Prechádzať zdrojové kódy

fix: caption or footnote match algorithm

icecraft 1 rok pred
rodič
commit
ef45ad0874
1 zmenil súbory, kde vykonal 25 pridanie a 5 odobranie
  1. 25 5
      magic_pdf/model/magic_model.py

+ 25 - 5
magic_pdf/model/magic_model.py

@@ -110,6 +110,26 @@ class MagicModel:
         self.__fix_by_remove_high_iou_and_low_confidence()
         self.__fix_footnote()
 
+    def _bbox_distance(self, bbox1, bbox2):
+        left, right, bottom, top = bbox_relative_pos(bbox1, bbox2)
+        flags = [left, right, bottom, top]
+        count = sum([1 if v else 0 for v in flags])
+        if count > 1:
+            return float('inf')
+        if left or right:
+            l1 = bbox1[3] - bbox1[1]
+            l2 = bbox2[3] - bbox2[1]
+            minL, maxL = min(l1, l2), max(l1, l2)
+            if (maxL - minL) / minL > 0.5:
+                return float('inf')
+        if bottom or top:
+            l1 = bbox1[2] - bbox1[0]
+            l2 = bbox2[2] - bbox2[0]
+            minL, maxL = min(l1, l2), max(l1, l2)
+            if (maxL - minL) / minL > 0.5:
+                return float('inf')
+        return bbox_distance(bbox1, bbox2)
+
     def __fix_footnote(self):
         # 3: figure, 5: table, 7: footnote
         for model_page_info in self.__model_list:
@@ -144,7 +164,7 @@ class MagicModel:
                     if pos_flag_count > 1:
                         continue
                     dis_figure_footnote[i] = min(
-                        bbox_distance(figures[j]['bbox'], footnotes[i]['bbox']),
+                        self._bbox_distance(figures[j]['bbox'], footnotes[i]['bbox']),
                         dis_figure_footnote.get(i, float('inf')),
                     )
             for i in range(len(footnotes)):
@@ -163,7 +183,7 @@ class MagicModel:
                         continue
 
                     dis_table_footnote[i] = min(
-                        bbox_distance(tables[j]['bbox'], footnotes[i]['bbox']),
+                        self._bbox_distance(tables[j]['bbox'], footnotes[i]['bbox']),
                         dis_table_footnote.get(i, float('inf')),
                     )
             for i in range(len(footnotes)):
@@ -350,7 +370,7 @@ class MagicModel:
                     dis[j][i] = dis[i][j]
                     continue
 
-                dis[i][j] = bbox_distance(all_bboxes[i]['bbox'], all_bboxes[j]['bbox'])
+                dis[i][j] = self._bbox_distance(all_bboxes[i]['bbox'], all_bboxes[j]['bbox'])
                 dis[j][i] = dis[i][j]
 
         used = set()
@@ -441,7 +461,7 @@ class MagicModel:
 
                     if is_nearest:
                         nx0, ny0, nx1, ny1 = expand_bbbox(list(seen) + [k])
-                        n_dis = bbox_distance(
+                        n_dis = self._bbox_distance(
                             all_bboxes[i]['bbox'], [nx0, ny0, nx1, ny1]
                         )
                         if float_gt(dis[i][j], n_dis):
@@ -537,7 +557,7 @@ class MagicModel:
         # 计算已经配对的 distance 距离
         for i in subject_object_relation_map.keys():
             for j in subject_object_relation_map[i]:
-                total_subject_object_dis += bbox_distance(
+                total_subject_object_dis += self._bbox_distance(
                     all_bboxes[i]['bbox'], all_bboxes[j]['bbox']
                 )