""" 批量处理 OmniDocBench 图片并生成符合评测要求的预测结果 根据 OmniDocBench 评测要求: - 输入:OpenDataLab___OmniDocBench/images 下的所有 .jpg 图片 - 输出:每个图片对应的 .md、.json 和带标注的 layout 图片文件 - 输出目录:用于后续的 end2end 评测 """ import os import sys import json import tempfile import uuid import shutil from pathlib import Path from PIL import Image from tqdm import tqdm import argparse # 导入 dots.ocr 相关模块 from dots_ocr.parser import DotsOCRParser from dots_ocr.utils import dict_promptmode_to_prompt from dots_ocr.utils.consts import MIN_PIXELS, MAX_PIXELS class OmniDocBenchProcessor: """OmniDocBench 批量处理器""" def __init__(self, ip="127.0.0.1", port=8101, model_name="DotsOCR", prompt_mode="prompt_layout_all_en", dpi=200, min_pixels=MIN_PIXELS, max_pixels=MAX_PIXELS): """ 初始化处理器 Args: ip: vLLM 服务器 IP port: vLLM 服务器端口 model_name: 模型名称 prompt_mode: 提示模式 dpi: PDF 处理 DPI min_pixels: 最小像素数 max_pixels: 最大像素数 """ self.parser = DotsOCRParser( ip=ip, port=port, dpi=dpi, min_pixels=min_pixels, max_pixels=max_pixels, model_name=model_name ) self.prompt_mode = prompt_mode print(f"DotsOCR Parser 初始化完成:") print(f" - 服务器: {ip}:{port}") print(f" - 模型: {model_name}") print(f" - 提示模式: {prompt_mode}") print(f" - 像素范围: {min_pixels} - {max_pixels}") def create_temp_session_dir(self): """创建临时会话目录""" session_id = uuid.uuid4().hex[:8] temp_dir = os.path.join(tempfile.gettempdir(), f"omnidocbench_batch_{session_id}") os.makedirs(temp_dir, exist_ok=True) return temp_dir, session_id def save_results_to_output_dir(self, result, image_name, output_dir): """ 将处理结果保存到输出目录 Args: result: 解析结果 image_name: 图片文件名(不含扩展名) output_dir: 输出目录 Returns: dict: 保存的文件路径 """ saved_files = {} # 1. 保存 Markdown 文件(OmniDocBench 评测必需) output_md_path = os.path.join(output_dir, f"{image_name}.md") md_content = "" # 优先使用无页眉页脚的版本(符合 OmniDocBench 评测要求) if 'md_content_nohf_path' in result and os.path.exists(result['md_content_nohf_path']): with open(result['md_content_nohf_path'], 'r', encoding='utf-8') as f: md_content = f.read() elif 'md_content_path' in result and os.path.exists(result['md_content_path']): with open(result['md_content_path'], 'r', encoding='utf-8') as f: md_content = f.read() else: md_content = "# 解析失败\n\n未能提取到有效的文档内容。" with open(output_md_path, 'w', encoding='utf-8') as f: f.write(md_content) saved_files['md'] = output_md_path # 2. 保存 JSON 文件 output_json_path = os.path.join(output_dir, f"{image_name}.json") json_data = {} if 'layout_info_path' in result and os.path.exists(result['layout_info_path']): with open(result['layout_info_path'], 'r', encoding='utf-8') as f: json_data = json.load(f) else: json_data = { "error": "未能提取到有效的布局信息", "cells": [] } with open(output_json_path, 'w', encoding='utf-8') as f: json.dump(json_data, f, ensure_ascii=False, indent=2) saved_files['json'] = output_json_path # 3. 保存带标注的布局图片 output_layout_image_path = os.path.join(output_dir, f"{image_name}_layout.jpg") if 'layout_image_path' in result and os.path.exists(result['layout_image_path']): # 直接复制布局图片 shutil.copy2(result['layout_image_path'], output_layout_image_path) saved_files['layout_image'] = output_layout_image_path else: # 如果没有布局图片,使用原始图片作为占位符 try: original_image = Image.open(result.get('original_image_path', '')) original_image.save(output_layout_image_path, 'JPEG', quality=95) saved_files['layout_image'] = output_layout_image_path print(f"⚠️ 使用原始图片作为布局图片: {image_name}") except Exception as e: print(f"⚠️ 无法保存布局图片: {image_name}, 错误: {e}") saved_files['layout_image'] = None # 4. 可选:保存原始图片副本 output_original_image_path = os.path.join(output_dir, f"{image_name}_original.jpg") if 'original_image_path' in result and os.path.exists(result['original_image_path']): shutil.copy2(result['original_image_path'], output_original_image_path) saved_files['original_image'] = output_original_image_path return saved_files def process_single_image(self, image_path, output_dir): """ 处理单张图片 Args: image_path: 图片路径 output_dir: 输出目录 Returns: bool: 处理是否成功 """ try: # 获取图片文件名(不含扩展名) image_name = Path(image_path).stem # 检查输出文件是否已存在 output_md_path = os.path.join(output_dir, f"{image_name}.md") output_json_path = os.path.join(output_dir, f"{image_name}.json") output_layout_path = os.path.join(output_dir, f"{image_name}_layout.jpg") if all(os.path.exists(p) for p in [output_md_path, output_json_path, output_layout_path]): print(f"跳过已存在的文件: {image_name}") return True # 创建临时会话目录 temp_dir, session_id = self.create_temp_session_dir() try: # 读取图片 image = Image.open(image_path) # 使用 DotsOCRParser 处理图片 filename = f"omnidocbench_{session_id}" results = self.parser.parse_image( input_path=image, filename=filename, prompt_mode=self.prompt_mode, save_dir=temp_dir, fitz_preprocess=True # 对图片使用 fitz 预处理 ) # 解析结果 if not results: print(f"警告: {image_name} 未返回解析结果") return False result = results[0] # parse_image 返回单个结果的列表 # 添加原始图片路径到结果中 result['original_image_path'] = image_path # 保存所有结果文件到输出目录 saved_files = self.save_results_to_output_dir(result, image_name, output_dir) # 验证保存结果 success_count = sum(1 for path in saved_files.values() if path and os.path.exists(path)) total_expected = 3 # md, json, layout_image if success_count >= 2: # 至少保存了 md 和 json print(f"✅ 成功处理: {image_name} (保存了 {success_count}/{total_expected} 个文件)") return True else: print(f"⚠️ 部分成功: {image_name} (保存了 {success_count}/{total_expected} 个文件)") return False except Exception as e: print(f"❌ 处理 {image_name} 时出错: {str(e)}") return False finally: # 清理临时目录 if os.path.exists(temp_dir): shutil.rmtree(temp_dir, ignore_errors=True) except Exception as e: print(f"❌ 处理 {image_path} 时出现致命错误: {str(e)}") return False def process_batch(self, images_dir, output_dir): """ 批量处理图片 Args: images_dir: 输入图片目录 output_dir: 输出目录 """ # 创建输出目录 os.makedirs(output_dir, exist_ok=True) # 获取所有图片文件 image_extensions = ['.jpg', '.jpeg', '.png'] image_files = [] for ext in image_extensions: image_files.extend(Path(images_dir).glob(f"*{ext}")) image_files.extend(Path(images_dir).glob(f"*{ext.upper()}")) image_files = sorted(image_files) if not image_files: print(f"在 {images_dir} 中未找到图片文件") return print(f"找到 {len(image_files)} 个图片文件") print(f"输出目录结构: {output_dir}") # 统计变量 success_count = 0 failed_count = 0 skipped_count = 0 # 使用进度条处理 with tqdm(image_files, desc="处理图片", unit="张") as pbar: for image_path in pbar: # 更新进度条描述 pbar.set_description(f"处理: {image_path.name}") # 检查输出文件是否已存在(在主输出目录中) image_name = image_path.stem output_md_path = os.path.join(output_dir, f"{image_name}.md") output_json_path = os.path.join(output_dir, f"{image_name}.json") output_layout_path = os.path.join(output_dir, f"{image_name}_layout.jpg") if all(os.path.exists(p) for p in [output_md_path, output_json_path, output_layout_path]): skipped_count += 1 continue # 处理图片 if self.process_single_image(str(image_path), output_dir): success_count += 1 else: failed_count += 1 # 更新进度条后缀 pbar.set_postfix({ 'success': success_count, 'failed': failed_count, 'skipped': skipped_count }) # 输出最终统计 print(f"\n🎉 批量处理完成!") print(f" ✅ 成功: {success_count}") print(f" ❌ 失败: {failed_count}") print(f" ⏭️ 跳过: {skipped_count}") print(f" 📁 输出目录: {output_dir}") # 生成处理报告 self.generate_processing_report(output_dir, success_count, failed_count, skipped_count) def generate_processing_report(self, output_dir, success_count, failed_count, skipped_count): """生成处理报告""" report_path = os.path.join(output_dir, "processing_report.json") report = { "processing_summary": { "success_count": success_count, "failed_count": failed_count, "skipped_count": skipped_count, "total_processed": success_count + failed_count + skipped_count }, "output_structure": { "markdown_files": f"{output_dir}/*.md", "json_files": f"{output_dir}/*.json", "layout_images": f"{output_dir}/*_layout.jpg", "original_images": f"{output_dir}/*_original.jpg" }, "configuration": { "prompt_mode": self.prompt_mode, "server": f"{self.parser.ip}:{self.parser.port}", "pixel_range": f"{self.parser.min_pixels} - {self.parser.max_pixels}" } } with open(report_path, 'w', encoding='utf-8') as f: json.dump(report, f, ensure_ascii=False, indent=2) print(f"📊 处理报告已保存: {report_path}") def main(): parser = argparse.ArgumentParser(description="批量处理 OmniDocBench 图片") parser.add_argument( "--images_dir", type=str, default="../OmniDocBench/OpenDataLab___OmniDocBench/images", help="输入图片目录路径" ) parser.add_argument( "--output_dir", type=str, default="./omnidocbench_predictions", help="输出目录路径" ) parser.add_argument( "--ip", type=str, default="127.0.0.1", help="vLLM 服务器 IP" ) parser.add_argument( "--port", type=int, default=8101, help="vLLM 服务器端口" ) parser.add_argument( "--model_name", type=str, default="DotsOCR", help="模型名称" ) parser.add_argument( "--prompt_mode", type=str, default="prompt_layout_all_en", choices=list(dict_promptmode_to_prompt.keys()), help="提示模式" ) parser.add_argument( "--min_pixels", type=int, default=MIN_PIXELS, help="最小像素数" ) parser.add_argument( "--max_pixels", type=int, default=MAX_PIXELS, help="最大像素数" ) parser.add_argument( "--dpi", type=int, default=200, help="PDF 处理 DPI" ) args = parser.parse_args() # 检查输入目录 if not os.path.exists(args.images_dir): print(f"❌ 输入目录不存在: {args.images_dir}") return print(f"🚀 开始批量处理 OmniDocBench 图片") print(f"📁 输入目录: {args.images_dir}") print(f"📁 输出目录: {args.output_dir}") print("="*60) # 创建处理器 processor = OmniDocBenchProcessor( ip=args.ip, port=args.port, model_name=args.model_name, prompt_mode=args.prompt_mode, dpi=args.dpi, min_pixels=args.min_pixels, max_pixels=args.max_pixels ) # 开始批量处理 processor.process_batch(args.images_dir, args.output_dir) if __name__ == "__main__": print(f"🚀 启动单进程DotsOCR程序...") print(f"🔧 CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set')}") if len(sys.argv) == 1: # 如果没有命令行参数,使用默认配置运行 print("ℹ️ No command line arguments provided. Running with default configuration...") # 默认配置 default_config = { "input_dir": "../../OmniDocBench/OpenDataLab___OmniDocBench/images", "output_dir": "./OmniDocBench_Results_Single", } # 构造参数 sys.argv = [sys.argv[0]] for key, value in default_config.items(): sys.argv.extend([f"--{key}", str(value)]) # 测试模式 # sys.argv.append("--test_mode") sys.exit(main())