Переглянути джерело

Merge branch 'dev' of https://github.com/opendatalab/MinerU into dev

quyuan 1 рік тому
батько
коміт
9b30ea20e5
3 змінених файлів з 116 додано та 46 видалено
  1. 19 0
      magic_pdf/libs/boxbase.py
  2. 96 45
      magic_pdf/model/magic_model.py
  3. 1 1
      magic_pdf/tools/common.py

+ 19 - 0
magic_pdf/libs/boxbase.py

@@ -426,3 +426,22 @@ def bbox_distance(bbox1, bbox2):
     elif top:
         return y2 - y1b
     return 0.0
+
+
+def box_area(bbox):  # area of a box indexed as (x0, y0, x1, y1): width * height
+    return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
+
+
+def get_overlap_area(bbox1, bbox2):
+    """Return the area of the intersection of bbox1 and bbox2 (0.0 when they are disjoint)."""
+    # Determine the coordinates of the intersection rectangle
+    x_left = max(bbox1[0], bbox2[0])
+    y_top = max(bbox1[1], bbox2[1])
+    x_right = min(bbox1[2], bbox2[2])
+    y_bottom = min(bbox1[3], bbox2[3])
+
+    if x_right < x_left or y_bottom < y_top:
+        return 0.0
+
+    # Absolute overlap area; this is not divided by the area of either box
+    return (x_right - x_left) * (y_bottom - y_top)

+ 96 - 45
magic_pdf/model/magic_model.py

@@ -1,8 +1,9 @@
 import json
 
 from magic_pdf.libs.boxbase import (_is_in, _is_part_overlap, bbox_distance,
-                                    bbox_relative_pos, calculate_iou,
-                                    calculate_overlap_area_in_bbox1_area_ratio)
+                                    bbox_relative_pos, box_area, calculate_iou,
+                                    calculate_overlap_area_in_bbox1_area_ratio,
+                                    get_overlap_area)
 from magic_pdf.libs.commons import fitz, join_path
 from magic_pdf.libs.coordinate_transform import get_scale_ratio
 from magic_pdf.libs.local_math import float_gt
@@ -12,6 +13,7 @@ from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
 from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
 
 CAPATION_OVERLAP_AREA_RATIO = 0.6
+MERGE_BOX_OVERLAP_AREA_RATIO = 1.1
 
 
 class MagicModel:
@@ -124,49 +126,51 @@ class MagicModel:
                     tables.append(obj)
                 if len(footnotes) * len(figures) == 0:
                     continue
-                dis_figure_footnote = {}
-                dis_table_footnote = {}
-
-                for i in range(len(footnotes)):
-                    for j in range(len(figures)):
-                        pos_flag_count = sum(
-                            list(
-                                map(
-                                    lambda x: 1 if x else 0,
-                                    bbox_relative_pos(
-                                        footnotes[i]['bbox'], figures[j]['bbox']
-                                    ),
-                                )
+            dis_figure_footnote = {}
+            dis_table_footnote = {}
+
+            for i in range(len(footnotes)):
+                for j in range(len(figures)):
+                    pos_flag_count = sum(
+                        list(
+                            map(
+                                lambda x: 1 if x else 0,
+                                bbox_relative_pos(
+                                    footnotes[i]['bbox'], figures[j]['bbox']
+                                ),
                             )
                         )
-                        if pos_flag_count > 1:
-                            continue
-                        dis_figure_footnote[i] = min(
-                            bbox_distance(figures[j]['bbox'], footnotes[i]['bbox']),
-                            dis_figure_footnote.get(i, float('inf')),
-                        )
-                for i in range(len(footnotes)):
-                    for j in range(len(tables)):
-                        pos_flag_count = sum(
-                            list(
-                                map(
-                                    lambda x: 1 if x else 0,
-                                    bbox_relative_pos(
-                                        footnotes[i]['bbox'], tables[j]['bbox']
-                                    ),
-                                )
+                    )
+                    if pos_flag_count > 1:
+                        continue
+                    dis_figure_footnote[i] = min(
+                        bbox_distance(figures[j]['bbox'], footnotes[i]['bbox']),
+                        dis_figure_footnote.get(i, float('inf')),
+                    )
+            for i in range(len(footnotes)):
+                for j in range(len(tables)):
+                    pos_flag_count = sum(
+                        list(
+                            map(
+                                lambda x: 1 if x else 0,
+                                bbox_relative_pos(
+                                    footnotes[i]['bbox'], tables[j]['bbox']
+                                ),
                             )
                         )
-                        if pos_flag_count > 1:
-                            continue
+                    )
+                    if pos_flag_count > 1:
+                        continue
 
-                        dis_table_footnote[i] = min(
-                            bbox_distance(tables[j]['bbox'], footnotes[i]['bbox']),
-                            dis_table_footnote.get(i, float('inf')),
-                        )
-                for i in range(len(footnotes)):
-                    if dis_table_footnote.get(i, float('inf')) > dis_figure_footnote[i]:
-                        footnotes[i]['category_id'] = CategoryId.ImageFootnote
+                    dis_table_footnote[i] = min(
+                        bbox_distance(tables[j]['bbox'], footnotes[i]['bbox']),
+                        dis_table_footnote.get(i, float('inf')),
+                    )
+            for i in range(len(footnotes)):
+                if i not in dis_figure_footnote:
+                    continue
+                if dis_table_footnote.get(i, float('inf')) > dis_figure_footnote[i]:
+                    footnotes[i]['category_id'] = CategoryId.ImageFootnote
 
     def __reduct_overlap(self, bboxes):
         N = len(bboxes)
@@ -191,6 +195,44 @@ class MagicModel:
         筛选出所有和 merged bbox 有 overlap 且 overlap 面积大于 object 的面积的 subjects。
         再求出筛选出的 subjects 和 object 的最短距离
         """
+        def search_overlap_between_boxes(
+            subject_idx, object_idx
+        ):  # largest overlap of any other-category box with the merged box, divided by the object's area
+            idxes = [subject_idx, object_idx]
+            x0s = [all_bboxes[idx]['bbox'][0] for idx in idxes]
+            y0s = [all_bboxes[idx]['bbox'][1] for idx in idxes]
+            x1s = [all_bboxes[idx]['bbox'][2] for idx in idxes]
+            y1s = [all_bboxes[idx]['bbox'][3] for idx in idxes]
+            # Smallest box that encloses both the subject box and the object box.
+            merged_bbox = [
+                min(x0s),
+                min(y0s),
+                max(x1s),
+                max(y1s),
+            ]
+            ratio = 0  # largest ratio found so far
+
+            other_objects = list(  # detections on this page whose category is neither the subject's nor the object's
+                map(
+                    lambda x: {'bbox': x['bbox'], 'score': x['score']},
+                    filter(
+                        lambda x: x['category_id']
+                        not in (object_category_id, subject_category_id),
+                        self.__model_list[page_no]['layout_dets'],
+                    ),
+                )
+            )
+            for other_object in other_objects:
+                ratio = max(
+                    ratio,
+                    get_overlap_area(
+                        merged_bbox, other_object['bbox']
+                    ) * 1.0 / box_area(all_bboxes[object_idx]['bbox'])  # NOTE(review): unguarded division; a zero-area object box would fail here, confirm inputs exclude it
+                )
+                if ratio >= MERGE_BOX_OVERLAP_AREA_RATIO:
+                    break  # threshold reached, no need to scan the remaining boxes
+
+            return ratio
 
         def may_find_other_nearest_bbox(subject_idx, object_idx):
             ret = float('inf')
@@ -299,6 +341,15 @@ class MagicModel:
                 ):
                     continue
 
+                subject_idx, object_idx = i, j
+                if all_bboxes[j]['category_id'] == subject_category_id:
+                    subject_idx, object_idx = j, i
+
+                if search_overlap_between_boxes(subject_idx, object_idx) >= MERGE_BOX_OVERLAP_AREA_RATIO:
+                    dis[i][j] = float('inf')
+                    dis[j][i] = dis[i][j]
+                    continue
+
                 dis[i][j] = bbox_distance(all_bboxes[i]['bbox'], all_bboxes[j]['bbox'])
                 dis[j][i] = dis[i][j]
 
@@ -627,13 +678,13 @@ class MagicModel:
                     span['type'] = ContentType.Image
                 elif category_id == 5:
                     # 获取table模型结果
-                    latex = layout_det.get("latex", None)
-                    html = layout_det.get("html", None)
+                    latex = layout_det.get('latex', None)
+                    html = layout_det.get('html', None)
                     if latex:
-                        span["latex"] = latex
+                        span['latex'] = latex
                     elif html:
-                        span["html"] = html
-                    span["type"] = ContentType.Table
+                        span['html'] = html
+                    span['type'] = ContentType.Table
                 elif category_id == 13:
                     span['content'] = layout_det['latex']
                     span['type'] = ContentType.InlineEquation

+ 1 - 1
magic_pdf/tools/common.py

@@ -50,7 +50,7 @@ def do_parse(
 >>>>>>> 0140d7d271ac3b1561ca2272030e9e038b469999
 ):
     if debug_able:
-        logger.warning("debug mode is on")
+        logger.warning('debug mode is on')
         f_dump_content_list = True
         f_draw_model_bbox = True