Procházet zdrojové kódy

refactor: enhance block processing and sorting utilities for improved span management

myhloli před 5 měsíci
rodič
revize
0f21495a06

+ 70 - 17
mineru/backend/pipeline/model_json_to_middle_json.py

@@ -1,11 +1,16 @@
 # Copyright (c) Opendatalab. All rights reserved.
-from mineru.utils.block_pre_proc import prepare_block_bboxes
+from mineru.utils.block_pre_proc import prepare_block_bboxes, process_groups
+from mineru.utils.block_sort import sort_blocks_by_bbox
+from mineru.utils.cut_image import cut_image_and_table
 from mineru.utils.pipeline_magic_model import MagicModel
+from mineru.utils.span_block_fix import fill_spans_in_blocks, fix_discarded_block, fix_block_spans
+from mineru.utils.span_pre_proc import remove_outside_spans, remove_overlaps_low_confidence_spans, \
+    remove_overlaps_min_spans, txt_spans_extract
 from mineru.version import __version__
 from mineru.utils.hash_utils import str_md5
 
 
-def page_model_info_to_page_info(page_model_info, image_dict, page, image_writer, page_index, lang=None, ocr=False):
+def page_model_info_to_page_info(page_model_info, image_dict, page, image_writer, page_index, ocr=False):
     scale = image_dict["scale"]
     page_pil_img = image_dict["img_pil"]
     page_img_md5 = str_md5(image_dict["img_base64"])
@@ -54,6 +59,57 @@ def page_model_info_to_page_info(page_model_info, image_dict, page, image_writer
             page_w,
             page_h,
         )
+    """获取所有的spans信息"""
+    spans = magic_model.get_all_spans()
+    """在删除重复span之前,应该通过image_body和table_body的block过滤一下image和table的span"""
+    """顺便删除大水印并保留abandon的span"""
+    spans = remove_outside_spans(spans, all_bboxes, all_discarded_blocks)
+
+    """删除重叠spans中置信度较低的那些"""
+    spans, dropped_spans_by_confidence = remove_overlaps_low_confidence_spans(spans)
+    """删除重叠spans中较小的那些"""
+    spans, dropped_spans_by_span_overlap = remove_overlaps_min_spans(spans)
+
+    """根据parse_mode,构造spans,主要是文本类的字符填充"""
+    if ocr:
+        pass
+    else:
+        """使用新版本的混合ocr方案."""
+        spans = txt_spans_extract(page, spans, page_pil_img, scale)
+
+    """先处理不需要排版的discarded_blocks"""
+    discarded_block_with_spans, spans = fill_spans_in_blocks(
+        all_discarded_blocks, spans, 0.4
+    )
+    fix_discarded_blocks = fix_discarded_block(discarded_block_with_spans)
+
+    """如果当前页面没有有效的bbox则跳过"""
+    if len(all_bboxes) == 0:
+        return None
+
+    """对image和table截图"""
+    for span in spans:
+        if span['type'] in ['image', 'table']:
+            span = cut_image_and_table(
+                span, page_pil_img, page_img_md5, page_index, image_writer, scale=scale
+            )
+
+    """span填充进block"""
+    block_with_spans, spans = fill_spans_in_blocks(all_bboxes, spans, 0.5)
+
+    """对block进行fix操作"""
+    fix_blocks = fix_block_spans(block_with_spans)
+
+    """同一行被断开的titile合并"""
+    # merge_title_blocks(fix_blocks)
+
+    """对block进行排序"""
+    sorted_blocks = sort_blocks_by_bbox(fix_blocks, page_w, page_h, footnote_blocks)
+
+    """构造page_info"""
+    page_info = make_page_info_dict(sorted_blocks, page_index, page_w, page_h, fix_discarded_blocks)
+
+    return page_info
 
 
 def result_to_middle_json(model_list, images_list, pdf_doc, image_writer, lang=None, ocr=False):
@@ -62,23 +118,20 @@ def result_to_middle_json(model_list, images_list, pdf_doc, image_writer, lang=N
         page = pdf_doc[page_index]
         image_dict = images_list[page_index]
         page_info = page_model_info_to_page_info(
-            page_model_info, image_dict, page, image_writer, page_index, lang=lang, ocr=ocr
+            page_model_info, image_dict, page, image_writer, page_index, ocr=ocr
         )
+        if page_info is None:
+            page_w, page_h = map(int, page.get_size())
+            page_info = make_page_info_dict([], page_index, page_w, page_h, [])
         middle_json["pdf_info"].append(page_info)
     return middle_json
 
 
-def process_groups(groups, body_key, caption_key, footnote_key):
-    body_blocks = []
-    caption_blocks = []
-    footnote_blocks = []
-    for i, group in enumerate(groups):
-        group[body_key]['group_id'] = i
-        body_blocks.append(group[body_key])
-        for caption_block in group[caption_key]:
-            caption_block['group_id'] = i
-            caption_blocks.append(caption_block)
-        for footnote_block in group[footnote_key]:
-            footnote_block['group_id'] = i
-            footnote_blocks.append(footnote_block)
-    return body_blocks, caption_blocks, footnote_blocks
+def make_page_info_dict(blocks, page_id, page_w, page_h, discarded_blocks):
+    """Build the dictionary that describes one processed page.
+
+    Args:
+        blocks: Layout blocks, stored under ``'preproc_blocks'``.
+        page_id: Page index, stored under ``'page_idx'``.
+        page_w: Page width; first element of ``'page_size'``.
+        page_h: Page height; second element of ``'page_size'``.
+        discarded_blocks: Stored unchanged under ``'discarded_blocks'``.
+
+    Returns:
+        A new dict holding exactly the four keys listed above.
+    """
+    page_size = [page_w, page_h]
+    return {
+        'preproc_blocks': blocks,
+        'page_idx': page_id,
+        'page_size': page_size,
+        'discarded_blocks': discarded_blocks,
+    }

+ 16 - 0
mineru/utils/block_pre_proc.py

@@ -8,6 +8,22 @@ from mineru.utils.boxbase import (
 from mineru.utils.enum_class import BlockType
 
 
+def process_groups(groups, body_key, caption_key, footnote_key):
+    """Flatten grouped blocks into body, caption and footnote lists.
+
+    Every block taken from the group at position ``i`` is tagged in place
+    with ``'group_id' = i`` so the grouping can be rebuilt later.
+
+    Args:
+        groups: Sequence of dicts, each holding one body block under
+            ``body_key`` and iterables of blocks under ``caption_key`` and
+            ``footnote_key``.
+        body_key: Key of the single body block in each group.
+        caption_key: Key of the caption blocks in each group.
+        footnote_key: Key of the footnote blocks in each group.
+
+    Returns:
+        Tuple ``(body_blocks, caption_blocks, footnote_blocks)``.
+    """
+    bodies, captions, footnotes = [], [], []
+    for group_id, group in enumerate(groups):
+        body = group[body_key]
+        body['group_id'] = group_id
+        bodies.append(body)
+        # Captions are handled before footnotes, mirroring the group layout.
+        for member_key, bucket in ((caption_key, captions), (footnote_key, footnotes)):
+            for member in group[member_key]:
+                member['group_id'] = group_id
+                bucket.append(member)
+    return bodies, captions, footnotes
+
+
 def prepare_block_bboxes(
     img_body_blocks,
     img_caption_blocks,

+ 338 - 0
mineru/utils/block_sort.py

@@ -0,0 +1,338 @@
+# Copyright (c) Opendatalab. All rights reserved.
+import copy
+import os
+import statistics
+import warnings
+from typing import List
+import torch
+from loguru import logger
+
+from mineru.backend.pipeline.config_reader import get_device, get_local_layoutreader_model_dir
+from mineru.utils.enum_class import BlockType
+
+
+def sort_blocks_by_bbox(blocks, page_w, page_h, footnote_blocks):
+    """Assign a reading-order index to every block and return them sorted.
+
+    Args:
+        blocks: Block dicts carrying ``'type'``, ``'bbox'`` and ``'lines'``.
+        page_w: Page width, forwarded to the line sorter.
+        page_h: Page height, forwarded to the line sorter.
+        footnote_blocks: Extra footnote boxes forwarded to the line sorter.
+
+    Returns:
+        The regrouped blocks ordered by their ``'index'``; image and table
+        groups also have their inner ``'blocks'`` ordered by ``'index'``.
+    """
+
+    def by_index(item):
+        return item['index']
+
+    # Median height of the text lines on the page.
+    median_line_height = get_line_height(blocks)
+
+    # Reading order of all line boxes (None when the model is skipped).
+    ordered_line_bboxes = sort_lines_by_model(
+        blocks, page_w, page_h, median_line_height, footnote_blocks
+    )
+
+    # Derive each block's index from the order of its lines.
+    indexed_blocks = cal_block_index(blocks, ordered_line_bboxes)
+
+    # Fold image/table sub-blocks back into their group form.
+    grouped_blocks = revert_group_blocks(indexed_blocks)
+
+    sorted_blocks = sorted(grouped_blocks, key=by_index)
+
+    # Order captions/footnotes/body inside every image or table group.
+    for group in sorted_blocks:
+        if group['type'] in (BlockType.IMAGE, BlockType.TABLE):
+            group['blocks'] = sorted(group['blocks'], key=by_index)
+
+    return sorted_blocks
+
+
+def get_line_height(blocks):
+    """Return the median line height of the text-like blocks on a page.
+
+    Only text, title, and image/table caption and footnote blocks are
+    considered. Each line height is ``bbox[3] - bbox[1]`` truncated to int.
+
+    Returns:
+        The median height, or ``10`` when no such line exists.
+    """
+    text_like_types = [
+        BlockType.TEXT, BlockType.TITLE,
+        BlockType.IMAGE_CAPTION, BlockType.IMAGE_FOOTNOTE,
+        BlockType.TABLE_CAPTION, BlockType.TABLE_FOOTNOTE
+    ]
+    heights = [
+        int(line['bbox'][3] - line['bbox'][1])
+        for block in blocks
+        if block['type'] in text_like_types
+        for line in block['lines']
+    ]
+    if not heights:
+        return 10
+    return statistics.median(heights)
+
+
+def sort_lines_by_model(fix_blocks, page_w, page_h, line_height, footnote_blocks):
+    """Collect every line bbox on the page and order them with layoutreader.
+
+    Blocks without usable lines get virtual lines generated by
+    ``insert_lines_into_block``; when a block's existing lines are replaced
+    they are first saved under ``'real_lines'``. ``fix_blocks`` is mutated
+    in place.
+
+    Args:
+        fix_blocks: Block dicts with ``'type'``, ``'bbox'`` and ``'lines'``.
+        page_w: Page width, used for clamping and for scaling to 0..1000.
+        page_h: Page height, used for clamping and for scaling to 0..1000.
+        line_height: Reference line height used when creating virtual lines.
+        footnote_blocks: Sequences whose first four items form a bbox; their
+            virtual lines take part in the ordering.
+
+    Returns:
+        The line bboxes in predicted reading order, or ``None`` when the page
+        has more than 200 lines.
+    """
+    page_line_list = []
+
+    def add_lines_to_block(b):
+        # Replace b['lines'] with span-less virtual lines and register them.
+        line_bboxes = insert_lines_into_block(b['bbox'], line_height, page_w, page_h)
+        b['lines'] = []
+        for line_bbox in line_bboxes:
+            b['lines'].append({'bbox': line_bbox, 'spans': []})
+        page_line_list.extend(line_bboxes)
+
+    for block in fix_blocks:
+        if block['type'] in [
+            BlockType.TEXT, BlockType.TITLE,
+            BlockType.IMAGE_CAPTION, BlockType.IMAGE_FOOTNOTE,
+            BlockType.TABLE_CAPTION, BlockType.TABLE_FOOTNOTE
+        ]:
+            if len(block['lines']) == 0:
+                add_lines_to_block(block)
+            elif block['type'] in [BlockType.TITLE] and len(block['lines']) == 1 and (block['bbox'][3] - block['bbox'][1]) > line_height * 2:
+                # A single-line title taller than two text lines: keep the
+                # real line aside and sort using virtual lines instead.
+                block['real_lines'] = copy.deepcopy(block['lines'])
+                add_lines_to_block(block)
+            else:
+                for line in block['lines']:
+                    bbox = line['bbox']
+                    page_line_list.append(bbox)
+        elif block['type'] in [BlockType.IMAGE_BODY, BlockType.TABLE_BODY, BlockType.INTERLINE_EQUATION]:
+            # Body blocks always sort by virtual lines; real lines are saved.
+            block['real_lines'] = copy.deepcopy(block['lines'])
+            add_lines_to_block(block)
+
+    for block in footnote_blocks:
+        # Temporary dict: only its virtual lines are added to the page list.
+        footnote_block = {'bbox': block[:4]}
+        add_lines_to_block(footnote_block)
+
+    if len(page_line_list) > 200:  # layoutreader supports at most 512 lines
+        return None
+
+    # Sort with layoutreader
+    x_scale = 1000.0 / page_w
+    y_scale = 1000.0 / page_h
+    boxes = []
+    # logger.info(f"Scale: {x_scale}, {y_scale}, Boxes len: {len(page_line_list)}")
+    for left, top, right, bottom in page_line_list:
+        # Clamp every coordinate into the page before scaling.
+        if left < 0:
+            logger.warning(
+                f'left < 0, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
+            )  # noqa: E501
+            left = 0
+        if right > page_w:
+            logger.warning(
+                f'right > page_w, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
+            )  # noqa: E501
+            right = page_w
+        if top < 0:
+            logger.warning(
+                f'top < 0, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
+            )  # noqa: E501
+            top = 0
+        if bottom > page_h:
+            logger.warning(
+                f'bottom > page_h, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
+            )  # noqa: E501
+            bottom = page_h
+
+        # Scale into the 0..1000 coordinate space checked by the assert below.
+        left = round(left * x_scale)
+        top = round(top * y_scale)
+        right = round(right * x_scale)
+        bottom = round(bottom * y_scale)
+        assert (
+            1000 >= right >= left >= 0 and 1000 >= bottom >= top >= 0
+        ), f'Invalid box. right: {right}, left: {left}, bottom: {bottom}, top: {top}'  # noqa: E126, E121
+        boxes.append([left, top, right, bottom])
+    model_manager = ModelSingleton()
+    model = model_manager.get_model('layoutreader')
+    with torch.no_grad():
+        orders = do_predict(boxes, model)
+    # Map the predicted order back onto the original (unscaled) bboxes.
+    sorted_bboxes = [page_line_list[i] for i in orders]
+
+    return sorted_bboxes
+
+
+def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
+    """Split a block bbox into equal-height virtual line bboxes.
+
+    The number of slices depends on the block's size relative to the page:
+
+    * not taller than two text lines -> the block itself;
+    * taller than a quarter page and between 25% and 50% of the page width,
+      or (otherwise) between 25% and 40% of the page width -> one slice per
+      ``line_height``;
+    * wider than 40% of the page width -> three slices;
+    * narrower blocks -> the block itself when height/width exceeds 1.2,
+      else two slices.
+
+    Degenerate input no longer raises ``ZeroDivisionError``: a zero-width
+    block, or a zero ``line_height`` where slicing by line height is
+    requested, returns the block bbox unchanged.
+
+    Args:
+        block_bbox: ``(x0, y0, x1, y1)`` of the block.
+        line_height: Reference text line height.
+        page_w: Page width.
+        page_h: Page height.
+
+    Returns:
+        List of ``[x0, y, x1, y + slice_height]`` boxes stacked from ``y0``.
+    """
+    x0, y0, x1, y1 = block_bbox
+
+    block_height = y1 - y0
+    block_width = x1 - x0
+
+    # Blocks no taller than two text lines are not split.
+    if not line_height * 2 < block_height:
+        return [[x0, y0, x1, y1]]
+
+    # ``lines is None`` means "one slice per text line".
+    lines = None
+    if block_height > page_h * 0.25 and page_w * 0.5 > block_width > page_w * 0.25:
+        # Possibly a two-column layout: slice finely.
+        pass
+    elif block_width > page_w * 0.4:
+        # Wide block in a complex layout: do not slice too finely.
+        lines = 3
+    elif block_width > page_w * 0.25:
+        # Possibly a three-column layout: slice finely as well.
+        pass
+    elif block_width == 0 or block_height / block_width > 1.2:
+        # Slender (or zero-width) blocks are kept whole.
+        return [[x0, y0, x1, y1]]
+    else:
+        lines = 2
+
+    if lines is None:
+        if line_height == 0:
+            # Cannot derive a slice count from a zero line height.
+            return [[x0, y0, x1, y1]]
+        lines = int(block_height / line_height)
+
+    slice_height = (y1 - y0) / lines
+
+    # Accumulate y so the boxes tile the block from top to bottom.
+    current_y = y0
+    lines_positions = []
+    for _ in range(lines):
+        lines_positions.append([x0, current_y, x1, current_y + slice_height])
+        current_y += slice_height
+    return lines_positions
+
+
+def model_init(model_name: str):
+    """Load the model named ``model_name`` onto the configured device.
+
+    Only ``'layoutreader'`` is accepted. Weights come from the local model
+    directory when it exists, otherwise from the ``'hantian/layoutreader'``
+    hub id. The model is put in eval mode and cast to bfloat16 when the
+    device is MPS or a CUDA device reporting bf16 support.
+
+    Any other name logs an error and terminates the process via ``exit(1)``.
+    """
+    from transformers import LayoutLMv3ForTokenClassification
+    device_name = get_device()
+    if device_name.startswith("cuda"):
+        use_bf16 = torch.cuda.is_bf16_supported()
+    else:
+        use_bf16 = device_name.startswith("mps")
+
+    device = torch.device(device_name)
+    if model_name != 'layoutreader':
+        logger.error('model name not allow')
+        exit(1)
+
+    # Prefer the locally cached weights when the directory is present.
+    local_dir = get_local_layoutreader_model_dir()
+    if os.path.exists(local_dir):
+        weights_source = local_dir
+    else:
+        logger.warning(
+            'local layoutreader model not exists, use online model from huggingface'
+        )
+        weights_source = 'hantian/layoutreader'
+    model = LayoutLMv3ForTokenClassification.from_pretrained(weights_source)
+
+    prepared = model.to(device).eval()
+    if use_bf16:
+        prepared.bfloat16()
+    return model
+
+
+class ModelSingleton:
+    """Process-wide holder that loads each named model at most once."""
+
+    # The single shared instance, created on first construction.
+    _instance = None
+    # Loaded models keyed by model name; shared at class level.
+    _models = {}
+
+    def __new__(cls, *args, **kwargs):
+        """Return the shared instance, creating it on first use."""
+        existing = cls._instance
+        if existing is not None:
+            return existing
+        cls._instance = super().__new__(cls)
+        return cls._instance
+
+    def get_model(self, model_name: str):
+        """Return the cached model for ``model_name``, loading it if absent."""
+        if model_name in self._models:
+            return self._models[model_name]
+        loaded = model_init(model_name=model_name)
+        self._models[model_name] = loaded
+        return loaded
+
+
+def do_predict(boxes: List[List[int]], model) -> List[int]:
+    """Run the reading-order model on ``boxes`` and return the parsed order.
+
+    ``FutureWarning`` messages from ``transformers`` are silenced while the
+    inputs are built and the forward pass runs.
+    """
+    from mineru.model.reading_order.layout_reader import (
+        boxes2inputs, parse_logits, prepare_inputs)
+
+    with warnings.catch_warnings():
+        warnings.filterwarnings("ignore", category=FutureWarning, module="transformers")
+
+        model_inputs = prepare_inputs(boxes2inputs(boxes), model)
+        output = model(**model_inputs)
+        logits = output.logits.cpu().squeeze(0)
+    box_count = len(boxes)
+    return parse_logits(logits, box_count)
+
+
+def cal_block_index(fix_blocks, sorted_bboxes):
+    """Write a reading-order ``'index'`` onto every block and line, in place.
+
+    With ``sorted_bboxes`` given, each line's index is its position in that
+    list and a block's index is the median of its line indexes (or the
+    position of its own bbox when it has no lines). With ``None``, block
+    bboxes are clamped to non-negative values and ordered by recursive
+    XY-cut, then lines are numbered from 1 following the block order.
+
+    In both modes, blocks holding ``'real_lines'`` get their current lines
+    moved to ``'virtual_lines'`` and the real lines restored.
+
+    Returns:
+        The same ``fix_blocks`` list.
+    """
+
+    def restore_real_lines(block):
+        # Drop the virtual lines used for sorting and put the real ones back.
+        restorable_types = [
+            BlockType.IMAGE_BODY, BlockType.TABLE_BODY,
+            BlockType.TITLE, BlockType.INTERLINE_EQUATION,
+        ]
+        if block['type'] in restorable_types and 'real_lines' in block:
+            block['virtual_lines'] = copy.deepcopy(block['lines'])
+            block['lines'] = copy.deepcopy(block['real_lines'])
+            del block['real_lines']
+
+    if sorted_bboxes is not None:
+        # Order produced by layoutreader.
+        for block in fix_blocks:
+            block_lines = block['lines']
+            if len(block_lines) == 0:
+                block['index'] = sorted_bboxes.index(block['bbox'])
+            else:
+                line_positions = []
+                for line in block_lines:
+                    line['index'] = sorted_bboxes.index(line['bbox'])
+                    line_positions.append(line['index'])
+                block['index'] = statistics.median(line_positions)
+            restore_real_lines(block)
+        return fix_blocks
+
+    # Fallback: order the blocks with XY-cut.
+    block_bboxes = []
+    for block in fix_blocks:
+        # Negative coordinates are raised to 0.
+        clamped_bbox = [max(0, coord) for coord in block['bbox']]
+        block['bbox'] = clamped_bbox
+        block_bboxes.append(clamped_bbox)
+        restore_real_lines(block)
+
+    import numpy as np
+    from mineru.model.reading_order.xycut import recursive_xy_cut
+
+    random_boxes = np.array(block_bboxes)
+    np.random.shuffle(random_boxes)
+    res = []
+    recursive_xy_cut(np.asarray(random_boxes).astype(int), np.arange(len(block_bboxes)), res)
+    assert len(res) == len(block_bboxes)
+    sorted_boxes = random_boxes[np.array(res)].tolist()
+
+    for block in fix_blocks:
+        block['index'] = sorted_boxes.index(block['bbox'])
+
+    # Number the lines consecutively, following the block order.
+    next_line_index = 1
+    for block in sorted(fix_blocks, key=lambda b: b['index']):
+        for line in block['lines']:
+            line['index'] = next_line_index
+            next_line_index += 1
+
+    return fix_blocks
+
+
+def revert_group_blocks(blocks):
+    """Fold image/table body, caption and footnote blocks back into groups.
+
+    Blocks sharing a ``'group_id'`` are merged by ``process_block_list``
+    into one IMAGE or TABLE block. All other blocks pass through unchanged.
+
+    Returns:
+        A new list: ungrouped blocks first, then image groups, then table
+        groups, each in first-seen order.
+    """
+    image_member_types = [BlockType.IMAGE_BODY, BlockType.IMAGE_CAPTION, BlockType.IMAGE_FOOTNOTE]
+    table_member_types = [BlockType.TABLE_BODY, BlockType.TABLE_CAPTION, BlockType.TABLE_FOOTNOTE]
+
+    image_groups, table_groups = {}, {}
+    new_blocks = []
+    for block in blocks:
+        block_type = block['type']
+        if block_type in image_member_types:
+            image_groups.setdefault(block['group_id'], []).append(block)
+        elif block_type in table_member_types:
+            table_groups.setdefault(block['group_id'], []).append(block)
+        else:
+            new_blocks.append(block)
+
+    for members in image_groups.values():
+        new_blocks.append(process_block_list(members, BlockType.IMAGE_BODY, BlockType.IMAGE))
+
+    for members in table_groups.values():
+        new_blocks.append(process_block_list(members, BlockType.TABLE_BODY, BlockType.TABLE))
+
+    return new_blocks
+
+
+def process_block_list(blocks, body_type, block_type):
+    """Wrap the member blocks of one group into a single parent block.
+
+    Args:
+        blocks: Member block dicts, each carrying an ``'index'``.
+        body_type: Type value identifying the body member.
+        block_type: Type value given to the resulting parent block.
+
+    Returns:
+        Dict with ``'type'``, ``'bbox'`` (bbox of the first member of
+        ``body_type``, or ``[]`` when none exists), ``'blocks'`` and
+        ``'index'`` (median of the members' indexes).
+    """
+    median_index = statistics.median([member['index'] for member in blocks])
+
+    # The group takes the bbox of its first body member, if any.
+    for member in blocks:
+        if member.get('type') == body_type:
+            body_bbox = member['bbox']
+            break
+    else:
+        body_bbox = []
+
+    return {
+        'type': block_type,
+        'bbox': body_bbox,
+        'blocks': blocks,
+        'index': median_index,
+    }

+ 2 - 2
mineru/utils/cut_image.py

@@ -3,14 +3,14 @@ from loguru import logger
 from .pdf_image_tools import cut_image
 
 
-def cut_image_and_table(span, page_pil_img, page_img_md5, page_id, imageWriter, scale=2):
+def cut_image_and_table(span, page_pil_img, page_img_md5, page_id, image_writer, scale=2):
 
     def return_path(path_type):
         return f"{path_type}/{page_img_md5}"
 
     span_type = span["type"]
 
-    if not check_img_bbox(span["bbox"]) or not imageWriter:
+    if not check_img_bbox(span["bbox"]) or not image_writer:
         span["image_path"] = ""
     else:
         span["image_path"] = cut_image(

+ 1 - 1
mineru/utils/pdf_image_tools.py

@@ -54,7 +54,7 @@ def load_images_from_pdf(
     return images_list, pdf_doc
 
 
-def cut_image(bbox: tuple, page_num: int, page_pil_img, return_path, imageWriter: FileBasedDataWriter, scale=3):
+def cut_image(bbox: tuple, page_num: int, page_pil_img, return_path, imageWriter: FileBasedDataWriter, scale=2):
     """从第page_num页的page中,根据bbox进行裁剪出一张jpg图片,返回图片路径 save_path:需要同时支持s3和本地,
     图片存放在save_path下,文件名是:
     {page_num}_{bbox[0]}_{bbox[1]}_{bbox[2]}_{bbox[3]}.jpg , bbox内数字取整。"""

+ 162 - 0
mineru/utils/span_block_fix.py

@@ -0,0 +1,162 @@
+# Copyright (c) Opendatalab. All rights reserved.
+from mineru.utils.boxbase import calculate_overlap_area_in_bbox1_area_ratio
+from mineru.utils.enum_class import BlockType, ContentType
+from mineru.utils.ocr_utils import __is_overlaps_y_exceeds_threshold
+
+
+def fill_spans_in_blocks(blocks, spans, radio):
+    """Distribute spans into blocks according to their position.
+
+    A span joins the first block that contains more than ``radio`` of the
+    span's area and whose type is compatible with the span's type. Spans
+    that were placed are removed from ``spans`` in place.
+
+    Args:
+        blocks: Sequences whose items 0..3 form the bbox, item 7 is the
+            block type and the last item is the group id.
+        spans: Span dicts with ``'bbox'`` and ``'type'``; mutated.
+        radio: Minimum overlap ratio (exclusive).
+
+    Returns:
+        ``(block_with_spans, spans)``: the new block dicts and the spans
+        that were left over.
+    """
+    grouped_types = [
+        BlockType.IMAGE_BODY, BlockType.IMAGE_CAPTION, BlockType.IMAGE_FOOTNOTE,
+        BlockType.TABLE_BODY, BlockType.TABLE_CAPTION, BlockType.TABLE_FOOTNOTE
+    ]
+    block_with_spans = []
+    for block in blocks:
+        block_type = block[7]
+        block_bbox = block[0:4]
+        block_dict = {'type': block_type, 'bbox': block_bbox}
+        # Only image/table members carry a group id.
+        if block_type in grouped_types:
+            block_dict['group_id'] = block[-1]
+
+        claimed = [
+            span for span in spans
+            if calculate_overlap_area_in_bbox1_area_ratio(span['bbox'], block_bbox) > radio
+            and span_block_type_compatible(span['type'], block_type)
+        ]
+        block_dict['spans'] = claimed
+        block_with_spans.append(block_dict)
+
+        # A claimed span must not be offered to any later block.
+        for span in claimed:
+            spans.remove(span)
+
+    return block_with_spans, spans
+
+
+def span_block_type_compatible(span_type, block_type):
+    """Tell whether a span of ``span_type`` may be placed in ``block_type``.
+
+    * Text and inline-equation spans go into text-like blocks (text, title,
+      image/table caption and footnote, discarded).
+    * Interline-equation spans go into interline-equation or text blocks.
+    * Image spans go into image bodies, table spans into table bodies.
+    * Anything else is incompatible.
+
+    Returns:
+        ``True`` when the pairing is allowed, else ``False``.
+    """
+    # The first branch must test INLINE_EQUATION: testing INTERLINE_EQUATION
+    # here made the dedicated interline branch below unreachable and left
+    # inline-equation spans incompatible with every block.
+    if span_type in [ContentType.TEXT, ContentType.INLINE_EQUATION]:
+        return block_type in [
+            BlockType.TEXT,
+            BlockType.TITLE,
+            BlockType.IMAGE_CAPTION,
+            BlockType.IMAGE_FOOTNOTE,
+            BlockType.TABLE_CAPTION,
+            BlockType.TABLE_FOOTNOTE,
+            BlockType.DISCARDED
+        ]
+    elif span_type == ContentType.INTERLINE_EQUATION:
+        return block_type in [BlockType.INTERLINE_EQUATION, BlockType.TEXT]
+    elif span_type == ContentType.IMAGE:
+        return block_type in [BlockType.IMAGE_BODY]
+    elif span_type == ContentType.TABLE:
+        return block_type in [BlockType.TABLE_BODY]
+    else:
+        return False
+
+
+def fix_discarded_block(discarded_block_with_spans):
+    """Run ``fix_text_block`` on every discarded block.
+
+    Returns:
+        A new list with the processed blocks, in input order.
+    """
+    return [fix_text_block(block) for block in discarded_block_with_spans]
+
+
+def fix_text_block(block):
+    """Turn the ``'spans'`` of a text block into sorted ``'lines'``.
+
+    Interline-equation spans inside a text block are retyped as inline
+    equations first. The block is modified in place: ``'lines'`` is added
+    and ``'spans'`` is removed.
+
+    Returns:
+        The same ``block`` dict.
+    """
+    # Formulas inside a text block are treated as inline.
+    for span in block['spans']:
+        if span['type'] == ContentType.INTERLINE_EQUATION:
+            span['type'] = ContentType.INLINE_EQUATION
+    grouped_lines = merge_spans_to_line(block['spans'])
+    block['lines'] = line_sort_spans_by_left_to_right(grouped_lines)
+    del block['spans']
+    return block
+
+
+def merge_spans_to_line(spans, threshold=0.6):
+    """Group spans into lines by vertical overlap.
+
+    ``spans`` is sorted in place by the top coordinate ``bbox[1]``. A span
+    joins the current line when it overlaps the line's last span on the y
+    axis (per ``threshold``). Interline-equation, image and table spans
+    always occupy a line of their own.
+
+    Returns:
+        A list of lines, each a list of spans; ``[]`` for empty input.
+    """
+    if len(spans) == 0:
+        return []
+
+    standalone_types = [
+        ContentType.INTERLINE_EQUATION, ContentType.IMAGE,
+        ContentType.TABLE
+    ]
+
+    def is_standalone(candidate):
+        return candidate['type'] in standalone_types
+
+    # Process spans from top to bottom.
+    spans.sort(key=lambda span: span['bbox'][1])
+
+    lines = []
+    current_line = [spans[0]]
+    for span in spans[1:]:
+        joins_current_line = (
+            not is_standalone(span)
+            and not any(is_standalone(member) for member in current_line)
+            and __is_overlaps_y_exceeds_threshold(span['bbox'], current_line[-1]['bbox'], threshold)
+        )
+        if joins_current_line:
+            current_line.append(span)
+        else:
+            # Close the current line and start a new one with this span.
+            lines.append(current_line)
+            current_line = [span]
+
+    # The last line is never empty at this point.
+    lines.append(current_line)
+    return lines
+
+
+# Sort the spans of every line from left to right.
+def line_sort_spans_by_left_to_right(lines):
+    """Sort each line's spans by x0 and wrap the line with its bounding box.
+
+    Every inner list is sorted in place by ``bbox[0]``.
+
+    Returns:
+        One ``{'bbox': [x0, y0, x1, y1], 'spans': line}`` dict per line,
+        where the bbox encloses all spans of that line.
+    """
+    line_objects = []
+    for line in lines:
+        line.sort(key=lambda span: span['bbox'][0])
+        span_boxes = [span['bbox'] for span in line]
+        x0 = min(box[0] for box in span_boxes)
+        y0 = min(box[1] for box in span_boxes)
+        x1 = max(box[2] for box in span_boxes)
+        y1 = max(box[3] for box in span_boxes)
+        line_objects.append({'bbox': [x0, y0, x1, y1], 'spans': line})
+    return line_objects
+
+
+def fix_block_spans(block_with_spans):
+    """Convert the spans of every supported block into lines.
+
+    Text-like blocks (text, title, image/table caption and footnote) go
+    through ``fix_text_block``; interline-equation, image-body and
+    table-body blocks go through ``fix_interline_block``. Blocks of any
+    other type are dropped.
+
+    Returns:
+        A new list with the processed blocks, in input order.
+    """
+    # IMAGE_FOOTNOTE belongs in this list: it used to hold IMAGE_CAPTION
+    # twice, which silently dropped every image-footnote block.
+    text_like_types = [
+        BlockType.TEXT, BlockType.TITLE,
+        BlockType.IMAGE_CAPTION, BlockType.IMAGE_FOOTNOTE,
+        BlockType.TABLE_CAPTION, BlockType.TABLE_FOOTNOTE
+    ]
+    interline_types = [BlockType.INTERLINE_EQUATION, BlockType.IMAGE_BODY, BlockType.TABLE_BODY]
+
+    fix_blocks = []
+    for block in block_with_spans:
+        block_type = block['type']
+
+        if block_type in text_like_types:
+            block = fix_text_block(block)
+        elif block_type in interline_types:
+            block = fix_interline_block(block)
+        else:
+            continue
+        fix_blocks.append(block)
+    return fix_blocks
+
+
+def fix_interline_block(block):
+    """Turn the ``'spans'`` of a body/equation block into sorted ``'lines'``.
+
+    Span types are left untouched. The block is modified in place:
+    ``'lines'`` is added and ``'spans'`` is removed.
+
+    Returns:
+        The same ``block`` dict.
+    """
+    grouped_lines = merge_spans_to_line(block['spans'])
+    block['lines'] = line_sort_spans_by_left_to_right(grouped_lines)
+    del block['spans']
+    return block

+ 163 - 0
mineru/utils/span_pre_proc.py

@@ -0,0 +1,163 @@
+# Copyright (c) Opendatalab. All rights reserved.
+import cv2
+import numpy as np
+
+from mineru.utils.boxbase import calculate_overlap_area_in_bbox1_area_ratio, calculate_iou, \
+    get_minbox_if_overlap_by_ratio
+from mineru.utils.enum_class import BlockType, ContentType
+from mineru.utils.pdf_image_tools import get_crop_img
+
+
+def remove_outside_spans(spans, all_bboxes, all_discarded_blocks):
+    """Keep only the spans that lie inside a block of a matching kind.
+
+    A span is kept when more than 40% of it lies in a discarded block.
+    Otherwise it needs more than 50% of its area inside an image body (image
+    spans), a table body (table spans) or any other block (all other spans).
+
+    Args:
+        spans: Span dicts with ``'bbox'`` and ``'type'``.
+        all_bboxes: Block sequences; items 0..3 are the bbox, item 7 the type.
+        all_discarded_blocks: Discarded block sequences of the same shape.
+
+    Returns:
+        A new list with the surviving spans, in input order.
+    """
+
+    def get_block_bboxes(blocks, block_type_list):
+        return [block[0:4] for block in blocks if block[7] in block_type_list]
+
+    def inside_any(span_bbox, block_bboxes, min_ratio):
+        return any(
+            calculate_overlap_area_in_bbox1_area_ratio(span_bbox, block_bbox) > min_ratio
+            for block_bbox in block_bboxes
+        )
+
+    image_bboxes = get_block_bboxes(all_bboxes, [BlockType.IMAGE_BODY])
+    table_bboxes = get_block_bboxes(all_bboxes, [BlockType.TABLE_BODY])
+    # Every string-valued BlockType member except the image/table bodies.
+    other_block_type = [
+        block_type
+        for block_type in BlockType.__dict__.values()
+        if isinstance(block_type, str)
+        and block_type not in [BlockType.IMAGE_BODY, BlockType.TABLE_BODY]
+    ]
+    other_block_bboxes = get_block_bboxes(all_bboxes, other_block_type)
+    discarded_block_bboxes = get_block_bboxes(all_discarded_blocks, [BlockType.DISCARDED])
+
+    new_spans = []
+    for span in spans:
+        span_bbox = span['bbox']
+        span_type = span['type']
+
+        # Spans inside discarded blocks are kept whatever their type.
+        if inside_any(span_bbox, discarded_block_bboxes, 0.4):
+            new_spans.append(span)
+            continue
+
+        if span_type == ContentType.IMAGE:
+            candidate_bboxes = image_bboxes
+        elif span_type == ContentType.TABLE:
+            candidate_bboxes = table_bboxes
+        else:
+            candidate_bboxes = other_block_bboxes
+
+        if inside_any(span_bbox, candidate_bboxes, 0.5):
+            new_spans.append(span)
+
+    return new_spans
+
+
+def remove_overlaps_low_confidence_spans(spans):
+    """Drop the lower-scored span of every heavily overlapping pair.
+
+    Two spans overlap heavily when their IoU exceeds 0.9; the one whose
+    ``'score'`` is lower is dropped (the second of the pair on a tie).
+    ``spans`` is modified in place.
+
+    Returns:
+        ``(spans, dropped_spans)``.
+    """
+    dropped_spans = []
+    for span1 in spans:
+        for span2 in spans:
+            if span1 == span2:
+                continue
+            # Pairs involving an already dropped span are ignored.
+            if span1 in dropped_spans or span2 in dropped_spans:
+                continue
+            if not calculate_iou(span1['bbox'], span2['bbox']) > 0.9:
+                continue
+            loser = span1 if span1['score'] < span2['score'] else span2
+            if loser is not None and loser not in dropped_spans:
+                dropped_spans.append(loser)
+
+    for loser in dropped_spans:
+        spans.remove(loser)
+
+    return spans, dropped_spans
+
+
+def remove_overlaps_min_spans(spans):
+    """Drop the smaller span of every pair that overlaps by ratio.
+
+    ``get_minbox_if_overlap_by_ratio`` (ratio 0.65) names the bbox to drop
+    for a pair; the first span in ``spans`` carrying that bbox is removed.
+    ``spans`` is modified in place.
+
+    Returns:
+        ``(spans, dropped_spans)``.
+    """
+    dropped_spans = []
+    for span1 in spans:
+        for span2 in spans:
+            if span1 == span2:
+                continue
+            # Pairs involving an already dropped span are ignored.
+            if span1 in dropped_spans or span2 in dropped_spans:
+                continue
+            overlap_box = get_minbox_if_overlap_by_ratio(span1['bbox'], span2['bbox'], 0.65)
+            if overlap_box is None:
+                continue
+            smaller = next((span for span in spans if span['bbox'] == overlap_box), None)
+            if smaller is not None and smaller not in dropped_spans:
+                dropped_spans.append(smaller)
+
+    for smaller in dropped_spans:
+        spans.remove(smaller)
+
+    return spans, dropped_spans
+
+
+def txt_spans_extract(pdf_page, spans, pil_img, scale):
+    """Fill span text from the PDF text layer and prepare the rest for OCR.
+
+    Spans for which the text page returns text get ``'content'`` (stripped)
+    and ``'score' = 1.0``. Each remaining span is cropped from ``pil_img``:
+    crops with a contrast of at most 0.17 are removed from ``spans``, the
+    others get an empty ``'content'``, ``'score' = 1.0`` and the BGR crop
+    under ``'np_img'``. ``spans`` is modified in place.
+
+    Returns:
+        The same ``spans`` list.
+    """
+    textpage = pdf_page.get_textpage()
+    _, height = pdf_page.get_size()
+    cropbox = pdf_page.get_cropbox()
+
+    need_ocr_spans = []
+    for span in spans:
+        span_bbox = span['bbox']
+        # Flip the y axis against the page height and shift by the crop box.
+        left = span_bbox[0] + cropbox[0]
+        top = height - span_bbox[3] + cropbox[1]
+        right = span_bbox[2] + cropbox[0]
+        bottom = height - span_bbox[1] + cropbox[1]
+        text = textpage.get_text_bounded(left=left, top=top,
+                                         right=right, bottom=bottom)
+        if not (text and len(text) > 0):
+            need_ocr_spans.append(span)
+            continue
+        span['content'] = text.strip()
+        span['score'] = 1.0
+
+    for span in need_ocr_spans:
+        # Crop the span region; the crop is handed on under 'np_img'.
+        span_pil_img = get_crop_img(span['bbox'], pil_img, scale)
+        span_img = cv2.cvtColor(np.array(span_pil_img), cv2.COLOR_RGB2BGR)
+        # Low-contrast crops (<= 0.17) are removed from the result.
+        if calculate_contrast(span_img, img_mode='bgr') <= 0.17:
+            spans.remove(span)
+            continue
+
+        span['content'] = ''
+        span['score'] = 1.0
+        span['np_img'] = span_img
+
+    return spans
+
+
+def calculate_contrast(img, img_mode) -> float:
+    """Compute the contrast of an image.
+
+    The image is converted to grayscale and the contrast is the standard
+    deviation divided by the mean (plus ``1e-6`` to avoid division by zero),
+    rounded to two decimals.
+
+    :param img: image as a numpy.ndarray
+    :param img_mode: channel order of ``img``, ``'rgb'`` or ``'bgr'``
+    :return: contrast value of the image
+    :raises ValueError: if ``img_mode`` is neither ``'rgb'`` nor ``'bgr'``
+    """
+    if img_mode != 'rgb' and img_mode != 'bgr':
+        raise ValueError("Invalid image mode. Please provide 'rgb' or 'bgr'.")
+
+    # Pick the grayscale conversion that matches the channel order.
+    conversion = cv2.COLOR_RGB2GRAY if img_mode == 'rgb' else cv2.COLOR_BGR2GRAY
+    gray_img = cv2.cvtColor(img, conversion)
+
+    mean_value = np.mean(gray_img)
+    std_dev = np.std(gray_img)
+    return round(std_dev / (mean_value + 1e-6), 2)