Преглед на файлове

fix: update remove overlap

许瑞 преди 1 година
родител
ревизия
8193128f73
променени са 1 файла, в които са добавени 61 реда и са изтрити 17 реда
  1. 61 17
      magic_pdf/model/magic_model.py

+ 61 - 17
magic_pdf/model/magic_model.py

@@ -21,6 +21,7 @@ from magic_pdf.libs.ModelBlockTypeEnum import ModelBlockTypeEnum
 
 CAPATION_OVERLAP_AREA_RATIO = 0.6
 
+
 class MagicModel:
     """
     每个函数没有得到元素的时候返回空list
@@ -75,7 +76,7 @@ class MagicModel:
             for j in range(N):
                 if i == j:
                     continue
-                if _is_in(bboxes[i], bboxes[j]):
+                if _is_in(bboxes[i]["bbox"], bboxes[j]["bbox"]):
                     keep[i] = False
 
         return [bboxes[i] for i in range(N) if keep[i]]
@@ -93,25 +94,44 @@ class MagicModel:
         # 再求出筛选出的 subjects 和 object 的最短距离!
         def may_find_other_nearest_bbox(subject_idx, object_idx):
             ret = float("inf")
-            x0 = min(all_bboxes[subject_idx]["bbox"][0], all_bboxes[object_idx]["bbox"][0])
-            y0 = min(all_bboxes[subject_idx]["bbox"][1], all_bboxes[object_idx]["bbox"][1])
-            x1 = max(all_bboxes[subject_idx]["bbox"][2], all_bboxes[object_idx]["bbox"][2])
-            y1 = max(all_bboxes[subject_idx]["bbox"][3], all_bboxes[object_idx]["bbox"][3])
+            x0 = min(
+                all_bboxes[subject_idx]["bbox"][0], all_bboxes[object_idx]["bbox"][0]
+            )
+            y0 = min(
+                all_bboxes[subject_idx]["bbox"][1], all_bboxes[object_idx]["bbox"][1]
+            )
+            x1 = max(
+                all_bboxes[subject_idx]["bbox"][2], all_bboxes[object_idx]["bbox"][2]
+            )
+            y1 = max(
+                all_bboxes[subject_idx]["bbox"][3], all_bboxes[object_idx]["bbox"][3]
+            )
 
-            object_area = abs(all_bboxes[object_idx]["bbox"][2] - all_bboxes[object_idx]["bbox"][0]) * abs(all_bboxes[object_idx]["bbox"][3] - all_bboxes[object_idx]["bbox"][1])
+            object_area = abs(
+                all_bboxes[object_idx]["bbox"][2] - all_bboxes[object_idx]["bbox"][0]
+            ) * abs(
+                all_bboxes[object_idx]["bbox"][3] - all_bboxes[object_idx]["bbox"][1]
+            )
             for i in range(len(all_bboxes)):
-                if i == subject_idx or all_bboxes[i]["category_id"] != subject_category_id:
+                if (
+                    i == subject_idx
+                    or all_bboxes[i]["category_id"] != subject_category_id
+                ):
                     continue
-                if _is_part_overlap([x0, y0, x1, y1], all_bboxes[i]["bbox"]) or _is_in(all_bboxes[i]["bbox"], [x0, y0, x1, y1]):
-                    i_area = abs(all_bboxes[i]["bbox"][2] - all_bboxes[i]["bbox"][0]) * abs(all_bboxes[i]["bbox"][3] - all_bboxes[i]["bbox"][1])
+                if _is_part_overlap([x0, y0, x1, y1], all_bboxes[i]["bbox"]) or _is_in(
+                    all_bboxes[i]["bbox"], [x0, y0, x1, y1]
+                ):
+                    i_area = abs(
+                        all_bboxes[i]["bbox"][2] - all_bboxes[i]["bbox"][0]
+                    ) * abs(all_bboxes[i]["bbox"][3] - all_bboxes[i]["bbox"][1])
                     if i_area >= object_area:
-                        ret = min(float("inf"), dis[i][object_idx]) 
+                        ret = min(float("inf"), dis[i][object_idx])
             return ret
 
         subjects = self.__reduct_overlap(
             list(
                 map(
-                    lambda x: x["bbox"],
+                    lambda x: {"bbox": x["bbox"], "score": x["score"]},
                     filter(
                         lambda x: x["category_id"] == subject_category_id,
                         self.__model_list[page_no]["layout_dets"],
@@ -123,7 +143,7 @@ class MagicModel:
         objects = self.__reduct_overlap(
             list(
                 map(
-                    lambda x: x["bbox"],
+                    lambda x: {"bbox": x["bbox"], "score": x["score"]},
                     filter(
                         lambda x: x["category_id"] == object_category_id,
                         self.__model_list[page_no]["layout_dets"],
@@ -133,15 +153,29 @@ class MagicModel:
         )
         subject_object_relation_map = {}
 
-        subjects.sort(key=lambda x: x[0] ** 2 + x[1] ** 2)  # get the distance !
+        subjects.sort(
+            key=lambda x: x["bbox"][0] ** 2 + x["bbox"][1] ** 2
+        )  # get the distance !
 
         all_bboxes = []
 
         for v in subjects:
-            all_bboxes.append({"category_id": subject_category_id, "bbox": v, "score": v["score"]})
+            all_bboxes.append(
+                {
+                    "category_id": subject_category_id,
+                    "bbox": v["bbox"],
+                    "score": v["score"],
+                }
+            )
 
         for v in objects:
-            all_bboxes.append({"category_id": object_category_id, "bbox": v, "score": v["score"]})
+            all_bboxes.append(
+                {
+                    "category_id": object_category_id,
+                    "bbox": v["bbox"],
+                    "score": v["score"],
+                }
+            )
 
         N = len(all_bboxes)
         dis = [[MAX_DIS_OF_POINT] * N for _ in range(N)]
@@ -266,7 +300,12 @@ class MagicModel:
             for bbox in caption_poses:
                 embed_arr = []
                 for idx in seen:
-                    if calculate_overlap_area_in_bbox1_area_ratio(all_bboxes[idx]["bbox"], bbox) > CAPATION_OVERLAP_AREA_RATIO:
+                    if (
+                        calculate_overlap_area_in_bbox1_area_ratio(
+                            all_bboxes[idx]["bbox"], bbox
+                        )
+                        > CAPATION_OVERLAP_AREA_RATIO
+                    ):
                         embed_arr.append(idx)
 
                 if len(embed_arr) > 0:
@@ -286,7 +325,12 @@ class MagicModel:
                 caption_bbox = caption_poses[max_area_idx]
 
                 for j in seen:
-                    if calculate_overlap_area_in_bbox1_area_ratio(all_bboxes[j]["bbox"], caption_bbox) > CAPATION_OVERLAP_AREA_RATIO:
+                    if (
+                        calculate_overlap_area_in_bbox1_area_ratio(
+                            all_bboxes[j]["bbox"], caption_bbox
+                        )
+                        > CAPATION_OVERLAP_AREA_RATIO
+                    ):
                         used.add(j)
                         subject_object_relation_map[i].append(j)