| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443 |
- """
- 批量处理 OmniDocBench 图片并生成符合评测要求的预测结果
- 根据 OmniDocBench 评测要求:
- - 输入:OpenDataLab___OmniDocBench/images 下的所有 .jpg 图片
- - 输出:每个图片对应的 .md、.json 和带标注的 layout 图片文件
- - 输出目录:用于后续的 end2end 评测
- """
- import os
- import sys
- import json
- import tempfile
- import uuid
- import shutil
- from pathlib import Path
- from PIL import Image
- from tqdm import tqdm
- import argparse
- # 导入 dots.ocr 相关模块
- from dots_ocr.parser import DotsOCRParser
- from dots_ocr.utils import dict_promptmode_to_prompt
- from dots_ocr.utils.consts import MIN_PIXELS, MAX_PIXELS
- class OmniDocBenchProcessor:
- """OmniDocBench 批量处理器"""
-
- def __init__(self,
- ip="127.0.0.1",
- port=8101,
- model_name="DotsOCR",
- prompt_mode="prompt_layout_all_en",
- dpi=200,
- min_pixels=MIN_PIXELS,
- max_pixels=MAX_PIXELS):
- """
- 初始化处理器
-
- Args:
- ip: vLLM 服务器 IP
- port: vLLM 服务器端口
- model_name: 模型名称
- prompt_mode: 提示模式
- dpi: PDF 处理 DPI
- min_pixels: 最小像素数
- max_pixels: 最大像素数
- """
- self.parser = DotsOCRParser(
- ip=ip,
- port=port,
- dpi=dpi,
- min_pixels=min_pixels,
- max_pixels=max_pixels,
- model_name=model_name
- )
- self.prompt_mode = prompt_mode
-
- print(f"DotsOCR Parser 初始化完成:")
- print(f" - 服务器: {ip}:{port}")
- print(f" - 模型: {model_name}")
- print(f" - 提示模式: {prompt_mode}")
- print(f" - 像素范围: {min_pixels} - {max_pixels}")
-
- def create_temp_session_dir(self):
- """创建临时会话目录"""
- session_id = uuid.uuid4().hex[:8]
- temp_dir = os.path.join(tempfile.gettempdir(), f"omnidocbench_batch_{session_id}")
- os.makedirs(temp_dir, exist_ok=True)
- return temp_dir, session_id
-
- def save_results_to_output_dir(self, result, image_name, output_dir):
- """
- 将处理结果保存到输出目录
-
- Args:
- result: 解析结果
- image_name: 图片文件名(不含扩展名)
- output_dir: 输出目录
-
- Returns:
- dict: 保存的文件路径
- """
- saved_files = {}
-
- # 1. 保存 Markdown 文件(OmniDocBench 评测必需)
- output_md_path = os.path.join(output_dir, f"{image_name}.md")
- md_content = ""
-
- # 优先使用无页眉页脚的版本(符合 OmniDocBench 评测要求)
- if 'md_content_nohf_path' in result and os.path.exists(result['md_content_nohf_path']):
- with open(result['md_content_nohf_path'], 'r', encoding='utf-8') as f:
- md_content = f.read()
- elif 'md_content_path' in result and os.path.exists(result['md_content_path']):
- with open(result['md_content_path'], 'r', encoding='utf-8') as f:
- md_content = f.read()
- else:
- md_content = "# 解析失败\n\n未能提取到有效的文档内容。"
-
- with open(output_md_path, 'w', encoding='utf-8') as f:
- f.write(md_content)
- saved_files['md'] = output_md_path
-
- # 2. 保存 JSON 文件
- output_json_path = os.path.join(output_dir, f"{image_name}.json")
- json_data = {}
-
- if 'layout_info_path' in result and os.path.exists(result['layout_info_path']):
- with open(result['layout_info_path'], 'r', encoding='utf-8') as f:
- json_data = json.load(f)
- else:
- json_data = {
- "error": "未能提取到有效的布局信息",
- "cells": []
- }
-
- with open(output_json_path, 'w', encoding='utf-8') as f:
- json.dump(json_data, f, ensure_ascii=False, indent=2)
- saved_files['json'] = output_json_path
-
- # 3. 保存带标注的布局图片
- output_layout_image_path = os.path.join(output_dir, f"{image_name}_layout.jpg")
-
- if 'layout_image_path' in result and os.path.exists(result['layout_image_path']):
- # 直接复制布局图片
- shutil.copy2(result['layout_image_path'], output_layout_image_path)
- saved_files['layout_image'] = output_layout_image_path
- else:
- # 如果没有布局图片,使用原始图片作为占位符
- try:
- original_image = Image.open(result.get('original_image_path', ''))
- original_image.save(output_layout_image_path, 'JPEG', quality=95)
- saved_files['layout_image'] = output_layout_image_path
- print(f"⚠️ 使用原始图片作为布局图片: {image_name}")
- except Exception as e:
- print(f"⚠️ 无法保存布局图片: {image_name}, 错误: {e}")
- saved_files['layout_image'] = None
-
- # 4. 可选:保存原始图片副本
- output_original_image_path = os.path.join(output_dir, f"{image_name}_original.jpg")
- if 'original_image_path' in result and os.path.exists(result['original_image_path']):
- shutil.copy2(result['original_image_path'], output_original_image_path)
- saved_files['original_image'] = output_original_image_path
-
- return saved_files
-
- def process_single_image(self, image_path, output_dir):
- """
- 处理单张图片
-
- Args:
- image_path: 图片路径
- output_dir: 输出目录
-
- Returns:
- bool: 处理是否成功
- """
- try:
- # 获取图片文件名(不含扩展名)
- image_name = Path(image_path).stem
-
- # 检查输出文件是否已存在
- output_md_path = os.path.join(output_dir, f"{image_name}.md")
- output_json_path = os.path.join(output_dir, f"{image_name}.json")
- output_layout_path = os.path.join(output_dir, f"{image_name}_layout.jpg")
-
- if all(os.path.exists(p) for p in [output_md_path, output_json_path, output_layout_path]):
- print(f"跳过已存在的文件: {image_name}")
- return True
-
- # 创建临时会话目录
- temp_dir, session_id = self.create_temp_session_dir()
-
- try:
- # 读取图片
- image = Image.open(image_path)
-
- # 使用 DotsOCRParser 处理图片
- filename = f"omnidocbench_{session_id}"
- results = self.parser.parse_image(
- input_path=image,
- filename=filename,
- prompt_mode=self.prompt_mode,
- save_dir=temp_dir,
- fitz_preprocess=True # 对图片使用 fitz 预处理
- )
-
- # 解析结果
- if not results:
- print(f"警告: {image_name} 未返回解析结果")
- return False
-
- result = results[0] # parse_image 返回单个结果的列表
-
- # 添加原始图片路径到结果中
- result['original_image_path'] = image_path
-
- # 保存所有结果文件到输出目录
- saved_files = self.save_results_to_output_dir(result, image_name, output_dir)
-
- # 验证保存结果
- success_count = sum(1 for path in saved_files.values() if path and os.path.exists(path))
- total_expected = 3 # md, json, layout_image
-
- if success_count >= 2: # 至少保存了 md 和 json
- print(f"✅ 成功处理: {image_name} (保存了 {success_count}/{total_expected} 个文件)")
- return True
- else:
- print(f"⚠️ 部分成功: {image_name} (保存了 {success_count}/{total_expected} 个文件)")
- return False
-
- except Exception as e:
- print(f"❌ 处理 {image_name} 时出错: {str(e)}")
- return False
-
- finally:
- # 清理临时目录
- if os.path.exists(temp_dir):
- shutil.rmtree(temp_dir, ignore_errors=True)
-
- except Exception as e:
- print(f"❌ 处理 {image_path} 时出现致命错误: {str(e)}")
- return False
-
- def process_batch(self, images_dir, output_dir):
- """
- 批量处理图片
-
- Args:
- images_dir: 输入图片目录
- output_dir: 输出目录
- """
- # 创建输出目录
- os.makedirs(output_dir, exist_ok=True)
-
- # 获取所有图片文件
- image_extensions = ['.jpg', '.jpeg', '.png']
- image_files = []
-
- for ext in image_extensions:
- image_files.extend(Path(images_dir).glob(f"*{ext}"))
- image_files.extend(Path(images_dir).glob(f"*{ext.upper()}"))
-
- image_files = sorted(image_files)
-
- if not image_files:
- print(f"在 {images_dir} 中未找到图片文件")
- return
-
- print(f"找到 {len(image_files)} 个图片文件")
- print(f"输出目录结构: {output_dir}")
-
- # 统计变量
- success_count = 0
- failed_count = 0
- skipped_count = 0
-
- # 使用进度条处理
- with tqdm(image_files, desc="处理图片", unit="张") as pbar:
- for image_path in pbar:
- # 更新进度条描述
- pbar.set_description(f"处理: {image_path.name}")
-
- # 检查输出文件是否已存在(在主输出目录中)
- image_name = image_path.stem
- output_md_path = os.path.join(output_dir, f"{image_name}.md")
- output_json_path = os.path.join(output_dir, f"{image_name}.json")
- output_layout_path = os.path.join(output_dir, f"{image_name}_layout.jpg")
-
- if all(os.path.exists(p) for p in [output_md_path, output_json_path, output_layout_path]):
- skipped_count += 1
- continue
-
- # 处理图片
- if self.process_single_image(str(image_path), output_dir):
- success_count += 1
- else:
- failed_count += 1
-
- # 更新进度条后缀
- pbar.set_postfix({
- 'success': success_count,
- 'failed': failed_count,
- 'skipped': skipped_count
- })
-
- # 输出最终统计
- print(f"\n🎉 批量处理完成!")
- print(f" ✅ 成功: {success_count}")
- print(f" ❌ 失败: {failed_count}")
- print(f" ⏭️ 跳过: {skipped_count}")
- print(f" 📁 输出目录: {output_dir}")
-
- # 生成处理报告
- self.generate_processing_report(output_dir, success_count, failed_count, skipped_count)
-
- def generate_processing_report(self, output_dir, success_count, failed_count, skipped_count):
- """生成处理报告"""
- report_path = os.path.join(output_dir, "processing_report.json")
-
- report = {
- "processing_summary": {
- "success_count": success_count,
- "failed_count": failed_count,
- "skipped_count": skipped_count,
- "total_processed": success_count + failed_count + skipped_count
- },
- "output_structure": {
- "markdown_files": f"{output_dir}/*.md",
- "json_files": f"{output_dir}/*.json",
- "layout_images": f"{output_dir}/*_layout.jpg",
- "original_images": f"{output_dir}/*_original.jpg"
- },
- "configuration": {
- "prompt_mode": self.prompt_mode,
- "server": f"{self.parser.ip}:{self.parser.port}",
- "pixel_range": f"{self.parser.min_pixels} - {self.parser.max_pixels}"
- }
- }
-
- with open(report_path, 'w', encoding='utf-8') as f:
- json.dump(report, f, ensure_ascii=False, indent=2)
-
- print(f"📊 处理报告已保存: {report_path}")
- def main():
- parser = argparse.ArgumentParser(description="批量处理 OmniDocBench 图片")
-
- parser.add_argument(
- "--images_dir",
- type=str,
- default="../OmniDocBench/OpenDataLab___OmniDocBench/images",
- help="输入图片目录路径"
- )
-
- parser.add_argument(
- "--output_dir",
- type=str,
- default="./omnidocbench_predictions",
- help="输出目录路径"
- )
-
- parser.add_argument(
- "--ip",
- type=str,
- default="127.0.0.1",
- help="vLLM 服务器 IP"
- )
-
- parser.add_argument(
- "--port",
- type=int,
- default=8101,
- help="vLLM 服务器端口"
- )
-
- parser.add_argument(
- "--model_name",
- type=str,
- default="DotsOCR",
- help="模型名称"
- )
-
- parser.add_argument(
- "--prompt_mode",
- type=str,
- default="prompt_layout_all_en",
- choices=list(dict_promptmode_to_prompt.keys()),
- help="提示模式"
- )
-
- parser.add_argument(
- "--min_pixels",
- type=int,
- default=MIN_PIXELS,
- help="最小像素数"
- )
-
- parser.add_argument(
- "--max_pixels",
- type=int,
- default=MAX_PIXELS,
- help="最大像素数"
- )
-
- parser.add_argument(
- "--dpi",
- type=int,
- default=200,
- help="PDF 处理 DPI"
- )
-
- args = parser.parse_args()
-
- # 检查输入目录
- if not os.path.exists(args.images_dir):
- print(f"❌ 输入目录不存在: {args.images_dir}")
- return
-
- print(f"🚀 开始批量处理 OmniDocBench 图片")
- print(f"📁 输入目录: {args.images_dir}")
- print(f"📁 输出目录: {args.output_dir}")
- print("="*60)
-
- # 创建处理器
- processor = OmniDocBenchProcessor(
- ip=args.ip,
- port=args.port,
- model_name=args.model_name,
- prompt_mode=args.prompt_mode,
- dpi=args.dpi,
- min_pixels=args.min_pixels,
- max_pixels=args.max_pixels
- )
-
- # 开始批量处理
- processor.process_batch(args.images_dir, args.output_dir)
- if __name__ == "__main__":
- print(f"🚀 启动单进程DotsOCR程序...")
- print(f"🔧 CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set')}")
-
- if len(sys.argv) == 1:
- # 如果没有命令行参数,使用默认配置运行
- print("ℹ️ No command line arguments provided. Running with default configuration...")
-
- # 默认配置
- default_config = {
- "input_dir": "../../OmniDocBench/OpenDataLab___OmniDocBench/images",
- "output_dir": "./OmniDocBench_Results_Single",
- }
- # 构造参数
- sys.argv = [sys.argv[0]]
- for key, value in default_config.items():
- sys.argv.extend([f"--{key}", str(value)])
-
- # 测试模式
- # sys.argv.append("--test_mode")
-
- sys.exit(main())
|