Browse source

feat: add batch processing for OCR detection and implement new client and common utilities

myhloli 5 months ago
parent
commit
0a899f1af8

+ 189 - 0
mineru/api/vlm_middle_json_mkcontent.py

@@ -0,0 +1,189 @@
+import re
+from ..utils.enum_class import MakeMode, BlockType, ContentType
+
+
+def merge_para_with_text(para_block):
+
+    para_text = ''
+    for line in para_block['lines']:
+        for span in line['spans']:
+            content = span['content']
+            content = content.strip()
+
+            if content:
+                para_text += content
+
+    return para_text
+
+def mk_blocks_to_markdown(para_blocks, make_mode, img_buket_path=''):
+    page_markdown = []
+    for para_block in para_blocks:
+        para_text = ''
+        para_type = para_block['type']
+        if para_type in [BlockType.TEXT, BlockType.LIST, BlockType.INDEX, BlockType.TITLE, BlockType.INTERLINE_EQUATION]:
+            para_text = merge_para_with_text(para_block)
+        elif para_type == BlockType.IMAGE:
+            if make_mode == MakeMode.NLP_MD:
+                continue
+            elif make_mode == MakeMode.MM_MD:
+                # Check whether the image block contains a footnote
+                has_image_footnote = any(block['type'] == BlockType.IMAGE_FOOTNOTE for block in para_block['blocks'])
+                # If a footnote exists, append it after the image body
+                if has_image_footnote:
+                    for block in para_block['blocks']:  # 1st: append image_caption
+                        if block['type'] == BlockType.IMAGE_CAPTION:
+                            para_text += merge_para_with_text(block) + '  \n'
+                    for block in para_block['blocks']:  # 2nd: append image_body
+                        if block['type'] == BlockType.IMAGE_BODY:
+                            for line in block['lines']:
+                                for span in line['spans']:
+                                    if span['type'] == ContentType.IMAGE:
+                                        if span.get('image_path', ''):
+                                            para_text += f"![]({img_buket_path}/{span['image_path']})"
+                    for block in para_block['blocks']:  # 3rd: append image_footnote
+                        if block['type'] == BlockType.IMAGE_FOOTNOTE:
+                            para_text += '  \n' + merge_para_with_text(block)
+                else:
+                    for block in para_block['blocks']:  # 1st: append image_body
+                        if block['type'] == BlockType.IMAGE_BODY:
+                            for line in block['lines']:
+                                for span in line['spans']:
+                                    if span['type'] == ContentType.IMAGE:
+                                        if span.get('image_path', ''):
+                                            para_text += f"![]({img_buket_path}/{span['image_path']})"
+                    for block in para_block['blocks']:  # 2nd: append image_caption
+                        if block['type'] == BlockType.IMAGE_CAPTION:
+                            para_text += '  \n' + merge_para_with_text(block)
+
+        elif para_type == BlockType.TABLE:
+            if make_mode == MakeMode.NLP_MD:
+                continue
+            elif make_mode == MakeMode.MM_MD:
+                for block in para_block['blocks']:  # 1st: append table_caption
+                    if block['type'] == BlockType.TABLE_CAPTION:
+                        para_text += merge_para_with_text(block) + '  \n'
+                for block in para_block['blocks']:  # 2nd: append table_body
+                    if block['type'] == BlockType.TABLE_BODY:
+                        for line in block['lines']:
+                            for span in line['spans']:
+                                if span['type'] == ContentType.TABLE:
+                                    # if processed by table model
+                                    if span.get('html', ''):
+                                        para_text += f"\n{span['html']}\n"
+                                    elif span.get('image_path', ''):
+                                        para_text += f"![]({img_buket_path}/{span['image_path']})"
+                for block in para_block['blocks']:  # 3rd: append table_footnote
+                    if block['type'] == BlockType.TABLE_FOOTNOTE:
+                        para_text += '\n' + merge_para_with_text(block) + '  '
+
+        if para_text.strip() == '':
+            continue
+        else:
+            # page_markdown.append(para_text.strip() + '  ')
+            page_markdown.append(para_text.strip())
+
+    return page_markdown
+
+
+def count_leading_hashes(text):
+    match = re.match(r'^(#+)', text)
+    return len(match.group(1)) if match else 0
+
+def strip_leading_hashes(text):
+    # Strip the leading '#' characters and any whitespace right after them
+    return re.sub(r'^#+\s*', '', text)
+
+
+def make_blocks_to_content_list(para_block, img_buket_path, page_idx):
+    para_type = para_block['type']
+    para_content = {}
+    if para_type in [BlockType.TEXT, BlockType.LIST, BlockType.INDEX]:
+        para_content = {
+            'type': 'text',
+            'text': merge_para_with_text(para_block),
+        }
+    elif para_type == BlockType.TITLE:
+        title_content = merge_para_with_text(para_block)
+        title_level = count_leading_hashes(title_content)
+        para_content = {
+            'type': 'text',
+            'text': strip_leading_hashes(title_content),
+        }
+        if title_level != 0:
+            para_content['text_level'] = title_level
+    elif para_type == BlockType.INTERLINE_EQUATION:
+        para_content = {
+            'type': 'equation',
+            'text': merge_para_with_text(para_block),
+            'text_format': 'latex',
+        }
+    elif para_type == BlockType.IMAGE:
+        para_content = {'type': 'image', 'img_path': '', 'img_caption': [], 'img_footnote': []}
+        for block in para_block['blocks']:
+            if block['type'] == BlockType.IMAGE_BODY:
+                for line in block['lines']:
+                    for span in line['spans']:
+                        if span['type'] == ContentType.IMAGE:
+                            if span.get('image_path', ''):
+                                para_content['img_path'] = f"{img_buket_path}/{span['image_path']}"
+            if block['type'] == BlockType.IMAGE_CAPTION:
+                para_content['img_caption'].append(merge_para_with_text(block))
+            if block['type'] == BlockType.IMAGE_FOOTNOTE:
+                para_content['img_footnote'].append(merge_para_with_text(block))
+    elif para_type == BlockType.TABLE:
+        para_content = {'type': 'table', 'img_path': '', 'table_caption': [], 'table_footnote': []}
+        for block in para_block['blocks']:
+            if block['type'] == BlockType.TABLE_BODY:
+                for line in block['lines']:
+                    for span in line['spans']:
+                        if span['type'] == ContentType.TABLE:
+
+                            if span.get('html', ''):
+                                para_content['table_body'] = f"{span['html']}"
+
+                            if span.get('image_path', ''):
+                                para_content['img_path'] = f"{img_buket_path}/{span['image_path']}"
+
+            if block['type'] == BlockType.TABLE_CAPTION:
+                para_content['table_caption'].append(merge_para_with_text(block))
+            if block['type'] == BlockType.TABLE_FOOTNOTE:
+                para_content['table_footnote'].append(merge_para_with_text(block))
+
+    para_content['page_idx'] = page_idx
+
+    return para_content
+
+def union_make(pdf_info_dict: list,
+               make_mode: str,
+               img_buket_path: str = '',
+               ):
+    output_content = []
+    for page_info in pdf_info_dict:
+        paras_of_layout = page_info.get('para_blocks')
+        page_idx = page_info.get('page_idx')
+        if not paras_of_layout:
+            continue
+        if make_mode in [MakeMode.MM_MD, MakeMode.NLP_MD]:
+            page_markdown = mk_blocks_to_markdown(paras_of_layout, make_mode, img_buket_path)
+            output_content.extend(page_markdown)
+        elif make_mode == MakeMode.STANDARD_FORMAT:
+            for para_block in paras_of_layout:
+                para_content = make_blocks_to_content_list(para_block, img_buket_path, page_idx)
+                output_content.append(para_content)
+
+    if make_mode in [MakeMode.MM_MD, MakeMode.NLP_MD]:
+        return '\n\n'.join(output_content)
+    elif make_mode == MakeMode.STANDARD_FORMAT:
+        return output_content
+    return None
+
+
+def get_title_level(block):
+    title_level = block.get('level', 1)
+    if title_level > 4:
+        title_level = 4
+    elif title_level < 1:
+        title_level = 0
+    return title_level
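For orientation, a minimal sketch of driving the new union_make entry point by hand. The one-page pdf_info below is a hypothetical stand-in for what the VLM backend emits; the 'text' type string mirrors BlockType.TEXT:

from mineru.api.vlm_middle_json_mkcontent import union_make
from mineru.utils.enum_class import MakeMode

# Hypothetical one-page pdf_info holding a single text paragraph.
pdf_info = [{
    'page_idx': 0,
    'para_blocks': [{
        'type': 'text',  # BlockType.TEXT
        'lines': [{'spans': [{'content': 'Hello world'}]}],
    }],
}]

print(union_make(pdf_info, MakeMode.NLP_MD))           # Hello world
print(union_make(pdf_info, MakeMode.STANDARD_FORMAT))  # [{'type': 'text', 'text': 'Hello world', 'page_idx': 0}]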

+ 157 - 42
mineru/backend/pipeline/batch_analyze.py

@@ -1,6 +1,8 @@
 import cv2
 from loguru import logger
 from tqdm import tqdm
+from collections import defaultdict
+import numpy as np
 
 from .model_init import AtomModelSingleton
 from ...utils.model_utils import crop_img, get_res_list_from_layout_res, get_coords_and_area
@@ -12,11 +14,12 @@ MFR_BASE_BATCH_SIZE = 16
 
 
 class BatchAnalyze:
-    def __init__(self, model_manager, batch_ratio: int, formula_enable, table_enable):
+    def __init__(self, model_manager, batch_ratio: int, formula_enable, table_enable, enable_ocr_det_batch: bool = True):
         self.batch_ratio = batch_ratio
         self.formula_enable = formula_enable
         self.table_enable = table_enable
         self.model_manager = model_manager
+        self.enable_ocr_det_batch = enable_ocr_det_batch
 
     def __call__(self, images_with_extra_info: list) -> list:
         if len(images_with_extra_info) == 0:
@@ -89,48 +92,160 @@ class BatchAnalyze:
                                                 'table_img':table_img,
                                               })
 
-        # Text box detection
-
-        for ocr_res_list_dict in tqdm(ocr_res_list_all_page, desc="OCR-det Predict"):
-            # Process each area that requires OCR processing
-            _lang = ocr_res_list_dict['lang']
-            # Get OCR results for this language's images
-            ocr_model = atom_model_manager.get_atom_model(
-                atom_model_name='ocr',
-                det_db_box_thresh=0.3,
-                lang=_lang
-            )
-            for res in ocr_res_list_dict['ocr_res_list']:
-                new_image, useful_list = crop_img(
-                    res, ocr_res_list_dict['np_array_img'], crop_paste_x=50, crop_paste_y=50
-                )
-                adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
-                    ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
-                )
+        # OCR detection
+        if self.enable_ocr_det_batch:
+            # Batch mode - group crops by language and resolution
+            # Collect every cropped image that needs OCR detection
+            all_cropped_images_info = []
+
+            for ocr_res_list_dict in ocr_res_list_all_page:
+                _lang = ocr_res_list_dict['lang']
+
+                for res in ocr_res_list_dict['ocr_res_list']:
+                    new_image, useful_list = crop_img(
+                        res, ocr_res_list_dict['np_array_img'], crop_paste_x=50, crop_paste_y=50
+                    )
+                    adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
+                        ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
+                    )
+
+                    # RGB to BGR conversion
+                    new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
+
+                    all_cropped_images_info.append((
+                        new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang
+                    ))
+
+            # Group by language
+            lang_groups = defaultdict(list)
+            for crop_info in all_cropped_images_info:
+                lang = crop_info[5]
+                lang_groups[lang].append(crop_info)
+
+            # For each language, group by resolution and process in batches
+            for lang, lang_crop_list in lang_groups.items():
+                if not lang_crop_list:
+                    continue
+
+                # logger.info(f"Processing OCR detection for language {lang} with {len(lang_crop_list)} images")
+
+                # Get the OCR model for this language
+                ocr_model = atom_model_manager.get_atom_model(
+                    atom_model_name='ocr',
+                    ocr_show_log=False,
+                    det_db_box_thresh=0.3,
+                    lang=lang
+                )
 
-                # OCR-det
-                new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
-                ocr_res = ocr_model.ocr(
-                    new_image, mfd_res=adjusted_mfdetrec_res, rec=False
-                )[0]
-
-                # Integration results
-                if ocr_res:
-                    ocr_result_list = get_ocr_result_list(ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang)
-
-                    if res["category_id"] == 3:
-                        # Sum of the areas of all bboxes in ocr_result_list
-                        ocr_res_area = sum(get_coords_and_area(ocr_res_item)[4] for ocr_res_item in ocr_result_list if 'poly' in ocr_res_item)
-                        # Ratio of ocr_res_area to the area of res
-                        res_area = get_coords_and_area(res)[4]
-                        if res_area > 0:
-                            ratio = ocr_res_area / res_area
-                            if ratio > 0.25:
-                                res["category_id"] = 1
-                            else:
-                                continue
-
-                    ocr_res_list_dict['layout_res'].extend(ocr_result_list)
+                # Group by resolution, padding as we go
+                resolution_groups = defaultdict(list)
+                for crop_info in lang_crop_list:
+                    cropped_img = crop_info[0]
+                    h, w = cropped_img.shape[:2]
+                    # Use a generous grouping tolerance to keep the number of groups small:
+                    # normalize sizes up to the next multiple of 32
+                    normalized_h = ((h + 32) // 32) * 32  # round up to a multiple of 32
+                    normalized_w = ((w + 32) // 32) * 32
+                    group_key = (normalized_h, normalized_w)
+                    resolution_groups[group_key].append(crop_info)
+
+                # Run batched detection on each resolution group
+                for group_key, group_crops in tqdm(resolution_groups.items(), desc=f"OCR-det {lang}"):
+                    raw_images = [crop_info[0] for crop_info in group_crops]
+
+                    # Target size: the group's max dimensions, rounded up to a multiple of 32
+                    max_h = max(img.shape[0] for img in raw_images)
+                    max_w = max(img.shape[1] for img in raw_images)
+                    target_h = ((max_h + 32 - 1) // 32) * 32
+                    target_w = ((max_w + 32 - 1) // 32) * 32
+
+                    # Pad every image to the shared target size
+                    batch_images = []
+                    for img in raw_images:
+                        h, w = img.shape[:2]
+                        # White canvas of the target size
+                        padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255
+                        # Paste the original image into the top-left corner
+                        padded_img[:h, :w] = img
+                        batch_images.append(padded_img)
+
+                    # Batched detection
+                    batch_size = min(len(batch_images), self.batch_ratio * 16)
+                    # logger.debug(f"OCR-det batch: {batch_size} images, target size: {target_h}x{target_w}")
+                    batch_results = ocr_model.text_detector.batch_predict(batch_images, batch_size)
+
+                    # Unpack the per-image results
+                    for crop_info, (dt_boxes, elapse) in zip(group_crops, batch_results):
+                        new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang = crop_info
+
+                        if dt_boxes is not None:
+                            # Build the OCR result format - each box is a list of four points
+                            ocr_res = [box.tolist() for box in dt_boxes]
+
+                            if ocr_res:
+                                ocr_result_list = get_ocr_result_list(
+                                    ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang
+                                )
+
+                                if res["category_id"] == 3:
+                                    # Sum of the areas of all bboxes in ocr_result_list
+                                    ocr_res_area = sum(
+                                        get_coords_and_area(ocr_res_item)[4] for ocr_res_item in ocr_result_list if 'poly' in ocr_res_item)
+                                    # Ratio of ocr_res_area to the area of res
+                                    res_area = get_coords_and_area(res)[4]
+                                    if res_area > 0:
+                                        ratio = ocr_res_area / res_area
+                                        if ratio > 0.25:
+                                            res["category_id"] = 1
+                                        else:
+                                            continue
+
+                                ocr_res_list_dict['layout_res'].extend(ocr_result_list)
+        else:
+            # Original single-image processing mode
+            for ocr_res_list_dict in tqdm(ocr_res_list_all_page, desc="OCR-det Predict"):
+                # Process each area that requires OCR processing
+                _lang = ocr_res_list_dict['lang']
+                # Get OCR results for this language's images
+                ocr_model = atom_model_manager.get_atom_model(
+                    atom_model_name='ocr',
+                    ocr_show_log=False,
+                    det_db_box_thresh=0.3,
+                    lang=_lang
+                )
+                for res in ocr_res_list_dict['ocr_res_list']:
+                    new_image, useful_list = crop_img(
+                        res, ocr_res_list_dict['np_array_img'], crop_paste_x=50, crop_paste_y=50
+                    )
+                    adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
+                        ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
+                    )
+
+                    # OCR-det
+                    new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
+                    ocr_res = ocr_model.ocr(
+                        new_image, mfd_res=adjusted_mfdetrec_res, rec=False
+                    )[0]
+
+                    # Integration results
+                    if ocr_res:
+                        ocr_result_list = get_ocr_result_list(ocr_res, useful_list, ocr_res_list_dict['ocr_enable'],
+                                                              new_image, _lang)
+
+                        if res["category_id"] == 3:
+                            # Sum of the areas of all bboxes in ocr_result_list
+                            ocr_res_area = sum(
+                                get_coords_and_area(ocr_res_item)[4] for ocr_res_item in ocr_result_list if 'poly' in ocr_res_item)
+                            # Ratio of ocr_res_area to the area of res
+                            res_area = get_coords_and_area(res)[4]
+                            if res_area > 0:
+                                ratio = ocr_res_area / res_area
+                                if ratio > 0.25:
+                                    res["category_id"] = 1
+                                else:
+                                    continue
+
+                        ocr_res_list_dict['layout_res'].extend(ocr_result_list)
 
         # Table recognition
         if self.table_enable:
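The heart of the batch path is the bucket-by-rounded-resolution step above: crops are grouped so each group can be padded onto one canvas and stacked into a single tensor. A self-contained sketch of just that step, NumPy only (the committed code rounds with extra headroom, (h + 32) // 32 * 32, so sizes near a bucket boundary share a group):

from collections import defaultdict
import numpy as np

def group_and_pad(images):
    # Bucket images by size rounded up to a multiple of 32, then pad each
    # bucket onto a shared white canvas so it can be stacked into one batch.
    groups = defaultdict(list)
    for img in images:
        h, w = img.shape[:2]
        groups[((h + 31) // 32 * 32, (w + 31) // 32 * 32)].append(img)
    for (th, tw), imgs in groups.items():
        batch = np.full((len(imgs), th, tw, 3), 255, dtype=np.uint8)
        for i, img in enumerate(imgs):
            batch[i, :img.shape[0], :img.shape[1]] = img  # paste top-left
        yield batch  # one stackable batch per resolution bucket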

+ 2 - 2
mineru/backend/vlm/token_to_middle_json.py

@@ -3,7 +3,7 @@ import re
 from mineru.utils.cut_image import cut_image_and_table
 from mineru.utils.enum_class import BlockType, ContentType
 from mineru.utils.hash_utils import str_md5
-from mineru.utils.magic_model import fix_two_layer_blocks
+from mineru.utils.vlm_magic_model import fix_two_layer_blocks
 from mineru.version import __version__
 
 
@@ -113,7 +113,7 @@ def token_to_page_info(token, image_dict, page, image_writer, page_index) -> dic
     # Sort page_blocks by their index value
     page_blocks.sort(key=lambda x: x["index"])
 
-    page_info = {"para_blocks": page_blocks, "page_size": [width, height], "page_idx": page_index}
+    page_info = {"para_blocks": page_blocks, "discarded_blocks": [], "page_size": [width, height], "page_idx": page_index}
     return page_info
 
 

+ 91 - 0
mineru/cli/client.py

@@ -0,0 +1,91 @@
+# Copyright (c) Opendatalab. All rights reserved.
+import os
+import click
+from pathlib import Path
+from loguru import logger
+from ..version import __version__
+from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
+
+
+@click.command()
+@click.version_option(__version__,
+                      '--version',
+                      '-v',
+                      help='display the version and exit')
+@click.option(
+    '-p',
+    '--path',
+    'input_path',
+    type=click.Path(exists=True),
+    required=True,
+    help='local file path or directory; supports pdf, png, jpg, jpeg files',
+)
+@click.option(
+    '-o',
+    '--output-dir',
+    'output_dir',
+    type=click.Path(),
+    required=True,
+    help='output local directory',
+)
+@click.option(
+    '-b',
+    '--backend',
+    'backend',
+    type=click.Choice(['pipeline', 'vlm-huggingface', 'vlm-sglang-engine', 'vlm-sglang-client']),
+    help="""the backend for parsing pdf:
+    pipeline: More general.
+    vlm-huggingface: More general.
+    vlm-sglang-engine: Faster(engine).
+    vlm-sglang-client: Faster(client).
+    without method specified, huggingface will be used by default.""",
+    default='pipeline',
+)
+@click.option(
+    '-u',
+    '--url',
+    'server_url',
+    type=str,
+    help="""
+    When the backend is `vlm-sglang-client`, you need to specify the server_url, for example: `http://127.0.0.1:30000`
+    """,
+    default=None,
+)
+@click.option(
+    '-s',
+    '--start',
+    'start_page_id',
+    type=int,
+    help='The starting page for PDF parsing, beginning from 0.',
+    default=0,
+)
+@click.option(
+    '-e',
+    '--end',
+    'end_page_id',
+    type=int,
+    help='The ending page for PDF parsing, beginning from 0.',
+    default=None,
+)
+
+def main(input_path, output_dir, backend, server_url, start_page_id, end_page_id):
+    os.makedirs(output_dir, exist_ok=True)
+
+    def parse_doc(path: Path):
+        try:
+            file_name = str(Path(path).stem)
+            pdf_bits = read_fn(path)
+            do_parse(output_dir, file_name, pdf_bits, backend, server_url,
+                     start_page_id=start_page_id, end_page_id=end_page_id)
+        except Exception as e:
+            logger.exception(e)
+
+    if os.path.isdir(input_path):
+        for doc_path in Path(input_path).glob('*'):
+            if doc_path.suffix in pdf_suffixes + image_suffixes:
+                parse_doc(Path(doc_path))
+    else:
+        parse_doc(Path(input_path))
+
+if __name__ == '__main__':
+    main()
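Since main is a plain click command, it can be smoke-tested without installing a console script by passing argv directly; the file paths here are hypothetical:

from mineru.cli.client import main

# standalone_mode=False stops click from calling sys.exit(), which makes
# this form convenient in tests and notebooks.
main(['-p', 'demo.pdf', '-o', './output', '-b', 'pipeline'], standalone_mode=False)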

+ 153 - 0
mineru/cli/common.py

@@ -0,0 +1,153 @@
+# Copyright (c) Opendatalab. All rights reserved.
+import io
+import json
+import os
+from pathlib import Path
+
+import pypdfium2 as pdfium
+from loguru import logger
+from ..api.vlm_middle_json_mkcontent import union_make
+from ..backend.vlm.vlm_analyze import doc_analyze
+from ..data.data_reader_writer import FileBasedDataWriter
+from ..utils.draw_bbox import draw_layout_bbox, draw_span_bbox
+from ..utils.enum_class import MakeMode
+from ..utils.pdf_image_tools import images_bytes_to_pdf_bytes
+
+pdf_suffixes = [".pdf"]
+image_suffixes = [".png", ".jpeg", ".jpg"]
+
+
+def read_fn(path: Path):
+    with open(str(path), "rb") as input_file:
+        file_bytes = input_file.read()
+        if path.suffix in image_suffixes:
+            return images_bytes_to_pdf_bytes(file_bytes)
+        elif path.suffix in pdf_suffixes:
+            return file_bytes
+        else:
+            raise Exception(f"Unknown file suffix: {path.suffix}")
+
+
+def prepare_env(output_dir, pdf_file_name):
+    local_parent_dir = os.path.join(output_dir, pdf_file_name)
+
+    local_image_dir = os.path.join(str(local_parent_dir), "images")
+    local_md_dir = local_parent_dir
+    os.makedirs(local_image_dir, exist_ok=True)
+    os.makedirs(local_md_dir, exist_ok=True)
+    return local_image_dir, local_md_dir
+
+
+def convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id=0, end_page_id=None):
+
+    # Load the PDF from the byte data
+    pdf = pdfium.PdfDocument(pdf_bytes)
+
+    # Determine the final page
+    end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else len(pdf) - 1
+    if end_page_id > len(pdf) - 1:
+        logger.warning("end_page_id is out of range, use pdf_docs length")
+        end_page_id = len(pdf) - 1
+
+    # Create a new PDF document
+    output_pdf = pdfium.PdfDocument.new()
+
+    # Select the page indices to import
+    page_indices = list(range(start_page_id, end_page_id + 1))
+
+    # Import the selected pages from the source PDF into the new one
+    output_pdf.import_pages(pdf, page_indices)
+
+    # Save the new PDF into an in-memory buffer
+    output_buffer = io.BytesIO()
+    output_pdf.save(output_buffer)
+
+    # Get the raw bytes
+    output_bytes = output_buffer.getvalue()
+
+    return output_bytes
+
+
+def do_parse(
+    output_dir,
+    pdf_file_name,
+    pdf_bytes,
+    backend="pipeline",
+    model_path="jinzhenj/OEEzRkQ3RTAtMDMx-0415",  # TODO: change to formal path after release.
+    server_url=None,
+    f_draw_layout_bbox=True,
+    f_draw_span_bbox=False,
+    f_dump_md=True,
+    f_dump_middle_json=True,
+    f_dump_model_output=True,
+    f_dump_orig_pdf=True,
+    f_dump_content_list=True,
+    f_make_md_mode=MakeMode.MM_MD,
+    start_page_id=0,
+    end_page_id=None,
+):
+    if backend == 'pipeline':
+        f_draw_span_bbox = True
+
+    pdf_bytes = convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id, end_page_id)
+    local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name)
+    image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)
+
+    middle_json, infer_result = doc_analyze(pdf_bytes, image_writer=image_writer, backend=backend, server_url=server_url)
+    pdf_info = middle_json["pdf_info"]
+
+    if f_draw_layout_bbox:
+        draw_layout_bbox(pdf_info, pdf_bytes, local_md_dir, f"{pdf_file_name}_layout.pdf")
+
+    if f_draw_span_bbox:
+        draw_span_bbox(pdf_info, pdf_bytes, local_md_dir, f"{pdf_file_name}_span.pdf")
+
+    if f_dump_orig_pdf:
+        md_writer.write(
+            f"{pdf_file_name}_origin.pdf",
+            pdf_bytes,
+        )
+
+    if f_dump_md:
+        image_dir = str(os.path.basename(local_image_dir))
+        md_content_str = union_make(pdf_info, f_make_md_mode, image_dir)
+        md_writer.write_string(
+            f"{pdf_file_name}.md",
+            md_content_str,
+        )
+
+    if f_dump_content_list:
+        image_dir = str(os.path.basename(local_image_dir))
+        content_list = union_make(pdf_info, MakeMode.STANDARD_FORMAT, image_dir)
+        md_writer.write_string(
+            f"{pdf_file_name}_content_list.json",
+            json.dumps(content_list, ensure_ascii=False, indent=4),
+        )
+
+    if f_dump_middle_json:
+        md_writer.write_string(
+            f"{pdf_file_name}_middle.json",
+            json.dumps(middle_json, ensure_ascii=False, indent=4),
+        )
+
+    if f_dump_model_output:
+        model_output = ("\n" + "-" * 50 + "\n").join(infer_result)
+        md_writer.write_string(
+            f"{pdf_file_name}_model_output.txt",
+            model_output,
+        )
+
+    logger.info(f"local output dir is {local_md_dir}")
+
+    return infer_result
+
+
+if __name__ == "__main__":
+    pdf_path = "../../demo/demo2.pdf"
+    with open(pdf_path, "rb") as f:
+        try:
+            result = do_parse("./output", Path(pdf_path).stem, f.read())
+            # Print the inference result as JSON
+            print(json.dumps(result, ensure_ascii=False, indent=4))
+        except Exception as e:
+            logger.exception(e)
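A small sketch of using convert_pdf_bytes_to_bytes_by_pypdfium2 on its own to slice out a page range (file names are hypothetical):

from pathlib import Path
from mineru.cli.common import convert_pdf_bytes_to_bytes_by_pypdfium2

pdf_bytes = Path('demo.pdf').read_bytes()
# Keep pages 0-4; an out-of-range end_page_id is clamped with a warning.
first_five = convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id=0, end_page_id=4)
Path('demo_p1-5.pdf').write_bytes(first_five)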

+ 122 - 0
mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_det.py

@@ -117,6 +117,128 @@ class TextDetector(BaseOCRV20):
         self.net.eval()
         self.net.to(self.device)
 
+    def _batch_process_same_size(self, img_list):
+        """
+        Run batched detection over images that share the same size.
+
+        Args:
+            img_list: list of images with identical dimensions
+
+        Returns:
+            batch_results: list of per-image results
+            total_elapse: total elapsed time
+        """
+        starttime = time.time()
+
+        # Preprocess every image
+        batch_data = []
+        batch_shapes = []
+        ori_imgs = []
+
+        for img in img_list:
+            ori_im = img.copy()
+            ori_imgs.append(ori_im)
+
+            data = {'image': img}
+            data = transform(data, self.preprocess_op)
+            if data is None:
+                # If preprocessing fails, return empty results
+                return [(None, 0) for _ in img_list], 0
+
+            img_processed, shape_list = data
+            batch_data.append(img_processed)
+            batch_shapes.append(shape_list)
+
+        # Stack into batch tensors
+        try:
+            batch_tensor = np.stack(batch_data, axis=0)
+            batch_shapes = np.stack(batch_shapes, axis=0)
+        except Exception as e:
+            # If stacking fails, fall back to per-image processing
+            batch_results = []
+            for img in img_list:
+                dt_boxes, elapse = self.__call__(img)
+                batch_results.append((dt_boxes, elapse))
+            return batch_results, time.time() - starttime
+
+        # Batched inference
+        with torch.no_grad():
+            inp = torch.from_numpy(batch_tensor)
+            inp = inp.to(self.device)
+            outputs = self.net(inp)
+
+        # Collect the model outputs
+        preds = {}
+        if self.det_algorithm == "EAST":
+            preds['f_geo'] = outputs['f_geo'].cpu().numpy()
+            preds['f_score'] = outputs['f_score'].cpu().numpy()
+        elif self.det_algorithm == 'SAST':
+            preds['f_border'] = outputs['f_border'].cpu().numpy()
+            preds['f_score'] = outputs['f_score'].cpu().numpy()
+            preds['f_tco'] = outputs['f_tco'].cpu().numpy()
+            preds['f_tvo'] = outputs['f_tvo'].cpu().numpy()
+        elif self.det_algorithm in ['DB', 'PSE', 'DB++']:
+            preds['maps'] = outputs['maps'].cpu().numpy()
+        elif self.det_algorithm == 'FCE':
+            for i, (k, output) in enumerate(outputs.items()):
+                preds['level_{}'.format(i)] = output.cpu().numpy()
+        else:
+            raise NotImplementedError
+
+        # Post-process each image's result
+        batch_results = []
+        total_elapse = time.time() - starttime
+
+        for i in range(len(img_list)):
+            # Slice out this image's predictions
+            single_preds = {}
+            for key, value in preds.items():
+                if isinstance(value, np.ndarray):
+                    single_preds[key] = value[i:i + 1]  # keep the batch dimension
+                else:
+                    single_preds[key] = value
+
+            # Post-processing
+            post_result = self.postprocess_op(single_preds, batch_shapes[i:i + 1])
+            dt_boxes = post_result[0]['points']
+
+            # Filter and clip the detection boxes
+            if (self.det_algorithm == "SAST" and
+                self.det_sast_polygon) or (self.det_algorithm in ["PSE", "FCE"] and
+                                           self.postprocess_op.box_type == 'poly'):
+                dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_imgs[i].shape)
+            else:
+                dt_boxes = self.filter_tag_det_res(dt_boxes, ori_imgs[i].shape)
+
+            batch_results.append((dt_boxes, total_elapse / len(img_list)))
+
+        return batch_results, total_elapse
+
+    def batch_predict(self, img_list, max_batch_size=8):
+        """
+        Batched prediction over multiple images.
+
+        Args:
+            img_list: list of images (expected to share the same size)
+            max_batch_size: maximum batch size
+
+        Returns:
+            batch_results: list of (dt_boxes, elapse) tuples, one per image
+        """
+        if not img_list:
+            return []
+
+        batch_results = []
+
+        # Process in chunks of max_batch_size
+        for i in range(0, len(img_list), max_batch_size):
+            batch_imgs = img_list[i:i + max_batch_size]
+            # all images in a chunk are expected to share the same size (callers group by resolution)
+            batch_dt_boxes, batch_elapse = self._batch_process_same_size(batch_imgs)
+            batch_results.extend(batch_dt_boxes)
+
+        return batch_results
+
     def order_points_clockwise(self, pts):
         """
         reference from: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py
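batch_predict leaves size grouping to the caller: _batch_process_same_size stacks each chunk with np.stack, which only succeeds when every crop shares one shape (mixed shapes hit the per-image fallback). A hedged sketch of the calling convention, with detector construction elided:

def detect_same_size_crops(text_detector, crops):
    # text_detector: an already-initialized TextDetector (construction elided);
    # crops: HxWx3 uint8 arrays with identical shapes, e.g. one padded
    # resolution group produced by BatchAnalyze above.
    results = text_detector.batch_predict(crops, max_batch_size=8)
    return [dt_boxes for dt_boxes, _elapse in results if dt_boxes is not None]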

+ 115 - 7
mineru/utils/draw_bbox.py

@@ -4,7 +4,7 @@ from io import BytesIO
 from PyPDF2 import PdfReader, PdfWriter
 from reportlab.pdfgen import canvas
 
-from .enum_class import BlockType
+from .enum_class import BlockType, ContentType
 
 
 def draw_bbox_without_number(i, bbox_list, page, c, rgb_config, fill_config):
@@ -54,7 +54,7 @@ def draw_bbox_with_number(i, bbox_list, page, c, rgb_config, fill_config, draw_b
 
 
 def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
-    # dropped_bbox_list = []
+    dropped_bbox_list = []
     tables_list, tables_body_list = [], []
     tables_caption_list, tables_footnote_list = [], []
     imgs_list, imgs_body_list, imgs_caption_list = [], [], []
@@ -65,7 +65,7 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
     lists_list = []
     indexs_list = []
     for page in pdf_info:
-        # page_dropped_list = []
+        page_dropped_list = []
         tables, tables_body, tables_caption, tables_footnote = [], [], [], []
         imgs, imgs_body, imgs_caption, imgs_footnote = [], [], [], []
         titles = []
@@ -74,9 +74,9 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
         lists = []
         indices = []
 
-        # for dropped_bbox in page['discarded_blocks']:
-        #     page_dropped_list.append(dropped_bbox['bbox'])
-        # dropped_bbox_list.append(page_dropped_list)
+        for dropped_bbox in page['discarded_blocks']:
+            page_dropped_list.append(dropped_bbox['bbox'])
+        dropped_bbox_list.append(page_dropped_list)
         for block in page["para_blocks"]:
             bbox = block["bbox"]
             if block["type"] == BlockType.TABLE:
@@ -164,7 +164,7 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
         # Create the canvas using the original PDF's page size
         c = canvas.Canvas(packet, pagesize=custom_page_size)
 
-        # c = draw_bbox_without_number(i, dropped_bbox_list, page, c, [158, 158, 158], True)
+        c = draw_bbox_without_number(i, dropped_bbox_list, page, c, [158, 158, 158], True)
         c = draw_bbox_without_number(i, tables_body_list, page, c, [204, 204, 0], True)
         c = draw_bbox_without_number(i, tables_caption_list, page, c, [255, 255, 102], True)
         c = draw_bbox_without_number(i, tables_footnote_list, page, c, [229, 255, 204], True)
@@ -190,6 +190,114 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
         output_pdf.write(f)
 
 
+def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename):
+    text_list = []
+    inline_equation_list = []
+    interline_equation_list = []
+    image_list = []
+    table_list = []
+    dropped_list = []
+    next_page_text_list = []
+    next_page_inline_equation_list = []
+
+    def get_span_info(span):
+        if span['type'] == ContentType.TEXT:
+            if span.get('cross_page', False):
+                next_page_text_list.append(span['bbox'])
+            else:
+                page_text_list.append(span['bbox'])
+        elif span['type'] == ContentType.INLINE_EQUATION:
+            if span.get('cross_page', False):
+                next_page_inline_equation_list.append(span['bbox'])
+            else:
+                page_inline_equation_list.append(span['bbox'])
+        elif span['type'] == ContentType.INTERLINE_EQUATION:
+            page_interline_equation_list.append(span['bbox'])
+        elif span['type'] == ContentType.IMAGE:
+            page_image_list.append(span['bbox'])
+        elif span['type'] == ContentType.TABLE:
+            page_table_list.append(span['bbox'])
+
+    for page in pdf_info:
+        page_text_list = []
+        page_inline_equation_list = []
+        page_interline_equation_list = []
+        page_image_list = []
+        page_table_list = []
+        page_dropped_list = []
+
+        # Pull in cross-page spans carried over from the previous page
+        if len(next_page_text_list) > 0:
+            page_text_list.extend(next_page_text_list)
+            next_page_text_list.clear()
+        if len(next_page_inline_equation_list) > 0:
+            page_inline_equation_list.extend(next_page_inline_equation_list)
+            next_page_inline_equation_list.clear()
+
+        # Build dropped_list
+        for block in page['discarded_blocks']:
+            if block['type'] == BlockType.DISCARDED:
+                for line in block['lines']:
+                    for span in line['spans']:
+                        page_dropped_list.append(span['bbox'])
+        dropped_list.append(page_dropped_list)
+        # Build the remaining span lists
+        # for block in page['para_blocks']:  # spans from the blocks before paragraph merging work just as well
+        for block in page['preproc_blocks']:
+            if block['type'] in [
+                BlockType.TEXT,
+                BlockType.TITLE,
+                BlockType.INTERLINE_EQUATION,
+                BlockType.LIST,
+                BlockType.INDEX,
+            ]:
+                for line in block['lines']:
+                    for span in line['spans']:
+                        get_span_info(span)
+            elif block['type'] in [BlockType.IMAGE, BlockType.TABLE]:
+                for sub_block in block['blocks']:
+                    for line in sub_block['lines']:
+                        for span in line['spans']:
+                            get_span_info(span)
+        text_list.append(page_text_list)
+        inline_equation_list.append(page_inline_equation_list)
+        interline_equation_list.append(page_interline_equation_list)
+        image_list.append(page_image_list)
+        table_list.append(page_table_list)
+
+    pdf_bytes_io = BytesIO(pdf_bytes)
+    pdf_docs = PdfReader(pdf_bytes_io)
+    output_pdf = PdfWriter()
+
+    for i, page in enumerate(pdf_docs.pages):
+        # Get the original page size
+        page_width, page_height = float(page.cropbox[2]), float(page.cropbox[3])
+        custom_page_size = (page_width, page_height)
+
+        packet = BytesIO()
+        # Create the canvas using the original PDF's page size
+        c = canvas.Canvas(packet, pagesize=custom_page_size)
+
+        # Draw the current page's data
+        draw_bbox_without_number(i, text_list, page, c, [255, 0, 0], False)
+        draw_bbox_without_number(i, inline_equation_list, page, c, [0, 255, 0], False)
+        draw_bbox_without_number(i, interline_equation_list, page, c, [0, 0, 255], False)
+        draw_bbox_without_number(i, image_list, page, c, [255, 204, 0], False)
+        draw_bbox_without_number(i, table_list, page, c, [204, 0, 255], False)
+        draw_bbox_without_number(i, dropped_list, page, c, [158, 158, 158], False)
+
+        c.save()
+        packet.seek(0)
+        overlay_pdf = PdfReader(packet)
+
+        page.merge_page(overlay_pdf.pages[0])
+        output_pdf.add_page(page)
+
+    # Save the PDF
+    with open(f"{out_path}/{filename}", "wb") as f:
+        output_pdf.write(f)
+
+
 if __name__ == "__main__":
     # Read the PDF file
     pdf_path = "examples/demo1.pdf"

+ 2 - 0
mineru/utils/enum_class.py

@@ -12,6 +12,7 @@ class BlockType:
     INTERLINE_EQUATION = 'interline_equation'
     LIST = 'list'
     INDEX = 'index'
+    DISCARDED = 'discarded'
 
 
 class ContentType:
@@ -19,6 +20,7 @@ class ContentType:
     TABLE = 'table'
     TEXT = 'text'
     INTERLINE_EQUATION = 'interline_equation'
+    INLINE_EQUATION = 'inline_equation'
 
 
 class MakeMode:

+ 0 - 0
mineru/backend/pipeline/magic_model.py → mineru/utils/pipeline_magic_model.py


+ 0 - 0
mineru/utils/magic_model.py → mineru/utils/vlm_magic_model.py