Prechádzať zdrojové kódy

Merge pull request #103 from icecraft/fix/update_remove_overlap

fix: update remove overlap
myhloli 1 rok pred
rodič
commit
536300fdc0

+ 1 - 0
magic_pdf/libs/drop_reason.py

@@ -22,4 +22,5 @@ class DropReason:
     SPECIAL_PDF = "special_pdf"
     PSEUDO_SINGLE_COLUMN = "pseudo_single_column" # 无法精确判断文字分栏
     CAN_NOT_DETECT_PAGE_LAYOUT="can_not_detect_page_layout" # 无法分析页面的版面
+    NEGATIVE_BBOX_AREA = "negative_bbox_area" # 缩放导致 bbox 面积为负
     

+ 50 - 11
magic_pdf/model/magic_model.py

@@ -76,7 +76,7 @@ class MagicModel:
             for j in range(N):
                 if i == j:
                     continue
-                if _is_in(bboxes[i], bboxes[j]):
+                if _is_in(bboxes[i]["bbox"], bboxes[j]["bbox"]):
                     keep[i] = False
 
         return [bboxes[i] for i in range(N) if keep[i]]
@@ -104,11 +104,26 @@ class MagicModel:
         # 再求出筛选出的 subjects 和 object 的最短距离!
         def may_find_other_nearest_bbox(subject_idx, object_idx):
             ret = float("inf")
-            x0, y0, x1, y1 = expand_bbox(
-                all_bboxes[subject_idx]["bbox"], all_bboxes[object_idx]["bbox"]
+
+            x0 = min(
+                all_bboxes[subject_idx]["bbox"][0], all_bboxes[object_idx]["bbox"][0]
+            )
+            y0 = min(
+                all_bboxes[subject_idx]["bbox"][1], all_bboxes[object_idx]["bbox"][1]
+            )
+            x1 = max(
+                all_bboxes[subject_idx]["bbox"][2], all_bboxes[object_idx]["bbox"][2]
+            )
+            y1 = max(
+                all_bboxes[subject_idx]["bbox"][3], all_bboxes[object_idx]["bbox"][3]
+            )
+
+            object_area = abs(
+                all_bboxes[object_idx]["bbox"][2] - all_bboxes[object_idx]["bbox"][0]
+            ) * abs(
+                all_bboxes[object_idx]["bbox"][3] - all_bboxes[object_idx]["bbox"][1]
             )
 
-            object_area = get_bbox_area(all_bboxes[object_idx]["bbox"])
             for i in range(len(all_bboxes)):
                 if (
                     i == subject_idx
@@ -118,15 +133,19 @@ class MagicModel:
                 if _is_part_overlap([x0, y0, x1, y1], all_bboxes[i]["bbox"]) or _is_in(
                     all_bboxes[i]["bbox"], [x0, y0, x1, y1]
                 ):
-                    i_area = get_bbox_area(all_bboxes[i]["bbox"])
+
+                    i_area = abs(
+                        all_bboxes[i]["bbox"][2] - all_bboxes[i]["bbox"][0]
+                    ) * abs(all_bboxes[i]["bbox"][3] - all_bboxes[i]["bbox"][1])
                     if i_area >= object_area:
-                        ret = min(ret, dis[i][object_idx])
+                        ret = min(float("inf"), dis[i][object_idx])
+
             return ret
 
         subjects = self.__reduct_overlap(
             list(
                 map(
-                    lambda x: x["bbox"],
+                    lambda x: {"bbox": x["bbox"], "score": x["score"]},
                     filter(
                         lambda x: x["category_id"] == subject_category_id,
                         self.__model_list[page_no]["layout_dets"],
@@ -138,7 +157,7 @@ class MagicModel:
         objects = self.__reduct_overlap(
             list(
                 map(
-                    lambda x: x["bbox"],
+                    lambda x: {"bbox": x["bbox"], "score": x["score"]},
                     filter(
                         lambda x: x["category_id"] == object_category_id,
                         self.__model_list[page_no]["layout_dets"],
@@ -148,15 +167,29 @@ class MagicModel:
         )
         subject_object_relation_map = {}
 
-        subjects.sort(key=lambda x: x[0] ** 2 + x[1] ** 2)  # get the distance !
+        subjects.sort(
+            key=lambda x: x["bbox"][0] ** 2 + x["bbox"][1] ** 2
+        )  # get the distance !
 
         all_bboxes = []
 
         for v in subjects:
-            all_bboxes.append({"category_id": subject_category_id, "bbox": v})
+            all_bboxes.append(
+                {
+                    "category_id": subject_category_id,
+                    "bbox": v["bbox"],
+                    "score": v["score"],
+                }
+            )
 
         for v in objects:
-            all_bboxes.append({"category_id": object_category_id, "bbox": v})
+            all_bboxes.append(
+                {
+                    "category_id": object_category_id,
+                    "bbox": v["bbox"],
+                    "score": v["score"],
+                }
+            )
 
         N = len(all_bboxes)
         dis = [[MAX_DIS_OF_POINT] * N for _ in range(N)]
@@ -319,6 +352,7 @@ class MagicModel:
             result = {
                 "subject_body": all_bboxes[i]["bbox"],
                 "all": all_bboxes[i]["bbox"],
+                "score": all_bboxes[i]["score"],
             }
 
             if len(subject_object_relation_map[i]) > 0:
@@ -383,6 +417,7 @@ class MagicModel:
                 "bbox": record["all"],
                 "img_body_bbox": record["subject_body"],
                 "img_caption_bbox": record.get("object_body", None),
+                "score": record["score"],
             }
             for record in records
         ]
@@ -397,6 +432,7 @@ class MagicModel:
         assert N == M
         for i in range(N):
             record = {
+                "score": with_captions[i]["score"],
                 "table_caption_bbox": with_captions[i].get("object_body", None),
                 "table_body_bbox": with_captions[i]["subject_body"],
                 "table_footnote_bbox": with_footnotes[i].get("object_body", None),
@@ -516,6 +552,9 @@ class MagicModel:
                     blocks.append(block)
         return blocks
 
+    def get_model_list(self, page_no):
+        """Return the stored model-output entry for page index ``page_no``.
+
+        The value is handed back as-is (no copy), so callers that mutate it
+        mutate this object's state as well.
+        """
+        return self.__model_list[page_no]
+
 
 
 if __name__ == "__main__":

+ 84 - 95
magic_pdf/pre_proc/remove_bbox_overlap.py

@@ -1,109 +1,98 @@
-from magic_pdf.libs.boxbase import _is_in_or_part_overlap, _is_in
-
-
-def _remove_overlap_between_bbox_for_span(spans):
-    res = []
-
-    keeps = [True] * len(spans)
-    for i in range(len(spans)):
-        for j in range(len(spans)):
+from magic_pdf.libs.boxbase import _is_in_or_part_overlap, _is_in, _is_part_overlap
+from magic_pdf.libs.drop_reason import DropReason
+
+def _remove_overlap_between_bbox(bbox1, bbox2):
+    """Pull two partially overlapping boxes apart along their thinner overlap.
+
+    Args:
+        bbox1: first box as ``[x0, y0, x1, y1]``.
+        bbox2: second box as ``[x0, y0, x1, y1]``.
+
+    Returns:
+        A ``(bbox1, bbox2, drop_reason)`` triple.  When the boxes do not
+        partially overlap they are returned untouched with ``None``.  When
+        they do, each box is cut back to 0.25 on its own side of the
+        floor-divided midpoint of the overlapping edges and the two new
+        lists are returned with ``None``.  If the cut would leave either box
+        with a non-positive width or height, the ORIGINAL boxes are returned
+        together with ``DropReason.NEGATIVE_BBOX_AREA``.
+    """
+    if not _is_part_overlap(bbox1, bbox2):
+        return bbox1, bbox2, None
+
+    a_x0, a_y0, a_x1, a_y1 = bbox1
+    b_x0, b_y0, b_x1, b_y1 = bbox2
+
+    overlap_w = min(b_x1, a_x1) - max(b_x0, a_x0)
+    overlap_h = min(b_y1, a_y1) - max(b_y0, a_y0)
+
+    if overlap_h > overlap_w:
+        # Overlap is taller than it is wide: separate the boxes horizontally.
+        if b_x1 >= a_x1:
+            cut = (b_x0 + a_x1) // 2
+            a_x1 = min(cut - 0.25, a_x1)
+            b_x0 = max(cut + 0.25, b_x0)
+        else:
+            cut = (a_x0 + b_x1) // 2
+            a_x0 = max(cut + 0.25, a_x0)
+            b_x1 = min(cut - 0.25, b_x1)
+    elif b_y1 >= a_y1:
+        # Otherwise separate vertically; bbox2 reaches at least as low as bbox1.
+        cut = (b_y0 + a_y1) // 2
+        b_y0 = max(cut + 0.25, b_y0)
+        a_y1 = min(a_y1, cut - 0.25)
+    else:
+        cut = (a_y0 + b_y1) // 2
+        b_y1 = min(b_y1, cut - 0.25)
+        a_y0 = max(cut + 0.25, a_y0)
+
+    both_positive = (
+        a_x1 > a_x0 and a_y1 > a_y0 and b_y1 > b_y0 and b_x1 > b_x0
+    )
+    if not both_positive:
+        # The cut collapsed a box; report it and hand back the inputs as-is.
+        return bbox1, bbox2, DropReason.NEGATIVE_BBOX_AREA
+    return [a_x0, a_y0, a_x1, a_y1], [b_x0, b_y0, b_x1, b_y1], None
+
+
+def _remove_overlap_between_bboxes(arr):
+    """Drop contained boxes and separate partially overlapping ones.
+
+    Args:
+        arr: list of dicts, each with a ``"bbox"`` (``[x0, y0, x1, y1]``) and
+            a ``"score"`` entry.  The dicts' ``"bbox"`` values are rewritten
+            in place when a box gets trimmed.
+
+    Returns:
+        ``(res, drop_reasons)``.  ``res`` has the same length as ``arr``;
+        slot ``k`` holds ``arr[k]`` if that box survived and ``None`` if it
+        was dropped.  ``drop_reasons`` collects one drop reason per pair
+        that could not be separated.
+    """
+    drop_reasons = []
+    N = len(arr)
+    keeps = [True] * N
+    res = [None] * N
+    # Pass 1: a box that lies inside any other box is discarded outright.
+    for i in range(N):
+        for j in range(N):
             if i == j:
                 continue
-            if _is_in(spans[i]["bbox"], spans[j]["bbox"]):
+            if _is_in(arr[i]["bbox"], arr[j]["bbox"]):
                 keeps[i] = False
 
-    for idx, v in enumerate(spans):
+    # Pass 2: reconcile every surviving box with the boxes accepted so far.
+    for idx, v in enumerate(arr):
         if not keeps[idx]:
             continue
-
-        for i in range(len(res)):
-            if _is_in(v["bbox"], res[i]["bbox"]):
+        for i in range(N):
+            if res[i] is None:
                 continue
-            if _is_in_or_part_overlap(res[i]["bbox"], v["bbox"]):
-                ix0, iy0, ix1, iy1 = res[i]["bbox"]
-                x0, y0, x1, y1 = v["bbox"]
-
-                diff_x = min(x1, ix1) - max(x0, ix0)
-                diff_y = min(y1, iy1) - max(y0, iy0)
-
-                if diff_y > diff_x:
-                    if x1 >= ix1:
-                        mid = (x0 + ix1) // 2
-                        ix1 = min(mid - 0.25, ix1)
-                        x0 = max(mid + 0.25, x0)
-                    else:
-                        mid = (ix0 + x1) // 2
-                        ix0 = max(mid + 0.25, ix0)
-                        x1 = min(mid - 0.25, x1)
+
+            bbox1, bbox2, drop_reason = _remove_overlap_between_bbox(v["bbox"], res[i]["bbox"])
+            if drop_reason is None:
+                v["bbox"] = bbox1
+                res[i]["bbox"] = bbox2
+            else:
+                # The pair cannot be separated: keep only the higher score.
+                if v["score"] > res[i]["score"]:
+                    keeps[i] = False
+                    # Must be None (not False): every consumer of `res`
+                    # recognises dropped slots with `is None`.
+                    res[i] = None
                 else:
-                    if y1 >= iy1:
-                        mid = (y0 + iy1) // 2
-                        y0 = max(mid + 0.25, y0)
-                        iy1 = min(iy1, mid-0.25)
-                    else:
-                        mid = (iy0 + y1) // 2
-                        y1 = min(y1, mid-0.25)
-                        iy0 = max(mid + 0.25, iy0)
-                res[i]["bbox"] = [ix0, iy0, ix1, iy1]
-                v["bbox"] = [x0, y0, x1, y1]
-
-        res.append(v)
-    return res
-
-
-def _remove_overlap_between_bbox_for_block(all_bboxes):
-    res = []
-
-    keeps = [True] * len(all_bboxes)
-    for i in range(len(all_bboxes)):
-        for j in range(len(all_bboxes)):
-            if i == j:
-                continue
-            if _is_in(all_bboxes[i][:4], all_bboxes[j][:4]):
-                keeps[i] = False
-
-    for idx, v in enumerate(all_bboxes):
-        if not keeps[idx]:
-            continue
-
-        for i in range(len(res)):
-            if _is_in(v[:4], res[i][:4]):
-                continue
-            if _is_in_or_part_overlap(res[i][:4], v[:4]):
-                ix0, iy0, ix1, iy1 = res[i][:4]
-                x0, y0, x1, y1 = v[:4]
-
-                diff_x = min(x1, ix1) - max(x0, ix0)
-                diff_y = min(y1, iy1) - max(y0, iy0)
-
-                if diff_y > diff_x:
-                    if x1 >= ix1:
-                        mid = (x0 + ix1) // 2
-                        ix1 = min(mid - 0.25, ix1)
-                        x0 = max(mid + 0.25, x0)
-                    else:
-                        mid = (ix0 + x1) // 2
-                        ix0 = max(mid + 0.25, ix0)
-                        x1 = min(mid - 0.25, x1)
-                else:
-                    if y1 >= iy1:
-                        mid = (y0 + iy1) // 2
-                        y0 = max(mid + 0.25, y0)
-                        iy1 = min(iy1, mid-0.25)
-                    else:
-                        mid = (iy0 + y1) // 2
-                        y1 = min(y1, mid-0.25)
-                        iy0 = max(mid + 0.25, iy0)
-                res[i][:4] = [ix0, iy0, ix1, iy1]
-                v[:4] = [x0, y0, x1, y1]
-
-        res.append(v)
-    return res
+                    keeps[idx] = False
+                # Record the reason for this pair, not the list itself.
+                drop_reasons.append(drop_reason)
+        if keeps[idx]:
+            res[idx] = v
+    return res, drop_reasons
 
 
 def remove_overlap_between_bbox_for_span(spans):
-    return _remove_overlap_between_bbox_for_span(spans)
+    """Resolve bbox overlaps among ``spans`` and return the survivors.
+
+    Args:
+        spans: list of dicts with a ``"bbox"`` entry and an optional
+            ``"score"`` entry (0.1 is used when the score is missing).
+
+    Returns:
+        ``(kept_spans, drop_reasons)``.  ``kept_spans`` holds the original
+        span dicts that survived, in input order, with ``"bbox"`` updated in
+        place to the trimmed box.
+    """
+    arr = [{"bbox": span["bbox"], "score": span.get("score", 0.1)} for span in spans]
+    res, drop_reasons = _remove_overlap_between_bboxes(arr)
+    ret = []
+    for span, kept in zip(spans, res):
+        # Truthiness rather than `is None`: any falsy placeholder marks a
+        # dropped slot and must never be subscripted.
+        if not kept:
+            continue
+        span["bbox"] = kept["bbox"]
+        ret.append(span)
+    return ret, drop_reasons
 
 
 def remove_overlap_between_bbox_for_block(all_bboxes):
-    return _remove_overlap_between_bbox_for_block(all_bboxes)
+    """Resolve bbox overlaps among block records and return the survivors.
+
+    Args:
+        all_bboxes: list of mutable sequences whose first four items are
+            ``x0, y0, x1, y1`` and whose LAST item is read as the score
+            (NOTE(review): confirm every caller puts the score last).
+
+    Returns:
+        ``(kept_blocks, drop_reasons)``.  ``kept_blocks`` holds the original
+        records that survived, in input order, with their first four items
+        overwritten in place by the trimmed box.
+    """
+    arr = [{"bbox": bbox[:4], "score": bbox[-1]} for bbox in all_bboxes]
+    res, drop_reasons = _remove_overlap_between_bboxes(arr)
+    ret = []
+    for block, kept in zip(all_bboxes, res):
+        # Truthiness rather than `is None`: any falsy placeholder marks a
+        # dropped slot and must never be subscripted.
+        if not kept:
+            continue
+        block[:4] = kept["bbox"]
+        ret.append(block)
+    return ret, drop_reasons
+