Selaa lähdekoodia

fix: table caption relation

许瑞 1 vuosi sitten
vanhempi
commit
86d7cff111
1 muutettua tiedostoa jossa 40 lisäystä ja 12 poistoa
  1. 40 12
      magic_pdf/model/magic_model.py

+ 40 - 12
magic_pdf/model/magic_model.py

@@ -21,6 +21,7 @@ from magic_pdf.libs.ModelBlockTypeEnum import ModelBlockTypeEnum
 
 CAPATION_OVERLAP_AREA_RATIO = 0.6
 
+
 class MagicModel:
     """
     每个函数没有得到元素的时候返回空list
@@ -89,23 +90,37 @@ class MagicModel:
         ret = []
         MAX_DIS_OF_POINT = 10**9 + 7
 
+        def expand_bbox(bbox1, bbox2):
+            x0 = min(bbox1[0], bbox2[0])
+            y0 = min(bbox1[1], bbox2[1])
+            x1 = max(bbox1[2], bbox2[2])
+            y1 = max(bbox1[3], bbox2[3])
+            return [x0, y0, x1, y1]
+
+        def get_bbox_area(bbox):
+            return abs(bbox[2] - bbox[0]) * abs(bbox[3] - bbox[1])
+
         # subject 和 object 的 bbox 会合并成一个大的 bbox (named: merged bbox)。 筛选出所有和 merged bbox 有 overlap 且 overlap 面积大于 object 的面积的 subjects。
         # 再求出筛选出的 subjects 和 object 的最短距离!
         def may_find_other_nearest_bbox(subject_idx, object_idx):
             ret = float("inf")
-            x0 = min(all_bboxes[subject_idx]["bbox"][0], all_bboxes[object_idx]["bbox"][0])
-            y0 = min(all_bboxes[subject_idx]["bbox"][1], all_bboxes[object_idx]["bbox"][1])
-            x1 = max(all_bboxes[subject_idx]["bbox"][2], all_bboxes[object_idx]["bbox"][2])
-            y1 = max(all_bboxes[subject_idx]["bbox"][3], all_bboxes[object_idx]["bbox"][3])
+            x0, y0, x1, y1 = expand_bbox(
+                all_bboxes[subject_idx]["bbox"], all_bboxes[object_idx]["bbox"]
+            )
 
-            object_area = abs(all_bboxes[object_idx]["bbox"][2] - all_bboxes[object_idx]["bbox"][0]) * abs(all_bboxes[object_idx]["bbox"][3] - all_bboxes[object_idx]["bbox"][1])
+            object_area = get_bbox_area(all_bboxes[object_idx]["bbox"])
             for i in range(len(all_bboxes)):
-                if i == subject_idx or all_bboxes[i]["category_id"] != subject_category_id:
+                if (
+                    i == subject_idx
+                    or all_bboxes[i]["category_id"] != subject_category_id
+                ):
                     continue
-                if _is_part_overlap([x0, y0, x1, y1], all_bboxes[i]["bbox"]) or _is_in(all_bboxes[i]["bbox"], [x0, y0, x1, y1]):
-                    i_area = abs(all_bboxes[i]["bbox"][2] - all_bboxes[i]["bbox"][0]) * abs(all_bboxes[i]["bbox"][3] - all_bboxes[i]["bbox"][1])
+                if _is_part_overlap([x0, y0, x1, y1], all_bboxes[i]["bbox"]) or _is_in(
+                    all_bboxes[i]["bbox"], [x0, y0, x1, y1]
+                ):
+                    i_area = get_bbox_area(all_bboxes[i]["bbox"])
                     if i_area >= object_area:
-                        ret = min(float("inf"), dis[i][object_idx]) 
+                        ret = min(ret, dis[i][object_idx])
             return ret
 
         subjects = self.__reduct_overlap(
@@ -190,7 +205,7 @@ class MagicModel:
             arr.sort(key=lambda x: x[0])
             if len(arr) > 0:
                 # bug: 离该subject 最近的 object 可能跨越了其它的 subject 。比如 [this subect] [some sbuject] [the nearest objec of subject]
-                if may_find_other_nearest_bbox(i, j) >= arr[0][0]:
+                if may_find_other_nearest_bbox(i, arr[0][1]) >= arr[0][0]:
                     candidates.append(arr[0][1])
                     seen.add(arr[0][1])
 
@@ -266,7 +281,12 @@ class MagicModel:
             for bbox in caption_poses:
                 embed_arr = []
                 for idx in seen:
-                    if calculate_overlap_area_in_bbox1_area_ratio(all_bboxes[idx]["bbox"], bbox) > CAPATION_OVERLAP_AREA_RATIO:
+                    if (
+                        calculate_overlap_area_in_bbox1_area_ratio(
+                            all_bboxes[idx]["bbox"], bbox
+                        )
+                        > CAPATION_OVERLAP_AREA_RATIO
+                    ):
                         embed_arr.append(idx)
 
                 if len(embed_arr) > 0:
@@ -286,7 +306,12 @@ class MagicModel:
                 caption_bbox = caption_poses[max_area_idx]
 
                 for j in seen:
-                    if calculate_overlap_area_in_bbox1_area_ratio(all_bboxes[j]["bbox"], caption_bbox) > CAPATION_OVERLAP_AREA_RATIO:
+                    if (
+                        calculate_overlap_area_in_bbox1_area_ratio(
+                            all_bboxes[j]["bbox"], caption_bbox
+                        )
+                        > CAPATION_OVERLAP_AREA_RATIO
+                    ):
                         used.add(j)
                         subject_object_relation_map[i].append(j)
 
@@ -482,6 +507,9 @@ class MagicModel:
                     blocks.append(block)
         return blocks
 
+    def get_model_list(self, page_no):
+        return self.__model_list[page_no]
+
 
 if __name__ == "__main__":
     drw = DiskReaderWriter(r"D:/project/20231108code-clean")