Selaa lähdekoodia

fix: using overlap area ratio to calculate box relation when build figure/caption relations

许瑞 1 vuosi sitten
vanhempi
commit
96d17cb010
1 muutettua tiedostoa jossa 42 lisäystä ja 28 poistoa
  1. 42 28
      magic_pdf/model/magic_model.py

+ 42 - 28
magic_pdf/model/magic_model.py

@@ -10,9 +10,16 @@ from magic_pdf.libs.ocr_content_type import ContentType
 from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
 from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
 from magic_pdf.libs.math import float_gt
-from magic_pdf.libs.boxbase import _is_in, bbox_relative_pos, bbox_distance
+from magic_pdf.libs.boxbase import (
+    _is_in,
+    bbox_relative_pos,
+    bbox_distance,
+    _is_part_overlap,
+    calculate_overlap_area_in_bbox1_area_ratio,
+)
 from magic_pdf.libs.ModelBlockTypeEnum import ModelBlockTypeEnum
 
+CAPATION_OVERLAP_AREA_RATIO = 0.6
 
 class MagicModel:
     """
@@ -74,13 +81,13 @@ class MagicModel:
         return [bboxes[i] for i in range(N) if keep[i]]
 
     def __tie_up_category_by_distance(
-            self, page_no, subject_category_id, object_category_id
+        self, page_no, subject_category_id, object_category_id
     ):
         """
         假定每个 subject 最多有一个 object (可以有多个相邻的 object 合并为单个 object),每个 object 只能属于一个 subject
         """
         ret = []
-        MAX_DIS_OF_POINT = 10 ** 9 + 7
+        MAX_DIS_OF_POINT = 10**9 + 7
 
         subjects = self.__reduct_overlap(
             list(
@@ -123,8 +130,8 @@ class MagicModel:
         for i in range(N):
             for j in range(i):
                 if (
-                        all_bboxes[i]["category_id"] == subject_category_id
-                        and all_bboxes[j]["category_id"] == subject_category_id
+                    all_bboxes[i]["category_id"] == subject_category_id
+                    and all_bboxes[j]["category_id"] == subject_category_id
                 ):
                     continue
 
@@ -154,9 +161,9 @@ class MagicModel:
                 if pos_flag_count > 1:
                     continue
                 if (
-                        all_bboxes[j]["category_id"] != object_category_id
-                        or j in used
-                        or dis[i][j] == MAX_DIS_OF_POINT
+                    all_bboxes[j]["category_id"] != object_category_id
+                    or j in used
+                    or dis[i][j] == MAX_DIS_OF_POINT
                 ):
                     continue
                 arr.append((dis[i][j], j))
@@ -185,10 +192,10 @@ class MagicModel:
                         continue
 
                     if (
-                            all_bboxes[k]["category_id"] != object_category_id
-                            or k in used
-                            or k in seen
-                            or dis[j][k] == MAX_DIS_OF_POINT
+                        all_bboxes[k]["category_id"] != object_category_id
+                        or k in used
+                        or k in seen
+                        or dis[j][k] == MAX_DIS_OF_POINT
                     ):
                         continue
                     is_nearest = True
@@ -238,7 +245,7 @@ class MagicModel:
             for bbox in caption_poses:
                 embed_arr = []
                 for idx in seen:
-                    if _is_in(all_bboxes[idx]["bbox"], bbox):
+                    if calculate_overlap_area_in_bbox1_area_ratio(all_bboxes[idx]["bbox"], bbox) > CAPATION_OVERLAP_AREA_RATIO:
                         embed_arr.append(idx)
 
                 if len(embed_arr) > 0:
@@ -258,7 +265,7 @@ class MagicModel:
                 caption_bbox = caption_poses[max_area_idx]
 
                 for j in seen:
-                    if _is_in(all_bboxes[j]["bbox"], caption_bbox):
+                    if calculate_overlap_area_in_bbox1_area_ratio(all_bboxes[j]["bbox"], caption_bbox) > CAPATION_OVERLAP_AREA_RATIO:
                         used.add(j)
                         subject_object_relation_map[i].append(j)
 
@@ -312,8 +319,8 @@ class MagicModel:
             candidates = []
             for j in range(N):
                 if (
-                        all_bboxes[j]["category_id"] != subject_category_id
-                        or j in with_caption_subject
+                    all_bboxes[j]["category_id"] != subject_category_id
+                    or j in with_caption_subject
                 ):
                     continue
                 candidates.append((dis[i][j], j))
@@ -335,7 +342,7 @@ class MagicModel:
         ]
 
     def get_tables(
-            self, page_no: int
+        self, page_no: int
     ) -> list:  # 3个坐标, caption, table主体,table-note
         with_captions, _ = self.__tie_up_category_by_distance(page_no, 5, 6)
         with_footnotes, _ = self.__tie_up_category_by_distance(page_no, 5, 7)
@@ -358,9 +365,15 @@ class MagicModel:
         return ret
 
     def get_equations(self, page_no: int) -> list:  # 有坐标,也有字
-        inline_equations = self.__get_blocks_by_type(ModelBlockTypeEnum.EMBEDDING.value, page_no, ["latex"])
-        interline_equations = self.__get_blocks_by_type(ModelBlockTypeEnum.ISOLATED.value, page_no, ["latex"])
-        interline_equations_blocks = self.__get_blocks_by_type(ModelBlockTypeEnum.ISOLATE_FORMULA.value, page_no)
+        inline_equations = self.__get_blocks_by_type(
+            ModelBlockTypeEnum.EMBEDDING.value, page_no, ["latex"]
+        )
+        interline_equations = self.__get_blocks_by_type(
+            ModelBlockTypeEnum.ISOLATED.value, page_no, ["latex"]
+        )
+        interline_equations_blocks = self.__get_blocks_by_type(
+            ModelBlockTypeEnum.ISOLATE_FORMULA.value, page_no
+        )
         return inline_equations, interline_equations, interline_equations_blocks
 
     def get_discarded(self, page_no: int) -> list:  # 自研模型,只有坐标
@@ -382,7 +395,7 @@ class MagicModel:
         for layout_det in layout_dets:
             if layout_det["category_id"] == "15":
                 span = {
-                    "bbox": layout_det['bbox'],
+                    "bbox": layout_det["bbox"],
                     "content": layout_det["text"],
                 }
                 text_spans.append(span)
@@ -402,9 +415,7 @@ class MagicModel:
         for layout_det in layout_dets:
             category_id = layout_det["category_id"]
             if category_id in allow_category_id_list:
-                span = {
-                    "bbox": layout_det['bbox']
-                }
+                span = {"bbox": layout_det["bbox"]}
                 if category_id == 3:
                     span["type"] = ContentType.Image
                 elif category_id == 5:
@@ -429,7 +440,9 @@ class MagicModel:
         page_h = page.rect.height
         return page_w, page_h
 
-    def __get_blocks_by_type(self, type: int, page_no: int, extra_col: list[str] = []) -> list:
+    def __get_blocks_by_type(
+        self, type: int, page_no: int, extra_col: list[str] = []
+    ) -> list:
         blocks = []
         for page_dict in self.__model_list:
             layout_dets = page_dict.get("layout_dets", [])
@@ -442,14 +455,15 @@ class MagicModel:
                 bbox = item.get("bbox", None)
 
                 if category_id == type:
-                    block = {
-                        "bbox": bbox
-                    }
+                    block = {"bbox": bbox}
                     for col in extra_col:
                         block[col] = item.get(col, None)
                     blocks.append(block)
         return blocks
 
+    def get_model_list(self, page_no):
+        return self.__model_list[page_no]
+
 
 if __name__ == "__main__":
     drw = DiskReaderWriter(r"D:/project/20231108code-clean")