# overall_indicator.py
  1. import json
  2. import pandas as pd
  3. import numpy as np
  4. import re
  5. from nltk.translate.bleu_score import sentence_bleu
  6. import time
  7. import argparse
  8. import os
  9. from sklearn.metrics import classification_report,confusion_matrix
  10. from collections import Counter
  11. from sklearn import metrics
  12. from pandas import isnull
  13. def indicator_cal(json_standard,json_test):
  14. json_standard = pd.DataFrame(json_standard)
  15. json_test = pd.DataFrame(json_test)
  16. '''数据集总体指标'''
  17. a=json_test[['id','mid_json']]
  18. b=json_standard[['id','mid_json','pass_label']]
  19. outer_merge=pd.merge(a,b,on='id',how='outer')
  20. outer_merge.columns=['id','standard_mid_json','test_mid_json','pass_label']
  21. standard_exist=outer_merge.standard_mid_json.apply(lambda x: not isnull(x))
  22. test_exist=outer_merge.test_mid_json.apply(lambda x: not isnull(x))
  23. overall_report = {}
  24. overall_report['accuracy']=metrics.accuracy_score(standard_exist,test_exist)
  25. overall_report['precision']=metrics.precision_score(standard_exist,test_exist)
  26. overall_report['recall']=metrics.recall_score(standard_exist,test_exist)
  27. overall_report['f1_score']=metrics.f1_score(standard_exist,test_exist)
  28. inner_merge=pd.merge(a,b,on='id',how='inner')
  29. inner_merge.columns=['id','standard_mid_json','test_mid_json','pass_label']
  30. json_standard = inner_merge['standard_mid_json']#check一下是否对齐
  31. json_test = inner_merge['test_mid_json']
  32. '''批量读取中间生成的json文件'''
  33. test_inline_equations=[]
  34. test_interline_equations=[]
  35. test_dropped_text_bboxes=[]
  36. test_dropped_text_tag=[]
  37. test_dropped_image_bboxes=[]
  38. test_dropped_table_bboxes=[]
  39. test_preproc_num=[]#阅读顺序
  40. test_para_num=[]
  41. test_para_text=[]
  42. for i in json_test:
  43. mid_json=pd.DataFrame(i)
  44. mid_json=mid_json.iloc[:,:-1]
  45. for j1 in mid_json.loc['inline_equations',:]:
  46. page_in=[]
  47. for k1 in j1:
  48. page_in.append(k1['latex_text'])
  49. test_inline_equations.append(page_in)
  50. for j2 in mid_json.loc['interline_equations',:]:
  51. page_in=[]
  52. for k2 in j2:
  53. page_in.append(k2['latex_text'])
  54. test_interline_equations.append(page_in)
  55. for j3 in mid_json.loc['droped_text_block',:]:
  56. page_in_bbox=[]
  57. page_in_tag=[]
  58. for k3 in j3:
  59. page_in_bbox.append(k3['bbox'])
  60. #如果k3中存在tag这个key
  61. if 'tag' in k3.keys():
  62. page_in_tag.append(k3['tag'])
  63. else:
  64. page_in_tag.append('None')
  65. test_dropped_text_tag.append(page_in_tag)
  66. test_dropped_text_bboxes.append(page_in_bbox)
  67. for j4 in mid_json.loc['droped_image_block',:]:
  68. test_dropped_image_bboxes.append(j4)
  69. for j5 in mid_json.loc['droped_table_block',:]:
  70. test_dropped_table_bboxes.append(j5)
  71. for j6 in mid_json.loc['preproc_blocks',:]:
  72. page_in=[]
  73. for k6 in j6:
  74. page_in.append(k6['number'])
  75. test_preproc_num.append(page_in)
  76. test_pdf_text=[]
  77. for j7 in mid_json.loc['para_blocks',:]:
  78. test_para_num.append(len(j7))
  79. for k7 in j7:
  80. test_pdf_text.append(k7['text'])
  81. test_para_text.append(test_pdf_text)
  82. standard_inline_equations=[]
  83. standard_interline_equations=[]
  84. standard_dropped_text_bboxes=[]
  85. standard_dropped_text_tag=[]
  86. standard_dropped_image_bboxes=[]
  87. standard_dropped_table_bboxes=[]
  88. standard_preproc_num=[]#阅读顺序
  89. standard_para_num=[]
  90. standard_para_text=[]
  91. for i in json_standard:
  92. mid_json=pd.DataFrame(i)
  93. mid_json=mid_json.iloc[:,:-1]
  94. for j1 in mid_json.loc['inline_equations',:]:
  95. page_in=[]
  96. for k1 in j1:
  97. page_in.append(k1['latex_text'])
  98. standard_inline_equations.append(page_in)
  99. for j2 in mid_json.loc['interline_equations',:]:
  100. page_in=[]
  101. for k2 in j2:
  102. page_in.append(k2['latex_text'])
  103. standard_interline_equations.append(page_in)
  104. for j3 in mid_json.loc['droped_text_block',:]:
  105. page_in_bbox=[]
  106. page_in_tag=[]
  107. for k3 in j3:
  108. page_in_bbox.append(k3['bbox'])
  109. if 'tag' in k3.keys():
  110. page_in_tag.append(k3['tag'])
  111. else:
  112. page_in_tag.append('None')
  113. standard_dropped_text_bboxes.append(page_in_bbox)
  114. standard_dropped_text_tag.append(page_in_tag)
  115. for j4 in mid_json.loc['droped_image_block',:]:
  116. standard_dropped_image_bboxes.append(j4)
  117. for j5 in mid_json.loc['droped_table_block',:]:
  118. standard_dropped_table_bboxes.append(j5)
  119. for j6 in mid_json.loc['preproc_blocks',:]:
  120. page_in=[]
  121. for k6 in j6:
  122. page_in.append(k6['number'])
  123. standard_preproc_num.append(page_in)
  124. standard_pdf_text=[]
  125. for j7 in mid_json.loc['para_blocks',:]:
  126. standard_para_num.append(len(j7))
  127. for k7 in j7:
  128. standard_pdf_text.append(k7['text'])
  129. standard_para_text.append(standard_pdf_text)
  130. """
  131. 在计算指标之前最好先确认基本统计信息是否一致
  132. """
  133. '''
  134. 计算pdf之间的总体编辑距离和bleu
  135. 这里只计算正例的pdf
  136. '''
  137. test_para_text=np.asarray(test_para_text, dtype = object)[inner_merge['pass_label']=='yes']
  138. standard_para_text=np.asarray(standard_para_text, dtype = object)[inner_merge['pass_label']=='yes']
  139. pdf_dis=[]
  140. pdf_bleu=[]
  141. for a,b in zip(test_para_text,standard_para_text):
  142. a1=[ ''.join(i) for i in a]
  143. b1=[ ''.join(i) for i in b]
  144. pdf_dis.append(Levenshtein_Distance(a1,b1))
  145. pdf_bleu.append(sentence_bleu([a1],b1))
  146. overall_report['pdf间的平均编辑距离']=np.mean(pdf_dis)
  147. overall_report['pdf间的平均bleu']=np.mean(pdf_bleu)
  148. '''行内公式编辑距离和bleu'''
  149. dis1=[]
  150. bleu1=[]
  151. test_inline_equations=[ ''.join(i) for i in test_inline_equations]
  152. standard_inline_equations=[ ''.join(i) for i in standard_inline_equations]
  153. for a,b in zip(test_inline_equations,standard_inline_equations):
  154. if len(a)==0 and len(b)==0:
  155. continue
  156. else:
  157. if a==b:
  158. dis1.append(0)
  159. bleu1.append(1)
  160. else:
  161. dis1.append(Levenshtein_Distance(a,b))
  162. bleu1.append(sentence_bleu([a],b))
  163. inline_equations_edit=np.mean(dis1)
  164. inline_equations_bleu=np.mean(bleu1)
  165. '''行间公式编辑距离和bleu'''
  166. dis2=[]
  167. bleu2=[]
  168. test_interline_equations=[ ''.join(i) for i in test_interline_equations]
  169. standard_interline_equations=[ ''.join(i) for i in standard_interline_equations]
  170. for a,b in zip(test_interline_equations,standard_interline_equations):
  171. if len(a)==0 and len(b)==0:
  172. continue
  173. else:
  174. if a==b:
  175. dis2.append(0)
  176. bleu2.append(1)
  177. else:
  178. dis2.append(Levenshtein_Distance(a,b))
  179. bleu2.append(sentence_bleu([a],b))
  180. interline_equations_edit=np.mean(dis2)
  181. interline_equations_bleu=np.mean(bleu2)
  182. '''可以先检查page和bbox数量是否一致'''
  183. '''dropped_text_block的bbox匹配相关指标'''
  184. test_text_bbox=[]
  185. standard_text_bbox=[]
  186. test_tag=[]
  187. standard_tag=[]
  188. index=0
  189. for a,b in zip(test_dropped_text_bboxes,standard_dropped_text_bboxes):
  190. test_page_tag=[]
  191. standard_page_tag=[]
  192. test_page_bbox=[]
  193. standard_page_bbox=[]
  194. if len(a)==0 and len(b)==0:
  195. pass
  196. else:
  197. for i in range(len(b)):
  198. judge=0
  199. standard_page_tag.append(standard_dropped_text_tag[index][i])
  200. standard_page_bbox.append(1)
  201. for j in range(len(a)):
  202. if bbox_offset(b[i],a[j]):
  203. judge=1
  204. test_page_tag.append(test_dropped_text_tag[index][j])
  205. test_page_bbox.append(1)
  206. break
  207. if judge==0:
  208. test_page_tag.append('None')
  209. test_page_bbox.append(0)
  210. if len(test_dropped_text_tag[index])+test_page_tag.count('None')>len(standard_dropped_text_tag[index]):#有多删的情况出现
  211. test_page_tag1=test_page_tag.copy()
  212. if 'None' in test_page_tag:
  213. test_page_tag1=test_page_tag1.remove('None')
  214. else:
  215. test_page_tag1=test_page_tag
  216. diff=list((Counter(test_dropped_text_tag[index]) - Counter(test_page_tag1)).elements())
  217. test_page_tag.extend(diff)
  218. standard_page_tag.extend(['None']*len(diff))
  219. test_page_bbox.extend([1]*len(diff))
  220. standard_page_bbox.extend([0]*len(diff))
  221. test_tag.extend(test_page_tag)
  222. standard_tag.extend(standard_page_tag)
  223. test_text_bbox.extend(test_page_bbox)
  224. standard_text_bbox.extend(standard_page_bbox)
  225. index+=1
  226. text_block_report = {}
  227. text_block_report['accuracy']=metrics.accuracy_score(standard_text_bbox,test_text_bbox)
  228. text_block_report['precision']=metrics.precision_score(standard_text_bbox,test_text_bbox)
  229. text_block_report['recall']=metrics.recall_score(standard_text_bbox,test_text_bbox)
  230. text_block_report['f1_score']=metrics.f1_score(standard_text_bbox,test_text_bbox)
  231. '''删除的text_block的tag的准确率,召回率和f1-score'''
  232. text_block_tag_report = classification_report(y_true=standard_tag , y_pred=test_tag,output_dict=True)
  233. del text_block_tag_report['None']
  234. del text_block_tag_report["macro avg"]
  235. del text_block_tag_report["weighted avg"]
  236. '''dropped_image_block的bbox匹配相关指标'''
  237. '''有数据格式不一致的问题'''
  238. test_image_bbox=[]
  239. standard_image_bbox=[]
  240. for a,b in zip(test_dropped_image_bboxes,standard_dropped_image_bboxes):
  241. test_page_bbox=[]
  242. standard_page_bbox=[]
  243. if len(a)==0 and len(b)==0:
  244. pass
  245. else:
  246. for i in b:
  247. if len(i)!=4:
  248. continue
  249. else:
  250. judge=0
  251. standard_page_bbox.append(1)
  252. for j in a:
  253. if bbox_offset(i,j):
  254. judge=1
  255. test_page_bbox.append(1)
  256. break
  257. if judge==0:
  258. test_page_bbox.append(0)
  259. diff_num=len(a)+test_page_bbox.count(0)-len(b)
  260. if diff_num>0:#有多删的情况出现
  261. test_page_bbox.extend([1]*diff_num)
  262. standard_page_bbox.extend([0]*diff_num)
  263. test_image_bbox.extend(test_page_bbox)
  264. standard_image_bbox.extend(standard_page_bbox)
  265. image_block_report = {}
  266. image_block_report['accuracy']=metrics.accuracy_score(standard_image_bbox,test_image_bbox)
  267. image_block_report['precision']=metrics.precision_score(standard_image_bbox,test_image_bbox)
  268. image_block_report['recall']=metrics.recall_score(standard_image_bbox,test_image_bbox)
  269. image_block_report['f1_score']=metrics.f1_score(standard_image_bbox,test_image_bbox)
  270. '''dropped_table_block的bbox匹配相关指标'''
  271. test_table_bbox=[]
  272. standard_table_bbox=[]
  273. for a,b in zip(test_dropped_table_bboxes,standard_dropped_table_bboxes):
  274. test_page_bbox=[]
  275. standard_page_bbox=[]
  276. if len(a)==0 and len(b)==0:
  277. pass
  278. else:
  279. for i in b:
  280. if len(i)!=4:
  281. continue
  282. else:
  283. judge=0
  284. standard_page_bbox.append(1)
  285. for j in a:
  286. if bbox_offset(i,j):
  287. judge=1
  288. test_page_bbox.append(1)
  289. break
  290. if judge==0:
  291. test_page_bbox.append(0)
  292. diff_num=len(a)+test_page_bbox.count(0)-len(b)
  293. if diff_num>0:#有多删的情况出现
  294. test_page_bbox.extend([1]*diff_num)
  295. standard_page_bbox.extend([0]*diff_num)
  296. test_table_bbox.extend(test_page_bbox)
  297. standard_table_bbox.extend(standard_page_bbox)
  298. table_block_report = {}
  299. table_block_report['accuracy']=metrics.accuracy_score(standard_table_bbox,test_table_bbox)
  300. table_block_report['precision']=metrics.precision_score(standard_table_bbox,test_table_bbox)
  301. table_block_report['recall']=metrics.recall_score(standard_table_bbox,test_table_bbox)
  302. table_block_report['f1_score']=metrics.f1_score(standard_table_bbox,test_table_bbox)
  303. '''阅读顺序编辑距离的均值'''
  304. preproc_num_dis=[]
  305. for a,b in zip(test_preproc_num,standard_preproc_num):
  306. preproc_num_dis.append(Levenshtein_Distance(a,b))
  307. preproc_num_edit=np.mean(preproc_num_dis)
  308. '''分段准确率'''
  309. test_para_num=np.array(test_para_num)
  310. standard_para_num=np.array(standard_para_num)
  311. acc_para=np.mean(test_para_num==standard_para_num)
  312. output=pd.DataFrame()
  313. output['总体指标']=[overall_report]
  314. output['行内公式平均编辑距离']=[inline_equations_edit]
  315. output['行间公式平均编辑距离']=[interline_equations_edit]
  316. output['行内公式平均bleu']=[inline_equations_bleu]
  317. output['行间公式平均bleu']=[interline_equations_bleu]
  318. output['阅读顺序平均编辑距离']=[preproc_num_edit]
  319. output['分段准确率']=[acc_para]
  320. output['删除的text block的相关指标']=[text_block_report]
  321. output['删除的image block的相关指标']=[image_block_report]
  322. output['删除的table block的相关指标']=[table_block_report]
  323. output['删除的text block的tag相关指标']=[text_block_tag_report]
  324. return output
  325. """
  326. 计算编辑距离
  327. """
  328. def Levenshtein_Distance(str1, str2):
  329. matrix = [[ i + j for j in range(len(str2) + 1)] for i in range(len(str1) + 1)]
  330. for i in range(1, len(str1)+1):
  331. for j in range(1, len(str2)+1):
  332. if(str1[i-1] == str2[j-1]):
  333. d = 0
  334. else:
  335. d = 1
  336. matrix[i][j] = min(matrix[i-1][j]+1, matrix[i][j-1]+1, matrix[i-1][j-1]+d)
  337. return matrix[len(str1)][len(str2)]
  338. '''
  339. 计算bbox偏移量是否符合标准的函数
  340. '''
  341. def bbox_offset(b_t,b_s):
  342. '''b_t是test_doc里的bbox,b_s是standard_doc里的bbox'''
  343. x1_t,y1_t,x2_t,y2_t=b_t
  344. x1_s,y1_s,x2_s,y2_s=b_s
  345. x1=max(x1_t,x1_s)
  346. x2=min(x2_t,x2_s)
  347. y1=max(y1_t,y1_s)
  348. y2=min(y2_t,y2_s)
  349. area_overlap=(x2-x1)*(y2-y1)
  350. area_t=(x2_t-x1_t)*(y2_t-y1_t)+(x2_s-x1_s)*(y2_s-y1_s)-area_overlap
  351. if area_t-area_overlap==0 or area_overlap/(area_t-area_overlap)>0.95:
  352. return True
  353. else:
  354. return False
  355. parser = argparse.ArgumentParser()
  356. parser.add_argument('--test', type=str)
  357. parser.add_argument('--standard', type=str)
  358. args = parser.parse_args()
  359. pdf_json_test = args.test
  360. pdf_json_standard = args.standard
  361. if __name__ == '__main__':
  362. pdf_json_test = [json.loads(line)
  363. for line in open(pdf_json_test, 'r', encoding='utf-8')]
  364. pdf_json_standard = [json.loads(line)
  365. for line in open(pdf_json_standard, 'r', encoding='utf-8')]
  366. overall_indicator=indicator_cal(pdf_json_standard,pdf_json_test)
  367. '''计算的指标输出到overall_indicator_output.json中'''
  368. overall_indicator.to_json('overall_indicator_output.json',orient='records',lines=True,force_ascii=False)