Browse Source

Merge branch 'master' into fix/update_remove_overlap

myhloli 1 year ago
parent
commit
1baf13f379

+ 11 - 0
magic_pdf/libs/boxbase.py

@@ -161,6 +161,17 @@ def __is_overlaps_y_exceeds_threshold(bbox1, bbox2, overlap_ratio_threshold=0.8)
 
 
 def calculate_iou(bbox1, bbox2):
+    """
+    计算两个边界框的交并比(IOU)。
+
+    Args:
+        bbox1 (list[float]): 第一个边界框的坐标,格式为 [x1, y1, x2, y2],其中 (x1, y1) 为左上角坐标,(x2, y2) 为右下角坐标。
+        bbox2 (list[float]): 第二个边界框的坐标,格式与 `bbox1` 相同。
+
+    Returns:
+        float: 两个边界框的交并比(IOU),取值范围为 [0, 1]。
+
+    """
     # Determine the coordinates of the intersection rectangle
     x_left = max(bbox1[0], bbox2[0])
     y_top = max(bbox1[1], bbox2[1])

+ 27 - 3
magic_pdf/model/magic_model.py

@@ -90,10 +90,21 @@ class MagicModel:
         ret = []
         MAX_DIS_OF_POINT = 10**9 + 7
 
+        def expand_bbox(bbox1, bbox2):
+            x0 = min(bbox1[0], bbox2[0])
+            y0 = min(bbox1[1], bbox2[1])
+            x1 = max(bbox1[2], bbox2[2])
+            y1 = max(bbox1[3], bbox2[3])
+            return [x0, y0, x1, y1]
+
+        def get_bbox_area(bbox):
+            return abs(bbox[2] - bbox[0]) * abs(bbox[3] - bbox[1])
+
         # subject 和 object 的 bbox 会合并成一个大的 bbox (named: merged bbox)。 筛选出所有和 merged bbox 有 overlap 且 overlap 面积大于 object 的面积的 subjects。
         # 再求出筛选出的 subjects 和 object 的最短距离!
         def may_find_other_nearest_bbox(subject_idx, object_idx):
             ret = float("inf")
+
             x0 = min(
                 all_bboxes[subject_idx]["bbox"][0], all_bboxes[object_idx]["bbox"][0]
             )
@@ -112,6 +123,7 @@ class MagicModel:
             ) * abs(
                 all_bboxes[object_idx]["bbox"][3] - all_bboxes[object_idx]["bbox"][1]
             )
+
             for i in range(len(all_bboxes)):
                 if (
                     i == subject_idx
@@ -121,11 +133,13 @@ class MagicModel:
                 if _is_part_overlap([x0, y0, x1, y1], all_bboxes[i]["bbox"]) or _is_in(
                     all_bboxes[i]["bbox"], [x0, y0, x1, y1]
                 ):
+
                     i_area = abs(
                         all_bboxes[i]["bbox"][2] - all_bboxes[i]["bbox"][0]
                     ) * abs(all_bboxes[i]["bbox"][3] - all_bboxes[i]["bbox"][1])
                     if i_area >= object_area:
                         ret = min(float("inf"), dis[i][object_idx])
+
             return ret
 
         subjects = self.__reduct_overlap(
@@ -224,7 +238,7 @@ class MagicModel:
             arr.sort(key=lambda x: x[0])
             if len(arr) > 0:
                 # bug: 离该subject 最近的 object 可能跨越了其它的 subject 。比如 [this subect] [some sbuject] [the nearest objec of subject]
-                if may_find_other_nearest_bbox(i, j) >= arr[0][0]:
+                if may_find_other_nearest_bbox(i, arr[0][1]) >= arr[0][0]:
                     candidates.append(arr[0][1])
                     seen.add(arr[0][1])
 
@@ -470,6 +484,12 @@ class MagicModel:
         return text_spans
 
     def get_all_spans(self, page_no: int) -> list:
+        def remove_duplicate_spans(spans):
+            new_spans = []
+            for span in spans:
+                if not any(span == existing_span for existing_span in new_spans):
+                    new_spans.append(span)
+            return new_spans
         all_spans = []
         model_page_info = self.__model_list[page_no]
         layout_dets = model_page_info["layout_dets"]
@@ -483,7 +503,10 @@ class MagicModel:
         for layout_det in layout_dets:
             category_id = layout_det["category_id"]
             if category_id in allow_category_id_list:
-                span = {"bbox": layout_det["bbox"]}
+                span = {
+                    "bbox": layout_det["bbox"],
+                    "score": layout_det["score"]
+                }
                 if category_id == 3:
                     span["type"] = ContentType.Image
                 elif category_id == 5:
@@ -498,7 +521,7 @@ class MagicModel:
                     span["content"] = layout_det["text"]
                     span["type"] = ContentType.Text
                 all_spans.append(span)
-        return all_spans
+        return remove_duplicate_spans(all_spans)
 
     def get_page_size(self, page_no: int):  # 获取页面宽高
         # 获取当前页的page对象
@@ -533,6 +556,7 @@ class MagicModel:
         return self.__model_list[page_no]
 
 
+
 if __name__ == "__main__":
     drw = DiskReaderWriter(r"D:/project/20231108code-clean")
     if 0:

+ 7 - 1
magic_pdf/pdf_parse_union_core.py

@@ -19,7 +19,8 @@ from magic_pdf.pre_proc.equations_replace import remove_chars_in_text_blocks, re
 from magic_pdf.pre_proc.ocr_detect_all_bboxes import ocr_prepare_bboxes_for_layout_split
 from magic_pdf.pre_proc.ocr_dict_merge import sort_blocks_by_layout, fill_spans_in_blocks, fix_block_spans, \
     fix_discarded_block
-from magic_pdf.pre_proc.ocr_span_list_modify import remove_overlaps_min_spans, get_qa_need_list_v2
+from magic_pdf.pre_proc.ocr_span_list_modify import remove_overlaps_min_spans, get_qa_need_list_v2, \
+    remove_overlaps_low_confidence_spans
 from magic_pdf.pre_proc.resolve_bbox_conflict import check_useful_block_horizontal_overlap
 
 
@@ -64,6 +65,7 @@ def txt_spans_extract(pdf_page, inline_equations, interline_equations):
                             "bbox": list(span["bbox"]),
                             "content": span["latex"],
                             "type": ContentType.InlineEquation,
+                            "score": 1.0,
                         }
                     )
                 elif span.get('type') == ContentType.InterlineEquation:
@@ -72,6 +74,7 @@ def txt_spans_extract(pdf_page, inline_equations, interline_equations):
                             "bbox": list(span["bbox"]),
                             "content": span["latex"],
                             "type": ContentType.InterlineEquation,
+                            "score": 1.0,
                         }
                     )
                 else:
@@ -80,6 +83,7 @@ def txt_spans_extract(pdf_page, inline_equations, interline_equations):
                             "bbox": list(span["bbox"]),
                             "content": span["text"],
                             "type": ContentType.Text,
+                            "score": 1.0,
                         }
                     )
     return spans
@@ -117,6 +121,8 @@ def parse_page_core(pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter,
     else:
         raise Exception("parse_mode must be txt or ocr")
 
+    '''删除重叠spans中置信度较低的那些'''
+    spans, dropped_spans_by_confidence = remove_overlaps_low_confidence_spans(spans)
     '''删除重叠spans中较小的那些'''
     spans, dropped_spans_by_span_overlap = remove_overlaps_min_spans(spans)
     '''对image和table截图'''

+ 22 - 1
magic_pdf/pre_proc/ocr_span_list_modify.py

@@ -1,10 +1,31 @@
 from loguru import logger
 
 from magic_pdf.libs.boxbase import calculate_overlap_area_in_bbox1_area_ratio, get_minbox_if_overlap_by_ratio, \
-    __is_overlaps_y_exceeds_threshold
+    __is_overlaps_y_exceeds_threshold, calculate_iou
 from magic_pdf.libs.drop_tag import DropTag
 from magic_pdf.libs.ocr_content_type import ContentType, BlockType
 
+def remove_overlaps_low_confidence_spans(spans):
+    dropped_spans = []
+    #  删除重叠spans中置信度低的的那些
+    for span1 in spans:
+        for span2 in spans:
+            if span1 != span2:
+                if calculate_iou(span1['bbox'], span2['bbox']) > 0.9:
+                    if span1['score'] < span2['score']:
+                        span_need_remove = span1
+                    else:
+                        span_need_remove = span2
+                    if span_need_remove is not None and span_need_remove not in dropped_spans:
+                        dropped_spans.append(span_need_remove)
+
+    if len(dropped_spans) > 0:
+        for span_need_remove in dropped_spans:
+            spans.remove(span_need_remove)
+            span_need_remove['tag'] = DropTag.SPAN_OVERLAP
+
+    return spans, dropped_spans
+
 
 def remove_overlaps_min_spans(spans):
     dropped_spans = []