
refactor: streamline document analysis and enhance image handling in processing pipeline

myhloli · 5 months ago
commit 38ace5dc61
4 changed files, with 63 additions and 78 deletions
  1. +1 -23   mineru/backend/pipeline/pipeline_analyze.py
  2. +59 -52  mineru/cli/common.py
  3. +1 -1    mineru/utils/cut_image.py
  4. +2 -2    mineru/utils/pdf_image_tools.py

+ 1 - 23
mineru/backend/pipeline/pipeline_analyze.py

@@ -80,23 +80,10 @@ def custom_model_init(
 def doc_analyze(
         pdf_bytes_list,
         lang_list,
-        image_writer: DataWriter | None,
         parse_method: str = 'auto',
         formula_enable=None,
         table_enable=None,
 ):
-    """
-    Unified document-analysis entry point; decides whether to process a single
-    dataset or a list of datasets based on the input type.
-
-    Args:
-        dataset_or_datasets: a single Dataset object or a list of Dataset objects
-        parse_method: parsing method, 'auto'/'ocr'/'txt'
-        formula_enable: whether to enable formula recognition
-        table_enable: whether to enable table recognition
-
-    Returns:
-        a single model_json for a single dataset, or a list of model_json for multiple datasets
-    """
     MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))
 
     # Collect information from all pages
@@ -159,16 +146,7 @@ def doc_analyze(
         page_dict = {'layout_dets': result, 'page_info': page_info_dict}
         infer_results[pdf_idx][page_idx] = page_dict
 
-    middle_json_list = []
-    for pdf_idx, model_list in enumerate(infer_results):
-        images_list = all_image_lists[pdf_idx]
-        pdf_doc = all_pdf_docs[pdf_idx]
-        _lang = lang_list[pdf_idx]
-        _ocr = ocr_enabled_list[pdf_idx]
-        middle_json = result_to_middle_json(model_list, images_list, pdf_doc, image_writer, _lang, _ocr)
-        middle_json_list.append(middle_json)
-
-    return middle_json_list, infer_results
+    return infer_results, all_image_lists, all_pdf_docs, lang_list, ocr_enabled_list
 
 
 def batch_image_analyze(

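With this change, the pipeline backend's doc_analyze no longer takes an image_writer and no longer assembles middle JSON itself; it returns the raw inference results together with the per-document context the caller needs. A minimal sketch of the new calling convention, using only names that appear in this diff (result_to_middle_json is the helper that common.py imports below; pdf_bytes_list, lang_list, and image_writer are assumed to be prepared by the caller):

    from mineru.backend.pipeline.pipeline_analyze import doc_analyze
    from mineru.backend.pipeline.model_json_to_middle_json import result_to_middle_json

    # New return shape: raw results plus everything needed to build middle JSON later.
    infer_results, all_image_lists, all_pdf_docs, lang_list, ocr_enabled_list = doc_analyze(
        pdf_bytes_list, lang_list, parse_method="auto"
    )

    # The caller now builds middle JSON per document, with a writer of its choosing.
    for idx, model_list in enumerate(infer_results):
        middle_json = result_to_middle_json(
            model_list, all_image_lists[idx], all_pdf_docs[idx],
            image_writer, lang_list[idx], ocr_enabled_list[idx],
        )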
+ 59 - 52
mineru/cli/common.py

@@ -6,6 +6,8 @@ from pathlib import Path
 
 import pypdfium2 as pdfium
 from loguru import logger
+
+from mineru.backend.pipeline.model_json_to_middle_json import result_to_middle_json as pipeline_result_to_middle_json
 from ..api.vlm_middle_json_mkcontent import union_make
 from ..backend.vlm.vlm_analyze import doc_analyze as vlm_doc_analyze
 from ..backend.pipeline.pipeline_analyze import doc_analyze as pipeline_doc_analyze
@@ -30,10 +32,8 @@ def read_fn(path: Path):
 
 
 def prepare_env(output_dir, pdf_file_name, parse_method):
-    local_parent_dir = os.path.join(output_dir, pdf_file_name, parse_method)
-
-    local_image_dir = os.path.join(str(local_parent_dir), "images")
-    local_md_dir = local_parent_dir
+    local_md_dir = str(os.path.join(output_dir, pdf_file_name, parse_method))
+    local_image_dir = os.path.join(str(local_md_dir), "images")
     os.makedirs(local_image_dir, exist_ok=True)
     os.makedirs(local_md_dir, exist_ok=True)
     return local_image_dir, local_md_dir
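
The prepare_env rewrite is behavior-preserving; for a hypothetical call the resulting layout is unchanged:

    local_image_dir, local_md_dir = prepare_env("/out", "paper", "auto")
    # local_md_dir    -> /out/paper/auto
    # local_image_dir -> /out/paper/auto/images
    # both directories are created with exist_ok=True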
@@ -95,15 +95,23 @@ def do_parse(
     if backend == "pipeline":
         for pdf_bytes in pdf_bytes_list:
             pdf_bytes = convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id, end_page_id)
-        middle_json_list, infer_results = pipeline_doc_analyze(pdf_bytes_list, p_lang_list, parse_method=parse_method, formula_enable=p_formula_enable,table_enable=p_table_enable)
-        for idx, middle_json in enumerate(middle_json_list):
+        infer_results, all_image_lists, all_pdf_docs, lang_list, ocr_enabled_list = pipeline_doc_analyze(pdf_bytes_list, p_lang_list, parse_method=parse_method, formula_enable=p_formula_enable,table_enable=p_table_enable)
+
+        for idx, model_list in enumerate(infer_results):
             pdf_file_name = pdf_file_names[idx]
             model_json = infer_results[idx]
             local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)
             image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)
 
+            images_list = all_image_lists[idx]
+            pdf_doc = all_pdf_docs[idx]
+            _lang = lang_list[idx]
+            _ocr = ocr_enabled_list[idx]
+            middle_json = pipeline_result_to_middle_json(model_list, images_list, pdf_doc, image_writer, _lang, _ocr)
+
             pdf_info = middle_json["pdf_info"]
 
+            pdf_bytes = pdf_bytes_list[idx]
             if f_draw_layout_bbox:
                 draw_layout_bbox(pdf_info, pdf_bytes, local_md_dir, f"{pdf_file_name}_layout.pdf")
 
@@ -155,52 +163,51 @@ def do_parse(
             image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)
             middle_json, infer_result = vlm_doc_analyze(pdf_bytes, image_writer=image_writer, backend=backend, model_path=model_path, server_url=server_url)
 
-        pdf_info = middle_json["pdf_info"]
-
-        if f_draw_layout_bbox:
-            draw_layout_bbox(pdf_info, pdf_bytes, local_md_dir, f"{pdf_file_name}_layout.pdf")
-
-        if f_draw_span_bbox:
-            draw_span_bbox(pdf_info, pdf_bytes, local_md_dir, f"{pdf_file_name}_span.pdf")
-
-        if f_dump_orig_pdf:
-            md_writer.write(
-                f"{pdf_file_name}_origin.pdf",
-                pdf_bytes,
-            )
-
-        if f_dump_md:
-            image_dir = str(os.path.basename(local_image_dir))
-            md_content_str = union_make(pdf_info, f_make_md_mode, image_dir)
-            md_writer.write_string(
-                f"{pdf_file_name}.md",
-                md_content_str,
-            )
-
-        if f_dump_content_list:
-            image_dir = str(os.path.basename(local_image_dir))
-            content_list = union_make(pdf_info, MakeMode.STANDARD_FORMAT, image_dir)
-            md_writer.write_string(
-                f"{pdf_file_name}_content_list.json",
-                json.dumps(content_list, ensure_ascii=False, indent=4),
-            )
-
-        if f_dump_middle_json:
-            md_writer.write_string(
-                f"{pdf_file_name}_middle.json",
-                json.dumps(middle_json, ensure_ascii=False, indent=4),
-            )
-
-        if f_dump_model_output:
-            model_output = ("\n" + "-" * 50 + "\n").join(infer_result)
-            md_writer.write_string(
-                f"{pdf_file_name}_model_output.txt",
-                model_output,
-            )
-
-        logger.info(f"local output dir is {local_md_dir}")
-
-    return infer_result
+            pdf_info = middle_json["pdf_info"]
+
+            if f_draw_layout_bbox:
+                draw_layout_bbox(pdf_info, pdf_bytes, local_md_dir, f"{pdf_file_name}_layout.pdf")
+
+            if f_draw_span_bbox:
+                draw_span_bbox(pdf_info, pdf_bytes, local_md_dir, f"{pdf_file_name}_span.pdf")
+
+            if f_dump_orig_pdf:
+                md_writer.write(
+                    f"{pdf_file_name}_origin.pdf",
+                    pdf_bytes,
+                )
+
+            if f_dump_md:
+                image_dir = str(os.path.basename(local_image_dir))
+                md_content_str = union_make(pdf_info, f_make_md_mode, image_dir)
+                md_writer.write_string(
+                    f"{pdf_file_name}.md",
+                    md_content_str,
+                )
+
+            if f_dump_content_list:
+                image_dir = str(os.path.basename(local_image_dir))
+                content_list = union_make(pdf_info, MakeMode.STANDARD_FORMAT, image_dir)
+                md_writer.write_string(
+                    f"{pdf_file_name}_content_list.json",
+                    json.dumps(content_list, ensure_ascii=False, indent=4),
+                )
+
+            if f_dump_middle_json:
+                md_writer.write_string(
+                    f"{pdf_file_name}_middle.json",
+                    json.dumps(middle_json, ensure_ascii=False, indent=4),
+                )
+
+            if f_dump_model_output:
+                model_output = ("\n" + "-" * 50 + "\n").join(infer_result)
+                md_writer.write_string(
+                    f"{pdf_file_name}_model_output.txt",
+                    model_output,
+                )
+
+            logger.info(f"local output dir is {local_md_dir}")
+
 
 
 if __name__ == "__main__":

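Two things are worth noting in this hunk. First, the draw/dump block that used to sit after the per-file loop now runs inside it, so when the VLM backend receives several PDFs each one gets its own outputs rather than only the last. Second, the trailing return infer_result is gone, so do_parse now returns None and results should be read from the output directory. A hedged sketch of the control-flow fix (the loop header itself is outside this hunk and is reconstructed here):

    # Before: the dump code ran once, after the loop, with only the
    # last file's middle_json and pdf_bytes still in scope.
    for pdf_bytes in pdf_bytes_list:
        middle_json, infer_result = vlm_doc_analyze(
            pdf_bytes, image_writer=image_writer, backend=backend,
            model_path=model_path, server_url=server_url)
    pdf_info = middle_json["pdf_info"]  # last file only

    # After: drawing and dumping happen per file, inside the loop.
    for pdf_bytes in pdf_bytes_list:
        middle_json, infer_result = vlm_doc_analyze(
            pdf_bytes, image_writer=image_writer, backend=backend,
            model_path=model_path, server_url=server_url)
        pdf_info = middle_json["pdf_info"]
        # draw_layout_bbox / draw_span_bbox / dump md, json, model output ...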
+ 1 - 1
mineru/utils/cut_image.py

@@ -14,7 +14,7 @@ def cut_image_and_table(span, page_pil_img, page_img_md5, page_id, image_writer,
         span["image_path"] = ""
     else:
         span["image_path"] = cut_image(
-            span["bbox"], page_id, page_pil_img, return_path=return_path(span_type), imageWriter=imageWriter, scale=scale
+            span["bbox"], page_id, page_pil_img, return_path=return_path(span_type), image_writer=image_writer, scale=scale
         )
 
     return span

+ 2 - 2
mineru/utils/pdf_image_tools.py

@@ -54,7 +54,7 @@ def load_images_from_pdf(
     return images_list, pdf_doc
 
 
-def cut_image(bbox: tuple, page_num: int, page_pil_img, return_path, imageWriter: FileBasedDataWriter, scale=2):
+def cut_image(bbox: tuple, page_num: int, page_pil_img, return_path, image_writer: FileBasedDataWriter, scale=2):
     """Crop a JPEG image from page page_num according to bbox and return the image path.
     save_path must support both S3 and local storage; the image is stored under save_path with the
     file name {page_num}_{bbox[0]}_{bbox[1]}_{bbox[2]}_{bbox[3]}.jpg (bbox values rounded to integers)."""
@@ -73,7 +73,7 @@ def cut_image(bbox: tuple, page_num: int, page_pil_img, return_path, imageWriter
 
     img_bytes = image_to_bytes(crop_img, image_format="JPEG")
 
-    imageWriter.write(img_hash256_path, img_bytes)
+    image_writer.write(img_hash256_path, img_bytes)
     return img_hash256_path
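
The imageWriter parameter is renamed to the snake_case image_writer (matching the cut_image.py change above), so any external caller that passes it by keyword must be updated. A minimal before/after sketch; rp and writer stand for whatever return-path helper and FileBasedDataWriter the caller already has:

    # before:
    # path = cut_image(span["bbox"], page_id, page_pil_img, return_path=rp, imageWriter=writer, scale=2)

    # after:
    path = cut_image(span["bbox"], page_id, page_pil_img, return_path=rp, image_writer=writer, scale=2)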