# overall_indicator.py — aggregate evaluation indicators for pdf-extraction mid-json outputs
import argparse
import json
import os
import re
import time
from collections import Counter

import numpy as np
import pandas as pd
from nltk.translate.bleu_score import sentence_bleu
from pandas import isnull
from sklearn import metrics
from sklearn.metrics import classification_report, confusion_matrix
  13. def indicator_cal(json_standard,json_test):
  14. json_standard = pd.DataFrame(json_standard)
  15. json_test = pd.DataFrame(json_test)
  16. '''数据集总体指标'''
  17. a=json_test[['id','mid_json']]
  18. b=json_standard[['id','mid_json','pass_label']]
  19. outer_merge=pd.merge(a,b,on='id',how='outer')
  20. outer_merge.columns=['id','standard_mid_json','test_mid_json','pass_label']
  21. standard_exist=outer_merge.standard_mid_json.apply(lambda x: not isnull(x))
  22. test_exist=outer_merge.test_mid_json.apply(lambda x: not isnull(x))
  23. overall_report = {}
  24. overall_report['accuracy']=metrics.accuracy_score(standard_exist,test_exist)
  25. overall_report['precision']=metrics.precision_score(standard_exist,test_exist)
  26. overall_report['recall']=metrics.recall_score(standard_exist,test_exist)
  27. overall_report['f1_score']=metrics.f1_score(standard_exist,test_exist)
  28. inner_merge=pd.merge(a,b,on='id',how='inner')
  29. inner_merge.columns=['id','standard_mid_json','test_mid_json','pass_label']
  30. json_standard = inner_merge['standard_mid_json']#check一下是否对齐
  31. json_test = inner_merge['test_mid_json']
  32. '''批量读取中间生成的json文件'''
  33. test_inline_equations=[]
  34. test_interline_equations=[]
  35. test_inline_euqations_bboxs=[]
  36. test_interline_equations_bboxs=[]
  37. test_dropped_text_bboxes=[]
  38. test_dropped_text_tag=[]
  39. test_dropped_image_bboxes=[]
  40. test_dropped_table_bboxes=[]
  41. test_preproc_num=[]#阅读顺序
  42. test_para_num=[]
  43. test_para_text=[]
  44. for i in json_test:
  45. mid_json=pd.DataFrame(i)
  46. mid_json=mid_json.iloc[:,:-1]
  47. for j1 in mid_json.loc['inline_equations',:]:
  48. page_in_text=[]
  49. page_in_bbox=[]
  50. for k1 in j1:
  51. page_in_text.append(k1['latex_text'])
  52. page_in_bbox.append(k1['bbox'])
  53. test_inline_equations.append(page_in_text)
  54. test_inline_euqations_bboxs.append(page_in_bbox)
  55. for j2 in mid_json.loc['interline_equations',:]:
  56. page_in_text=[]
  57. page_in_bbox=[]
  58. for k2 in j2:
  59. page_in_text.append(k2['latex_text'])
  60. test_interline_equations.append(page_in_text)
  61. test_interline_equations_bboxs.append(page_in_bbox)
  62. for j3 in mid_json.loc['droped_text_block',:]:
  63. page_in_bbox=[]
  64. page_in_tag=[]
  65. for k3 in j3:
  66. page_in_bbox.append(k3['bbox'])
  67. #如果k3中存在tag这个key
  68. if 'tag' in k3.keys():
  69. page_in_tag.append(k3['tag'])
  70. else:
  71. page_in_tag.append('None')
  72. test_dropped_text_tag.append(page_in_tag)
  73. test_dropped_text_bboxes.append(page_in_bbox)
  74. for j4 in mid_json.loc['droped_image_block',:]:
  75. test_dropped_image_bboxes.append(j4)
  76. for j5 in mid_json.loc['droped_table_block',:]:
  77. test_dropped_table_bboxes.append(j5)
  78. for j6 in mid_json.loc['preproc_blocks',:]:
  79. page_in=[]
  80. for k6 in j6:
  81. page_in.append(k6['number'])
  82. test_preproc_num.append(page_in)
  83. test_pdf_text=[]
  84. for j7 in mid_json.loc['para_blocks',:]:
  85. test_para_num.append(len(j7))
  86. for k7 in j7:
  87. test_pdf_text.append(k7['text'])
  88. test_para_text.append(test_pdf_text)
  89. standard_inline_equations=[]
  90. standard_interline_equations=[]
  91. standard_inline_euqations_bboxs=[]
  92. standard_interline_equations_bboxs=[]
  93. standard_dropped_text_bboxes=[]
  94. standard_dropped_text_tag=[]
  95. standard_dropped_image_bboxes=[]
  96. standard_dropped_table_bboxes=[]
  97. standard_preproc_num=[]#阅读顺序
  98. standard_para_num=[]
  99. standard_para_text=[]
  100. for i in json_standard:
  101. mid_json=pd.DataFrame(i)
  102. mid_json=mid_json.iloc[:,:-1]
  103. for j1 in mid_json.loc['inline_equations',:]:
  104. page_in_text=[]
  105. page_in_bbox=[]
  106. for k1 in j1:
  107. page_in_text.append(k1['latex_text'])
  108. page_in_bbox.append(k1['bbox'])
  109. standard_inline_equations.append(page_in_text)
  110. standard_inline_euqations_bboxs.append(page_in_bbox)
  111. for j2 in mid_json.loc['interline_equations',:]:
  112. page_in_text=[]
  113. page_in_bbox=[]
  114. for k2 in j2:
  115. page_in_text.append(k2['latex_text'])
  116. page_in_bbox.append(k2['bbox'])
  117. standard_interline_equations.append(page_in_text)
  118. standard_interline_equations_bboxs.append(page_in_bbox)
  119. for j3 in mid_json.loc['droped_text_block',:]:
  120. page_in_bbox=[]
  121. page_in_tag=[]
  122. for k3 in j3:
  123. page_in_bbox.append(k3['bbox'])
  124. if 'tag' in k3.keys():
  125. page_in_tag.append(k3['tag'])
  126. else:
  127. page_in_tag.append('None')
  128. standard_dropped_text_bboxes.append(page_in_bbox)
  129. standard_dropped_text_tag.append(page_in_tag)
  130. for j4 in mid_json.loc['droped_image_block',:]:
  131. standard_dropped_image_bboxes.append(j4)
  132. for j5 in mid_json.loc['droped_table_block',:]:
  133. standard_dropped_table_bboxes.append(j5)
  134. for j6 in mid_json.loc['preproc_blocks',:]:
  135. page_in=[]
  136. for k6 in j6:
  137. page_in.append(k6['number'])
  138. standard_preproc_num.append(page_in)
  139. standard_pdf_text=[]
  140. for j7 in mid_json.loc['para_blocks',:]:
  141. standard_para_num.append(len(j7))
  142. for k7 in j7:
  143. standard_pdf_text.append(k7['text'])
  144. standard_para_text.append(standard_pdf_text)
  145. """
  146. 在计算指标之前最好先确认基本统计信息是否一致
  147. """
  148. '''
  149. 计算pdf之间的总体编辑距离和bleu
  150. 这里只计算正例的pdf
  151. '''
  152. test_para_text=np.asarray(test_para_text, dtype = object)[inner_merge['pass_label']=='yes']
  153. standard_para_text=np.asarray(standard_para_text, dtype = object)[inner_merge['pass_label']=='yes']
  154. pdf_dis=[]
  155. pdf_bleu=[]
  156. for a,b in zip(test_para_text,standard_para_text):
  157. a1=[ ''.join(i) for i in a]
  158. b1=[ ''.join(i) for i in b]
  159. pdf_dis.append(Levenshtein_Distance(a1,b1))
  160. pdf_bleu.append(sentence_bleu([a1],b1))
  161. overall_report['pdf间的平均编辑距离']=np.mean(pdf_dis)
  162. overall_report['pdf间的平均bleu']=np.mean(pdf_bleu)
  163. '''行内公式编辑距离和bleu'''
  164. dis1=[]
  165. bleu1=[]
  166. test_inline_equations=[ ''.join(i) for i in test_inline_equations]
  167. standard_inline_equations=[ ''.join(i) for i in standard_inline_equations]
  168. for a,b in zip(test_inline_equations,standard_inline_equations):
  169. if len(a)==0 and len(b)==0:
  170. continue
  171. else:
  172. if a==b:
  173. dis1.append(0)
  174. bleu1.append(1)
  175. else:
  176. dis1.append(Levenshtein_Distance(a,b))
  177. bleu1.append(sentence_bleu([a],b))
  178. inline_equations_edit=np.mean(dis1)
  179. inline_equations_bleu=np.mean(bleu1)
  180. '''行内公式bbox匹配相关指标'''
  181. inline_equations_bbox_report=bbox_match_indicator(test_inline_euqations_bboxs,standard_inline_euqations_bboxs)
  182. '''行间公式编辑距离和bleu'''
  183. dis2=[]
  184. bleu2=[]
  185. test_interline_equations=[ ''.join(i) for i in test_interline_equations]
  186. standard_interline_equations=[ ''.join(i) for i in standard_interline_equations]
  187. for a,b in zip(test_interline_equations,standard_interline_equations):
  188. if len(a)==0 and len(b)==0:
  189. continue
  190. else:
  191. if a==b:
  192. dis2.append(0)
  193. bleu2.append(1)
  194. else:
  195. dis2.append(Levenshtein_Distance(a,b))
  196. bleu2.append(sentence_bleu([a],b))
  197. interline_equations_edit=np.mean(dis2)
  198. interline_equations_bleu=np.mean(bleu2)
  199. '''行间公式bbox匹配相关指标'''
  200. interline_equations_bbox_report=bbox_match_indicator(test_interline_equations_bboxs,standard_interline_equations_bboxs)
  201. '''可以先检查page和bbox数量是否一致'''
  202. '''dropped_text_block的bbox匹配相关指标'''
  203. test_text_bbox=[]
  204. standard_text_bbox=[]
  205. test_tag=[]
  206. standard_tag=[]
  207. index=0
  208. for a,b in zip(test_dropped_text_bboxes,standard_dropped_text_bboxes):
  209. test_page_tag=[]
  210. standard_page_tag=[]
  211. test_page_bbox=[]
  212. standard_page_bbox=[]
  213. if len(a)==0 and len(b)==0:
  214. pass
  215. else:
  216. for i in range(len(b)):
  217. judge=0
  218. standard_page_tag.append(standard_dropped_text_tag[index][i])
  219. standard_page_bbox.append(1)
  220. for j in range(len(a)):
  221. if bbox_offset(b[i],a[j]):
  222. judge=1
  223. test_page_tag.append(test_dropped_text_tag[index][j])
  224. test_page_bbox.append(1)
  225. break
  226. if judge==0:
  227. test_page_tag.append('None')
  228. test_page_bbox.append(0)
  229. if len(test_dropped_text_tag[index])+test_page_tag.count('None')>len(standard_dropped_text_tag[index]):#有多删的情况出现
  230. test_page_tag1=test_page_tag.copy()
  231. if 'None' in test_page_tag:
  232. test_page_tag1=test_page_tag1.remove('None')
  233. else:
  234. test_page_tag1=test_page_tag
  235. diff=list((Counter(test_dropped_text_tag[index]) - Counter(test_page_tag1)).elements())
  236. test_page_tag.extend(diff)
  237. standard_page_tag.extend(['None']*len(diff))
  238. test_page_bbox.extend([1]*len(diff))
  239. standard_page_bbox.extend([0]*len(diff))
  240. test_tag.extend(test_page_tag)
  241. standard_tag.extend(standard_page_tag)
  242. test_text_bbox.extend(test_page_bbox)
  243. standard_text_bbox.extend(standard_page_bbox)
  244. index+=1
  245. text_block_report = {}
  246. text_block_report['accuracy']=metrics.accuracy_score(standard_text_bbox,test_text_bbox)
  247. text_block_report['precision']=metrics.precision_score(standard_text_bbox,test_text_bbox)
  248. text_block_report['recall']=metrics.recall_score(standard_text_bbox,test_text_bbox)
  249. text_block_report['f1_score']=metrics.f1_score(standard_text_bbox,test_text_bbox)
  250. '''删除的text_block的tag的准确率,召回率和f1-score'''
  251. text_block_tag_report = classification_report(y_true=standard_tag , y_pred=test_tag,output_dict=True)
  252. del text_block_tag_report['None']
  253. del text_block_tag_report["macro avg"]
  254. del text_block_tag_report["weighted avg"]
  255. '''dropped_image_block的bbox匹配相关指标'''
  256. '''有数据格式不一致的问题'''
  257. image_block_report=bbox_match_indicator(test_dropped_image_bboxes,standard_dropped_image_bboxes)
  258. '''dropped_table_block的bbox匹配相关指标'''
  259. table_block_report=bbox_match_indicator(test_dropped_table_bboxes,standard_dropped_table_bboxes)
  260. '''阅读顺序编辑距离的均值'''
  261. preproc_num_dis=[]
  262. for a,b in zip(test_preproc_num,standard_preproc_num):
  263. preproc_num_dis.append(Levenshtein_Distance(a,b))
  264. preproc_num_edit=np.mean(preproc_num_dis)
  265. '''分段准确率'''
  266. test_para_num=np.array(test_para_num)
  267. standard_para_num=np.array(standard_para_num)
  268. acc_para=np.mean(test_para_num==standard_para_num)
  269. output=pd.DataFrame()
  270. output['总体指标']=[overall_report]
  271. output['行内公式平均编辑距离']=[inline_equations_edit]
  272. output['行间公式平均编辑距离']=[interline_equations_edit]
  273. output['行内公式平均bleu']=[inline_equations_bleu]
  274. output['行间公式平均bleu']=[interline_equations_bleu]
  275. output['行内公式识别相关指标']=[inline_equations_bbox_report]
  276. output['行间公式识别相关指标']=[interline_equations_bbox_report]
  277. output['阅读顺序平均编辑距离']=[preproc_num_edit]
  278. output['分段准确率']=[acc_para]
  279. output['删除的text block的相关指标']=[text_block_report]
  280. output['删除的image block的相关指标']=[image_block_report]
  281. output['删除的table block的相关指标']=[table_block_report]
  282. output['删除的text block的tag相关指标']=[text_block_tag_report]
  283. return output
  284. """
  285. 计算编辑距离
  286. """
  287. def Levenshtein_Distance(str1, str2):
  288. matrix = [[ i + j for j in range(len(str2) + 1)] for i in range(len(str1) + 1)]
  289. for i in range(1, len(str1)+1):
  290. for j in range(1, len(str2)+1):
  291. if(str1[i-1] == str2[j-1]):
  292. d = 0
  293. else:
  294. d = 1
  295. matrix[i][j] = min(matrix[i-1][j]+1, matrix[i][j-1]+1, matrix[i-1][j-1]+d)
  296. return matrix[len(str1)][len(str2)]
  297. '''
  298. 计算bbox偏移量是否符合标准的函数
  299. '''
  300. def bbox_offset(b_t,b_s):
  301. '''b_t是test_doc里的bbox,b_s是standard_doc里的bbox'''
  302. x1_t,y1_t,x2_t,y2_t=b_t
  303. x1_s,y1_s,x2_s,y2_s=b_s
  304. x1=max(x1_t,x1_s)
  305. x2=min(x2_t,x2_s)
  306. y1=max(y1_t,y1_s)
  307. y2=min(y2_t,y2_s)
  308. area_overlap=(x2-x1)*(y2-y1)
  309. area_t=(x2_t-x1_t)*(y2_t-y1_t)+(x2_s-x1_s)*(y2_s-y1_s)-area_overlap
  310. if area_t-area_overlap==0 or area_overlap/(area_t-area_overlap)>0.95:
  311. return True
  312. else:
  313. return False
  314. '''bbox匹配和对齐函数,输出相关指标'''
  315. '''输入的是以page为单位的bbox列表'''
  316. def bbox_match_indicator(test_bbox_list,standard_bbox_list):
  317. test_bbox=[]
  318. standard_bbox=[]
  319. for a,b in zip(test_bbox_list,standard_bbox_list):
  320. test_page_bbox=[]
  321. standard_page_bbox=[]
  322. if len(a)==0 and len(b)==0:
  323. pass
  324. else:
  325. for i in b:
  326. if len(i)!=4:
  327. continue
  328. else:
  329. judge=0
  330. standard_page_bbox.append(1)
  331. for j in a:
  332. if bbox_offset(i,j):
  333. judge=1
  334. test_page_bbox.append(1)
  335. break
  336. if judge==0:
  337. test_page_bbox.append(0)
  338. diff_num=len(a)+test_page_bbox.count(0)-len(b)
  339. if diff_num>0:#有多删的情况出现
  340. test_page_bbox.extend([1]*diff_num)
  341. standard_page_bbox.extend([0]*diff_num)
  342. test_bbox.extend(test_page_bbox)
  343. standard_bbox.extend(standard_page_bbox)
  344. block_report = {}
  345. block_report['accuracy']=metrics.accuracy_score(standard_bbox,test_bbox)
  346. block_report['precision']=metrics.precision_score(standard_bbox,test_bbox)
  347. block_report['recall']=metrics.recall_score(standard_bbox,test_bbox)
  348. block_report['f1_score']=metrics.f1_score(standard_bbox,test_bbox)
  349. return block_report
  350. parser = argparse.ArgumentParser()
  351. parser.add_argument('--test', type=str)
  352. parser.add_argument('--standard', type=str)
  353. args = parser.parse_args()
  354. pdf_json_test = args.test
  355. pdf_json_standard = args.standard
  356. if __name__ == '__main__':
  357. pdf_json_test = [json.loads(line)
  358. for line in open(pdf_json_test, 'r', encoding='utf-8')]
  359. pdf_json_standard = [json.loads(line)
  360. for line in open(pdf_json_standard, 'r', encoding='utf-8')]
  361. overall_indicator=indicator_cal(pdf_json_standard,pdf_json_test)
  362. '''计算的指标输出到overall_indicator_output.json中'''
  363. overall_indicator.to_json('overall_indicator_output.json',orient='records',lines=True,force_ascii=False)