|
@@ -0,0 +1,644 @@
|
|
|
|
|
+"""
|
|
|
|
|
+批量处理图片/PDF文件并生成符合评测要求的预测结果(MinerU版本)
|
|
|
|
|
+
|
|
|
|
|
+根据 MinerU demo.py 框架调用方式:
|
|
|
|
|
+- 输入:支持 PDF 和各种图片格式
|
|
|
|
|
+- 输出:每个文件对应的 .md、.json 文件,所有图片保存为单独的图片文件
|
|
|
|
|
+- 调用方式:通过 vlm-http-client 连接到 MinerU vLLM 服务器
|
|
|
|
|
+"""
|
|
|
|
|
+
|
|
|
|
|
+import os
|
|
|
|
|
+import sys
|
|
|
|
|
+import json
|
|
|
|
|
+import copy
|
|
|
|
|
+import shutil
|
|
|
|
|
+import time
|
|
|
|
|
+import traceback
|
|
|
|
|
+from pathlib import Path
|
|
|
|
|
+from typing import List, Dict, Any
|
|
|
|
|
+from PIL import Image
|
|
|
|
|
+from tqdm import tqdm
|
|
|
|
|
+import argparse
|
|
|
|
|
+
|
|
|
|
|
+from loguru import logger
|
|
|
|
|
+
|
|
|
|
|
+# 导入 MinerU 核心组件 (参考 demo.py)
|
|
|
|
|
+from mineru.cli.common import read_fn, convert_pdf_bytes_to_bytes_by_pypdfium2, prepare_env
|
|
|
|
|
+from mineru.data.data_reader_writer import FileBasedDataWriter
|
|
|
|
|
+from mineru.utils.draw_bbox import draw_layout_bbox
|
|
|
|
|
+from mineru.utils.enum_class import MakeMode
|
|
|
|
|
+from mineru.backend.vlm.vlm_analyze import doc_analyze as vlm_doc_analyze
|
|
|
|
|
+from mineru.backend.vlm.vlm_middle_json_mkcontent import union_make as vlm_union_make
|
|
|
|
|
+
|
|
|
|
|
+# 导入工具函数
|
|
|
|
|
+from utils import (
|
|
|
|
|
+ get_input_files,
|
|
|
|
|
+ collect_pid_files,
|
|
|
|
|
+ normalize_markdown_table,
|
|
|
|
|
+ normalize_json_table,
|
|
|
|
|
+)
|
|
|
|
|
+
|
|
|
|
|
+class MinerUVLLMProcessor:
|
|
|
|
|
+ """MinerU vLLM 处理器 (基于 demo.py 框架)"""
|
|
|
|
|
+
|
|
|
|
|
+ def __init__(self,
|
|
|
|
|
+ server_url: str = "http://127.0.0.1:8121",
|
|
|
|
|
+ timeout: int = 300,
|
|
|
|
|
+ normalize_numbers: bool = False,
|
|
|
|
|
+ debug: bool = False):
|
|
|
|
|
+ """
|
|
|
|
|
+ 初始化处理器
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ server_url: vLLM 服务器地址
|
|
|
|
|
+ timeout: 请求超时时间
|
|
|
|
|
+ normalize_numbers: 是否标准化数字
|
|
|
|
|
+ debug: 是否启用调试模式
|
|
|
|
|
+ """
|
|
|
|
|
+ self.server_url = server_url.rstrip('/')
|
|
|
|
|
+ self.timeout = timeout
|
|
|
|
|
+ self.normalize_numbers = normalize_numbers
|
|
|
|
|
+ self.debug = debug
|
|
|
|
|
+ self.backend = "http-client" # 固定使用 http-client 后端
|
|
|
|
|
+
|
|
|
|
|
+ print(f"MinerU vLLM Processor 初始化完成:")
|
|
|
|
|
+ print(f" - 服务器: {server_url}")
|
|
|
|
|
+ print(f" - 后端: vlm-{self.backend}")
|
|
|
|
|
+ print(f" - 超时: {timeout}s")
|
|
|
|
|
+ print(f" - 数字标准化: {normalize_numbers}")
|
|
|
|
|
+ print(f" - 调试模式: {debug}")
|
|
|
|
|
+
|
|
|
|
|
+ def do_parse_single_file(self,
|
|
|
|
|
+ input_file: str,
|
|
|
|
|
+ output_dir: str,
|
|
|
|
|
+ start_page_id: int = 0,
|
|
|
|
|
+ end_page_id: int = None) -> Dict[str, Any]:
|
|
|
|
|
+ """
|
|
|
|
|
+ 解析单个文件 (参考 demo.py 的 do_parse 函数)
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ file_path: 文件路径
|
|
|
|
|
+ output_dir: 输出目录
|
|
|
|
|
+ start_page_id: 起始页ID
|
|
|
|
|
+ end_page_id: 结束页ID
|
|
|
|
|
+
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ dict: 处理结果
|
|
|
|
|
+ """
|
|
|
|
|
+ try:
|
|
|
|
|
+ # 准备文件名和字节数据
|
|
|
|
|
+ file_path = Path(input_file)
|
|
|
|
|
+ pdf_file_name = file_path.stem
|
|
|
|
|
+ pdf_bytes = read_fn(str(file_path))
|
|
|
|
|
+
|
|
|
|
|
+ # 转换PDF字节流 (如果需要)
|
|
|
|
|
+ if file_path.suffix.lower() == '.pdf':
|
|
|
|
|
+ pdf_bytes = convert_pdf_bytes_to_bytes_by_pypdfium2(
|
|
|
|
|
+ pdf_bytes, start_page_id, end_page_id
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ # 准备环境 (创建输出目录)
|
|
|
|
|
+ # parse_method = "vlm"
|
|
|
|
|
+ # local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)
|
|
|
|
|
+ local_md_dir = Path(output_dir).resolve()
|
|
|
|
|
+ local_image_dir = local_md_dir / "images"
|
|
|
|
|
+ image_writer = FileBasedDataWriter(local_image_dir.as_posix())
|
|
|
|
|
+ md_writer = FileBasedDataWriter(local_md_dir.as_posix())
|
|
|
|
|
+
|
|
|
|
|
+ # 使用 VLM 分析文档 (核心调用)
|
|
|
|
|
+ middle_json, model_output = vlm_doc_analyze(
|
|
|
|
|
+ pdf_bytes,
|
|
|
|
|
+ image_writer=image_writer,
|
|
|
|
|
+ backend=self.backend,
|
|
|
|
|
+ server_url=self.server_url
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ pdf_info = middle_json["pdf_info"]
|
|
|
|
|
+
|
|
|
|
|
+ # 处理输出 (参考 demo.py 的 _process_output)
|
|
|
|
|
+ output_files = self._process_output(
|
|
|
|
|
+ pdf_info=pdf_info,
|
|
|
|
|
+ pdf_bytes=pdf_bytes,
|
|
|
|
|
+ pdf_file_name=pdf_file_name,
|
|
|
|
|
+ local_md_dir=local_md_dir,
|
|
|
|
|
+ local_image_dir=local_image_dir,
|
|
|
|
|
+ md_writer=md_writer,
|
|
|
|
|
+ middle_json=middle_json,
|
|
|
|
|
+ model_output=model_output,
|
|
|
|
|
+ original_file_path=str(file_path)
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ # 统计提取信息
|
|
|
|
|
+ extraction_stats = self._get_extraction_stats(middle_json)
|
|
|
|
|
+
|
|
|
|
|
+ return {
|
|
|
|
|
+ "success": True,
|
|
|
|
|
+ "pdf_info": pdf_info,
|
|
|
|
|
+ "middle_json": middle_json,
|
|
|
|
|
+ "model_output": model_output,
|
|
|
|
|
+ "output_files": output_files,
|
|
|
|
|
+ "extraction_stats": extraction_stats
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error(f"Failed to process {file_path}: {e}")
|
|
|
|
|
+ return {
|
|
|
|
|
+ "success": False,
|
|
|
|
|
+ "error": str(e)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ def _process_output(self,
|
|
|
|
|
+ pdf_info,
|
|
|
|
|
+ pdf_bytes,
|
|
|
|
|
+ pdf_file_name,
|
|
|
|
|
+ local_md_dir,
|
|
|
|
|
+ local_image_dir,
|
|
|
|
|
+ md_writer,
|
|
|
|
|
+ middle_json,
|
|
|
|
|
+ model_output,
|
|
|
|
|
+ original_file_path: str) -> Dict[str, str]:
|
|
|
|
|
+ """
|
|
|
|
|
+ 处理输出文件 (改进版的 demo.py _process_output)
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ pdf_info: PDF信息
|
|
|
|
|
+ pdf_bytes: PDF字节数据
|
|
|
|
|
+ pdf_file_name: PDF文件名
|
|
|
|
|
+ local_md_dir: Markdown目录
|
|
|
|
|
+ local_image_dir: 图片目录
|
|
|
|
|
+ md_writer: Markdown写入器
|
|
|
|
|
+ middle_json: 中间JSON数据
|
|
|
|
|
+ model_output: 模型输出
|
|
|
|
|
+ original_file_path: 原始文件路径
|
|
|
|
|
+
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ dict: 保存的文件路径信息
|
|
|
|
|
+ """
|
|
|
|
|
+ saved_files = {}
|
|
|
|
|
+
|
|
|
|
|
+ try:
|
|
|
|
|
+ # 设置相对图片目录名
|
|
|
|
|
+ image_dir = str(os.path.basename(local_image_dir))
|
|
|
|
|
+
|
|
|
|
|
+ # 1. 生成并保存 Markdown 文件
|
|
|
|
|
+ md_content_str = vlm_union_make(pdf_info, MakeMode.MM_MD, image_dir)
|
|
|
|
|
+
|
|
|
|
|
+ # 数字标准化处理
|
|
|
|
|
+ if self.normalize_numbers:
|
|
|
|
|
+ original_md = md_content_str
|
|
|
|
|
+ md_content_str = normalize_markdown_table(md_content_str)
|
|
|
|
|
+
|
|
|
|
|
+ changes_count = len([1 for o, n in zip(original_md, md_content_str) if o != n])
|
|
|
|
|
+ if changes_count > 0:
|
|
|
|
|
+ saved_files['md_normalized'] = f"✅ 已标准化 {changes_count} 个字符(全角→半角)"
|
|
|
|
|
+ else:
|
|
|
|
|
+ saved_files['md_normalized'] = "ℹ️ 无需标准化(已是标准格式)"
|
|
|
|
|
+
|
|
|
|
|
+ md_writer.write_string(f"{pdf_file_name}.md", md_content_str)
|
|
|
|
|
+ saved_files['md'] = os.path.join(local_md_dir, f"{pdf_file_name}.md")
|
|
|
|
|
+
|
|
|
|
|
+ # 2. 生成并保存 content_list JSON 文件
|
|
|
|
|
+ content_list = vlm_union_make(pdf_info, MakeMode.CONTENT_LIST, image_dir)
|
|
|
|
|
+ content_list_str = json.dumps(content_list, ensure_ascii=False, indent=2)
|
|
|
|
|
+
|
|
|
|
|
+ # 数字标准化处理
|
|
|
|
|
+ if self.normalize_numbers:
|
|
|
|
|
+ original_json = content_list_str
|
|
|
|
|
+ content_list_str = normalize_json_table(content_list_str)
|
|
|
|
|
+
|
|
|
|
|
+ changes_count = len([1 for o, n in zip(original_json, content_list_str) if o != n])
|
|
|
|
|
+ if changes_count > 0:
|
|
|
|
|
+ saved_files['json_normalized'] = f"✅ 已标准化 {changes_count} 个字符(全角→半角)"
|
|
|
|
|
+ else:
|
|
|
|
|
+ saved_files['json_normalized'] = "ℹ️ 无需标准化(已是标准格式)"
|
|
|
|
|
+
|
|
|
|
|
+ md_writer.write_string(f"{pdf_file_name}.json", content_list_str)
|
|
|
|
|
+ saved_files['json'] = os.path.join(local_md_dir, f"{pdf_file_name}.json")
|
|
|
|
|
+
|
|
|
|
|
+ # 绘制布局边界框
|
|
|
|
|
+ try:
|
|
|
|
|
+ draw_layout_bbox(pdf_info, pdf_bytes, local_md_dir, f"{pdf_file_name}_layout.pdf")
|
|
|
|
|
+ saved_files['layout_pdf'] = os.path.join(local_md_dir, f"{pdf_file_name}_layout.pdf")
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.warning(f"Failed to draw layout bbox: {e}")
|
|
|
|
|
+
|
|
|
|
|
+ # 3. 保存原始文件到 images 目录
|
|
|
|
|
+ # original_file_path = Path(original_file_path)
|
|
|
|
|
+ # output_image_path = os.path.join(local_image_dir, f"{pdf_file_name}{original_file_path.suffix}")
|
|
|
|
|
+ # if original_file_path.exists():
|
|
|
|
|
+ # shutil.copy2(str(original_file_path), output_image_path)
|
|
|
|
|
+ # saved_files['image'] = output_image_path
|
|
|
|
|
+
|
|
|
|
|
+ # 4. 调试模式下保存额外信息
|
|
|
|
|
+ if self.debug:
|
|
|
|
|
+ # 保存 middle.json
|
|
|
|
|
+ middle_json_str = json.dumps(middle_json, ensure_ascii=False, indent=2)
|
|
|
|
|
+ if self.normalize_numbers:
|
|
|
|
|
+ middle_json_str = normalize_json_table(middle_json_str)
|
|
|
|
|
+
|
|
|
|
|
+ md_writer.write_string(f"{pdf_file_name}_middle.json", middle_json_str)
|
|
|
|
|
+ saved_files['middle_json'] = os.path.join(local_md_dir, f"{pdf_file_name}_middle.json")
|
|
|
|
|
+
|
|
|
|
|
+ # 保存 model output
|
|
|
|
|
+ if model_output:
|
|
|
|
|
+ model_output_str = json.dumps(model_output, ensure_ascii=False, indent=2)
|
|
|
|
|
+ md_writer.write_string(f"{pdf_file_name}_model.json", model_output_str)
|
|
|
|
|
+ saved_files['model_output'] = os.path.join(local_md_dir, f"{pdf_file_name}_model.json")
|
|
|
|
|
+
|
|
|
|
|
+ # # 保存原始PDF
|
|
|
|
|
+ # md_writer.write(f"{pdf_file_name}_origin.pdf", pdf_bytes)
|
|
|
|
|
+ # saved_files['origin_pdf'] = os.path.join(local_md_dir, f"{pdf_file_name}_origin.pdf")
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+ logger.info(f"Output saved to: {local_md_dir}")
|
|
|
|
|
+
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error(f"Error in _process_output: {e}")
|
|
|
|
|
+ if self.debug:
|
|
|
|
|
+ traceback.print_exc()
|
|
|
|
|
+
|
|
|
|
|
+ return saved_files
|
|
|
|
|
+
|
|
|
|
|
+ def _get_extraction_stats(self, middle_json: Dict) -> Dict[str, Any]:
|
|
|
|
|
+ """
|
|
|
|
|
+ 获取提取统计信息
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ middle_json: 中间JSON数据
|
|
|
|
|
+
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ dict: 统计信息
|
|
|
|
|
+ """
|
|
|
|
|
+ stats = {
|
|
|
|
|
+ "total_blocks": 0,
|
|
|
|
|
+ "block_types": {},
|
|
|
|
|
+ "total_pages": 0
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ try:
|
|
|
|
|
+ pdf_info = middle_json.get("pdf_info", [])
|
|
|
|
|
+ if isinstance(pdf_info, list):
|
|
|
|
|
+ stats["total_pages"] = len(pdf_info)
|
|
|
|
|
+
|
|
|
|
|
+ for page_info in pdf_info:
|
|
|
|
|
+ para_blocks = page_info.get("para_blocks", [])
|
|
|
|
|
+ stats["total_blocks"] += len(para_blocks)
|
|
|
|
|
+
|
|
|
|
|
+ for block in para_blocks:
|
|
|
|
|
+ block_type = block.get("type", "unknown")
|
|
|
|
|
+ stats["block_types"][block_type] = stats["block_types"].get(block_type, 0) + 1
|
|
|
|
|
+
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.warning(f"Failed to get extraction stats: {e}")
|
|
|
|
|
+
|
|
|
|
|
+ return stats
|
|
|
|
|
+
|
|
|
|
|
+ def process_single_image(self, image_path: str, output_dir: str) -> Dict[str, Any]:
|
|
|
|
|
+ """
|
|
|
|
|
+ 处理单张图片 (保持与原接口兼容)
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ image_path: 图片路径
|
|
|
|
|
+ output_dir: 输出目录
|
|
|
|
|
+
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ dict: 处理结果
|
|
|
|
|
+ """
|
|
|
|
|
+ start_time = time.time()
|
|
|
|
|
+ image_name = Path(image_path).stem
|
|
|
|
|
+
|
|
|
|
|
+ result_info = {
|
|
|
|
|
+ "image_path": image_path,
|
|
|
|
|
+ "processing_time": 0,
|
|
|
|
|
+ "success": False,
|
|
|
|
|
+ "server": self.server_url,
|
|
|
|
|
+ "error": None,
|
|
|
|
|
+ "output_files": {},
|
|
|
|
|
+ "is_pdf_page": "_page_" in Path(image_path).name,
|
|
|
|
|
+ "extraction_stats": {}
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ try:
|
|
|
|
|
+ # 检查输出文件是否已存在
|
|
|
|
|
+ expected_md_path = Path(output_dir) / f"{image_name}.md"
|
|
|
|
|
+ expected_json_path = Path(output_dir) / f"{image_name}.json"
|
|
|
|
|
+
|
|
|
|
|
+ if expected_md_path.exists() and expected_json_path.exists():
|
|
|
|
|
+ result_info.update({
|
|
|
|
|
+ "success": True,
|
|
|
|
|
+ "processing_time": 0,
|
|
|
|
|
+ "output_files": {
|
|
|
|
|
+ "md": str(expected_md_path),
|
|
|
|
|
+ "json": str(expected_json_path)
|
|
|
|
|
+ },
|
|
|
|
|
+ "skipped": True
|
|
|
|
|
+ })
|
|
|
|
|
+ return result_info
|
|
|
|
|
+
|
|
|
|
|
+ # 使用 do_parse_single_file 处理
|
|
|
|
|
+ parse_result = self.do_parse_single_file(image_path, output_dir)
|
|
|
|
|
+
|
|
|
|
|
+ if parse_result["success"]:
|
|
|
|
|
+ result_info.update({
|
|
|
|
|
+ "success": True,
|
|
|
|
|
+ "output_files": parse_result["output_files"],
|
|
|
|
|
+ "extraction_stats": parse_result["extraction_stats"]
|
|
|
|
|
+ })
|
|
|
|
|
+ else:
|
|
|
|
|
+ result_info["error"] = parse_result.get("error", "Unknown error")
|
|
|
|
|
+
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ result_info["error"] = str(e)
|
|
|
|
|
+ logger.error(f"Error processing {image_name}: {e}")
|
|
|
|
|
+ if self.debug:
|
|
|
|
|
+ traceback.print_exc()
|
|
|
|
|
+
|
|
|
|
|
+ finally:
|
|
|
|
|
+ result_info["processing_time"] = time.time() - start_time
|
|
|
|
|
+
|
|
|
|
|
+ return result_info
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def process_images_single_process(image_paths: List[str],
|
|
|
|
|
+ processor: MinerUVLLMProcessor,
|
|
|
|
|
+ batch_size: int = 1,
|
|
|
|
|
+ output_dir: str = "./output") -> List[Dict[str, Any]]:
|
|
|
|
|
+ """
|
|
|
|
|
+ 单进程版本的图像处理函数
|
|
|
|
|
+ """
|
|
|
|
|
+ # 创建输出目录
|
|
|
|
|
+ output_path = Path(output_dir)
|
|
|
|
|
+ output_path.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
+
|
|
|
|
|
+ all_results = []
|
|
|
|
|
+ total_images = len(image_paths)
|
|
|
|
|
+
|
|
|
|
|
+ print(f"Processing {total_images} images with batch size {batch_size}")
|
|
|
|
|
+
|
|
|
|
|
+ with tqdm(total=total_images, desc="Processing images", unit="img",
|
|
|
|
|
+ bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') as pbar:
|
|
|
|
|
+
|
|
|
|
|
+ for i in range(0, total_images, batch_size):
|
|
|
|
|
+ batch = image_paths[i:i + batch_size]
|
|
|
|
|
+ batch_start_time = time.time()
|
|
|
|
|
+ batch_results = []
|
|
|
|
|
+
|
|
|
|
|
+ try:
|
|
|
|
|
+ for image_path in batch:
|
|
|
|
|
+ try:
|
|
|
|
|
+ result = processor.process_single_image(image_path, output_dir)
|
|
|
|
|
+ batch_results.append(result)
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error(f"Error processing {image_path}: {e}")
|
|
|
|
|
+ batch_results.append({
|
|
|
|
|
+ "image_path": image_path,
|
|
|
|
|
+ "processing_time": 0,
|
|
|
|
|
+ "success": False,
|
|
|
|
|
+ "server": processor.server_url,
|
|
|
|
|
+ "error": str(e)
|
|
|
|
|
+ })
|
|
|
|
|
+
|
|
|
|
|
+ batch_processing_time = time.time() - batch_start_time
|
|
|
|
|
+ all_results.extend(batch_results)
|
|
|
|
|
+
|
|
|
|
|
+ # 更新进度条
|
|
|
|
|
+ success_count = sum(1 for r in batch_results if r.get('success', False))
|
|
|
|
|
+ skipped_count = sum(1 for r in batch_results if r.get('skipped', False))
|
|
|
|
|
+ total_success = sum(1 for r in all_results if r.get('success', False))
|
|
|
|
|
+ total_skipped = sum(1 for r in all_results if r.get('skipped', False))
|
|
|
|
|
+ avg_time = batch_processing_time / len(batch) if len(batch) > 0 else 0
|
|
|
|
|
+
|
|
|
|
|
+ total_blocks = sum(r.get('extraction_stats', {}).get('total_blocks', 0) for r in batch_results)
|
|
|
|
|
+
|
|
|
|
|
+ pbar.update(len(batch))
|
|
|
|
|
+ pbar.set_postfix({
|
|
|
|
|
+ 'batch_time': f"{batch_processing_time:.2f}s",
|
|
|
|
|
+ 'avg_time': f"{avg_time:.2f}s/img",
|
|
|
|
|
+ 'success': f"{total_success}/{len(all_results)}",
|
|
|
|
|
+ 'skipped': f"{total_skipped}",
|
|
|
|
|
+ 'blocks': f"{total_blocks}",
|
|
|
|
|
+ 'rate': f"{total_success/len(all_results)*100:.1f}%" if len(all_results) > 0 else "0%"
|
|
|
|
|
+ })
|
|
|
|
|
+
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error(f"Error processing batch {[Path(p).name for p in batch]}: {e}")
|
|
|
|
|
+ error_results = []
|
|
|
|
|
+ for img_path in batch:
|
|
|
|
|
+ error_results.append({
|
|
|
|
|
+ "image_path": str(img_path),
|
|
|
|
|
+ "processing_time": 0,
|
|
|
|
|
+ "success": False,
|
|
|
|
|
+ "server": processor.server_url,
|
|
|
|
|
+ "error": str(e)
|
|
|
|
|
+ })
|
|
|
|
|
+ all_results.extend(error_results)
|
|
|
|
|
+ pbar.update(len(batch))
|
|
|
|
|
+
|
|
|
|
|
+ return all_results
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def main():
|
|
|
|
|
+ """主函数"""
|
|
|
|
|
+ parser = argparse.ArgumentParser(description="MinerU vLLM Batch Processing (demo.py framework)")
|
|
|
|
|
+
|
|
|
|
|
+ # 输入参数组
|
|
|
|
|
+ input_group = parser.add_mutually_exclusive_group(required=True)
|
|
|
|
|
+ input_group.add_argument("--input_file", type=str, help="Input file (supports both PDF and image file)")
|
|
|
|
|
+ input_group.add_argument("--input_dir", type=str, help="Input directory (supports both PDF and image files)")
|
|
|
|
|
+ input_group.add_argument("--input_file_list", type=str, help="Input file list (one file per line)")
|
|
|
|
|
+ input_group.add_argument("--input_csv", type=str, help="Input CSV file with image_path and status columns")
|
|
|
|
|
+
|
|
|
|
|
+ # 输出参数
|
|
|
|
|
+ parser.add_argument("--output_dir", type=str, required=True, help="Output directory")
|
|
|
|
|
+
|
|
|
|
|
+ # MinerU vLLM 参数
|
|
|
|
|
+ parser.add_argument("--server_url", type=str, default="http://127.0.0.1:8121",
|
|
|
|
|
+ help="MinerU vLLM server URL")
|
|
|
|
|
+ parser.add_argument("--timeout", type=int, default=300, help="Request timeout in seconds")
|
|
|
|
|
+ parser.add_argument("--dpi", type=int, default=200, help="PDF processing DPI")
|
|
|
|
|
+ parser.add_argument('--no-normalize', action='store_true', help='禁用数字标准化')
|
|
|
|
|
+ parser.add_argument('--debug', action='store_true', help='启用调试模式')
|
|
|
|
|
+
|
|
|
|
|
+ # 处理参数
|
|
|
|
|
+ parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
|
|
|
|
|
+ parser.add_argument("--test_mode", action="store_true", help="Test mode (process only 10 images)")
|
|
|
|
|
+ parser.add_argument("--collect_results", type=str, help="收集处理结果到指定CSV文件")
|
|
|
|
|
+
|
|
|
|
|
+ args = parser.parse_args()
|
|
|
|
|
+
|
|
|
|
|
+ try:
|
|
|
|
|
+ # 获取并预处理输入文件
|
|
|
|
|
+ print("🔄 Preprocessing input files...")
|
|
|
|
|
+ image_files = get_input_files(args)
|
|
|
|
|
+
|
|
|
|
|
+ if not image_files:
|
|
|
|
|
+ print("❌ No input files found or processed")
|
|
|
|
|
+ return 1
|
|
|
|
|
+
|
|
|
|
|
+ output_dir = Path(args.output_dir).resolve()
|
|
|
|
|
+ print(f"📁 Output dir: {output_dir}")
|
|
|
|
|
+ print(f"📊 Found {len(image_files)} image files to process")
|
|
|
|
|
+
|
|
|
|
|
+ if args.test_mode:
|
|
|
|
|
+ image_files = image_files[:10]
|
|
|
|
|
+ print(f"🧪 Test mode: processing only {len(image_files)} images")
|
|
|
|
|
+
|
|
|
|
|
+ print(f"🌐 Using server: {args.server_url}")
|
|
|
|
|
+ print(f"📦 Batch size: {args.batch_size}")
|
|
|
|
|
+ print(f"⏱️ Timeout: {args.timeout}s")
|
|
|
|
|
+
|
|
|
|
|
+ # 创建处理器
|
|
|
|
|
+ processor = MinerUVLLMProcessor(
|
|
|
|
|
+ server_url=args.server_url,
|
|
|
|
|
+ timeout=args.timeout,
|
|
|
|
|
+ normalize_numbers=not args.no_normalize,
|
|
|
|
|
+ debug=args.debug
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ # 开始处理
|
|
|
|
|
+ start_time = time.time()
|
|
|
|
|
+ results = process_images_single_process(
|
|
|
|
|
+ image_files,
|
|
|
|
|
+ processor,
|
|
|
|
|
+ args.batch_size,
|
|
|
|
|
+ str(output_dir)
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ total_time = time.time() - start_time
|
|
|
|
|
+
|
|
|
|
|
+ # 统计结果
|
|
|
|
|
+ success_count = sum(1 for r in results if r.get('success', False))
|
|
|
|
|
+ skipped_count = sum(1 for r in results if r.get('skipped', False))
|
|
|
|
|
+ error_count = len(results) - success_count
|
|
|
|
|
+ pdf_page_count = sum(1 for r in results if r.get('is_pdf_page', False))
|
|
|
|
|
+
|
|
|
|
|
+ # 统计提取的块信息
|
|
|
|
|
+ total_blocks = sum(r.get('extraction_stats', {}).get('total_blocks', 0) for r in results)
|
|
|
|
|
+ block_type_stats = {}
|
|
|
|
|
+ for result in results:
|
|
|
|
|
+ if 'extraction_stats' in result and 'block_types' in result['extraction_stats']:
|
|
|
|
|
+ for block_type, count in result['extraction_stats']['block_types'].items():
|
|
|
|
|
+ block_type_stats[block_type] = block_type_stats.get(block_type, 0) + count
|
|
|
|
|
+
|
|
|
|
|
+ print(f"\n" + "="*60)
|
|
|
|
|
+ print(f"✅ Processing completed!")
|
|
|
|
|
+ print(f"📊 Statistics:")
|
|
|
|
|
+ print(f" Total files processed: {len(image_files)}")
|
|
|
|
|
+ print(f" PDF pages processed: {pdf_page_count}")
|
|
|
|
|
+ print(f" Regular images processed: {len(image_files) - pdf_page_count}")
|
|
|
|
|
+ print(f" Successful: {success_count}")
|
|
|
|
|
+ print(f" Skipped: {skipped_count}")
|
|
|
|
|
+ print(f" Failed: {error_count}")
|
|
|
|
|
+ if len(image_files) > 0:
|
|
|
|
|
+ print(f" Success rate: {success_count / len(image_files) * 100:.2f}%")
|
|
|
|
|
+
|
|
|
|
|
+ print(f"📋 Content Extraction:")
|
|
|
|
|
+ print(f" Total blocks extracted: {total_blocks}")
|
|
|
|
|
+ if block_type_stats:
|
|
|
|
|
+ print(f" Block types:")
|
|
|
|
|
+ for block_type, count in sorted(block_type_stats.items()):
|
|
|
|
|
+ print(f" {block_type}: {count}")
|
|
|
|
|
+
|
|
|
|
|
+ print(f"⏱️ Performance:")
|
|
|
|
|
+ print(f" Total time: {total_time:.2f} seconds")
|
|
|
|
|
+ if total_time > 0:
|
|
|
|
|
+ print(f" Throughput: {len(image_files) / total_time:.2f} images/second")
|
|
|
|
|
+ print(f" Avg time per image: {total_time / len(image_files):.2f} seconds")
|
|
|
|
|
+
|
|
|
|
|
+ print(f"\n📁 Output Structure (demo.py compatible):")
|
|
|
|
|
+ print(f" output_dir/")
|
|
|
|
|
+ print(f" ├── filename.md # Markdown content")
|
|
|
|
|
+ print(f" ├── filename.json # Content list")
|
|
|
|
|
+ print(f" ├── filename_layout.json # Debug: layout bbox")
|
|
|
|
|
+ print(f" └── images/ # Extracted images")
|
|
|
|
|
+ print(f" └── filename.png")
|
|
|
|
|
+ if args.debug:
|
|
|
|
|
+ print(f" ├── filename_middle.json # Debug: middle JSON")
|
|
|
|
|
+ print(f" └── filename_model.json # Debug: model output")
|
|
|
|
|
+
|
|
|
|
|
+ # 保存结果统计
|
|
|
|
|
+ stats = {
|
|
|
|
|
+ "total_files": len(image_files),
|
|
|
|
|
+ "pdf_pages": pdf_page_count,
|
|
|
|
|
+ "regular_images": len(image_files) - pdf_page_count,
|
|
|
|
|
+ "success_count": success_count,
|
|
|
|
|
+ "skipped_count": skipped_count,
|
|
|
|
|
+ "error_count": error_count,
|
|
|
|
|
+ "success_rate": success_count / len(image_files) if len(image_files) > 0 else 0,
|
|
|
|
|
+ "total_time": total_time,
|
|
|
|
|
+ "throughput": len(image_files) / total_time if total_time > 0 else 0,
|
|
|
|
|
+ "avg_time_per_image": total_time / len(image_files) if len(image_files) > 0 else 0,
|
|
|
|
|
+ "batch_size": args.batch_size,
|
|
|
|
|
+ "server": args.server_url,
|
|
|
|
|
+ "backend": "vlm-http-client",
|
|
|
|
|
+ "timeout": args.timeout,
|
|
|
|
|
+ "pdf_dpi": args.dpi,
|
|
|
|
|
+ "total_blocks": total_blocks,
|
|
|
|
|
+ "block_type_stats": block_type_stats,
|
|
|
|
|
+ "normalization_enabled": not args.no_normalize,
|
|
|
|
|
+ "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ # 保存最终结果
|
|
|
|
|
+ output_file_name = Path(output_dir).name
|
|
|
|
|
+ output_file = os.path.join(output_dir, f"{output_file_name}_results.json")
|
|
|
|
|
+ final_results = {
|
|
|
|
|
+ "stats": stats,
|
|
|
|
|
+ "results": results
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ with open(output_file, 'w', encoding='utf-8') as f:
|
|
|
|
|
+ json.dump(final_results, f, ensure_ascii=False, indent=2)
|
|
|
|
|
+
|
|
|
|
|
+ print(f"💾 Results saved to: {output_file}")
|
|
|
|
|
+
|
|
|
|
|
+ # 收集处理结果
|
|
|
|
|
+ if not args.collect_results:
|
|
|
|
|
+ output_file_processed = Path(args.output_dir) / f"processed_files_{time.strftime('%Y%m%d_%H%M%S')}.csv"
|
|
|
|
|
+ else:
|
|
|
|
|
+ output_file_processed = Path(args.collect_results).resolve()
|
|
|
|
|
+
|
|
|
|
|
+ processed_files = collect_pid_files(output_file)
|
|
|
|
|
+ with open(output_file_processed, 'w', encoding='utf-8') as f:
|
|
|
|
|
+ f.write("image_path,status\n")
|
|
|
|
|
+ for file_path, status in processed_files:
|
|
|
|
|
+ f.write(f"{file_path},{status}\n")
|
|
|
|
|
+ print(f"💾 Processed files saved to: {output_file_processed}")
|
|
|
|
|
+
|
|
|
|
|
+ return 0
|
|
|
|
|
+
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error(f"Processing failed: {e}")
|
|
|
|
|
+ traceback.print_exc()
|
|
|
|
|
+ return 1
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+if __name__ == "__main__":
|
|
|
|
|
+ print(f"🚀 启动MinerU vLLM统一PDF/图像处理程序...")
|
|
|
|
|
+ print(f"🔧 CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set')}")
|
|
|
|
|
+
|
|
|
|
|
+ if len(sys.argv) == 1:
|
|
|
|
|
+ # 如果没有命令行参数,使用默认配置运行
|
|
|
|
|
+ print("ℹ️ No command line arguments provided. Running with default configuration...")
|
|
|
|
|
+
|
|
|
|
|
+ # 默认配置
|
|
|
|
|
+ default_config = {
|
|
|
|
|
+ "input_file": "/home/ubuntu/zhch/data/至远彩色印刷工业有限公司/2023年度报告母公司.img/2023年度报告母公司_page_003.png",
|
|
|
|
|
+ "output_dir": "./output",
|
|
|
|
|
+ "collect_results": f"./output/processed_files_{time.strftime('%Y%m%d_%H%M%S')}.csv",
|
|
|
|
|
+ "server_url": "http://127.0.0.1:8121",
|
|
|
|
|
+ "timeout": "300",
|
|
|
|
|
+ "batch_size": "1",
|
|
|
|
|
+ "dpi": "200",
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ # 构造参数
|
|
|
|
|
+ sys.argv = [sys.argv[0]]
|
|
|
|
|
+ for key, value in default_config.items():
|
|
|
|
|
+ sys.argv.extend([f"--{key}", str(value)])
|
|
|
|
|
+
|
|
|
|
|
+ # 测试模式和调试模式
|
|
|
|
|
+ sys.argv.append("--test_mode")
|
|
|
|
|
+ sys.argv.append("--debug")
|
|
|
|
|
+
|
|
|
|
|
+ sys.exit(main())
|