瀏覽代碼

Merge pull request #1701 from icecraft/fix/caption_match

fix: caption match algorithm
Xiaomeng Zhao 9 月之前
父節點
當前提交
d46b87be49
共有 1 個文件被更改,包括 113 次插入4 次删除
  1. 113 4
      magic_pdf/model/magic_model.py

+ 113 - 4
magic_pdf/model/magic_model.py

@@ -450,11 +450,120 @@ class MagicModel:
             )
         return ret
 
+
+    def __tie_up_category_by_distance_v3(
+        self,
+        page_no: int,
+        subject_category_id: int,
+        object_category_id: int,
+        priority_pos: PosRelationEnum,
+    ):
+        subjects = self.__reduct_overlap(
+            list(
+                map(
+                    lambda x: {'bbox': x['bbox'], 'score': x['score']},
+                    filter(
+                        lambda x: x['category_id'] == subject_category_id,
+                        self.__model_list[page_no]['layout_dets'],
+                    ),
+                )
+            )
+        )
+        objects = self.__reduct_overlap(
+            list(
+                map(
+                    lambda x: {'bbox': x['bbox'], 'score': x['score']},
+                    filter(
+                        lambda x: x['category_id'] == object_category_id,
+                        self.__model_list[page_no]['layout_dets'],
+                    ),
+                )
+            )
+        )
+
+        ret = []
+        N, M = len(subjects), len(objects)
+        subjects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)
+        objects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)
+
+        OBJ_IDX_OFFSET = 10000
+        SUB_BIT_KIND, OBJ_BIT_KIND = 0, 1
+        
+        all_boxes_with_idx = [(i, SUB_BIT_KIND, sub['bbox'][0], sub['bbox'][1]) for i, sub in enumerate(subjects)] + [(i + OBJ_IDX_OFFSET , OBJ_BIT_KIND, obj['bbox'][0], obj['bbox'][1]) for i, obj in enumerate(objects)]
+        seen_idx = set()
+        seen_sub_idx = set()
+        
+        while N > len(seen_sub_idx):
+            candidates = [] 
+            for idx, kind, x0, y0 in all_boxes_with_idx:
+                if idx in seen_idx:
+                    continue 
+                candidates.append((idx, kind, x0, y0))
+            
+            if len(candidates) == 0:
+                break
+            left_x = min([v[2] for v in candidates])
+            top_y =  min([v[3] for v in candidates])
+            
+            candidates.sort(key=lambda x: (x[2]-left_x) ** 2 + (x[3] - top_y) ** 2)
+
+            
+            fst_idx, fst_kind, left_x, top_y = candidates[0]
+            candidates.sort(key=lambda x: (x[2] - left_x) ** 2 + (x[3] - top_y)**2)
+            nxt = None
+            
+            for i in range(1, len(candidates)):
+                if candidates[i][1] ^ fst_kind == 1:
+                    nxt = candidates[i]
+                    break 
+            if nxt is None:
+                break
+            
+            seen_idx.add(fst_idx)
+            seen_idx.add(nxt[0])
+            if fst_kind == SUB_BIT_KIND:
+                seen_sub_idx.add(fst_idx)
+                sub_idx, obj_idx = fst_idx, nxt[0] - OBJ_IDX_OFFSET
+                
+            else:
+                seen_sub_idx.add(nxt[0])
+                sub_idx, obj_idx = nxt[0], fst_idx - OBJ_IDX_OFFSET
+            ret.append(
+                {
+                    'sub_bbox': {
+                        'bbox': subjects[sub_idx]['bbox'],
+                        'score': subjects[sub_idx]['score'],
+                    },
+                    'obj_bboxes': [
+                        {'score': objects[obj_idx]['score'], 'bbox': objects[obj_idx]['bbox']}
+                    ],
+                    'sub_idx': sub_idx,
+                }
+            )
+
+        for i in range(len(subjects)):
+            if i in seen_sub_idx:
+                continue 
+            ret.append(
+                {
+                    'sub_bbox': {
+                        'bbox': subjects[i]['bbox'],
+                        'score': subjects[i]['score'],
+                    },
+                    'obj_bboxes': [],
+                    'sub_idx': i,
+                }
+            )
+        
+        
+        return ret
+
+
     def get_imgs_v2(self, page_no: int):
-        with_captions = self.__tie_up_category_by_distance_v2(
+        with_captions = self.__tie_up_category_by_distance_v3(
             page_no, 3, 4, PosRelationEnum.BOTTOM
         )
-        with_footnotes = self.__tie_up_category_by_distance_v2(
+        with_footnotes = self.__tie_up_category_by_distance_v3(
             page_no, 3, CategoryId.ImageFootnote, PosRelationEnum.ALL
         )
         ret = []
@@ -470,10 +579,10 @@ class MagicModel:
         return ret
 
     def get_tables_v2(self, page_no: int) -> list:
-        with_captions = self.__tie_up_category_by_distance_v2(
+        with_captions = self.__tie_up_category_by_distance_v3(
             page_no, 5, 6, PosRelationEnum.UP
         )
-        with_footnotes = self.__tie_up_category_by_distance_v2(
+        with_footnotes = self.__tie_up_category_by_distance_v3(
             page_no, 5, 7, PosRelationEnum.ALL
         )
         ret = []