Selaa lähdekoodia

Merge pull request #160 from icecraft/fix/figure_caption_relation

fix: object cluster algorithm
Xiaomeng Zhao 1 vuosi sitten
vanhempi
commit
95e7e3a76c
1 muutettua tiedostoa jossa 44 lisäystä ja 40 poistoa
  1. 44 40
      magic_pdf/model/magic_model.py

+ 44 - 40
magic_pdf/model/magic_model.py

@@ -15,7 +15,8 @@ from magic_pdf.libs.boxbase import (
     bbox_relative_pos,
     bbox_distance,
     _is_part_overlap,
-    calculate_overlap_area_in_bbox1_area_ratio, calculate_iou,
+    calculate_overlap_area_in_bbox1_area_ratio,
+    calculate_iou,
 )
 from magic_pdf.libs.ModelBlockTypeEnum import ModelBlockTypeEnum
 
@@ -78,9 +79,23 @@ class MagicModel:
                 for layout_det2 in layout_dets:
                     if layout_det1 == layout_det2:
                         continue
-                    if layout_det1["category_id"] in [0,1,2,3,4,5,6,7,8,9] and layout_det2["category_id"] in [0,1,2,3,4,5,6,7,8,9]:
-                        if calculate_iou(layout_det1['bbox'], layout_det2['bbox']) > 0.9:
-                            if layout_det1['score'] < layout_det2['score']:
+                    if layout_det1["category_id"] in [
+                        0,
+                        1,
+                        2,
+                        3,
+                        4,
+                        5,
+                        6,
+                        7,
+                        8,
+                        9,
+                    ] and layout_det2["category_id"] in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]:
+                        if (
+                            calculate_iou(layout_det1["bbox"], layout_det2["bbox"])
+                            > 0.9
+                        ):
+                            if layout_det1["score"] < layout_det2["score"]:
                                 layout_det_need_remove = layout_det1
                             else:
                                 layout_det_need_remove = layout_det2
@@ -97,11 +112,11 @@ class MagicModel:
     def __init__(self, model_list: list, docs: fitz.Document):
         self.__model_list = model_list
         self.__docs = docs
-        '''为所有模型数据添加bbox信息(缩放,poly->bbox)'''
+        """为所有模型数据添加bbox信息(缩放,poly->bbox)"""
         self.__fix_axis()
-        '''删除置信度特别低的模型数据(<0.05),提高质量'''
+        """删除置信度特别低的模型数据(<0.05),提高质量"""
         self.__fix_by_remove_low_confidence()
-        '''删除高iou(>0.9)数据中置信度较低的那个'''
+        """删除高iou(>0.9)数据中置信度较低的那个"""
         self.__fix_by_remove_high_iou_and_low_confidence()
 
     def __reduct_overlap(self, bboxes):
@@ -125,16 +140,6 @@ class MagicModel:
         ret = []
         MAX_DIS_OF_POINT = 10**9 + 7
 
-        def expand_bbox(bbox1, bbox2):
-            x0 = min(bbox1[0], bbox2[0])
-            y0 = min(bbox1[1], bbox2[1])
-            x1 = max(bbox1[2], bbox2[2])
-            y1 = max(bbox1[3], bbox2[3])
-            return [x0, y0, x1, y1]
-
-        def get_bbox_area(bbox):
-            return abs(bbox[2] - bbox[0]) * abs(bbox[3] - bbox[1])
-
         # subject 和 object 的 bbox 会合并成一个大的 bbox (named: merged bbox)。 筛选出所有和 merged bbox 有 overlap 且 overlap 面积大于 object 的面积的 subjects。
         # 再求出筛选出的 subjects 和 object 的最短距离!
         def may_find_other_nearest_bbox(subject_idx, object_idx):
@@ -177,6 +182,13 @@ class MagicModel:
 
             return ret
 
+        def expand_bbbox(idxes):
+            x0s = [all_bboxes[idx]["bbox"][0] for idx in idxes] 
+            y0s = [all_bboxes[idx]["bbox"][1] for idx in idxes] 
+            x1s = [all_bboxes[idx]["bbox"][2] for idx in idxes] 
+            y1s = [all_bboxes[idx]["bbox"][3] for idx in idxes] 
+            return min(x0s), min(y0s), max(x1s), max(y1s)
+
         subjects = self.__reduct_overlap(
             list(
                 map(
@@ -268,7 +280,9 @@ class MagicModel:
                     or dis[i][j] == MAX_DIS_OF_POINT
                 ):
                     continue
-                left, right, _, _ = bbox_relative_pos(all_bboxes[i]["bbox"], all_bboxes[j]["bbox"]) # 由  pos_flag_count 相关逻辑保证本段逻辑准确性
+                left, right, _, _ = bbox_relative_pos(
+                    all_bboxes[i]["bbox"], all_bboxes[j]["bbox"]
+                )  # 由  pos_flag_count 相关逻辑保证本段逻辑准确性
                 if left or right:
                     one_way_dis = all_bboxes[i]["bbox"][2] - all_bboxes[i]["bbox"][0]
                 else:
@@ -322,6 +336,10 @@ class MagicModel:
                             break
 
                     if is_nearest:
+                        nx0, ny0, nx1, ny1 = expand_bbbox(list(seen) + [k])
+                        n_dis = bbox_distance(all_bboxes[i]["bbox"], [nx0, ny0, nx1, ny1])
+                        if float_gt(dis[i][j], n_dis):
+                            continue
                         tmp.append(k)
                         seen.add(k)
 
@@ -331,20 +349,7 @@ class MagicModel:
 
             # 已经获取到某个 figure 下所有的最靠近的 captions,以及最靠近这些 captions 的 captions 。
             # 先扩一下 bbox,
-            x0s = [all_bboxes[idx]["bbox"][0] for idx in seen] + [
-                all_bboxes[i]["bbox"][0]
-            ]
-            y0s = [all_bboxes[idx]["bbox"][1] for idx in seen] + [
-                all_bboxes[i]["bbox"][1]
-            ]
-            x1s = [all_bboxes[idx]["bbox"][2] for idx in seen] + [
-                all_bboxes[i]["bbox"][2]
-            ]
-            y1s = [all_bboxes[idx]["bbox"][3] for idx in seen] + [
-                all_bboxes[i]["bbox"][3]
-            ]
-
-            ox0, oy0, ox1, oy1 = min(x0s), min(y0s), max(x1s), max(y1s)
+            ox0, oy0, ox1, oy1 = expand_bbbox(list(seen) + [i])
             ix0, iy0, ix1, iy1 = all_bboxes[i]["bbox"]
 
             # 分成了 4 个截取空间,需要计算落在每个截取空间下 objects 合并后占据的矩形面积
@@ -455,8 +460,10 @@ class MagicModel:
                 with_caption_subject.add(j)
         return ret, total_subject_object_dis
 
-    def get_imgs(self, page_no: int):  # @许瑞
-        records, _ = self.__tie_up_category_by_distance(page_no, 3, 4)
+    def get_imgs(self, page_no: int):
+        figure_captions, _ = self.__tie_up_category_by_distance(
+            page_no, 3, 4
+        )
         return [
             {
                 "bbox": record["all"],
@@ -464,7 +471,7 @@ class MagicModel:
                 "img_caption_bbox": record.get("object_body", None),
                 "score": record["score"],
             }
-            for record in records
+            for record in figure_captions
         ]
 
     def get_tables(
@@ -535,6 +542,7 @@ class MagicModel:
                 if not any(span == existing_span for existing_span in new_spans):
                     new_spans.append(span)
             return new_spans
+
         all_spans = []
         model_page_info = self.__model_list[page_no]
         layout_dets = model_page_info["layout_dets"]
@@ -548,10 +556,7 @@ class MagicModel:
         for layout_det in layout_dets:
             category_id = layout_det["category_id"]
             if category_id in allow_category_id_list:
-                span = {
-                    "bbox": layout_det["bbox"],
-                    "score": layout_det["score"]
-                }
+                span = {"bbox": layout_det["bbox"], "score": layout_det["score"]}
                 if category_id == 3:
                     span["type"] = ContentType.Image
                 elif category_id == 5:
@@ -604,7 +609,6 @@ class MagicModel:
         return self.__model_list[page_no]
 
 
-
 if __name__ == "__main__":
     drw = DiskReaderWriter(r"D:/project/20231108code-clean")
     if 0: