""" 批量处理 OmniDocBench 图片并生成符合评测要求的预测结果 根据 OmniDocBench 评测要求: - 输入:OpenDataLab___OmniDocBench/images 下的所有 .jpg 图片 - 输出:每个图片对应的 .md、.json 和带标注的 layout 图片文件 - 输出目录:用于后续的 end2end 评测 """ import os import sys import json import tempfile import uuid import shutil import time import traceback import warnings from pathlib import Path from typing import List, Dict, Any from PIL import Image from tqdm import tqdm import argparse # 导入 dots.ocr 相关模块 from dots_ocr.parser import DotsOCRParser from dots_ocr.utils import dict_promptmode_to_prompt from dots_ocr.utils.consts import MIN_PIXELS, MAX_PIXELS # 导入工具函数 from utils import ( get_image_files_from_dir, get_image_files_from_list, get_image_files_from_csv, collect_pid_files ) class DotsOCRProcessor: """DotsOCR 处理器""" def __init__(self, ip: str = "127.0.0.1", port: int = 8101, model_name: str = "DotsOCR", prompt_mode: str = "prompt_layout_all_en", dpi: int = 200, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS): """ 初始化处理器 Args: ip: vLLM 服务器 IP port: vLLM 服务器端口 model_name: 模型名称 prompt_mode: 提示模式 dpi: PDF 处理 DPI min_pixels: 最小像素数 max_pixels: 最大像素数 """ self.ip = ip self.port = port self.model_name = model_name self.prompt_mode = prompt_mode self.dpi = dpi self.min_pixels = min_pixels self.max_pixels = max_pixels # 初始化解析器 self.parser = DotsOCRParser( ip=ip, port=port, dpi=dpi, min_pixels=min_pixels, max_pixels=max_pixels, model_name=model_name ) print(f"DotsOCR Parser 初始化完成:") print(f" - 服务器: {ip}:{port}") print(f" - 模型: {model_name}") print(f" - 提示模式: {prompt_mode}") print(f" - 像素范围: {min_pixels} - {max_pixels}") def create_temp_session_dir(self) -> tuple: """创建临时会话目录""" session_id = uuid.uuid4().hex[:8] temp_dir = os.path.join(tempfile.gettempdir(), f"omnidocbench_batch_{session_id}") os.makedirs(temp_dir, exist_ok=True) return temp_dir, session_id def save_results_to_output_dir(self, result: Dict, image_name: str, output_dir: str) -> Dict[str, str]: """ 将处理结果保存到输出目录 Args: result: 解析结果 image_name: 图片文件名(不含扩展名) output_dir: 输出目录 Returns: dict: 保存的文件路径 """ saved_files = {} try: # 1. 保存 Markdown 文件(OmniDocBench 评测必需) output_md_path = os.path.join(output_dir, f"{image_name}.md") md_content = "" # 优先使用无页眉页脚的版本(符合 OmniDocBench 评测要求) if 'md_content_nohf_path' in result and os.path.exists(result['md_content_nohf_path']): with open(result['md_content_nohf_path'], 'r', encoding='utf-8') as f: md_content = f.read() elif 'md_content_path' in result and os.path.exists(result['md_content_path']): with open(result['md_content_path'], 'r', encoding='utf-8') as f: md_content = f.read() else: md_content = "# 解析失败\n\n未能提取到有效的文档内容。" with open(output_md_path, 'w', encoding='utf-8') as f: f.write(md_content) saved_files['md'] = output_md_path # 2. 保存 JSON 文件 output_json_path = os.path.join(output_dir, f"{image_name}.json") json_data = {} if 'layout_info_path' in result and os.path.exists(result['layout_info_path']): with open(result['layout_info_path'], 'r', encoding='utf-8') as f: json_data = json.load(f) else: json_data = { "error": "未能提取到有效的布局信息", "cells": [] } with open(output_json_path, 'w', encoding='utf-8') as f: json.dump(json_data, f, ensure_ascii=False, indent=2) saved_files['json'] = output_json_path # 3. 保存带标注的布局图片 output_layout_image_path = os.path.join(output_dir, f"{image_name}_layout.jpg") if 'layout_image_path' in result and os.path.exists(result['layout_image_path']): # 直接复制布局图片 shutil.copy2(result['layout_image_path'], output_layout_image_path) saved_files['layout_image'] = output_layout_image_path else: # 如果没有布局图片,使用原始图片作为占位符 try: original_image = Image.open(result.get('original_image_path', '')) original_image.save(output_layout_image_path, 'JPEG', quality=95) saved_files['layout_image'] = output_layout_image_path except Exception as e: saved_files['layout_image'] = None # # 4. 可选:保存原始图片副本 # output_original_image_path = os.path.join(output_dir, f"{image_name}_original.jpg") # if 'original_image_path' in result and os.path.exists(result['original_image_path']): # shutil.copy2(result['original_image_path'], output_original_image_path) # saved_files['original_image'] = output_original_image_path except Exception as e: print(f"Error saving results for {image_name}: {e}") return saved_files def process_single_image(self, image_path: str, output_dir: str) -> Dict[str, Any]: """ 处理单张图片 Args: image_path: 图片路径 output_dir: 输出目录 Returns: dict: 处理结果 """ start_time = time.time() image_name = Path(image_path).stem result_info = { "image_path": image_path, "processing_time": 0, "success": False, "device": f"{self.ip}:{self.port}", "error": None, "output_files": {} } try: # 检查输出文件是否已存在 output_md_path = os.path.join(output_dir, f"{image_name}.md") output_json_path = os.path.join(output_dir, f"{image_name}.json") output_layout_path = os.path.join(output_dir, f"{image_name}_layout.jpg") if all(os.path.exists(p) for p in [output_md_path, output_json_path, output_layout_path]): result_info.update({ "success": True, "processing_time": 0, "output_files": { "md": output_md_path, "json": output_json_path, "layout_image": output_layout_path }, "skipped": True }) return result_info # 创建临时会话目录 temp_dir, session_id = self.create_temp_session_dir() try: # 读取图片 image = Image.open(image_path) # 使用 DotsOCRParser 处理图片 filename = f"omnidocbench_{session_id}" results = self.parser.parse_image( input_path=image, filename=filename, prompt_mode=self.prompt_mode, save_dir=temp_dir, fitz_preprocess=True # 对图片使用 fitz 预处理 ) # 解析结果 if not results: raise Exception("未返回解析结果") result = results[0] # parse_image 返回单个结果的列表 # 添加原始图片路径到结果中 # result['original_image_path'] = image_path # 保存所有结果文件到输出目录 saved_files = self.save_results_to_output_dir(result, image_name, output_dir) # 验证保存结果 success_count = sum(1 for path in saved_files.values() if path and os.path.exists(path)) if success_count >= 2: # 至少保存了 md 和 json result_info.update({ "success": True, "output_files": saved_files }) else: raise Exception(f"保存文件不完整 ({success_count}/3)") finally: # 清理临时目录 if os.path.exists(temp_dir): shutil.rmtree(temp_dir, ignore_errors=True) except Exception as e: result_info["error"] = str(e) finally: result_info["processing_time"] = time.time() - start_time return result_info def process_images_single_process(image_paths: List[str], processor: DotsOCRProcessor, batch_size: int = 1, output_dir: str = "./output") -> List[Dict[str, Any]]: """ 单进程版本的图像处理函数 Args: image_paths: 图像路径列表 processor: DotsOCR处理器实例 batch_size: 批处理大小 output_dir: 输出目录 Returns: 处理结果列表 """ # 创建输出目录 output_path = Path(output_dir) output_path.mkdir(parents=True, exist_ok=True) all_results = [] total_images = len(image_paths) print(f"Processing {total_images} images with batch size {batch_size}") # 使用tqdm显示进度,添加更多统计信息 with tqdm(total=total_images, desc="Processing images", unit="img", bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') as pbar: # 按批次处理图像(DotsOCR通常单张处理) for i in range(0, total_images, batch_size): batch = image_paths[i:i + batch_size] batch_start_time = time.time() batch_results = [] try: # 处理批次中的每张图片 for image_path in batch: try: result = processor.process_single_image(image_path, output_dir) batch_results.append(result) except Exception as e: print(f"Error processing {image_path}: {e}", file=sys.stderr) traceback.print_exc() batch_results.append({ "image_path": image_path, "processing_time": 0, "success": False, "device": f"{processor.ip}:{processor.port}", "error": str(e) }) batch_processing_time = time.time() - batch_start_time all_results.extend(batch_results) # 更新进度条 success_count = sum(1 for r in batch_results if r.get('success', False)) skipped_count = sum(1 for r in batch_results if r.get('skipped', False)) total_success = sum(1 for r in all_results if r.get('success', False)) total_skipped = sum(1 for r in all_results if r.get('skipped', False)) avg_time = batch_processing_time / len(batch) pbar.update(len(batch)) pbar.set_postfix({ 'batch_time': f"{batch_processing_time:.2f}s", 'avg_time': f"{avg_time:.2f}s/img", 'success': f"{total_success}/{len(all_results)}", 'skipped': f"{total_skipped}", 'rate': f"{total_success/len(all_results)*100:.1f}%" }) except Exception as e: print(f"Error processing batch {[Path(p).name for p in batch]}: {e}", file=sys.stderr) traceback.print_exc() # 为批次中的所有图像添加错误结果 error_results = [] for img_path in batch: error_results.append({ "image_path": str(img_path), "processing_time": 0, "success": False, "device": f"{processor.ip}:{processor.port}", "error": str(e) }) all_results.extend(error_results) pbar.update(len(batch)) return all_results def main(): """主函数""" parser = argparse.ArgumentParser(description="DotsOCR OmniDocBench Single Process Processing") # 输入参数组 input_group = parser.add_mutually_exclusive_group(required=True) input_group.add_argument("--input_dir", type=str, help="Input directory") input_group.add_argument("--input_file_list", type=str, help="Input file list (one file per line)") input_group.add_argument("--input_csv", type=str, help="Input CSV file with image_path and status columns") # 输出参数 parser.add_argument("--output_dir", type=str, help="Output directory") # DotsOCR 参数 parser.add_argument("--ip", type=str, default="127.0.0.1", help="vLLM server IP") parser.add_argument("--port", type=int, default=8101, help="vLLM server port") parser.add_argument("--model_name", type=str, default="DotsOCR", help="Model name") parser.add_argument("--prompt_mode", type=str, default="prompt_layout_all_en", choices=list(dict_promptmode_to_prompt.keys()), help="Prompt mode") parser.add_argument("--min_pixels", type=int, default=MIN_PIXELS, help="Minimum pixels") parser.add_argument("--max_pixels", type=int, default=MAX_PIXELS, help="Maximum pixels") parser.add_argument("--dpi", type=int, default=200, help="PDF processing DPI") # 处理参数 parser.add_argument("--batch_size", type=int, default=1, help="Batch size") parser.add_argument("--input_pattern", type=str, default="*", help="Input file pattern") parser.add_argument("--test_mode", action="store_true", help="Test mode (process only 10 images)") parser.add_argument("--collect_results", type=str, help="收集处理结果到指定CSV文件") args = parser.parse_args() try: # 获取图像文件列表 if args.input_csv: # 从CSV文件读取 image_files = get_image_files_from_csv(args.input_csv, "fail") print(f"📊 Loaded {len(image_files)} files from CSV with status filter: fail") elif args.input_file_list: # 从文件列表读取 image_files = get_image_files_from_list(args.input_file_list) else: # 从目录读取 input_dir = Path(args.input_dir).resolve() print(f"📁 Input dir: {input_dir}") if not input_dir.exists(): print(f"❌ Input directory does not exist: {input_dir}") return 1 image_files = get_image_files_from_dir(input_dir, args.input_pattern) output_dir = Path(args.output_dir).resolve() print(f"📁 Output dir: {output_dir}") print(f"📊 Found {len(image_files)} image files") if args.test_mode: image_files = image_files[:10] print(f"🧪 Test mode: processing only {len(image_files)} images") print(f"🌐 Using server: {args.ip}:{args.port}") print(f"📦 Batch size: {args.batch_size}") print(f"🎯 Prompt mode: {args.prompt_mode}") # 创建处理器 processor = DotsOCRProcessor( ip=args.ip, port=args.port, model_name=args.model_name, prompt_mode=args.prompt_mode, dpi=args.dpi, min_pixels=args.min_pixels, max_pixels=args.max_pixels ) # 开始处理 start_time = time.time() results = process_images_single_process( image_files, processor, args.batch_size, str(output_dir) ) total_time = time.time() - start_time # 统计结果 success_count = sum(1 for r in results if r.get('success', False)) skipped_count = sum(1 for r in results if r.get('skipped', False)) error_count = len(results) - success_count print(f"\n" + "="*60) print(f"✅ Processing completed!") print(f"📊 Statistics:") print(f" Total files: {len(image_files)}") print(f" Successful: {success_count}") print(f" Skipped: {skipped_count}") print(f" Failed: {error_count}") if len(image_files) > 0: print(f" Success rate: {success_count / len(image_files) * 100:.2f}%") print(f"⏱️ Performance:") print(f" Total time: {total_time:.2f} seconds") if total_time > 0: print(f" Throughput: {len(image_files) / total_time:.2f} images/second") print(f" Avg time per image: {total_time / len(image_files):.2f} seconds") # 保存结果统计 stats = { "total_files": len(image_files), "success_count": success_count, "skipped_count": skipped_count, "error_count": error_count, "success_rate": success_count / len(image_files) if len(image_files) > 0 else 0, "total_time": total_time, "throughput": len(image_files) / total_time if total_time > 0 else 0, "avg_time_per_image": total_time / len(image_files) if len(image_files) > 0 else 0, "batch_size": args.batch_size, "server": f"{args.ip}:{args.port}", "model": args.model_name, "prompt_mode": args.prompt_mode, "timestamp": time.strftime("%Y-%m-%d %H:%M:%S") } # 保存最终结果 output_file_name = Path(output_dir).name output_file = os.path.join(output_dir, f"{output_file_name}_results.json") final_results = { "stats": stats, "results": results } with open(output_file, 'w', encoding='utf-8') as f: json.dump(final_results, f, ensure_ascii=False, indent=2) print(f"💾 Results saved to: {output_file}") # 收集处理结果 if args.collect_results: processed_files = collect_pid_files(output_file) output_file_processed = Path(args.collect_results).resolve() with open(output_file_processed, 'w', encoding='utf-8') as f: f.write("image_path,status\n") for file_path, status in processed_files: f.write(f"{file_path},{status}\n") print(f"💾 Processed files saved to: {output_file_processed}") return 0 except Exception as e: print(f"❌ Processing failed: {e}", file=sys.stderr) traceback.print_exc() return 1 if __name__ == "__main__": print(f"🚀 启动DotsOCR单进程程序...") print(f"🔧 CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set')}") if len(sys.argv) == 1: # 如果没有命令行参数,使用默认配置运行 print("ℹ️ No command line arguments provided. Running with default configuration...") # 默认配置 default_config = { "input_dir": "../../OmniDocBench/OpenDataLab___OmniDocBench/images", "output_dir": "./OmniDocBench_DotsOCR_Results", "ip": "10.192.72.11", "port": "8101", "model_name": "DotsOCR", "prompt_mode": "prompt_layout_all_en", "batch_size": "1", "collect_results": "./OmniDocBench_DotsOCR_Results/processed_files.csv", } # 如果需要处理失败的文件,可以使用这个配置 # default_config = { # "input_csv": "./OmniDocBench_DotsOCR_Results/processed_files.csv", # "output_dir": "./OmniDocBench_DotsOCR_Results", # "ip": "127.0.0.1", # "port": "8101", # "collect_results": f"./OmniDocBench_DotsOCR_Results/processed_files_{time.strftime('%Y%m%d_%H%M%S')}.csv", # } # 构造参数 sys.argv = [sys.argv[0]] for key, value in default_config.items(): sys.argv.extend([f"--{key}", str(value)]) # 测试模式 sys.argv.append("--test_mode") sys.exit(main())