# vali_bbox_sort.py

  1. import numpy as np
  2. import tqdm
  3. import json
  4. from validation import cal_edit_distance, format_gt_bbox
  5. from magic_pdf.layout.layout_sort import sort_with_layout
  6. with open('/mnt/petrelfs/share_data/ouyanglinke/OCR/OCR_validation_dataset_final_rotated_formulafix_highdpi_scihub.json', 'r') as f:
  7. samples = json.load(f)
  8. # labels = []
  9. # det_res = []
  10. edit_distance_dict = []
  11. edit_distance_list = []
  12. for i, sample in tqdm.tqdm(enumerate(samples)):
  13. pdf_name = sample['pdf_name']
  14. s3_pdf_path = sample['s3_path']
  15. page_num = sample['page']
  16. page_width = sample['annotations']['width']
  17. page_height = sample['annotations']['height']
  18. # pre = main(s3_pdf_path, pdf_bin_file_profile, join_path(pdf_model_dir, pdf_name), pdf_model_profile, save_path, page_num)
  19. # pre_dict_list = []
  20. # for item in pre:
  21. # pre_sample = {
  22. # 'box': [item[0],item[1],item[2],item[3]],
  23. # 'type': item[7],
  24. # 'score': 1
  25. # }
  26. # pre_dict_list.append(pre_sample)
  27. # det_res.append(pre_dict_list)
  28. # match_change_dict = { # 待确认
  29. # "figure": "image",
  30. # "svg_figure": "image",
  31. # "inline_fomula": "equations_inline",
  32. # "fomula": "equation_interline",
  33. # "figure_caption": "text",
  34. # "table_caption": "text",
  35. # "fomula_caption": "text"
  36. # }
  37. gt_annos = sample['annotations']
  38. # matched_label = label_match(gt_annos, match_change_dict)
  39. # labels.append(matched_label)
  40. # 判断排序函数的精度
  41. # 目前不考虑caption与图表相同序号的问题
  42. ignore_category = ['abandon', 'figure_caption', 'table_caption', 'formula_caption', 'inline_fomula']
  43. gt_bboxes = format_gt_bbox(gt_annos, ignore_category)
  44. sorted_bboxes, _ = sort_with_layout(gt_bboxes, page_width, page_height)
  45. if sorted_bboxes:
  46. edit_distance = cal_edit_distance(sorted_bboxes)
  47. edit_distance_list.append(edit_distance)
  48. edit_distance_dict.append({
  49. "sample_id": i,
  50. "s3_path": s3_pdf_path,
  51. "page_num": page_num,
  52. "page_s2_path": sample['page_path'],
  53. "edit_distance": edit_distance
  54. })
  55. # label_classes = ["image", "text", "table", "equation_interline"]
  56. # detect_matrix = detect_val(labels, det_res, label_classes)
  57. # print('detect_matrix', detect_matrix)
  58. edit_distance_mean = np.mean(edit_distance_list)
  59. print('edit_distance_mean', edit_distance_mean)
  60. edit_distance_dict_sorted = sorted(edit_distance_dict, key=lambda x: x['edit_distance'], reverse=True)
  61. # print(edit_distance_dict_sorted)
  62. result = {
  63. "edit_distance_mean": edit_distance_mean,
  64. "edit_distance_dict_sorted": edit_distance_dict_sorted
  65. }
  66. with open('vali_bbox_sort_result.json', 'w') as f:
  67. json.dump(result, f)