Browse Source

refactor: improve block processing logic and enhance span handling

myhloli 5 tháng trước
mục cha
commit
236a6033f1

+ 14 - 12
mineru/backend/pipeline/batch_analyze.py

@@ -230,18 +230,18 @@ class BatchAnalyze:
                         ocr_result_list = get_ocr_result_list(ocr_res, useful_list, ocr_res_list_dict['ocr_enable'],
                                                               new_image, _lang)
 
-                        if res["category_id"] == 3:
-                            # ocr_result_list中所有bbox的面积之和
-                            ocr_res_area = sum(
-                                get_coords_and_area(ocr_res_item)[4] for ocr_res_item in ocr_result_list if 'poly' in ocr_res_item)
-                            # 求ocr_res_area和res的面积的比值
-                            res_area = get_coords_and_area(res)[4]
-                            if res_area > 0:
-                                ratio = ocr_res_area / res_area
-                                if ratio > 0.3:
-                                    res["category_id"] = 1
-                                else:
-                                    continue
+                        # if res["category_id"] == 3 and ocr_res_list_dict['ocr_enable']:
+                        #     # ocr_result_list中所有bbox的面积之和
+                        #     ocr_res_area = sum(
+                        #         get_coords_and_area(ocr_res_item)[4] for ocr_res_item in ocr_result_list if 'poly' in ocr_res_item)
+                        #     # 求ocr_res_area和res的面积的比值
+                        #     res_area = get_coords_and_area(res)[4]
+                        #     if res_area > 0:
+                        #         ratio = ocr_res_area / res_area
+                        #         if ratio > 0.25:
+                        #             res["category_id"] = 1
+                        #         else:
+                        #             continue
 
                         ocr_res_list_dict['layout_res'].extend(ocr_result_list)
 
@@ -321,6 +321,8 @@ class BatchAnalyze:
                         ocr_text, ocr_score = ocr_res_list[index]
                         layout_res_item['text'] = ocr_text
                         layout_res_item['score'] = float(f"{ocr_score:.3f}")
+                        if ocr_score < 0.6:
+                            layout_res_item['category_id'] = 16
 
                     total_processed += len(img_crop_list)
 

+ 34 - 8
mineru/backend/pipeline/model_json_to_middle_json.py

@@ -8,6 +8,7 @@ from mineru.backend.pipeline.model_init import AtomModelSingleton
 from mineru.backend.pipeline.para_split import para_split
 from mineru.utils.block_pre_proc import prepare_block_bboxes, process_groups
 from mineru.utils.block_sort import sort_blocks_by_bbox
+from mineru.utils.boxbase import calculate_overlap_area_in_bbox1_area_ratio
 from mineru.utils.cut_image import cut_image_and_table
 from mineru.utils.llm_aided import llm_aided_title
 from mineru.utils.model_utils import clean_memory
@@ -27,22 +28,48 @@ def page_model_info_to_page_info(page_model_info, image_dict, page, image_writer
     magic_model = MagicModel(page_model_info, scale)
 
     """从magic_model对象中获取后面会用到的区块信息"""
+    discarded_blocks = magic_model.get_discarded()
+    text_blocks = magic_model.get_text_blocks()
+    title_blocks = magic_model.get_title_blocks()
+    inline_equations, interline_equations, interline_equation_blocks = magic_model.get_equations()
+
     img_groups = magic_model.get_imgs()
     table_groups = magic_model.get_tables()
 
     """对image和table的区块分组"""
-    img_body_blocks, img_caption_blocks, img_footnote_blocks = process_groups(
+    img_body_blocks, img_caption_blocks, img_footnote_blocks, maybe_text_image_blocks = process_groups(
         img_groups, 'image_body', 'image_caption_list', 'image_footnote_list'
     )
 
-    table_body_blocks, table_caption_blocks, table_footnote_blocks = process_groups(
+    table_body_blocks, table_caption_blocks, table_footnote_blocks, _ = process_groups(
         table_groups, 'table_body', 'table_caption_list', 'table_footnote_list'
     )
 
-    discarded_blocks = magic_model.get_discarded()
-    text_blocks = magic_model.get_text_blocks()
-    title_blocks = magic_model.get_title_blocks()
-    inline_equations, interline_equations, interline_equation_blocks = magic_model.get_equations()
+    """获取所有的spans信息"""
+    spans = magic_model.get_all_spans()
+
+    if len(maybe_text_image_blocks) > 0:
+        for block in maybe_text_image_blocks:
+            span_in_block_list = []
+            for span in spans:
+                if span['type'] == 'text' and calculate_overlap_area_in_bbox1_area_ratio(span['bbox'], block['bbox']) > 0.7:
+                    span_in_block_list.append(span)
+            if len(span_in_block_list) > 0:
+                # span_in_block_list中所有bbox的面积之和
+                spans_area = sum((span['bbox'][2] - span['bbox'][0]) * (span['bbox'][3] - span['bbox'][1]) for span in span_in_block_list)
+                # Ratio of the summed text-span area (spans_area) to the block's own area (block_area)
+                block_area = (block['bbox'][2] - block['bbox'][0]) * (block['bbox'][3] - block['bbox'][1])
+                if block_area > 0:
+                    ratio = spans_area / block_area
+                    if ratio > 0.25 and ocr:
+                        # 移除block的group_id
+                        block.pop('group_id', None)
+                        text_blocks.append(block)
+                    else:
+                        img_body_blocks.append(block)
+            else:
+                img_body_blocks.append(block)
+
 
     """将所有区块的bbox整理到一起"""
     interline_equation_blocks = []
@@ -68,8 +95,7 @@ def page_model_info_to_page_info(page_model_info, image_dict, page, image_writer
             page_w,
             page_h,
         )
-    """获取所有的spans信息"""
-    spans = magic_model.get_all_spans()
+
     """在删除重复span之前,应该通过image_body和table_body的block过滤一下image和table的span"""
     """顺便删除大水印并保留abandon的span"""
     spans = remove_outside_spans(spans, all_bboxes, all_discarded_blocks)

+ 16 - 9
mineru/utils/block_pre_proc.py

@@ -12,16 +12,23 @@ def process_groups(groups, body_key, caption_key, footnote_key):
     body_blocks = []
     caption_blocks = []
     footnote_blocks = []
+    maybe_text_image_blocks = []
     for i, group in enumerate(groups):
-        group[body_key]['group_id'] = i
-        body_blocks.append(group[body_key])
-        for caption_block in group[caption_key]:
-            caption_block['group_id'] = i
-            caption_blocks.append(caption_block)
-        for footnote_block in group[footnote_key]:
-            footnote_block['group_id'] = i
-            footnote_blocks.append(footnote_block)
-    return body_blocks, caption_blocks, footnote_blocks
+        if body_key == 'image_body' and len(group[caption_key]) == 0 and len(group[footnote_key]) == 0:
+            # Image body with no caption and no footnote: collect it separately as a possible text block (group_id is still assigned here)
+            group[body_key]['group_id'] = i
+            maybe_text_image_blocks.append(group[body_key])
+            continue
+        else:
+            group[body_key]['group_id'] = i
+            body_blocks.append(group[body_key])
+            for caption_block in group[caption_key]:
+                caption_block['group_id'] = i
+                caption_blocks.append(caption_block)
+            for footnote_block in group[footnote_key]:
+                footnote_block['group_id'] = i
+                footnote_blocks.append(footnote_block)
+    return body_blocks, caption_blocks, footnote_blocks, maybe_text_image_blocks
 
 
 def prepare_block_bboxes(

+ 0 - 11
mineru/utils/boxbase.py

@@ -148,17 +148,6 @@ def calculate_iou(bbox1, bbox2):
     return iou
 
 
-def _is_in(box1, box2) -> bool:
-    """box1是否完全在box2里面."""
-    x0_1, y0_1, x1_1, y1_1 = box1
-    x0_2, y0_2, x1_2, y1_2 = box2
-
-    return (x0_1 >= x0_2 and  # box1的左边界不在box2的左边外
-            y0_1 >= y0_2 and  # box1的上边界不在box2的上边外
-            x1_1 <= x1_2 and  # box1的右边界不在box2的右边外
-            y1_1 <= y1_2)  # box1的下边界不在box2的下边外
-
-
 def calculate_overlap_area_in_bbox1_area_ratio(bbox1, bbox2):
     """计算box1和box2的重叠面积占bbox1的比例."""
     # Determine the coordinates of the intersection rectangle

+ 2 - 2
mineru/utils/pipeline_magic_model.py

@@ -1,4 +1,4 @@
-from mineru.utils.boxbase import bbox_relative_pos, calculate_iou, bbox_distance, _is_in
+from mineru.utils.boxbase import bbox_relative_pos, calculate_iou, bbox_distance, is_in
 from mineru.utils.enum_class import CategoryId, ContentType
 
 
@@ -156,7 +156,7 @@ class MagicModel:
             for j in range(N):
                 if i == j:
                     continue
-                if _is_in(bboxes[i]['bbox'], bboxes[j]['bbox']):
+                if is_in(bboxes[i]['bbox'], bboxes[j]['bbox']):
                     keep[i] = False
         return [bboxes[i] for i in range(N) if keep[i]]