# overall_indicator.py
import argparse
import json
from collections import Counter

import numpy as np
import pandas as pd
from nltk.translate.bleu_score import sentence_bleu
from pandas import isnull
from sklearn import metrics
from sklearn.metrics import classification_report
  10. def indicator_cal(json_standard,json_test):
  11. json_standard = pd.DataFrame(json_standard)
  12. json_test = pd.DataFrame(json_test)
  13. '''数据集总体指标'''
  14. a=json_test[['id','mid_json']]
  15. b=json_standard[['id','mid_json','pass_label']]
  16. a=a.drop_duplicates(subset='id',keep='first')
  17. a.index=range(len(a))
  18. b=b.drop_duplicates(subset='id',keep='first')
  19. b.index=range(len(b))
  20. outer_merge=pd.merge(a,b,on='id',how='outer')
  21. outer_merge.columns=['id','standard_mid_json','test_mid_json','pass_label']
  22. standard_exist=outer_merge.standard_mid_json.apply(lambda x: not isnull(x))
  23. test_exist=outer_merge.test_mid_json.apply(lambda x: not isnull(x))
  24. overall_report = {}
  25. overall_report['accuracy']=metrics.accuracy_score(standard_exist,test_exist)
  26. overall_report['precision']=metrics.precision_score(standard_exist,test_exist)
  27. overall_report['recall']=metrics.recall_score(standard_exist,test_exist)
  28. overall_report['f1_score']=metrics.f1_score(standard_exist,test_exist)
  29. inner_merge=pd.merge(a,b,on='id',how='inner')
  30. inner_merge.columns=['id','standard_mid_json','test_mid_json','pass_label']
  31. json_standard = inner_merge['standard_mid_json']#check一下是否对齐
  32. json_test = inner_merge['test_mid_json']
  33. '''批量读取中间生成的json文件'''
  34. test_inline_equations=[]
  35. test_interline_equations=[]
  36. test_inline_euqations_bboxs=[]
  37. test_interline_equations_bboxs=[]
  38. test_dropped_text_bboxes=[]
  39. test_dropped_text_tag=[]
  40. test_dropped_image_bboxes=[]
  41. test_dropped_table_bboxes=[]
  42. test_preproc_num=[]#阅读顺序
  43. test_para_num=[]
  44. test_para_text=[]
  45. for i in json_test:
  46. mid_json=pd.DataFrame(i)
  47. mid_json=mid_json.iloc[:,:-1]
  48. for j1 in mid_json.loc['inline_equations',:]:
  49. page_in_text=[]
  50. page_in_bbox=[]
  51. for k1 in j1:
  52. page_in_text.append(k1['latex_text'])
  53. page_in_bbox.append(k1['bbox'])
  54. test_inline_equations.append(page_in_text)
  55. test_inline_euqations_bboxs.append(page_in_bbox)
  56. for j2 in mid_json.loc['interline_equations',:]:
  57. page_in_text=[]
  58. page_in_bbox=[]
  59. for k2 in j2:
  60. page_in_text.append(k2['latex_text'])
  61. page_in_bbox.append(k2['bbox'])
  62. test_interline_equations.append(page_in_text)
  63. test_interline_equations_bboxs.append(page_in_bbox)
  64. for j3 in mid_json.loc['droped_text_block',:]:
  65. page_in_bbox=[]
  66. page_in_tag=[]
  67. for k3 in j3:
  68. page_in_bbox.append(k3['bbox'])
  69. #如果k3中存在tag这个key
  70. if 'tag' in k3.keys():
  71. page_in_tag.append(k3['tag'])
  72. else:
  73. page_in_tag.append('None')
  74. test_dropped_text_tag.append(page_in_tag)
  75. test_dropped_text_bboxes.append(page_in_bbox)
  76. for j4 in mid_json.loc['droped_image_block',:]:
  77. test_dropped_image_bboxes.append(j4)
  78. for j5 in mid_json.loc['droped_table_block',:]:
  79. test_dropped_table_bboxes.append(j5)
  80. for j6 in mid_json.loc['preproc_blocks',:]:
  81. page_in=[]
  82. for k6 in j6:
  83. page_in.append(k6['number'])
  84. test_preproc_num.append(page_in)
  85. test_pdf_text=[]
  86. for j7 in mid_json.loc['para_blocks',:]:
  87. test_para_num.append(len(j7))
  88. for k7 in j7:
  89. test_pdf_text.append(k7['text'])
  90. test_para_text.append(test_pdf_text)
  91. standard_inline_equations=[]
  92. standard_interline_equations=[]
  93. standard_inline_euqations_bboxs=[]
  94. standard_interline_equations_bboxs=[]
  95. standard_dropped_text_bboxes=[]
  96. standard_dropped_text_tag=[]
  97. standard_dropped_image_bboxes=[]
  98. standard_dropped_table_bboxes=[]
  99. standard_preproc_num=[]#阅读顺序
  100. standard_para_num=[]
  101. standard_para_text=[]
  102. for i in json_standard:
  103. mid_json=pd.DataFrame(i)
  104. mid_json=mid_json.iloc[:,:-1]
  105. for j1 in mid_json.loc['inline_equations',:]:
  106. page_in_text=[]
  107. page_in_bbox=[]
  108. for k1 in j1:
  109. page_in_text.append(k1['latex_text'])
  110. page_in_bbox.append(k1['bbox'])
  111. standard_inline_equations.append(page_in_text)
  112. standard_inline_euqations_bboxs.append(page_in_bbox)
  113. for j2 in mid_json.loc['interline_equations',:]:
  114. page_in_text=[]
  115. page_in_bbox=[]
  116. for k2 in j2:
  117. page_in_text.append(k2['latex_text'])
  118. page_in_bbox.append(k2['bbox'])
  119. standard_interline_equations.append(page_in_text)
  120. standard_interline_equations_bboxs.append(page_in_bbox)
  121. for j3 in mid_json.loc['droped_text_block',:]:
  122. page_in_bbox=[]
  123. page_in_tag=[]
  124. for k3 in j3:
  125. page_in_bbox.append(k3['bbox'])
  126. if 'tag' in k3.keys():
  127. page_in_tag.append(k3['tag'])
  128. else:
  129. page_in_tag.append('None')
  130. standard_dropped_text_bboxes.append(page_in_bbox)
  131. standard_dropped_text_tag.append(page_in_tag)
  132. for j4 in mid_json.loc['droped_image_block',:]:
  133. standard_dropped_image_bboxes.append(j4)
  134. for j5 in mid_json.loc['droped_table_block',:]:
  135. standard_dropped_table_bboxes.append(j5)
  136. for j6 in mid_json.loc['preproc_blocks',:]:
  137. page_in=[]
  138. for k6 in j6:
  139. page_in.append(k6['number'])
  140. standard_preproc_num.append(page_in)
  141. standard_pdf_text=[]
  142. for j7 in mid_json.loc['para_blocks',:]:
  143. standard_para_num.append(len(j7))
  144. for k7 in j7:
  145. standard_pdf_text.append(k7['text'])
  146. standard_para_text.append(standard_pdf_text)
  147. """
  148. 在计算指标之前最好先确认基本统计信息是否一致
  149. """
  150. '''
  151. 计算pdf之间的总体编辑距离和bleu
  152. 这里只计算正例的pdf
  153. '''
  154. test_para_text=np.asarray(test_para_text, dtype = object)[inner_merge['pass_label']=='yes']
  155. standard_para_text=np.asarray(standard_para_text, dtype = object)[inner_merge['pass_label']=='yes']
  156. pdf_dis=[]
  157. pdf_bleu=[]
  158. for a,b in zip(test_para_text,standard_para_text):
  159. a1=[ ''.join(i) for i in a]
  160. b1=[ ''.join(i) for i in b]
  161. pdf_dis.append(Levenshtein_Distance(a1,b1))
  162. pdf_bleu.append(sentence_bleu([a1],b1))
  163. overall_report['pdf间的平均编辑距离']=np.mean(pdf_dis)
  164. overall_report['pdf间的平均bleu']=np.mean(pdf_bleu)
  165. '''行内公式和行间公式的编辑距离和bleu'''
  166. inline_equations_edit_bleu=equations_indicator(test_inline_euqations_bboxs,standard_inline_euqations_bboxs,test_inline_equations,standard_inline_equations)
  167. interline_equations_edit_bleu=equations_indicator(test_interline_equations_bboxs,standard_interline_equations_bboxs,test_interline_equations,standard_interline_equations)
  168. '''行内公式bbox匹配相关指标'''
  169. inline_equations_bbox_report=bbox_match_indicator(test_inline_euqations_bboxs,standard_inline_euqations_bboxs)
  170. '''行间公式bbox匹配相关指标'''
  171. interline_equations_bbox_report=bbox_match_indicator(test_interline_equations_bboxs,standard_interline_equations_bboxs)
  172. '''可以先检查page和bbox数量是否一致'''
  173. '''dropped_text_block的bbox匹配相关指标'''
  174. test_text_bbox=[]
  175. standard_text_bbox=[]
  176. test_tag=[]
  177. standard_tag=[]
  178. index=0
  179. for a,b in zip(test_dropped_text_bboxes,standard_dropped_text_bboxes):
  180. test_page_tag=[]
  181. standard_page_tag=[]
  182. test_page_bbox=[]
  183. standard_page_bbox=[]
  184. if len(a)==0 and len(b)==0:
  185. pass
  186. else:
  187. for i in range(len(b)):
  188. judge=0
  189. standard_page_tag.append(standard_dropped_text_tag[index][i])
  190. standard_page_bbox.append(1)
  191. for j in range(len(a)):
  192. if bbox_offset(b[i],a[j]):
  193. judge=1
  194. test_page_tag.append(test_dropped_text_tag[index][j])
  195. test_page_bbox.append(1)
  196. break
  197. if judge==0:
  198. test_page_tag.append('None')
  199. test_page_bbox.append(0)
  200. if len(test_dropped_text_tag[index])+test_page_tag.count('None')>len(standard_dropped_text_tag[index]):#有多删的情况出现
  201. test_page_tag1=test_page_tag.copy()
  202. if 'None' in test_page_tag:
  203. test_page_tag1=test_page_tag1.remove('None')
  204. else:
  205. test_page_tag1=test_page_tag
  206. diff=list((Counter(test_dropped_text_tag[index]) - Counter(test_page_tag1)).elements())
  207. test_page_tag.extend(diff)
  208. standard_page_tag.extend(['None']*len(diff))
  209. test_page_bbox.extend([1]*len(diff))
  210. standard_page_bbox.extend([0]*len(diff))
  211. test_tag.extend(test_page_tag)
  212. standard_tag.extend(standard_page_tag)
  213. test_text_bbox.extend(test_page_bbox)
  214. standard_text_bbox.extend(standard_page_bbox)
  215. index+=1
  216. text_block_report = {}
  217. text_block_report['accuracy']=metrics.accuracy_score(standard_text_bbox,test_text_bbox)
  218. text_block_report['precision']=metrics.precision_score(standard_text_bbox,test_text_bbox)
  219. text_block_report['recall']=metrics.recall_score(standard_text_bbox,test_text_bbox)
  220. text_block_report['f1_score']=metrics.f1_score(standard_text_bbox,test_text_bbox)
  221. '''删除的text_block的tag的准确率,召回率和f1-score'''
  222. text_block_tag_report = classification_report(y_true=standard_tag , y_pred=test_tag,output_dict=True)
  223. del text_block_tag_report['None']
  224. del text_block_tag_report["macro avg"]
  225. del text_block_tag_report["weighted avg"]
  226. '''dropped_image_block的bbox匹配相关指标'''
  227. '''有数据格式不一致的问题'''
  228. image_block_report=bbox_match_indicator(test_dropped_image_bboxes,standard_dropped_image_bboxes)
  229. '''dropped_table_block的bbox匹配相关指标'''
  230. table_block_report=bbox_match_indicator(test_dropped_table_bboxes,standard_dropped_table_bboxes)
  231. '''阅读顺序编辑距离的均值'''
  232. preproc_num_dis=[]
  233. for a,b in zip(test_preproc_num,standard_preproc_num):
  234. preproc_num_dis.append(Levenshtein_Distance(a,b))
  235. preproc_num_edit=np.mean(preproc_num_dis)
  236. '''分段准确率'''
  237. test_para_num=np.array(test_para_num)
  238. standard_para_num=np.array(standard_para_num)
  239. acc_para=np.mean(test_para_num==standard_para_num)
  240. output=pd.DataFrame()
  241. output['总体指标']=[overall_report]
  242. output['行内公式平均编辑距离']=[inline_equations_edit_bleu[0]]
  243. output['行内公式平均bleu']=[inline_equations_edit_bleu[1]]
  244. output['行间公式平均编辑距离']=[interline_equations_edit_bleu[0]]
  245. output['行间公式平均bleu']=[interline_equations_edit_bleu[1]]
  246. output['行内公式识别相关指标']=[inline_equations_bbox_report]
  247. output['行间公式识别相关指标']=[interline_equations_bbox_report]
  248. output['阅读顺序平均编辑距离']=[preproc_num_edit]
  249. output['分段准确率']=[acc_para]
  250. output['删除的text block的相关指标']=[text_block_report]
  251. output['删除的image block的相关指标']=[image_block_report]
  252. output['删除的table block的相关指标']=[table_block_report]
  253. output['删除的text block的tag相关指标']=[text_block_tag_report]
  254. return output
  255. """
  256. 计算编辑距离
  257. """
  258. def Levenshtein_Distance(str1, str2):
  259. matrix = [[ i + j for j in range(len(str2) + 1)] for i in range(len(str1) + 1)]
  260. for i in range(1, len(str1)+1):
  261. for j in range(1, len(str2)+1):
  262. if(str1[i-1] == str2[j-1]):
  263. d = 0
  264. else:
  265. d = 1
  266. matrix[i][j] = min(matrix[i-1][j]+1, matrix[i][j-1]+1, matrix[i-1][j-1]+d)
  267. return matrix[len(str1)][len(str2)]
  268. '''
  269. 计算bbox偏移量是否符合标准的函数
  270. '''
  271. def bbox_offset(b_t,b_s):
  272. '''b_t是test_doc里的bbox,b_s是standard_doc里的bbox'''
  273. x1_t,y1_t,x2_t,y2_t=b_t
  274. x1_s,y1_s,x2_s,y2_s=b_s
  275. x1=max(x1_t,x1_s)
  276. x2=min(x2_t,x2_s)
  277. y1=max(y1_t,y1_s)
  278. y2=min(y2_t,y2_s)
  279. area_overlap=(x2-x1)*(y2-y1)
  280. area_t=(x2_t-x1_t)*(y2_t-y1_t)+(x2_s-x1_s)*(y2_s-y1_s)-area_overlap
  281. if area_t-area_overlap==0 or area_overlap/(area_t-area_overlap)>0.95:
  282. return True
  283. else:
  284. return False
  285. '''bbox匹配和对齐函数,输出相关指标'''
  286. '''输入的是以page为单位的bbox列表'''
  287. def bbox_match_indicator(test_bbox_list,standard_bbox_list):
  288. test_bbox=[]
  289. standard_bbox=[]
  290. for a,b in zip(test_bbox_list,standard_bbox_list):
  291. test_page_bbox=[]
  292. standard_page_bbox=[]
  293. if len(a)==0 and len(b)==0:
  294. pass
  295. else:
  296. for i in b:
  297. if len(i)!=4:
  298. continue
  299. else:
  300. judge=0
  301. standard_page_bbox.append(1)
  302. for j in a:
  303. if bbox_offset(i,j):
  304. judge=1
  305. test_page_bbox.append(1)
  306. break
  307. if judge==0:
  308. test_page_bbox.append(0)
  309. diff_num=len(a)+test_page_bbox.count(0)-len(b)
  310. if diff_num>0:#有多删的情况出现
  311. test_page_bbox.extend([1]*diff_num)
  312. standard_page_bbox.extend([0]*diff_num)
  313. test_bbox.extend(test_page_bbox)
  314. standard_bbox.extend(standard_page_bbox)
  315. block_report = {}
  316. block_report['accuracy']=metrics.accuracy_score(standard_bbox,test_bbox)
  317. block_report['precision']=metrics.precision_score(standard_bbox,test_bbox)
  318. block_report['recall']=metrics.recall_score(standard_bbox,test_bbox)
  319. block_report['f1_score']=metrics.f1_score(standard_bbox,test_bbox)
  320. return block_report
  321. '''公式编辑距离和bleu'''
  322. def equations_indicator(test_euqations_bboxs,standard_euqations_bboxs,test_equations,standard_equations):
  323. test_match_equations=[]
  324. standard_match_equations=[]
  325. index=0
  326. for a,b in zip(test_euqations_bboxs,standard_euqations_bboxs):
  327. if len(a)==0 and len(b)==0:
  328. pass
  329. else:
  330. for i in range(len(b)):
  331. for j in range(len(a)):
  332. if bbox_offset(b[i],a[j]):
  333. standard_match_equations.append(standard_equations[index][i])
  334. test_match_equations.append(test_equations[index][j])
  335. break
  336. index+=1
  337. dis=[]
  338. bleu=[]
  339. for a,b in zip(test_match_equations,standard_match_equations):
  340. if len(a)==0 and len(b)==0:
  341. continue
  342. else:
  343. if a==b:
  344. dis.append(0)
  345. bleu.append(1)
  346. else:
  347. dis.append(Levenshtein_Distance(a,b))
  348. bleu.append(sentence_bleu([a],b))
  349. equations_edit=np.mean(dis)
  350. equations_bleu=np.mean(bleu)
  351. return (equations_edit,equations_bleu)
  352. parser = argparse.ArgumentParser()
  353. parser.add_argument('--test', type=str)
  354. parser.add_argument('--standard', type=str)
  355. args = parser.parse_args()
  356. pdf_json_test = args.test
  357. pdf_json_standard = args.standard
  358. if __name__ == '__main__':
  359. pdf_json_test = [json.loads(line)
  360. for line in open(pdf_json_test, 'r', encoding='utf-8')]
  361. pdf_json_standard = [json.loads(line)
  362. for line in open(pdf_json_standard, 'r', encoding='utf-8')]
  363. overall_indicator=indicator_cal(pdf_json_standard,pdf_json_test)
  364. '''计算的指标输出到overall_indicator_output.json中'''
  365. overall_indicator.to_json('overall_indicator_output.json',orient='records',lines=True,force_ascii=False)