|
|
@@ -0,0 +1,422 @@
|
|
|
+"""
|
|
|
+批量处理 OmniDocBench 图片并生成符合评测要求的预测结果
|
|
|
+
|
|
|
+根据 OmniDocBench 评测要求:
|
|
|
+- 输入:OpenDataLab___OmniDocBench/images 下的所有 .jpg 图片
|
|
|
+- 输出:每个图片对应的 .md、.json 和带标注的 layout 图片文件
|
|
|
+- 输出目录:用于后续的 end2end 评测
|
|
|
+"""
|
|
|
+
|
|
|
+import os
|
|
|
+import json
|
|
|
+import tempfile
|
|
|
+import uuid
|
|
|
+import shutil
|
|
|
+from pathlib import Path
|
|
|
+from PIL import Image
|
|
|
+from tqdm import tqdm
|
|
|
+import argparse
|
|
|
+
|
|
|
+# 导入 dots.ocr 相关模块
|
|
|
+from dots_ocr.parser import DotsOCRParser
|
|
|
+from dots_ocr.utils import dict_promptmode_to_prompt
|
|
|
+from dots_ocr.utils.consts import MIN_PIXELS, MAX_PIXELS
|
|
|
+
|
|
|
+
|
|
|
+class OmniDocBenchProcessor:
|
|
|
+ """OmniDocBench 批量处理器"""
|
|
|
+
|
|
|
+ def __init__(self,
|
|
|
+ ip="127.0.0.1",
|
|
|
+ port=8101,
|
|
|
+ model_name="DotsOCR",
|
|
|
+ prompt_mode="prompt_layout_all_en",
|
|
|
+ dpi=200,
|
|
|
+ min_pixels=MIN_PIXELS,
|
|
|
+ max_pixels=MAX_PIXELS):
|
|
|
+ """
|
|
|
+ 初始化处理器
|
|
|
+
|
|
|
+ Args:
|
|
|
+ ip: vLLM 服务器 IP
|
|
|
+ port: vLLM 服务器端口
|
|
|
+ model_name: 模型名称
|
|
|
+ prompt_mode: 提示模式
|
|
|
+ dpi: PDF 处理 DPI
|
|
|
+ min_pixels: 最小像素数
|
|
|
+ max_pixels: 最大像素数
|
|
|
+ """
|
|
|
+ self.parser = DotsOCRParser(
|
|
|
+ ip=ip,
|
|
|
+ port=port,
|
|
|
+ dpi=dpi,
|
|
|
+ min_pixels=min_pixels,
|
|
|
+ max_pixels=max_pixels,
|
|
|
+ model_name=model_name
|
|
|
+ )
|
|
|
+ self.prompt_mode = prompt_mode
|
|
|
+
|
|
|
+ print(f"DotsOCR Parser 初始化完成:")
|
|
|
+ print(f" - 服务器: {ip}:{port}")
|
|
|
+ print(f" - 模型: {model_name}")
|
|
|
+ print(f" - 提示模式: {prompt_mode}")
|
|
|
+ print(f" - 像素范围: {min_pixels} - {max_pixels}")
|
|
|
+
|
|
|
+ def create_temp_session_dir(self):
|
|
|
+ """创建临时会话目录"""
|
|
|
+ session_id = uuid.uuid4().hex[:8]
|
|
|
+ temp_dir = os.path.join(tempfile.gettempdir(), f"omnidocbench_batch_{session_id}")
|
|
|
+ os.makedirs(temp_dir, exist_ok=True)
|
|
|
+ return temp_dir, session_id
|
|
|
+
|
|
|
+ def save_results_to_output_dir(self, result, image_name, output_dir):
|
|
|
+ """
|
|
|
+ 将处理结果保存到输出目录
|
|
|
+
|
|
|
+ Args:
|
|
|
+ result: 解析结果
|
|
|
+ image_name: 图片文件名(不含扩展名)
|
|
|
+ output_dir: 输出目录
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ dict: 保存的文件路径
|
|
|
+ """
|
|
|
+ saved_files = {}
|
|
|
+
|
|
|
+ # 1. 保存 Markdown 文件(OmniDocBench 评测必需)
|
|
|
+ output_md_path = os.path.join(output_dir, f"{image_name}.md")
|
|
|
+ md_content = ""
|
|
|
+
|
|
|
+ # 优先使用无页眉页脚的版本(符合 OmniDocBench 评测要求)
|
|
|
+ if 'md_content_nohf_path' in result and os.path.exists(result['md_content_nohf_path']):
|
|
|
+ with open(result['md_content_nohf_path'], 'r', encoding='utf-8') as f:
|
|
|
+ md_content = f.read()
|
|
|
+ elif 'md_content_path' in result and os.path.exists(result['md_content_path']):
|
|
|
+ with open(result['md_content_path'], 'r', encoding='utf-8') as f:
|
|
|
+ md_content = f.read()
|
|
|
+ else:
|
|
|
+ md_content = "# 解析失败\n\n未能提取到有效的文档内容。"
|
|
|
+
|
|
|
+ with open(output_md_path, 'w', encoding='utf-8') as f:
|
|
|
+ f.write(md_content)
|
|
|
+ saved_files['md'] = output_md_path
|
|
|
+
|
|
|
+ # 2. 保存 JSON 文件
|
|
|
+ output_json_path = os.path.join(output_dir, f"{image_name}.json")
|
|
|
+ json_data = {}
|
|
|
+
|
|
|
+ if 'layout_info_path' in result and os.path.exists(result['layout_info_path']):
|
|
|
+ with open(result['layout_info_path'], 'r', encoding='utf-8') as f:
|
|
|
+ json_data = json.load(f)
|
|
|
+ else:
|
|
|
+ json_data = {
|
|
|
+ "error": "未能提取到有效的布局信息",
|
|
|
+ "cells": []
|
|
|
+ }
|
|
|
+
|
|
|
+ with open(output_json_path, 'w', encoding='utf-8') as f:
|
|
|
+ json.dump(json_data, f, ensure_ascii=False, indent=2)
|
|
|
+ saved_files['json'] = output_json_path
|
|
|
+
|
|
|
+ # 3. 保存带标注的布局图片
|
|
|
+ output_layout_image_path = os.path.join(output_dir, f"{image_name}_layout.jpg")
|
|
|
+
|
|
|
+ if 'layout_image_path' in result and os.path.exists(result['layout_image_path']):
|
|
|
+ # 直接复制布局图片
|
|
|
+ shutil.copy2(result['layout_image_path'], output_layout_image_path)
|
|
|
+ saved_files['layout_image'] = output_layout_image_path
|
|
|
+ else:
|
|
|
+ # 如果没有布局图片,使用原始图片作为占位符
|
|
|
+ try:
|
|
|
+ original_image = Image.open(result.get('original_image_path', ''))
|
|
|
+ original_image.save(output_layout_image_path, 'JPEG', quality=95)
|
|
|
+ saved_files['layout_image'] = output_layout_image_path
|
|
|
+ print(f"⚠️ 使用原始图片作为布局图片: {image_name}")
|
|
|
+ except Exception as e:
|
|
|
+ print(f"⚠️ 无法保存布局图片: {image_name}, 错误: {e}")
|
|
|
+ saved_files['layout_image'] = None
|
|
|
+
|
|
|
+ # 4. 可选:保存原始图片副本
|
|
|
+ output_original_image_path = os.path.join(output_dir, f"{image_name}_original.jpg")
|
|
|
+ if 'original_image_path' in result and os.path.exists(result['original_image_path']):
|
|
|
+ shutil.copy2(result['original_image_path'], output_original_image_path)
|
|
|
+ saved_files['original_image'] = output_original_image_path
|
|
|
+
|
|
|
+ return saved_files
|
|
|
+
|
|
|
+ def process_single_image(self, image_path, output_dir):
|
|
|
+ """
|
|
|
+ 处理单张图片
|
|
|
+
|
|
|
+ Args:
|
|
|
+ image_path: 图片路径
|
|
|
+ output_dir: 输出目录
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ bool: 处理是否成功
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ # 获取图片文件名(不含扩展名)
|
|
|
+ image_name = Path(image_path).stem
|
|
|
+
|
|
|
+ # 检查输出文件是否已存在
|
|
|
+ output_md_path = os.path.join(output_dir, f"{image_name}.md")
|
|
|
+ output_json_path = os.path.join(output_dir, f"{image_name}.json")
|
|
|
+ output_layout_path = os.path.join(output_dir, f"{image_name}_layout.jpg")
|
|
|
+
|
|
|
+ if all(os.path.exists(p) for p in [output_md_path, output_json_path, output_layout_path]):
|
|
|
+ print(f"跳过已存在的文件: {image_name}")
|
|
|
+ return True
|
|
|
+
|
|
|
+ # 创建临时会话目录
|
|
|
+ temp_dir, session_id = self.create_temp_session_dir()
|
|
|
+
|
|
|
+ try:
|
|
|
+ # 读取图片
|
|
|
+ image = Image.open(image_path)
|
|
|
+
|
|
|
+ # 使用 DotsOCRParser 处理图片
|
|
|
+ filename = f"omnidocbench_{session_id}"
|
|
|
+ results = self.parser.parse_image(
|
|
|
+ input_path=image,
|
|
|
+ filename=filename,
|
|
|
+ prompt_mode=self.prompt_mode,
|
|
|
+ save_dir=temp_dir,
|
|
|
+ fitz_preprocess=True # 对图片使用 fitz 预处理
|
|
|
+ )
|
|
|
+
|
|
|
+ # 解析结果
|
|
|
+ if not results:
|
|
|
+ print(f"警告: {image_name} 未返回解析结果")
|
|
|
+ return False
|
|
|
+
|
|
|
+ result = results[0] # parse_image 返回单个结果的列表
|
|
|
+
|
|
|
+ # 添加原始图片路径到结果中
|
|
|
+ result['original_image_path'] = image_path
|
|
|
+
|
|
|
+ # 保存所有结果文件到输出目录
|
|
|
+ saved_files = self.save_results_to_output_dir(result, image_name, output_dir)
|
|
|
+
|
|
|
+ # 验证保存结果
|
|
|
+ success_count = sum(1 for path in saved_files.values() if path and os.path.exists(path))
|
|
|
+ total_expected = 3 # md, json, layout_image
|
|
|
+
|
|
|
+ if success_count >= 2: # 至少保存了 md 和 json
|
|
|
+ print(f"✅ 成功处理: {image_name} (保存了 {success_count}/{total_expected} 个文件)")
|
|
|
+ return True
|
|
|
+ else:
|
|
|
+ print(f"⚠️ 部分成功: {image_name} (保存了 {success_count}/{total_expected} 个文件)")
|
|
|
+ return False
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ print(f"❌ 处理 {image_name} 时出错: {str(e)}")
|
|
|
+ return False
|
|
|
+
|
|
|
+ finally:
|
|
|
+ # 清理临时目录
|
|
|
+ if os.path.exists(temp_dir):
|
|
|
+ shutil.rmtree(temp_dir, ignore_errors=True)
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ print(f"❌ 处理 {image_path} 时出现致命错误: {str(e)}")
|
|
|
+ return False
|
|
|
+
|
|
|
+ def process_batch(self, images_dir, output_dir):
|
|
|
+ """
|
|
|
+ 批量处理图片
|
|
|
+
|
|
|
+ Args:
|
|
|
+ images_dir: 输入图片目录
|
|
|
+ output_dir: 输出目录
|
|
|
+ """
|
|
|
+ # 创建输出目录
|
|
|
+ os.makedirs(output_dir, exist_ok=True)
|
|
|
+
|
|
|
+ # 获取所有图片文件
|
|
|
+ image_extensions = ['.jpg', '.jpeg', '.png']
|
|
|
+ image_files = []
|
|
|
+
|
|
|
+ for ext in image_extensions:
|
|
|
+ image_files.extend(Path(images_dir).glob(f"*{ext}"))
|
|
|
+ image_files.extend(Path(images_dir).glob(f"*{ext.upper()}"))
|
|
|
+
|
|
|
+ image_files = sorted(image_files)
|
|
|
+
|
|
|
+ if not image_files:
|
|
|
+ print(f"在 {images_dir} 中未找到图片文件")
|
|
|
+ return
|
|
|
+
|
|
|
+ print(f"找到 {len(image_files)} 个图片文件")
|
|
|
+ print(f"输出目录结构: {output_dir}")
|
|
|
+
|
|
|
+ # 统计变量
|
|
|
+ success_count = 0
|
|
|
+ failed_count = 0
|
|
|
+ skipped_count = 0
|
|
|
+
|
|
|
+ # 使用进度条处理
|
|
|
+ with tqdm(image_files, desc="处理图片", unit="张") as pbar:
|
|
|
+ for image_path in pbar:
|
|
|
+ # 更新进度条描述
|
|
|
+ pbar.set_description(f"处理: {image_path.name}")
|
|
|
+
|
|
|
+ # 检查输出文件是否已存在(在主输出目录中)
|
|
|
+ image_name = image_path.stem
|
|
|
+ output_md_path = os.path.join(output_dir, f"{image_name}.md")
|
|
|
+ output_json_path = os.path.join(output_dir, f"{image_name}.json")
|
|
|
+ output_layout_path = os.path.join(output_dir, f"{image_name}_layout.jpg")
|
|
|
+
|
|
|
+ if all(os.path.exists(p) for p in [output_md_path, output_json_path, output_layout_path]):
|
|
|
+ skipped_count += 1
|
|
|
+ continue
|
|
|
+
|
|
|
+ # 处理图片
|
|
|
+ if self.process_single_image(str(image_path), output_dir):
|
|
|
+ success_count += 1
|
|
|
+ else:
|
|
|
+ failed_count += 1
|
|
|
+
|
|
|
+ # 更新进度条后缀
|
|
|
+ pbar.set_postfix({
|
|
|
+ 'success': success_count,
|
|
|
+ 'failed': failed_count,
|
|
|
+ 'skipped': skipped_count
|
|
|
+ })
|
|
|
+
|
|
|
+ # 输出最终统计
|
|
|
+ print(f"\n🎉 批量处理完成!")
|
|
|
+ print(f" ✅ 成功: {success_count}")
|
|
|
+ print(f" ❌ 失败: {failed_count}")
|
|
|
+ print(f" ⏭️ 跳过: {skipped_count}")
|
|
|
+ print(f" 📁 输出目录: {output_dir}")
|
|
|
+
|
|
|
+ # 生成处理报告
|
|
|
+ self.generate_processing_report(output_dir, success_count, failed_count, skipped_count)
|
|
|
+
|
|
|
+ def generate_processing_report(self, output_dir, success_count, failed_count, skipped_count):
|
|
|
+ """生成处理报告"""
|
|
|
+ report_path = os.path.join(output_dir, "processing_report.json")
|
|
|
+
|
|
|
+ report = {
|
|
|
+ "processing_summary": {
|
|
|
+ "success_count": success_count,
|
|
|
+ "failed_count": failed_count,
|
|
|
+ "skipped_count": skipped_count,
|
|
|
+ "total_processed": success_count + failed_count + skipped_count
|
|
|
+ },
|
|
|
+ "output_structure": {
|
|
|
+ "markdown_files": f"{output_dir}/*.md",
|
|
|
+ "json_files": f"{output_dir}/*.json",
|
|
|
+ "layout_images": f"{output_dir}/*_layout.jpg",
|
|
|
+ "original_images": f"{output_dir}/*_original.jpg"
|
|
|
+ },
|
|
|
+ "configuration": {
|
|
|
+ "prompt_mode": self.prompt_mode,
|
|
|
+ "server": f"{self.parser.ip}:{self.parser.port}",
|
|
|
+ "pixel_range": f"{self.parser.min_pixels} - {self.parser.max_pixels}"
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ with open(report_path, 'w', encoding='utf-8') as f:
|
|
|
+ json.dump(report, f, ensure_ascii=False, indent=2)
|
|
|
+
|
|
|
+ print(f"📊 处理报告已保存: {report_path}")
|
|
|
+
|
|
|
+
|
|
|
+def main():
|
|
|
+ parser = argparse.ArgumentParser(description="批量处理 OmniDocBench 图片")
|
|
|
+
|
|
|
+ parser.add_argument(
|
|
|
+ "--images_dir",
|
|
|
+ type=str,
|
|
|
+ default="../OmniDocBench/OpenDataLab___OmniDocBench/images",
|
|
|
+ help="输入图片目录路径"
|
|
|
+ )
|
|
|
+
|
|
|
+ parser.add_argument(
|
|
|
+ "--output_dir",
|
|
|
+ type=str,
|
|
|
+ default="./omnidocbench_predictions",
|
|
|
+ help="输出目录路径"
|
|
|
+ )
|
|
|
+
|
|
|
+ parser.add_argument(
|
|
|
+ "--ip",
|
|
|
+ type=str,
|
|
|
+ default="127.0.0.1",
|
|
|
+ help="vLLM 服务器 IP"
|
|
|
+ )
|
|
|
+
|
|
|
+ parser.add_argument(
|
|
|
+ "--port",
|
|
|
+ type=int,
|
|
|
+ default=8101,
|
|
|
+ help="vLLM 服务器端口"
|
|
|
+ )
|
|
|
+
|
|
|
+ parser.add_argument(
|
|
|
+ "--model_name",
|
|
|
+ type=str,
|
|
|
+ default="DotsOCR",
|
|
|
+ help="模型名称"
|
|
|
+ )
|
|
|
+
|
|
|
+ parser.add_argument(
|
|
|
+ "--prompt_mode",
|
|
|
+ type=str,
|
|
|
+ default="prompt_layout_all_en",
|
|
|
+ choices=list(dict_promptmode_to_prompt.keys()),
|
|
|
+ help="提示模式"
|
|
|
+ )
|
|
|
+
|
|
|
+ parser.add_argument(
|
|
|
+ "--min_pixels",
|
|
|
+ type=int,
|
|
|
+ default=MIN_PIXELS,
|
|
|
+ help="最小像素数"
|
|
|
+ )
|
|
|
+
|
|
|
+ parser.add_argument(
|
|
|
+ "--max_pixels",
|
|
|
+ type=int,
|
|
|
+ default=MAX_PIXELS,
|
|
|
+ help="最大像素数"
|
|
|
+ )
|
|
|
+
|
|
|
+ parser.add_argument(
|
|
|
+ "--dpi",
|
|
|
+ type=int,
|
|
|
+ default=200,
|
|
|
+ help="PDF 处理 DPI"
|
|
|
+ )
|
|
|
+
|
|
|
+ args = parser.parse_args()
|
|
|
+
|
|
|
+ # 检查输入目录
|
|
|
+ if not os.path.exists(args.images_dir):
|
|
|
+ print(f"❌ 输入目录不存在: {args.images_dir}")
|
|
|
+ return
|
|
|
+
|
|
|
+ print(f"🚀 开始批量处理 OmniDocBench 图片")
|
|
|
+ print(f"📁 输入目录: {args.images_dir}")
|
|
|
+ print(f"📁 输出目录: {args.output_dir}")
|
|
|
+ print("="*60)
|
|
|
+
|
|
|
+ # 创建处理器
|
|
|
+ processor = OmniDocBenchProcessor(
|
|
|
+ ip=args.ip,
|
|
|
+ port=args.port,
|
|
|
+ model_name=args.model_name,
|
|
|
+ prompt_mode=args.prompt_mode,
|
|
|
+ dpi=args.dpi,
|
|
|
+ min_pixels=args.min_pixels,
|
|
|
+ max_pixels=args.max_pixels
|
|
|
+ )
|
|
|
+
|
|
|
+ # 开始批量处理
|
|
|
+ processor.process_batch(args.images_dir, args.output_dir)
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+ main()
|