瀏覽代碼

fix: add priority match rule

icecraft 1 年之前
父節點
當前提交
34a13a898b
共有 1 個文件被更改,包括 61 次插入和 8 次删除
  1. 61 8
      magic_pdf/model/magic_model.py

+ 61 - 8
magic_pdf/model/magic_model.py

@@ -1,3 +1,4 @@
+import enum
 import json
 
 from magic_pdf.data.dataset import Dataset
@@ -18,6 +19,14 @@ CAPATION_OVERLAP_AREA_RATIO = 0.6
 MERGE_BOX_OVERLAP_AREA_RATIO = 1.1
 
 
+class PosRelationEnum(enum.Enum):
+    LEFT = 'left'
+    RIGHT = 'right'
+    UP = 'up'
+    BOTTOM = 'bottom'
+    ALL = 'all'
+
+
 class MagicModel:
     """每个函数没有得到元素的时候返回空list."""
 
@@ -591,9 +600,23 @@ class MagicModel:
         return ret, total_subject_object_dis
 
     def __tie_up_category_by_distance_v2(
-        self, page_no, subject_category_id, object_category_id
+        self,
+        page_no: int,
+        subject_category_id: int,
+        object_category_id: int,
+        priority_pos: PosRelationEnum,
     ):
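+        """Tie objects to subjects on one page by bbox distance.

+        Args:
+            page_no (int): page to match on.
+            subject_category_id (int): subject category (callers pass 3 or 5).
+            object_category_id (int): caption/footnote category to attach.
+            priority_pos (PosRelationEnum): side preferred on near top/bottom ties.
+
+        Returns:
+            Matches iterated by the get_*_v2 callers (shape not shown in this diff).
+        """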
         AXIS_MULPLICITY = 0.5
         subjects = self.__reduct_overlap(
             list(
@@ -680,6 +703,27 @@ class MagicModel:
                             j,
                             bbox_distance(obj['bbox'], sub['bbox']),
                         ]
+
+            if (
+                dis_by_directions['top'][i][1] != float('inf')
+                and dis_by_directions['bottom'][i][1] != float('inf')
+                and priority_pos in (PosRelationEnum.BOTTOM, PosRelationEnum.UP)
+            ):
+                RATIO = 3
+                if (
+                    abs(
+                        dis_by_directions['top'][i][1]
+                        - dis_by_directions['bottom'][i][1]
+                    )
+                    < RATIO * axis_unit
+                ):
+
+                    if priority_pos == PosRelationEnum.BOTTOM:
+                        sub_obj_map_h[dis_by_directions['bottom'][i][0]].append(i)
+                    else:
+                        sub_obj_map_h[dis_by_directions['top'][i][0]].append(i)
+                    continue
+
             if dis_by_directions['left'][i][1] != float('inf') or dis_by_directions[
                 'right'
             ][i][1] != float('inf'):
@@ -735,9 +779,12 @@ class MagicModel:
 
                         top_bottom_x_axis = top_bottom[2] - top_bottom[0]
                         bottom_top_x_axis = bottom_top[2] - bottom_top[0]
-                        if abs(top_bottom_x_axis - l_x_axis) + dis_by_directions['bottom'][i][1] > abs(
-                            bottom_top_x_axis - l_x_axis
-                        ) + dis_by_directions['top'][i][1]:
+                        if (
+                            abs(top_bottom_x_axis - l_x_axis)
+                            + dis_by_directions['bottom'][i][1]
+                            > abs(bottom_top_x_axis - l_x_axis)
+                            + dis_by_directions['top'][i][1]
+                        ):
                             top_or_bottom = dis_by_directions['top'][i]
                         else:
                             top_or_bottom = dis_by_directions['bottom'][i]
@@ -798,9 +845,11 @@ class MagicModel:
         return ret
 
     def get_imgs_v2(self, page_no: int):
-        with_captions = self.__tie_up_category_by_distance_v2(page_no, 3, 4)
+        with_captions = self.__tie_up_category_by_distance_v2(
+            page_no, 3, 4, PosRelationEnum.BOTTOM
+        )
         with_footnotes = self.__tie_up_category_by_distance_v2(
-            page_no, 3, CategoryId.ImageFootnote
+            page_no, 3, CategoryId.ImageFootnote, PosRelationEnum.ALL
         )
         ret = []
         for v in with_captions:
@@ -815,8 +864,12 @@ class MagicModel:
         return ret
 
     def get_tables_v2(self, page_no: int) -> list:
-        with_captions = self.__tie_up_category_by_distance_v2(page_no, 5, 6)
-        with_footnotes = self.__tie_up_category_by_distance_v2(page_no, 5, 7)
+        with_captions = self.__tie_up_category_by_distance_v2(
+            page_no, 5, 6, PosRelationEnum.UP
+        )
+        with_footnotes = self.__tie_up_category_by_distance_v2(
+            page_no, 5, 7, PosRelationEnum.ALL
+        )
         ret = []
         for v in with_captions:
             record = {