# overall_indicator.py
  1. import json
  2. import pandas as pd
  3. import numpy as np
  4. import re
  5. from nltk.translate.bleu_score import sentence_bleu
  6. import time
  7. import argparse
  8. import os
  9. from sklearn.metrics import classification_report,confusion_matrix
  10. from collections import Counter
  11. from sklearn import metrics
  12. from pandas import isnull
  13. def indicator_cal(json_standard,json_test):
  14. json_standard = pd.DataFrame(json_standard)
  15. json_test = pd.DataFrame(json_test)
  16. '''数据集总体指标'''
  17. a=json_test[['id','mid_json']]
  18. b=json_standard[['id','mid_json']]
  19. outer_merge=pd.merge(a,b,on='id',how='outer')
  20. outer_merge.columns=['id','standard_mid_json','test_mid_json']
  21. standard_exist=outer_merge.standard_mid_json.apply(lambda x: not isnull(x))
  22. test_exist=outer_merge.test_mid_json.apply(lambda x: not isnull(x))
  23. overall_report = {}
  24. overall_report['accuracy']=metrics.accuracy_score(standard_exist,test_exist)
  25. overall_report['precision']=metrics.precision_score(standard_exist,test_exist)
  26. overall_report['recall']=metrics.recall_score(standard_exist,test_exist)
  27. overall_report['f1_score']=metrics.f1_score(standard_exist,test_exist)
  28. inner_merge=pd.merge(a,b,on='id',how='inner')
  29. inner_merge.columns=['id','standard_mid_json','test_mid_json']
  30. json_standard = inner_merge['standard_mid_json']#check一下是否对齐
  31. json_test = inner_merge['test_mid_json']
  32. '''批量读取中间生成的json文件'''
  33. test_inline_equations=[]
  34. test_interline_equations=[]
  35. test_dropped_text_bboxes=[]
  36. test_dropped_text_tag=[]
  37. test_dropped_image_bboxes=[]
  38. test_dropped_table_bboxes=[]
  39. test_preproc_num=[]#阅读顺序
  40. test_para_num=[]
  41. test_para_text=[]
  42. for i in json_test:
  43. mid_json=pd.DataFrame(i)
  44. mid_json=mid_json.iloc[:,:-1]
  45. for j1 in mid_json.loc['inline_equations',:]:
  46. page_in=[]
  47. for k1 in j1:
  48. page_in.append(k1['latex_text'])
  49. test_inline_equations.append(page_in)
  50. for j2 in mid_json.loc['interline_equations',:]:
  51. page_in=[]
  52. for k2 in j2:
  53. page_in.append(k2['latex_text'])
  54. test_interline_equations.append(page_in)
  55. for j3 in mid_json.loc['droped_text_block',:]:
  56. page_in_bbox=[]
  57. page_in_tag=[]
  58. for k3 in j3:
  59. page_in_bbox.append(k3['bbox'])
  60. #如果k3中存在tag这个key
  61. if 'tag' in k3.keys():
  62. page_in_tag.append(k3['tag'])
  63. else:
  64. page_in_tag.append('None')
  65. test_dropped_text_tag.append(page_in_tag)
  66. test_dropped_text_bboxes.append(page_in_bbox)
  67. for j4 in mid_json.loc['droped_image_block',:]:
  68. test_dropped_image_bboxes.append(j4)
  69. for j5 in mid_json.loc['droped_table_block',:]:
  70. test_dropped_table_bboxes.append(j5)
  71. for j6 in mid_json.loc['preproc_blocks',:]:
  72. page_in=[]
  73. for k6 in j6:
  74. page_in.append(k6['number'])
  75. test_preproc_num.append(page_in)
  76. test_pdf_text=[]
  77. for j7 in mid_json.loc['para_blocks',:]:
  78. test_para_num.append(len(j7))
  79. for k7 in j7:
  80. test_pdf_text.append(k7['text'])
  81. test_para_text.append(test_pdf_text)
  82. standard_inline_equations=[]
  83. standard_interline_equations=[]
  84. standard_dropped_text_bboxes=[]
  85. standard_dropped_text_tag=[]
  86. standard_dropped_image_bboxes=[]
  87. standard_dropped_table_bboxes=[]
  88. standard_preproc_num=[]#阅读顺序
  89. standard_para_num=[]
  90. standard_para_text=[]
  91. for i in json_standard:
  92. mid_json=pd.DataFrame(i)
  93. mid_json=mid_json.iloc[:,:-1]
  94. for j1 in mid_json.loc['inline_equations',:]:
  95. page_in=[]
  96. for k1 in j1:
  97. page_in.append(k1['latex_text'])
  98. standard_inline_equations.append(page_in)
  99. for j2 in mid_json.loc['interline_equations',:]:
  100. page_in=[]
  101. for k2 in j2:
  102. page_in.append(k2['latex_text'])
  103. standard_interline_equations.append(page_in)
  104. for j3 in mid_json.loc['droped_text_block',:]:
  105. page_in_bbox=[]
  106. page_in_tag=[]
  107. for k3 in j3:
  108. page_in_bbox.append(k3['bbox'])
  109. if 'tag' in k3.keys():
  110. page_in_tag.append(k3['tag'])
  111. else:
  112. page_in_tag.append('None')
  113. standard_dropped_text_bboxes.append(page_in_bbox)
  114. standard_dropped_text_tag.append(page_in_tag)
  115. for j4 in mid_json.loc['droped_image_block',:]:
  116. standard_dropped_image_bboxes.append(j4)
  117. for j5 in mid_json.loc['droped_table_block',:]:
  118. standard_dropped_table_bboxes.append(j5)
  119. for j6 in mid_json.loc['preproc_blocks',:]:
  120. page_in=[]
  121. for k6 in j6:
  122. page_in.append(k6['number'])
  123. standard_preproc_num.append(page_in)
  124. standard_pdf_text=[]
  125. for j7 in mid_json.loc['para_blocks',:]:
  126. standard_para_num.append(len(j7))
  127. for k7 in j7:
  128. standard_pdf_text.append(k7['text'])
  129. standard_para_text.append(standard_pdf_text)
  130. """
  131. 在计算指标之前最好先确认基本统计信息是否一致
  132. """
  133. '''计算pdf之间的总体编辑距离和bleu'''
  134. pdf_dis=[]
  135. pdf_bleu=[]
  136. for a,b in zip(test_para_text,standard_para_text):
  137. a1=[ ''.join(i) for i in a]
  138. b1=[ ''.join(i) for i in b]
  139. pdf_dis.append(Levenshtein_Distance(a1,b1))
  140. pdf_bleu.append(sentence_bleu([a1],b1))
  141. overall_report['pdf间的平均编辑距离']=np.mean(pdf_dis)
  142. overall_report['pdf间的平均bleu']=np.mean(pdf_bleu)
  143. '''行内公式编辑距离和bleu'''
  144. dis1=[]
  145. bleu1=[]
  146. test_inline_equations=[ ''.join(i) for i in test_inline_equations]
  147. standard_inline_equations=[ ''.join(i) for i in standard_inline_equations]
  148. for a,b in zip(test_inline_equations,standard_inline_equations):
  149. if len(a)==0 and len(b)==0:
  150. continue
  151. else:
  152. if a==b:
  153. dis1.append(0)
  154. bleu1.append(1)
  155. else:
  156. dis1.append(Levenshtein_Distance(a,b))
  157. bleu1.append(sentence_bleu([a],b))
  158. inline_equations_edit=np.mean(dis1)
  159. inline_equations_bleu=np.mean(bleu1)
  160. '''行间公式编辑距离和bleu'''
  161. dis2=[]
  162. bleu2=[]
  163. test_interline_equations=[ ''.join(i) for i in test_interline_equations]
  164. standard_interline_equations=[ ''.join(i) for i in standard_interline_equations]
  165. for a,b in zip(test_interline_equations,standard_interline_equations):
  166. if len(a)==0 and len(b)==0:
  167. continue
  168. else:
  169. if a==b:
  170. dis2.append(0)
  171. bleu2.append(1)
  172. else:
  173. dis2.append(Levenshtein_Distance(a,b))
  174. bleu2.append(sentence_bleu([a],b))
  175. interline_equations_edit=np.mean(dis2)
  176. interline_equations_bleu=np.mean(bleu2)
  177. '''可以先检查page和bbox数量是否一致'''
  178. '''dropped_text_block的bbox匹配相关指标'''
  179. test_text_bbox=[]
  180. standard_text_bbox=[]
  181. test_tag=[]
  182. standard_tag=[]
  183. index=0
  184. for a,b in zip(test_dropped_text_bboxes,standard_dropped_text_bboxes):
  185. test_page_tag=[]
  186. standard_page_tag=[]
  187. test_page_bbox=[]
  188. standard_page_bbox=[]
  189. if len(a)==0 and len(b)==0:
  190. pass
  191. else:
  192. for i in range(len(b)):
  193. judge=0
  194. standard_page_tag.append(standard_dropped_text_tag[index][i])
  195. standard_page_bbox.append(1)
  196. for j in range(len(a)):
  197. if bbox_offset(b[i],a[j]):
  198. judge=1
  199. test_page_tag.append(test_dropped_text_tag[index][j])
  200. test_page_bbox.append(1)
  201. break
  202. if judge==0:
  203. test_page_tag.append('None')
  204. test_page_bbox.append(0)
  205. if len(test_dropped_text_tag[index])+test_page_tag.count('None')>len(standard_dropped_text_tag[index]):#有多删的情况出现
  206. test_page_tag1=test_page_tag.copy()
  207. if 'None' in test_page_tag:
  208. test_page_tag1=test_page_tag1.remove('None')
  209. else:
  210. test_page_tag1=test_page_tag
  211. diff=list((Counter(test_dropped_text_tag[index]) - Counter(test_page_tag1)).elements())
  212. test_page_tag.extend(diff)
  213. standard_page_tag.extend(['None']*len(diff))
  214. test_page_bbox.extend([1]*len(diff))
  215. standard_page_bbox.extend([0]*len(diff))
  216. test_tag.extend(test_page_tag)
  217. standard_tag.extend(standard_page_tag)
  218. test_text_bbox.extend(test_page_bbox)
  219. standard_text_bbox.extend(standard_page_bbox)
  220. index+=1
  221. text_block_report = {}
  222. text_block_report['accuracy']=metrics.accuracy_score(standard_text_bbox,test_text_bbox)
  223. text_block_report['precision']=metrics.precision_score(standard_text_bbox,test_text_bbox)
  224. text_block_report['recall']=metrics.recall_score(standard_text_bbox,test_text_bbox)
  225. text_block_report['f1_score']=metrics.f1_score(standard_text_bbox,test_text_bbox)
  226. '''删除的text_block的tag的准确率,召回率和f1-score'''
  227. text_block_tag_report = classification_report(y_true=standard_tag , y_pred=test_tag,output_dict=True)
  228. del text_block_tag_report['None']
  229. del text_block_tag_report["macro avg"]
  230. del text_block_tag_report["weighted avg"]
  231. '''dropped_image_block的bbox匹配相关指标'''
  232. '''有数据格式不一致的问题'''
  233. test_image_bbox=[]
  234. standard_image_bbox=[]
  235. for a,b in zip(test_dropped_image_bboxes,standard_dropped_image_bboxes):
  236. test_page_bbox=[]
  237. standard_page_bbox=[]
  238. if len(a)==0 and len(b)==0:
  239. pass
  240. else:
  241. for i in b:
  242. if len(i)!=4:
  243. continue
  244. else:
  245. judge=0
  246. standard_page_bbox.append(1)
  247. for j in a:
  248. if bbox_offset(i,j):
  249. judge=1
  250. test_page_bbox.append(1)
  251. break
  252. if judge==0:
  253. test_page_bbox.append(0)
  254. diff_num=len(a)+test_page_bbox.count(0)-len(b)
  255. if diff_num>0:#有多删的情况出现
  256. test_page_bbox.extend([1]*diff_num)
  257. standard_page_bbox.extend([0]*diff_num)
  258. test_image_bbox.extend(test_page_bbox)
  259. standard_image_bbox.extend(standard_page_bbox)
  260. image_block_report = {}
  261. image_block_report['accuracy']=metrics.accuracy_score(standard_image_bbox,test_image_bbox)
  262. image_block_report['precision']=metrics.precision_score(standard_image_bbox,test_image_bbox)
  263. image_block_report['recall']=metrics.recall_score(standard_image_bbox,test_image_bbox)
  264. image_block_report['f1_score']=metrics.f1_score(standard_image_bbox,test_image_bbox)
  265. '''dropped_table_block的bbox匹配相关指标'''
  266. test_table_bbox=[]
  267. standard_table_bbox=[]
  268. for a,b in zip(test_dropped_table_bboxes,standard_dropped_table_bboxes):
  269. test_page_bbox=[]
  270. standard_page_bbox=[]
  271. if len(a)==0 and len(b)==0:
  272. pass
  273. else:
  274. for i in b:
  275. if len(i)!=4:
  276. continue
  277. else:
  278. judge=0
  279. standard_page_bbox.append(1)
  280. for j in a:
  281. if bbox_offset(i,j):
  282. judge=1
  283. test_page_bbox.append(1)
  284. break
  285. if judge==0:
  286. test_page_bbox.append(0)
  287. diff_num=len(a)+test_page_bbox.count(0)-len(b)
  288. if diff_num>0:#有多删的情况出现
  289. test_page_bbox.extend([1]*diff_num)
  290. standard_page_bbox.extend([0]*diff_num)
  291. test_table_bbox.extend(test_page_bbox)
  292. standard_table_bbox.extend(standard_page_bbox)
  293. table_block_report = {}
  294. table_block_report['accuracy']=metrics.accuracy_score(standard_table_bbox,test_table_bbox)
  295. table_block_report['precision']=metrics.precision_score(standard_table_bbox,test_table_bbox)
  296. table_block_report['recall']=metrics.recall_score(standard_table_bbox,test_table_bbox)
  297. table_block_report['f1_score']=metrics.f1_score(standard_table_bbox,test_table_bbox)
  298. '''阅读顺序编辑距离的均值'''
  299. preproc_num_dis=[]
  300. for a,b in zip(test_preproc_num,standard_preproc_num):
  301. preproc_num_dis.append(Levenshtein_Distance(a,b))
  302. preproc_num_edit=np.mean(preproc_num_dis)
  303. '''分段准确率'''
  304. test_para_num=np.array(test_para_num)
  305. standard_para_num=np.array(standard_para_num)
  306. acc_para=np.mean(test_para_num==standard_para_num)
  307. output=pd.DataFrame()
  308. output['总体指标']=[overall_report]
  309. output['行内公式平均编辑距离']=[inline_equations_edit]
  310. output['行间公式平均编辑距离']=[interline_equations_edit]
  311. output['行内公式平均bleu']=[inline_equations_bleu]
  312. output['行间公式平均bleu']=[interline_equations_bleu]
  313. output['阅读顺序平均编辑距离']=[preproc_num_edit]
  314. output['分段准确率']=[acc_para]
  315. output['删除的text block的相关指标']=[text_block_report]
  316. output['删除的image block的相关指标']=[image_block_report]
  317. output['删除的table block的相关指标']=[table_block_report]
  318. output['删除的text block的tag相关指标']=[text_block_tag_report]
  319. return output
  320. """
  321. 计算编辑距离
  322. """
  323. def Levenshtein_Distance(str1, str2):
  324. matrix = [[ i + j for j in range(len(str2) + 1)] for i in range(len(str1) + 1)]
  325. for i in range(1, len(str1)+1):
  326. for j in range(1, len(str2)+1):
  327. if(str1[i-1] == str2[j-1]):
  328. d = 0
  329. else:
  330. d = 1
  331. matrix[i][j] = min(matrix[i-1][j]+1, matrix[i][j-1]+1, matrix[i-1][j-1]+d)
  332. return matrix[len(str1)][len(str2)]
  333. '''
  334. 计算bbox偏移量是否符合标准的函数
  335. '''
  336. def bbox_offset(b_t,b_s):
  337. '''b_t是test_doc里的bbox,b_s是standard_doc里的bbox'''
  338. x1_t,y1_t,x2_t,y2_t=b_t
  339. x1_s,y1_s,x2_s,y2_s=b_s
  340. x1=max(x1_t,x1_s)
  341. x2=min(x2_t,x2_s)
  342. y1=max(y1_t,y1_s)
  343. y2=min(y2_t,y2_s)
  344. area_overlap=(x2-x1)*(y2-y1)
  345. area_t=(x2_t-x1_t)*(y2_t-y1_t)+(x2_s-x1_s)*(y2_s-y1_s)-area_overlap
  346. if area_t-area_overlap==0 or area_overlap/(area_t-area_overlap)>0.95:
  347. return True
  348. else:
  349. return False
  350. parser = argparse.ArgumentParser()
  351. parser.add_argument('--test', type=str)
  352. parser.add_argument('--standard', type=str)
  353. args = parser.parse_args()
  354. pdf_json_test = args.test
  355. pdf_json_standard = args.standard
  356. if __name__ == '__main__':
  357. pdf_json_test = [json.loads(line)
  358. for line in open(pdf_json_test, 'r', encoding='utf-8')]
  359. pdf_json_standard = [json.loads(line)
  360. for line in open(pdf_json_standard, 'r', encoding='utf-8')]
  361. overall_indicator=indicator_cal(pdf_json_standard,pdf_json_test)
  362. '''计算的指标输出到overall_indicator_output.json中'''
  363. overall_indicator.to_json('overall_indicator_output.json',orient='records',lines=True,force_ascii=False)