
refactor: enhance document parsing by supporting multiple PDF files and improving method organization

myhloli, 5 months ago
commit ea5cb65a1f

+ 3 - 0
mineru/backend/pipeline/model_json_to_middle_json.py

@@ -0,0 +1,3 @@
+# Copyright (c) Opendatalab. All rights reserved.
+def result_to_middle_json(model_json, images_list, pdf_doc, image_writer):
+    pass
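
The new module only pins down the interface for now (the body is a stub). Below is a hypothetical call site, assuming the caller already holds the model output, the rendered page images, the pypdfium2 document and an image writer; every variable name besides result_to_middle_json is illustrative, not part of this commit.

from mineru.backend.pipeline.model_json_to_middle_json import result_to_middle_json

# Sketch only: the actual wiring is expected to land in a follow-up change.
middle_json = result_to_middle_json(
    model_json,     # per-page layout detections from the pipeline model
    images_list,    # rendered page images for the document
    pdf_doc,        # the pypdfium2 PdfDocument the pages came from
    image_writer,   # writer used to store cropped images (e.g. FileBasedDataWriter)
)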

+ 83 - 101
mineru/backend/pipeline/doc_analyze_by_custom_model.py → mineru/backend/pipeline/pipeline_analyze.py

@@ -2,12 +2,14 @@ import os
 import time
 import numpy as np
 import torch
+from pypdfium2 import PdfDocument
+
 from mineru.backend.pipeline.model_init import MineruPipelineModel
+from .model_json_to_middle_json import result_to_middle_json
+from ...utils.pdf_classify import classify
+from ...utils.pdf_image_tools import pdf_page_to_image
+
 
-os.environ['FLAGS_npu_jit_compile'] = '0'  # disable paddle's JIT compilation
-os.environ['FLAGS_use_stride_kernel'] = '0'
-os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # allow MPS to fall back to CPU
-os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # stop albumentations from checking for updates
 
 
 from loguru import logger
@@ -18,6 +20,11 @@ from magic_pdf.libs.config_reader import (get_device, get_formula_config,
                                           get_local_models_dir,
                                           get_table_recog_config)
 
+
+
+os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # allow MPS to fall back to CPU
+os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # stop albumentations from checking for updates
+
 class ModelSingleton:
     _instance = None
     _models = {}
@@ -76,117 +83,92 @@ def custom_model_init(
 
     return custom_model
 
+
 def doc_analyze(
-    dataset: Dataset,
-    ocr: bool = False,
-    start_page_id=0,
-    end_page_id=None,
-    lang=None,
-    formula_enable=None,
-    table_enable=None,
+        pdf_bytes_list,
+        lang_list,
+        parse_method: str = 'auto',
+        formula_enable=None,
+        table_enable=None,
 ):
-    end_page_id = (
-        end_page_id
-        if end_page_id is not None and end_page_id >= 0
-        else len(dataset) - 1
-    )
-
+    """
+    Unified document analysis: runs batched layout inference over one or more PDFs supplied as raw bytes.
+
+    Args:
+        pdf_bytes_list: list of PDF files as bytes, one entry per document
+        lang_list: list of language codes, aligned with pdf_bytes_list
+        parse_method: parsing method, 'auto'/'ocr'/'txt'
+        formula_enable: whether formula recognition is enabled
+        table_enable: whether table recognition is enabled
+
+    Returns:
+        tuple: (middle_json_list, infer_results), each a list with one entry per input PDF
+    """
     MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))
-    images = []
-    page_wh_list = []
-    for index in range(len(dataset)):
-        if start_page_id <= index <= end_page_id:
-            page_data = dataset.get_page(index)
-            img_dict = page_data.get_image()
-            images.append(img_dict['img'])
-            page_wh_list.append((img_dict['width'], img_dict['height']))
-
-    images_with_extra_info = [(images[index], ocr, dataset._lang) for index in range(len(images))]
-
-    if len(images) >= MIN_BATCH_INFERENCE_SIZE:
-        batch_size = MIN_BATCH_INFERENCE_SIZE
-        batch_images = [images_with_extra_info[i:i+batch_size] for i in range(0, len(images_with_extra_info), batch_size)]
-    else:
-        batch_images = [images_with_extra_info]
 
+    # Collect page-level information from every PDF
+    all_pages_info = []  # stores (pdf_idx, page_idx, img_pil, ocr, lang, scale)
+
+    for pdf_idx, pdf_bytes in enumerate(pdf_bytes_list):
+        # Determine the OCR setting for this PDF
+        _ocr = False
+        if parse_method == 'auto':
+            if classify(pdf_bytes) == 'ocr':
+                _ocr = True
+        elif parse_method == 'ocr':
+            _ocr = True
+
+        _lang = lang_list[pdf_idx]
+
+        # Collect the pages of this PDF
+        pdf_doc = PdfDocument(pdf_bytes)
+        for page_idx in range(len(pdf_doc)):
+            page_data = pdf_doc[page_idx]
+            img_dict = pdf_page_to_image(page_data)
+            all_pages_info.append((
+                pdf_idx, page_idx,
+                img_dict['img_pil'], _ocr, _lang,
+                img_dict['scale']
+            ))
+
+    # Prepare batches
+    images_with_extra_info = [(info[2], info[3], info[4]) for info in all_pages_info]
+    batch_size = MIN_BATCH_INFERENCE_SIZE
+    batch_images = [
+        images_with_extra_info[i:i + batch_size]
+        for i in range(0, len(images_with_extra_info), batch_size)
+    ]
+
+    # Run batched inference
     results = []
     processed_images_count = 0
     for index, batch_image in enumerate(batch_images):
         processed_images_count += len(batch_image)
-        logger.info(f'Batch {index + 1}/{len(batch_images)}: {processed_images_count} pages/{len(images_with_extra_info)} pages')
-        result = may_batch_image_analyze(batch_image, formula_enable, table_enable)
-        results.extend(result)
-
-    model_json = []
-    for index in range(len(dataset)):
-        if start_page_id <= index <= end_page_id:
-            result = results.pop(0)
-            page_width, page_height = page_wh_list.pop(0)
-        else:
-            result = []
-            page_height = 0
-            page_width = 0
-
-        page_info = {'page_no': index, 'width': page_width, 'height': page_height}
-        page_dict = {'layout_dets': result, 'page_info': page_info}
-        model_json.append(page_dict)
+        logger.info(
+            f'Batch {index + 1}/{len(batch_images)}: '
+            f'{processed_images_count} pages/{len(images_with_extra_info)} pages'
+        )
+        batch_results = may_batch_image_analyze(batch_image, formula_enable, table_enable)
+        results.extend(batch_results)
 
-    return model_json
-
-def batch_doc_analyze(
-    datasets: list[Dataset],
-    parse_method: str = 'auto',
-    lang=None,
-    formula_enable=None,
-    table_enable=None,
-):
-    MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))
-    batch_size = MIN_BATCH_INFERENCE_SIZE
-    page_wh_list = []
+    # Build the return values
 
-    images_with_extra_info = []
-    for dataset in datasets:
+    # Group the per-page results back by source PDF
+    infer_results = [[] for _ in pdf_bytes_list]
 
-        ocr = False
-        if parse_method == 'auto':
-            if dataset.classify() == 'txt':
-                ocr = False
-            elif dataset.classify() == 'ocr':
-                ocr = True
-        elif parse_method == 'ocr':
-            ocr = True
-        elif parse_method == 'txt':
-            ocr = False
+    for i, page_info in enumerate(all_pages_info):
+        pdf_idx, page_idx, pil_img, _, _, _ = page_info
+        result = results[i]
 
-        _lang = dataset._lang
+        page_info_dict = {'page_no': page_idx, 'width': pil_img.width, 'height': pil_img.height}
+        page_dict = {'layout_dets': result, 'page_info': page_info_dict}
+        infer_results[pdf_idx].append(page_dict)
 
-        for index in range(len(dataset)):
-            page_data = dataset.get_page(index)
-            img_dict = page_data.get_image()
-            page_wh_list.append((img_dict['width'], img_dict['height']))
-            images_with_extra_info.append((img_dict['img'], ocr, _lang))
+    middle_json_list = []
+    for model_json in infer_results:
+        middle_json = result_to_middle_json(model_json)
+        middle_json_list.append(middle_json)
 
-    batch_images = [images_with_extra_info[i:i+batch_size] for i in range(0, len(images_with_extra_info), batch_size)]
-    results = []
-    processed_images_count = 0
-    for index, batch_image in enumerate(batch_images):
-        processed_images_count += len(batch_image)
-        logger.info(f'Batch {index + 1}/{len(batch_images)}: {processed_images_count} pages/{len(images_with_extra_info)} pages')
-        result = may_batch_image_analyze(batch_image, formula_enable, table_enable)
-        results.extend(result)
-
-    infer_results = []
-    for index in range(len(datasets)):
-        dataset = datasets[index]
-        model_json = []
-        for i in range(len(dataset)):
-            result = results.pop(0)
-            page_width, page_height = page_wh_list.pop(0)
-            page_info = {'page_no': i, 'width': page_width, 'height': page_height}
-            page_dict = {'layout_dets': result, 'page_info': page_info}
-            model_json.append(page_dict)
-        infer_results.append(model_json)
-    return infer_results
+    return middle_json_list, infer_results
 
 
 def may_batch_image_analyze(
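
A minimal usage sketch of the refactored doc_analyze, assuming two local PDF files; the file names, language codes and flags below are illustrative, not part of the commit.

from pathlib import Path
from mineru.backend.pipeline.pipeline_analyze import doc_analyze

# Read each PDF as raw bytes; languages are aligned by position with the PDFs.
pdf_bytes_list = [Path(p).read_bytes() for p in ("a.pdf", "b.pdf")]
lang_list = ["ch", "en"]

# 'auto' lets classify() decide per document between text extraction and OCR.
middle_json_list, infer_results = doc_analyze(
    pdf_bytes_list,
    lang_list,
    parse_method="auto",
    formula_enable=True,
    table_enable=True,
)
# infer_results[i] holds one model-output dict per page of the i-th PDF.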

+ 119 - 58
mineru/cli/common.py

@@ -7,7 +7,8 @@ from pathlib import Path
 import pypdfium2 as pdfium
 from loguru import logger
 from ..api.vlm_middle_json_mkcontent import union_make
-from ..backend.vlm.vlm_analyze import doc_analyze
+from ..backend.vlm.vlm_analyze import doc_analyze as vlm_doc_analyze
+from ..backend.pipeline.pipeline_analyze import doc_analyze as pipeline_doc_analyze
 from ..data.data_reader_writer import FileBasedDataWriter
 from ..utils.draw_bbox import draw_layout_bbox, draw_span_bbox
 from ..utils.enum_class import MakeMode
@@ -28,8 +29,8 @@ def read_fn(path: Path):
             raise Exception(f"Unknown file suffix: {path.suffix}")
 
 
-def prepare_env(output_dir, pdf_file_name):
-    local_parent_dir = os.path.join(output_dir, pdf_file_name)
+def prepare_env(output_dir, pdf_file_name, parse_method):
+    local_parent_dir = os.path.join(output_dir, pdf_file_name, parse_method)
 
     local_image_dir = os.path.join(str(local_parent_dir), "images")
     local_md_dir = local_parent_dir
@@ -70,13 +71,17 @@ def convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id=0, end_page
 
 def do_parse(
     output_dir,
-    pdf_file_name,
-    pdf_bytes,
+    pdf_file_names: list[str],
+    pdf_bytes_list: list[bytes],
+    p_lang_list: list[str],
     backend="pipeline",
     model_path="jinzhenj/OEEzRkQ3RTAtMDMx-0415",  # TODO: change to formal path after release.
+    parse_method="auto",
+    p_formula_enable=True,
+    p_table_enable=True,
     server_url=None,
     f_draw_layout_bbox=True,
-    f_draw_span_bbox=False,
+    f_draw_span_bbox=True,
     f_dump_md=True,
     f_dump_middle_json=True,
     f_dump_model_output=True,
@@ -86,58 +91,114 @@ def do_parse(
     start_page_id=0,
     end_page_id=None,
 ):
-    if backend == 'pipeline':
-        f_draw_span_bbox = True
-
-    pdf_bytes = convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id, end_page_id)
-    local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name)
-    image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)
-
-    middle_json, infer_result = doc_analyze(pdf_bytes, image_writer=image_writer, backend=backend, server_url=server_url)
-    pdf_info = middle_json["pdf_info"]
-
-    if f_draw_layout_bbox:
-        draw_layout_bbox(pdf_info, pdf_bytes, local_md_dir, f"{pdf_file_name}_layout.pdf")
-
-    if f_draw_span_bbox:
-        draw_span_bbox(pdf_info, pdf_bytes, local_md_dir, f"{pdf_file_name}_span.pdf")
-
-    if f_dump_orig_pdf:
-        md_writer.write(
-            f"{pdf_file_name}_origin.pdf",
-            pdf_bytes,
-        )
-
-    if f_dump_md:
-        image_dir = str(os.path.basename(local_image_dir))
-        md_content_str = union_make(pdf_info, f_make_md_mode, image_dir)
-        md_writer.write_string(
-            f"{pdf_file_name}.md",
-            md_content_str,
-        )
-
-    if f_dump_content_list:
-        image_dir = str(os.path.basename(local_image_dir))
-        content_list = union_make(pdf_info, MakeMode.STANDARD_FORMAT, image_dir)
-        md_writer.write_string(
-            f"{pdf_file_name}_content_list.json",
-            json.dumps(content_list, ensure_ascii=False, indent=4),
-        )
-
-    if f_dump_middle_json:
-        md_writer.write_string(
-            f"{pdf_file_name}_middle.json",
-            json.dumps(middle_json, ensure_ascii=False, indent=4),
-        )
-
-    if f_dump_model_output:
-        model_output = ("\n" + "-" * 50 + "\n").join(infer_result)
-        md_writer.write_string(
-            f"{pdf_file_name}_model_output.txt",
-            model_output,
-        )
-
-    logger.info(f"local output dir is {local_md_dir}")
+
+    if backend == "pipeline":
+        for idx, pdf_bytes in enumerate(pdf_bytes_list):
+            pdf_bytes_list[idx] = convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id, end_page_id)
+        middle_json_list, infer_results = pipeline_doc_analyze(pdf_bytes_list, p_lang_list, parse_method=parse_method, formula_enable=p_formula_enable, table_enable=p_table_enable)
+        for idx, middle_json in enumerate(middle_json_list):
+            pdf_file_name = pdf_file_names[idx]
+            pdf_bytes = pdf_bytes_list[idx]
+            model_json = infer_results[idx]
+            local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)
+            image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)
+
+            pdf_info = middle_json["pdf_info"]
+
+            if f_draw_layout_bbox:
+                draw_layout_bbox(pdf_info, pdf_bytes, local_md_dir, f"{pdf_file_name}_layout.pdf")
+
+            if f_draw_span_bbox:
+                draw_span_bbox(pdf_info, pdf_bytes, local_md_dir, f"{pdf_file_name}_span.pdf")
+
+            if f_dump_orig_pdf:
+                md_writer.write(
+                    f"{pdf_file_name}_origin.pdf",
+                    pdf_bytes,
+                )
+
+            if f_dump_md:
+                image_dir = str(os.path.basename(local_image_dir))
+                md_content_str = union_make(pdf_info, f_make_md_mode, image_dir)
+                md_writer.write_string(
+                    f"{pdf_file_name}.md",
+                    md_content_str,
+                )
+
+            if f_dump_content_list:
+                image_dir = str(os.path.basename(local_image_dir))
+                content_list = union_make(pdf_info, MakeMode.STANDARD_FORMAT, image_dir)
+                md_writer.write_string(
+                    f"{pdf_file_name}_content_list.json",
+                    json.dumps(content_list, ensure_ascii=False, indent=4),
+                )
+
+            if f_dump_middle_json:
+                md_writer.write_string(
+                    f"{pdf_file_name}_middle.json",
+                    json.dumps(middle_json, ensure_ascii=False, indent=4),
+                )
+
+            if f_dump_model_output:
+                md_writer.write_string(
+                    f"{pdf_file_name}_model.json",
+                    json.dumps(model_json, ensure_ascii=False, indent=4),
+                )
+
+            logger.info(f"local output dir is {local_md_dir}")
+    else:
+        f_draw_span_bbox = False
+        parse_method = "vlm"
+        for idx, pdf_bytes in enumerate(pdf_bytes_list):
+            pdf_file_name = pdf_file_names[idx]
+            pdf_bytes = convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id, end_page_id)
+            local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)
+            image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)
+            middle_json, infer_result = vlm_doc_analyze(pdf_bytes, image_writer=image_writer, backend=backend, model_path=model_path, server_url=server_url)
+
+            pdf_info = middle_json["pdf_info"]
+
+            if f_draw_layout_bbox:
+                draw_layout_bbox(pdf_info, pdf_bytes, local_md_dir, f"{pdf_file_name}_layout.pdf")
+
+            if f_draw_span_bbox:
+                draw_span_bbox(pdf_info, pdf_bytes, local_md_dir, f"{pdf_file_name}_span.pdf")
+
+            if f_dump_orig_pdf:
+                md_writer.write(
+                    f"{pdf_file_name}_origin.pdf",
+                    pdf_bytes,
+                )
+
+            if f_dump_md:
+                image_dir = str(os.path.basename(local_image_dir))
+                md_content_str = union_make(pdf_info, f_make_md_mode, image_dir)
+                md_writer.write_string(
+                    f"{pdf_file_name}.md",
+                    md_content_str,
+                )
+
+            if f_dump_content_list:
+                image_dir = str(os.path.basename(local_image_dir))
+                content_list = union_make(pdf_info, MakeMode.STANDARD_FORMAT, image_dir)
+                md_writer.write_string(
+                    f"{pdf_file_name}_content_list.json",
+                    json.dumps(content_list, ensure_ascii=False, indent=4),
+                )
+
+            if f_dump_middle_json:
+                md_writer.write_string(
+                    f"{pdf_file_name}_middle.json",
+                    json.dumps(middle_json, ensure_ascii=False, indent=4),
+                )
+
+            if f_dump_model_output:
+                model_output = ("\n" + "-" * 50 + "\n").join(infer_result)
+                md_writer.write_string(
+                    f"{pdf_file_name}_model_output.txt",
+                    model_output,
+                )
+
+            logger.info(f"local output dir is {local_md_dir}")
 
     return infer_result
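
Since do_parse now takes parallel lists of names, bytes and languages, here is a hedged sketch of a caller, assuming two local PDF files; paths, languages and the output directory are illustrative.

from pathlib import Path
from mineru.cli.common import do_parse, read_fn

pdf_paths = [Path("a.pdf"), Path("b.pdf")]

do_parse(
    output_dir="output",
    pdf_file_names=[p.stem for p in pdf_paths],
    pdf_bytes_list=[read_fn(p) for p in pdf_paths],  # read_fn rejects unknown suffixes
    p_lang_list=["ch", "en"],
    backend="pipeline",   # any other value falls through to the VLM branch
    parse_method="auto",
)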
 

+ 155 - 0
mineru/utils/pdf_classify.py

@@ -0,0 +1,155 @@
+# Copyright (c) Opendatalab. All rights reserved.
+import re
+from io import BytesIO
+import numpy as np
+import pypdfium2 as pdfium
+from loguru import logger
+from pdfminer.high_level import extract_text
+from pdfminer.layout import LAParams
+
+
+def classify(pdf_bytes):
+    """
+    Decide whether a PDF's text can be extracted directly or whether it needs OCR.
+
+    Args:
+        pdf_bytes: raw bytes of the PDF file
+
+    Returns:
+        str: 'txt' if text can be extracted directly, 'ocr' if OCR is required
+    """
+    try:
+        # Load the PDF from bytes
+        sample_pdf_bytes = extract_pages(pdf_bytes)
+        pdf = pdfium.PdfDocument(sample_pdf_bytes)
+
+        # Get the number of pages
+        page_count = len(pdf)
+
+        # If the PDF has 0 pages, return 'ocr' directly
+        if page_count == 0:
+            return 'ocr'
+
+        # Total number of characters
+        total_chars = 0
+        # Total characters after whitespace is stripped
+        cleaned_total_chars = 0
+        # Number of pages to inspect (at most 10)
+        pages_to_check = min(page_count, 10)
+        pages_to_check = min(page_count, 10)
+
+        # Check the text of the sampled pages
+        for i in range(pages_to_check):
+            page = pdf[i]
+            text_page = page.get_textpage()
+            text = text_page.get_text_bounded()
+            total_chars += len(text)
+
+            # Clean the extracted text by removing whitespace
+            cleaned_text = re.sub(r'\s+', '', text)
+            cleaned_total_chars += len(cleaned_text)
+
+        # Average number of characters per page
+        # avg_chars_per_page = total_chars / pages_to_check
+        avg_cleaned_chars_per_page = cleaned_total_chars / pages_to_check
+
+        # Threshold: fewer than 50 effective characters per page on average means OCR is needed
+        chars_threshold = 50
+
+        # logger.debug(f"PDF分析: 平均每页{avg_chars_per_page:.1f}字符, 清理后{avg_cleaned_chars_per_page:.1f}字符")
+
+        if (avg_cleaned_chars_per_page < chars_threshold) or detect_invalid_chars(sample_pdf_bytes):
+            return 'ocr'
+        else:
+            return 'txt'
+    except Exception as e:
+        logger.error(f"判断PDF类型时出错: {e}")
+        # 出错时默认使用OCR
+        return 'ocr'
+
+
+def extract_pages(src_pdf_bytes: bytes) -> bytes:
+    """
+    Randomly sample up to 10 pages from the PDF bytes and return a new PDF as bytes.
+
+    Args:
+        src_pdf_bytes: raw bytes of the source PDF
+
+    Returns:
+        bytes: bytes of the PDF built from the sampled pages
+    """
+
+    # Load the PDF from bytes
+    pdf = pdfium.PdfDocument(src_pdf_bytes)
+
+    # Get the number of pages
+    total_page = len(pdf)
+    if total_page == 0:
+        # If the PDF has no pages, return an empty document directly
+        logger.warning("PDF is empty, return empty document")
+        return b''
+
+    # Select at most 10 pages
+    select_page_cnt = min(10, total_page)
+
+    # Randomly pick page indices from the full page range
+    page_indices = np.random.choice(total_page, select_page_cnt, replace=False).tolist()
+
+    # Create a new PDF document
+    sample_docs = pdfium.PdfDocument.new()
+
+    try:
+        # Import the selected pages into the new document
+        sample_docs.import_pages(pdf, page_indices)
+
+        # Save the new PDF to an in-memory buffer
+        output_buffer = BytesIO()
+        sample_docs.save(output_buffer)
+
+        # Return the raw bytes
+        return output_buffer.getvalue()
+    except Exception as e:
+        logger.exception(e)
+        return b''  # return empty bytes on error
+
+
+def detect_invalid_chars(sample_pdf_bytes: bytes) -> bool:
+    """"
+    检测PDF中是否包含非法字符
+    """
+    '''pdfminer比较慢,需要先随机抽取10页左右的sample'''
+    # sample_pdf_bytes = extract_pages(src_pdf_bytes)
+    sample_pdf_file_like_object = BytesIO(sample_pdf_bytes)
+    laparams = LAParams(
+        line_overlap=0.5,
+        char_margin=2.0,
+        line_margin=0.5,
+        word_margin=0.1,
+        boxes_flow=None,
+        detect_vertical=False,
+        all_texts=False,
+    )
+    text = extract_text(pdf_file=sample_pdf_file_like_object, laparams=laparams)
+    text = text.replace("\n", "")
+    # logger.info(text)
+    '''Garbled text extracted by pdfminer shows up as (cid:xxx) tokens'''
+    cid_pattern = re.compile(r'\(cid:\d+\)')
+    matches = cid_pattern.findall(text)
+    cid_count = len(matches)
+    cid_len = sum(len(match) for match in matches)
+    text_len = len(text)
+    if text_len == 0:
+        cid_chars_radio = 0
+    else:
+        cid_chars_radio = cid_count/(cid_count + text_len - cid_len)
+    # logger.debug(f"cid_count: {cid_count}, text_len: {text_len}, cid_chars_radio: {cid_chars_radio}")
+    '''If more than 5% of a document's text is garbled, treat the whole document as garbled'''
+    if cid_chars_radio > 0.05:
+        return True  # garbled document
+    else:
+        return False   # normal document
+
+
+if __name__ == '__main__':
+    with open('/Users/myhloli/pdf/luanma2x10.pdf', 'rb') as f:
+        p_bytes = f.read()
+        logger.info(f"PDF分类结果: {classify(p_bytes)}")