Эх сурвалжийг харах

fix: update remove overlap

许瑞 1 жил өмнө
parent
commit
6b28ec82b8

+ 1 - 0
magic_pdf/libs/drop_reason.py

@@ -22,4 +22,5 @@ class DropReason:
     SPECIAL_PDF = "special_pdf"
     PSEUDO_SINGLE_COLUMN = "pseudo_single_column" # 无法精确判断文字分栏
     CAN_NOT_DETECT_PAGE_LAYOUT="can_not_detect_page_layout" # 无法分析页面的版面
+    NEGATIVE_BBOX_AREA = "negative_bbox_area" # 缩放导致 bbox 面积为负
     

+ 8 - 2
magic_pdf/model/magic_model.py

@@ -138,10 +138,10 @@ class MagicModel:
         all_bboxes = []
 
         for v in subjects:
-            all_bboxes.append({"category_id": subject_category_id, "bbox": v})
+            all_bboxes.append({"category_id": subject_category_id, "bbox": v, "score": v["score"]})
 
         for v in objects:
-            all_bboxes.append({"category_id": object_category_id, "bbox": v})
+            all_bboxes.append({"category_id": object_category_id, "bbox": v, "score": v["score"]})
 
         N = len(all_bboxes)
         dis = [[MAX_DIS_OF_POINT] * N for _ in range(N)]
@@ -294,6 +294,7 @@ class MagicModel:
             result = {
                 "subject_body": all_bboxes[i]["bbox"],
                 "all": all_bboxes[i]["bbox"],
+                "score": all_bboxes[i]["score"],
             }
 
             if len(subject_object_relation_map[i]) > 0:
@@ -358,6 +359,7 @@ class MagicModel:
                 "bbox": record["all"],
                 "img_body_bbox": record["subject_body"],
                 "img_caption_bbox": record.get("object_body", None),
+                "score": record["score"],
             }
             for record in records
         ]
@@ -372,6 +374,7 @@ class MagicModel:
         assert N == M
         for i in range(N):
             record = {
+                "score": with_captions[i]["score"],
                 "table_caption_bbox": with_captions[i].get("object_body", None),
                 "table_body_bbox": with_captions[i]["subject_body"],
                 "table_footnote_bbox": with_footnotes[i].get("object_body", None),
@@ -482,6 +485,9 @@ class MagicModel:
                     blocks.append(block)
         return blocks
 
+    def get_model_list(self, page_no):
+        return self.__model_list[page_no]
+
 
 if __name__ == "__main__":
     drw = DiskReaderWriter(r"D:/project/20231108code-clean")

+ 84 - 95
magic_pdf/pre_proc/remove_bbox_overlap.py

@@ -1,109 +1,98 @@
-from magic_pdf.libs.boxbase import _is_in_or_part_overlap, _is_in
-
-
-def _remove_overlap_between_bbox_for_span(spans):
-    res = []
-
-    keeps = [True] * len(spans)
-    for i in range(len(spans)):
-        for j in range(len(spans)):
+from magic_pdf.libs.boxbase import _is_in_or_part_overlap, _is_in, _is_part_overlap
+from magic_pdf.libs.drop_reason import DropReason
+
+def _remove_overlap_between_bbox(bbox1, bbox2):
+   if _is_part_overlap(bbox1, bbox2):
+        ix0, iy0, ix1, iy1 = bbox1
+        x0, y0, x1, y1 = bbox2
+
+        diff_x = min(x1, ix1) - max(x0, ix0)
+        diff_y = min(y1, iy1) - max(y0, iy0)
+
+        if diff_y > diff_x:
+            if x1 >= ix1:
+                mid = (x0 + ix1) // 2
+                ix1 = min(mid - 0.25, ix1)
+                x0 = max(mid + 0.25, x0)
+            else:
+                mid = (ix0 + x1) // 2
+                ix0 = max(mid + 0.25, ix0)
+                x1 = min(mid - 0.25, x1)
+        else:
+            if y1 >= iy1:
+                mid = (y0 + iy1) // 2
+                y0 = max(mid + 0.25, y0)
+                iy1 = min(iy1, mid-0.25)
+            else:
+                mid = (iy0 + y1) // 2
+                y1 = min(y1, mid-0.25)
+                iy0 = max(mid + 0.25, iy0)
+
+        if ix1 > ix0 and iy1 > iy0 and y1 > y0 and x1 > x0:
+            bbox1 = [ix0, iy0, ix1, iy1]
+            bbox2 = [x0, y0, x1, y1]
+            return bbox1, bbox2, None
+        else:
+            return bbox1, bbox2, DropReason.NEGATIVE_BBOX_AREA
+   else:
+       return bbox1, bbox2, None
+
+
+def _remove_overlap_between_bboxes(arr):
+    drop_reasons = []
+    N = len(arr)
+    keeps = [True] * N
+    res = [None] * N
+    for i in range(N):
+        for j in range(N):
             if i == j:
                 continue
-            if _is_in(spans[i]["bbox"], spans[j]["bbox"]):
+            if _is_in(arr[i]["bbox"], arr[j]["bbox"]):
                 keeps[i] = False
 
-    for idx, v in enumerate(spans):
+    for idx, v in enumerate(arr):
         if not keeps[idx]:
             continue
-
-        for i in range(len(res)):
-            if _is_in(v["bbox"], res[i]["bbox"]):
+        for i in range(N):
+            if res[i] is None:
                 continue
-            if _is_in_or_part_overlap(res[i]["bbox"], v["bbox"]):
-                ix0, iy0, ix1, iy1 = res[i]["bbox"]
-                x0, y0, x1, y1 = v["bbox"]
-
-                diff_x = min(x1, ix1) - max(x0, ix0)
-                diff_y = min(y1, iy1) - max(y0, iy0)
-
-                if diff_y > diff_x:
-                    if x1 >= ix1:
-                        mid = (x0 + ix1) // 2
-                        ix1 = min(mid - 0.25, ix1)
-                        x0 = max(mid + 0.25, x0)
-                    else:
-                        mid = (ix0 + x1) // 2
-                        ix0 = max(mid + 0.25, ix0)
-                        x1 = min(mid - 0.25, x1)
+        
+            bbox1, bbox2, drop_reason = _remove_overlap_between_bbox(v["bbox"], res[i]["bbox"])
+            if drop_reason is None:
+                v["bbox"] = bbox1
+                res[i]["bbox"] = bbox2
+            else:
+                if v["score"] > res[i]["score"]:
+                    keeps[i] = False
+                    res[i] = False
                 else:
-                    if y1 >= iy1:
-                        mid = (y0 + iy1) // 2
-                        y0 = max(mid + 0.25, y0)
-                        iy1 = min(iy1, mid-0.25)
-                    else:
-                        mid = (iy0 + y1) // 2
-                        y1 = min(y1, mid-0.25)
-                        iy0 = max(mid + 0.25, iy0)
-                res[i]["bbox"] = [ix0, iy0, ix1, iy1]
-                v["bbox"] = [x0, y0, x1, y1]
-
-        res.append(v)
-    return res
-
-
-def _remove_overlap_between_bbox_for_block(all_bboxes):
-    res = []
-
-    keeps = [True] * len(all_bboxes)
-    for i in range(len(all_bboxes)):
-        for j in range(len(all_bboxes)):
-            if i == j:
-                continue
-            if _is_in(all_bboxes[i][:4], all_bboxes[j][:4]):
-                keeps[i] = False
-
-    for idx, v in enumerate(all_bboxes):
-        if not keeps[idx]:
-            continue
-
-        for i in range(len(res)):
-            if _is_in(v[:4], res[i][:4]):
-                continue
-            if _is_in_or_part_overlap(res[i][:4], v[:4]):
-                ix0, iy0, ix1, iy1 = res[i][:4]
-                x0, y0, x1, y1 = v[:4]
-
-                diff_x = min(x1, ix1) - max(x0, ix0)
-                diff_y = min(y1, iy1) - max(y0, iy0)
-
-                if diff_y > diff_x:
-                    if x1 >= ix1:
-                        mid = (x0 + ix1) // 2
-                        ix1 = min(mid - 0.25, ix1)
-                        x0 = max(mid + 0.25, x0)
-                    else:
-                        mid = (ix0 + x1) // 2
-                        ix0 = max(mid + 0.25, ix0)
-                        x1 = min(mid - 0.25, x1)
-                else:
-                    if y1 >= iy1:
-                        mid = (y0 + iy1) // 2
-                        y0 = max(mid + 0.25, y0)
-                        iy1 = min(iy1, mid-0.25)
-                    else:
-                        mid = (iy0 + y1) // 2
-                        y1 = min(y1, mid-0.25)
-                        iy0 = max(mid + 0.25, iy0)
-                res[i][:4] = [ix0, iy0, ix1, iy1]
-                v[:4] = [x0, y0, x1, y1]
-
-        res.append(v)
-    return res
+                    keeps[idx] = False
+                drop_reasons.append(drop_reasons)
+        if keeps[idx]:
+            res[idx] = v
+    return res, drop_reasons
 
 
 def remove_overlap_between_bbox_for_span(spans):
-    return _remove_overlap_between_bbox_for_span(spans)
+    arr = [{"bbox": span["bbox"], "score": span.get("score", 0.1)} for span in spans ]
+    res, drop_reasons = _remove_overlap_between_bboxes(arr)
+    ret = []
+    for i in range(len(res)):
+        if res[i] is None:
+            continue
+        spans[i]["bbox"] = res[i]["bbox"]
+        ret.append(spans[i])
+    return ret, drop_reasons
 
 
 def remove_overlap_between_bbox_for_block(all_bboxes):
-    return _remove_overlap_between_bbox_for_block(all_bboxes)
+    arr = [{"bbox": bbox[:4], "score": bbox[-1]} for bbox in all_bboxes ]
+    res, drop_reasons = _remove_overlap_between_bboxes(arr)
+    ret = []
+    for i in range(len(res)):
+        if res[i] is None:
+            continue
+        all_bboxes[i][:4] = res[i]["bbox"]
+        ret.append(all_bboxes[i])
+    return ret, drop_reasons
+