validation.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. import numpy as np
  2. from mmeval import COCODetection
  3. import distance
  4. def reformat_gt_and_pred(labels, det_res, label_classes):
  5. preds = []
  6. gts = []
  7. for idx, (ann, pred) in enumerate(zip(labels, det_res)):
  8. # with open(label_path, "r") as f:
  9. # ann = json.load(f)
  10. gt_bboxes = []
  11. gt_labels = []
  12. for item in ann['step_1']['result']:
  13. if item['attribute'] in label_classes:
  14. gt_bboxes.append([item['x'], item['y'], item['x']+item['width'], item['y']+item['height']])
  15. gt_labels.append(label_classes.index(item['attribute']))
  16. gts.append({
  17. 'img_id': idx,
  18. 'width': ann['width'],
  19. 'height': ann['height'],
  20. 'bboxes': np.array(gt_bboxes),
  21. 'labels': np.array(gt_labels),
  22. 'ignore_flags': [False]*len(gt_labels),
  23. })
  24. bboxes = []
  25. labels = []
  26. scores = []
  27. for item in pred:
  28. bboxes.append(item['box'])
  29. labels.append(label_classes.index(item['type']))
  30. scores.append(item['score'])
  31. preds.append({
  32. 'img_id': idx,
  33. 'bboxes': np.array(bboxes),
  34. 'scores': np.array(scores),
  35. 'labels': np.array(labels),
  36. })
  37. return gts, preds
  38. def detect_val(labels, det_res, label_classes):
  39. # label_classes = ['inline_formula', "formula"]
  40. meta={'CLASSES':tuple(label_classes)}
  41. coco_det_metric = COCODetection(dataset_meta=meta, metric=['bbox'])
  42. gts, preds = reformat_gt_and_pred(labels, det_res, label_classes)
  43. res = coco_det_metric(predictions=preds, groundtruths=gts)
  44. return res
  45. def label_match(annotations, match_change_dict):
  46. for item in annotations['step_1']['result']:
  47. if item['attribute'] in match_change_dict.keys():
  48. item['attribute'] = match_change_dict[item['attribute']]
  49. return annotations
  50. def format_gt_bbox(annotations, ignore_category):
  51. gt_bboxes = []
  52. for item in annotations['step_1']['result']:
  53. if item['textAttribute'] and item['attribute'] not in ignore_category:
  54. x0 = item['x']
  55. y0 = item['y']
  56. x1 = item['x'] + item['width']
  57. y1 = item['y'] + item['height']
  58. order = item['textAttribute']
  59. category = item['attribute']
  60. gt_bboxes.append([x0, y0, x1, y1, order, None, None, category])
  61. return gt_bboxes
  62. def cal_edit_distance(sorted_bboxes):
  63. # order_list = [int(bbox[4]) for bbox in sorted_bboxes]
  64. # print(sorted_bboxes[0][0][12])
  65. order_list = [int(bbox[12]) for bbox in sorted_bboxes]
  66. sorted_order = sorted(order_list, key=int)
  67. distance_cal = distance.levenshtein(order_list, sorted_order)
  68. if len(order_list) > 0:
  69. return distance_cal / len(order_list)
  70. else:
  71. return 0