|
|
@@ -13,7 +13,11 @@ import json
|
|
|
import tempfile
|
|
|
import uuid
|
|
|
import shutil
|
|
|
+import time
|
|
|
+import traceback
|
|
|
+import warnings
|
|
|
from pathlib import Path
|
|
|
+from typing import List, Dict, Any
|
|
|
from PIL import Image
|
|
|
from tqdm import tqdm
|
|
|
import argparse
|
|
|
@@ -23,18 +27,25 @@ from dots_ocr.parser import DotsOCRParser
|
|
|
from dots_ocr.utils import dict_promptmode_to_prompt
|
|
|
from dots_ocr.utils.consts import MIN_PIXELS, MAX_PIXELS
|
|
|
|
|
|
+# 导入工具函数
|
|
|
+from utils import (
|
|
|
+ get_image_files_from_dir,
|
|
|
+ get_image_files_from_list,
|
|
|
+ get_image_files_from_csv,
|
|
|
+ collect_pid_files
|
|
|
+)
|
|
|
|
|
|
-class OmniDocBenchProcessor:
|
|
|
- """OmniDocBench 批量处理器"""
|
|
|
+class DotsOCRProcessor:
|
|
|
+ """DotsOCR 处理器"""
|
|
|
|
|
|
def __init__(self,
|
|
|
- ip="127.0.0.1",
|
|
|
- port=8101,
|
|
|
- model_name="DotsOCR",
|
|
|
- prompt_mode="prompt_layout_all_en",
|
|
|
- dpi=200,
|
|
|
- min_pixels=MIN_PIXELS,
|
|
|
- max_pixels=MAX_PIXELS):
|
|
|
+ ip: str = "127.0.0.1",
|
|
|
+ port: int = 8101,
|
|
|
+ model_name: str = "DotsOCR",
|
|
|
+ prompt_mode: str = "prompt_layout_all_en",
|
|
|
+ dpi: int = 200,
|
|
|
+ min_pixels: int = MIN_PIXELS,
|
|
|
+ max_pixels: int = MAX_PIXELS):
|
|
|
"""
|
|
|
初始化处理器
|
|
|
|
|
|
@@ -47,6 +58,15 @@ class OmniDocBenchProcessor:
|
|
|
min_pixels: 最小像素数
|
|
|
max_pixels: 最大像素数
|
|
|
"""
|
|
|
+ self.ip = ip
|
|
|
+ self.port = port
|
|
|
+ self.model_name = model_name
|
|
|
+ self.prompt_mode = prompt_mode
|
|
|
+ self.dpi = dpi
|
|
|
+ self.min_pixels = min_pixels
|
|
|
+ self.max_pixels = max_pixels
|
|
|
+
|
|
|
+ # 初始化解析器
|
|
|
self.parser = DotsOCRParser(
|
|
|
ip=ip,
|
|
|
port=port,
|
|
|
@@ -55,7 +75,6 @@ class OmniDocBenchProcessor:
|
|
|
max_pixels=max_pixels,
|
|
|
model_name=model_name
|
|
|
)
|
|
|
- self.prompt_mode = prompt_mode
|
|
|
|
|
|
print(f"DotsOCR Parser 初始化完成:")
|
|
|
print(f" - 服务器: {ip}:{port}")
|
|
|
@@ -63,14 +82,14 @@ class OmniDocBenchProcessor:
|
|
|
print(f" - 提示模式: {prompt_mode}")
|
|
|
print(f" - 像素范围: {min_pixels} - {max_pixels}")
|
|
|
|
|
|
- def create_temp_session_dir(self):
|
|
|
+ def create_temp_session_dir(self) -> tuple:
|
|
|
"""创建临时会话目录"""
|
|
|
session_id = uuid.uuid4().hex[:8]
|
|
|
temp_dir = os.path.join(tempfile.gettempdir(), f"omnidocbench_batch_{session_id}")
|
|
|
os.makedirs(temp_dir, exist_ok=True)
|
|
|
return temp_dir, session_id
|
|
|
|
|
|
- def save_results_to_output_dir(self, result, image_name, output_dir):
|
|
|
+ def save_results_to_output_dir(self, result: Dict, image_name: str, output_dir: str) -> Dict[str, str]:
|
|
|
"""
|
|
|
将处理结果保存到输出目录
|
|
|
|
|
|
@@ -84,68 +103,70 @@ class OmniDocBenchProcessor:
|
|
|
"""
|
|
|
saved_files = {}
|
|
|
|
|
|
- # 1. 保存 Markdown 文件(OmniDocBench 评测必需)
|
|
|
- output_md_path = os.path.join(output_dir, f"{image_name}.md")
|
|
|
- md_content = ""
|
|
|
-
|
|
|
- # 优先使用无页眉页脚的版本(符合 OmniDocBench 评测要求)
|
|
|
- if 'md_content_nohf_path' in result and os.path.exists(result['md_content_nohf_path']):
|
|
|
- with open(result['md_content_nohf_path'], 'r', encoding='utf-8') as f:
|
|
|
- md_content = f.read()
|
|
|
- elif 'md_content_path' in result and os.path.exists(result['md_content_path']):
|
|
|
- with open(result['md_content_path'], 'r', encoding='utf-8') as f:
|
|
|
- md_content = f.read()
|
|
|
- else:
|
|
|
- md_content = "# 解析失败\n\n未能提取到有效的文档内容。"
|
|
|
-
|
|
|
- with open(output_md_path, 'w', encoding='utf-8') as f:
|
|
|
- f.write(md_content)
|
|
|
- saved_files['md'] = output_md_path
|
|
|
-
|
|
|
- # 2. 保存 JSON 文件
|
|
|
- output_json_path = os.path.join(output_dir, f"{image_name}.json")
|
|
|
- json_data = {}
|
|
|
-
|
|
|
- if 'layout_info_path' in result and os.path.exists(result['layout_info_path']):
|
|
|
- with open(result['layout_info_path'], 'r', encoding='utf-8') as f:
|
|
|
- json_data = json.load(f)
|
|
|
- else:
|
|
|
- json_data = {
|
|
|
- "error": "未能提取到有效的布局信息",
|
|
|
- "cells": []
|
|
|
- }
|
|
|
-
|
|
|
- with open(output_json_path, 'w', encoding='utf-8') as f:
|
|
|
- json.dump(json_data, f, ensure_ascii=False, indent=2)
|
|
|
- saved_files['json'] = output_json_path
|
|
|
-
|
|
|
- # 3. 保存带标注的布局图片
|
|
|
- output_layout_image_path = os.path.join(output_dir, f"{image_name}_layout.jpg")
|
|
|
-
|
|
|
- if 'layout_image_path' in result and os.path.exists(result['layout_image_path']):
|
|
|
- # 直接复制布局图片
|
|
|
- shutil.copy2(result['layout_image_path'], output_layout_image_path)
|
|
|
- saved_files['layout_image'] = output_layout_image_path
|
|
|
- else:
|
|
|
- # 如果没有布局图片,使用原始图片作为占位符
|
|
|
- try:
|
|
|
- original_image = Image.open(result.get('original_image_path', ''))
|
|
|
- original_image.save(output_layout_image_path, 'JPEG', quality=95)
|
|
|
+ try:
|
|
|
+ # 1. 保存 Markdown 文件(OmniDocBench 评测必需)
|
|
|
+ output_md_path = os.path.join(output_dir, f"{image_name}.md")
|
|
|
+ md_content = ""
|
|
|
+
|
|
|
+ # 优先使用无页眉页脚的版本(符合 OmniDocBench 评测要求)
|
|
|
+ if 'md_content_nohf_path' in result and os.path.exists(result['md_content_nohf_path']):
|
|
|
+ with open(result['md_content_nohf_path'], 'r', encoding='utf-8') as f:
|
|
|
+ md_content = f.read()
|
|
|
+ elif 'md_content_path' in result and os.path.exists(result['md_content_path']):
|
|
|
+ with open(result['md_content_path'], 'r', encoding='utf-8') as f:
|
|
|
+ md_content = f.read()
|
|
|
+ else:
|
|
|
+ md_content = "# 解析失败\n\n未能提取到有效的文档内容。"
|
|
|
+
|
|
|
+ with open(output_md_path, 'w', encoding='utf-8') as f:
|
|
|
+ f.write(md_content)
|
|
|
+ saved_files['md'] = output_md_path
|
|
|
+
|
|
|
+ # 2. 保存 JSON 文件
|
|
|
+ output_json_path = os.path.join(output_dir, f"{image_name}.json")
|
|
|
+ json_data = {}
|
|
|
+
|
|
|
+ if 'layout_info_path' in result and os.path.exists(result['layout_info_path']):
|
|
|
+ with open(result['layout_info_path'], 'r', encoding='utf-8') as f:
|
|
|
+ json_data = json.load(f)
|
|
|
+ else:
|
|
|
+ json_data = {
|
|
|
+ "error": "未能提取到有效的布局信息",
|
|
|
+ "cells": []
|
|
|
+ }
|
|
|
+
|
|
|
+ with open(output_json_path, 'w', encoding='utf-8') as f:
|
|
|
+ json.dump(json_data, f, ensure_ascii=False, indent=2)
|
|
|
+ saved_files['json'] = output_json_path
|
|
|
+
|
|
|
+ # 3. 保存带标注的布局图片
|
|
|
+ output_layout_image_path = os.path.join(output_dir, f"{image_name}_layout.jpg")
|
|
|
+
|
|
|
+ if 'layout_image_path' in result and os.path.exists(result['layout_image_path']):
|
|
|
+ # 直接复制布局图片
|
|
|
+ shutil.copy2(result['layout_image_path'], output_layout_image_path)
|
|
|
saved_files['layout_image'] = output_layout_image_path
|
|
|
- print(f"⚠️ 使用原始图片作为布局图片: {image_name}")
|
|
|
- except Exception as e:
|
|
|
- print(f"⚠️ 无法保存布局图片: {image_name}, 错误: {e}")
|
|
|
- saved_files['layout_image'] = None
|
|
|
-
|
|
|
- # 4. 可选:保存原始图片副本
|
|
|
- output_original_image_path = os.path.join(output_dir, f"{image_name}_original.jpg")
|
|
|
- if 'original_image_path' in result and os.path.exists(result['original_image_path']):
|
|
|
- shutil.copy2(result['original_image_path'], output_original_image_path)
|
|
|
- saved_files['original_image'] = output_original_image_path
|
|
|
-
|
|
|
+ else:
|
|
|
+ # 如果没有布局图片,使用原始图片作为占位符
|
|
|
+ try:
|
|
|
+ original_image = Image.open(result.get('original_image_path', ''))
|
|
|
+ original_image.save(output_layout_image_path, 'JPEG', quality=95)
|
|
|
+ saved_files['layout_image'] = output_layout_image_path
|
|
|
+ except Exception as e:
|
|
|
+ saved_files['layout_image'] = None
|
|
|
+
|
|
|
+ # # 4. 可选:保存原始图片副本
|
|
|
+ # output_original_image_path = os.path.join(output_dir, f"{image_name}_original.jpg")
|
|
|
+ # if 'original_image_path' in result and os.path.exists(result['original_image_path']):
|
|
|
+ # shutil.copy2(result['original_image_path'], output_original_image_path)
|
|
|
+ # saved_files['original_image'] = output_original_image_path
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ print(f"Error saving results for {image_name}: {e}")
|
|
|
+
|
|
|
return saved_files
|
|
|
|
|
|
- def process_single_image(self, image_path, output_dir):
|
|
|
+ def process_single_image(self, image_path: str, output_dir: str) -> Dict[str, Any]:
|
|
|
"""
|
|
|
处理单张图片
|
|
|
|
|
|
@@ -154,20 +175,38 @@ class OmniDocBenchProcessor:
|
|
|
output_dir: 输出目录
|
|
|
|
|
|
Returns:
|
|
|
- bool: 处理是否成功
|
|
|
+ dict: 处理结果
|
|
|
"""
|
|
|
+ start_time = time.time()
|
|
|
+ image_name = Path(image_path).stem
|
|
|
+
|
|
|
+ result_info = {
|
|
|
+ "image_path": image_path,
|
|
|
+ "processing_time": 0,
|
|
|
+ "success": False,
|
|
|
+ "device": f"{self.ip}:{self.port}",
|
|
|
+ "error": None,
|
|
|
+ "output_files": {}
|
|
|
+ }
|
|
|
+
|
|
|
try:
|
|
|
- # 获取图片文件名(不含扩展名)
|
|
|
- image_name = Path(image_path).stem
|
|
|
-
|
|
|
# 检查输出文件是否已存在
|
|
|
output_md_path = os.path.join(output_dir, f"{image_name}.md")
|
|
|
output_json_path = os.path.join(output_dir, f"{image_name}.json")
|
|
|
output_layout_path = os.path.join(output_dir, f"{image_name}_layout.jpg")
|
|
|
|
|
|
if all(os.path.exists(p) for p in [output_md_path, output_json_path, output_layout_path]):
|
|
|
- print(f"跳过已存在的文件: {image_name}")
|
|
|
- return True
|
|
|
+ result_info.update({
|
|
|
+ "success": True,
|
|
|
+ "processing_time": 0,
|
|
|
+ "output_files": {
|
|
|
+ "md": output_md_path,
|
|
|
+ "json": output_json_path,
|
|
|
+ "layout_image": output_layout_path
|
|
|
+ },
|
|
|
+ "skipped": True
|
|
|
+ })
|
|
|
+ return result_info
|
|
|
|
|
|
# 创建临时会话目录
|
|
|
temp_dir, session_id = self.create_temp_session_dir()
|
|
|
@@ -188,31 +227,26 @@ class OmniDocBenchProcessor:
|
|
|
|
|
|
# 解析结果
|
|
|
if not results:
|
|
|
- print(f"警告: {image_name} 未返回解析结果")
|
|
|
- return False
|
|
|
+ raise Exception("未返回解析结果")
|
|
|
|
|
|
result = results[0] # parse_image 返回单个结果的列表
|
|
|
|
|
|
# 添加原始图片路径到结果中
|
|
|
- result['original_image_path'] = image_path
|
|
|
+ # result['original_image_path'] = image_path
|
|
|
|
|
|
# 保存所有结果文件到输出目录
|
|
|
saved_files = self.save_results_to_output_dir(result, image_name, output_dir)
|
|
|
|
|
|
# 验证保存结果
|
|
|
success_count = sum(1 for path in saved_files.values() if path and os.path.exists(path))
|
|
|
- total_expected = 3 # md, json, layout_image
|
|
|
|
|
|
if success_count >= 2: # 至少保存了 md 和 json
|
|
|
- print(f"✅ 成功处理: {image_name} (保存了 {success_count}/{total_expected} 个文件)")
|
|
|
- return True
|
|
|
+ result_info.update({
|
|
|
+ "success": True,
|
|
|
+ "output_files": saved_files
|
|
|
+ })
|
|
|
else:
|
|
|
- print(f"⚠️ 部分成功: {image_name} (保存了 {success_count}/{total_expected} 个文件)")
|
|
|
- return False
|
|
|
-
|
|
|
- except Exception as e:
|
|
|
- print(f"❌ 处理 {image_name} 时出错: {str(e)}")
|
|
|
- return False
|
|
|
+ raise Exception(f"保存文件不完整 ({success_count}/3)")
|
|
|
|
|
|
finally:
|
|
|
# 清理临时目录
|
|
|
@@ -220,207 +254,261 @@ class OmniDocBenchProcessor:
|
|
|
shutil.rmtree(temp_dir, ignore_errors=True)
|
|
|
|
|
|
except Exception as e:
|
|
|
- print(f"❌ 处理 {image_path} 时出现致命错误: {str(e)}")
|
|
|
- return False
|
|
|
+ result_info["error"] = str(e)
|
|
|
+
|
|
|
+ finally:
|
|
|
+ result_info["processing_time"] = time.time() - start_time
|
|
|
+
|
|
|
+ return result_info
|
|
|
+
|
|
|
+
|
|
|
+def process_images_single_process(image_paths: List[str],
|
|
|
+ processor: DotsOCRProcessor,
|
|
|
+ batch_size: int = 1,
|
|
|
+ output_dir: str = "./output") -> List[Dict[str, Any]]:
|
|
|
+ """
|
|
|
+ 单进程版本的图像处理函数
|
|
|
|
|
|
- def process_batch(self, images_dir, output_dir):
|
|
|
- """
|
|
|
- 批量处理图片
|
|
|
-
|
|
|
- Args:
|
|
|
- images_dir: 输入图片目录
|
|
|
- output_dir: 输出目录
|
|
|
- """
|
|
|
- # 创建输出目录
|
|
|
- os.makedirs(output_dir, exist_ok=True)
|
|
|
-
|
|
|
- # 获取所有图片文件
|
|
|
- image_extensions = ['.jpg', '.jpeg', '.png']
|
|
|
- image_files = []
|
|
|
-
|
|
|
- for ext in image_extensions:
|
|
|
- image_files.extend(Path(images_dir).glob(f"*{ext}"))
|
|
|
- image_files.extend(Path(images_dir).glob(f"*{ext.upper()}"))
|
|
|
-
|
|
|
- image_files = sorted(image_files)
|
|
|
-
|
|
|
- if not image_files:
|
|
|
- print(f"在 {images_dir} 中未找到图片文件")
|
|
|
- return
|
|
|
-
|
|
|
- print(f"找到 {len(image_files)} 个图片文件")
|
|
|
- print(f"输出目录结构: {output_dir}")
|
|
|
+ Args:
|
|
|
+ image_paths: 图像路径列表
|
|
|
+ processor: DotsOCR处理器实例
|
|
|
+ batch_size: 批处理大小
|
|
|
+ output_dir: 输出目录
|
|
|
|
|
|
- # 统计变量
|
|
|
- success_count = 0
|
|
|
- failed_count = 0
|
|
|
- skipped_count = 0
|
|
|
+ Returns:
|
|
|
+ 处理结果列表
|
|
|
+ """
|
|
|
+ # 创建输出目录
|
|
|
+ output_path = Path(output_dir)
|
|
|
+ output_path.mkdir(parents=True, exist_ok=True)
|
|
|
+
|
|
|
+ all_results = []
|
|
|
+ total_images = len(image_paths)
|
|
|
+
|
|
|
+ print(f"Processing {total_images} images with batch size {batch_size}")
|
|
|
+
|
|
|
+ # 使用tqdm显示进度,添加更多统计信息
|
|
|
+ with tqdm(total=total_images, desc="Processing images", unit="img",
|
|
|
+ bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') as pbar:
|
|
|
|
|
|
- # 使用进度条处理
|
|
|
- with tqdm(image_files, desc="处理图片", unit="张") as pbar:
|
|
|
- for image_path in pbar:
|
|
|
- # 更新进度条描述
|
|
|
- pbar.set_description(f"处理: {image_path.name}")
|
|
|
-
|
|
|
- # 检查输出文件是否已存在(在主输出目录中)
|
|
|
- image_name = image_path.stem
|
|
|
- output_md_path = os.path.join(output_dir, f"{image_name}.md")
|
|
|
- output_json_path = os.path.join(output_dir, f"{image_name}.json")
|
|
|
- output_layout_path = os.path.join(output_dir, f"{image_name}_layout.jpg")
|
|
|
+ # 按批次处理图像(DotsOCR通常单张处理)
|
|
|
+ for i in range(0, total_images, batch_size):
|
|
|
+ batch = image_paths[i:i + batch_size]
|
|
|
+ batch_start_time = time.time()
|
|
|
+ batch_results = []
|
|
|
+
|
|
|
+ try:
|
|
|
+ # 处理批次中的每张图片
|
|
|
+ for image_path in batch:
|
|
|
+ try:
|
|
|
+ result = processor.process_single_image(image_path, output_dir)
|
|
|
+ batch_results.append(result)
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ print(f"Error processing {image_path}: {e}", file=sys.stderr)
|
|
|
+ traceback.print_exc()
|
|
|
+
|
|
|
+ batch_results.append({
|
|
|
+ "image_path": image_path,
|
|
|
+ "processing_time": 0,
|
|
|
+ "success": False,
|
|
|
+ "device": f"{processor.ip}:{processor.port}",
|
|
|
+ "error": str(e)
|
|
|
+ })
|
|
|
|
|
|
- if all(os.path.exists(p) for p in [output_md_path, output_json_path, output_layout_path]):
|
|
|
- skipped_count += 1
|
|
|
- continue
|
|
|
+ batch_processing_time = time.time() - batch_start_time
|
|
|
+ all_results.extend(batch_results)
|
|
|
|
|
|
- # 处理图片
|
|
|
- if self.process_single_image(str(image_path), output_dir):
|
|
|
- success_count += 1
|
|
|
- else:
|
|
|
- failed_count += 1
|
|
|
+ # 更新进度条
|
|
|
+ success_count = sum(1 for r in batch_results if r.get('success', False))
|
|
|
+ skipped_count = sum(1 for r in batch_results if r.get('skipped', False))
|
|
|
+ total_success = sum(1 for r in all_results if r.get('success', False))
|
|
|
+ total_skipped = sum(1 for r in all_results if r.get('skipped', False))
|
|
|
+ avg_time = batch_processing_time / len(batch)
|
|
|
|
|
|
- # 更新进度条后缀
|
|
|
+ pbar.update(len(batch))
|
|
|
pbar.set_postfix({
|
|
|
- 'success': success_count,
|
|
|
- 'failed': failed_count,
|
|
|
- 'skipped': skipped_count
|
|
|
+ 'batch_time': f"{batch_processing_time:.2f}s",
|
|
|
+ 'avg_time': f"{avg_time:.2f}s/img",
|
|
|
+ 'success': f"{total_success}/{len(all_results)}",
|
|
|
+ 'skipped': f"{total_skipped}",
|
|
|
+ 'rate': f"{total_success/len(all_results)*100:.1f}%"
|
|
|
})
|
|
|
-
|
|
|
- # 输出最终统计
|
|
|
- print(f"\n🎉 批量处理完成!")
|
|
|
- print(f" ✅ 成功: {success_count}")
|
|
|
- print(f" ❌ 失败: {failed_count}")
|
|
|
- print(f" ⏭️ 跳过: {skipped_count}")
|
|
|
- print(f" 📁 输出目录: {output_dir}")
|
|
|
-
|
|
|
- # 生成处理报告
|
|
|
- self.generate_processing_report(output_dir, success_count, failed_count, skipped_count)
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ print(f"Error processing batch {[Path(p).name for p in batch]}: {e}", file=sys.stderr)
|
|
|
+ traceback.print_exc()
|
|
|
+
|
|
|
+ # 为批次中的所有图像添加错误结果
|
|
|
+ error_results = []
|
|
|
+ for img_path in batch:
|
|
|
+ error_results.append({
|
|
|
+ "image_path": str(img_path),
|
|
|
+ "processing_time": 0,
|
|
|
+ "success": False,
|
|
|
+ "device": f"{processor.ip}:{processor.port}",
|
|
|
+ "error": str(e)
|
|
|
+ })
|
|
|
+ all_results.extend(error_results)
|
|
|
+ pbar.update(len(batch))
|
|
|
|
|
|
- def generate_processing_report(self, output_dir, success_count, failed_count, skipped_count):
|
|
|
- """生成处理报告"""
|
|
|
- report_path = os.path.join(output_dir, "processing_report.json")
|
|
|
-
|
|
|
- report = {
|
|
|
- "processing_summary": {
|
|
|
- "success_count": success_count,
|
|
|
- "failed_count": failed_count,
|
|
|
- "skipped_count": skipped_count,
|
|
|
- "total_processed": success_count + failed_count + skipped_count
|
|
|
- },
|
|
|
- "output_structure": {
|
|
|
- "markdown_files": f"{output_dir}/*.md",
|
|
|
- "json_files": f"{output_dir}/*.json",
|
|
|
- "layout_images": f"{output_dir}/*_layout.jpg",
|
|
|
- "original_images": f"{output_dir}/*_original.jpg"
|
|
|
- },
|
|
|
- "configuration": {
|
|
|
- "prompt_mode": self.prompt_mode,
|
|
|
- "server": f"{self.parser.ip}:{self.parser.port}",
|
|
|
- "pixel_range": f"{self.parser.min_pixels} - {self.parser.max_pixels}"
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- with open(report_path, 'w', encoding='utf-8') as f:
|
|
|
- json.dump(report, f, ensure_ascii=False, indent=2)
|
|
|
-
|
|
|
- print(f"📊 处理报告已保存: {report_path}")
|
|
|
+ return all_results
|
|
|
|
|
|
|
|
|
def main():
|
|
|
- parser = argparse.ArgumentParser(description="批量处理 OmniDocBench 图片")
|
|
|
-
|
|
|
- parser.add_argument(
|
|
|
- "--images_dir",
|
|
|
- type=str,
|
|
|
- default="../OmniDocBench/OpenDataLab___OmniDocBench/images",
|
|
|
- help="输入图片目录路径"
|
|
|
- )
|
|
|
-
|
|
|
- parser.add_argument(
|
|
|
- "--output_dir",
|
|
|
- type=str,
|
|
|
- default="./omnidocbench_predictions",
|
|
|
- help="输出目录路径"
|
|
|
- )
|
|
|
-
|
|
|
- parser.add_argument(
|
|
|
- "--ip",
|
|
|
- type=str,
|
|
|
- default="127.0.0.1",
|
|
|
- help="vLLM 服务器 IP"
|
|
|
- )
|
|
|
-
|
|
|
- parser.add_argument(
|
|
|
- "--port",
|
|
|
- type=int,
|
|
|
- default=8101,
|
|
|
- help="vLLM 服务器端口"
|
|
|
- )
|
|
|
+ """主函数"""
|
|
|
+ parser = argparse.ArgumentParser(description="DotsOCR OmniDocBench Single Process Processing")
|
|
|
|
|
|
- parser.add_argument(
|
|
|
- "--model_name",
|
|
|
- type=str,
|
|
|
- default="DotsOCR",
|
|
|
- help="模型名称"
|
|
|
- )
|
|
|
-
|
|
|
- parser.add_argument(
|
|
|
- "--prompt_mode",
|
|
|
- type=str,
|
|
|
- default="prompt_layout_all_en",
|
|
|
- choices=list(dict_promptmode_to_prompt.keys()),
|
|
|
- help="提示模式"
|
|
|
- )
|
|
|
-
|
|
|
- parser.add_argument(
|
|
|
- "--min_pixels",
|
|
|
- type=int,
|
|
|
- default=MIN_PIXELS,
|
|
|
- help="最小像素数"
|
|
|
- )
|
|
|
-
|
|
|
- parser.add_argument(
|
|
|
- "--max_pixels",
|
|
|
- type=int,
|
|
|
- default=MAX_PIXELS,
|
|
|
- help="最大像素数"
|
|
|
- )
|
|
|
+ # 输入参数组
|
|
|
+ input_group = parser.add_mutually_exclusive_group(required=True)
|
|
|
+ input_group.add_argument("--input_dir", type=str, help="Input directory")
|
|
|
+ input_group.add_argument("--input_file_list", type=str, help="Input file list (one file per line)")
|
|
|
+ input_group.add_argument("--input_csv", type=str, help="Input CSV file with image_path and status columns")
|
|
|
+
|
|
|
+ # 输出参数
|
|
|
+ parser.add_argument("--output_dir", type=str, help="Output directory")
|
|
|
|
|
|
- parser.add_argument(
|
|
|
- "--dpi",
|
|
|
- type=int,
|
|
|
- default=200,
|
|
|
- help="PDF 处理 DPI"
|
|
|
- )
|
|
|
+ # DotsOCR 参数
|
|
|
+ parser.add_argument("--ip", type=str, default="127.0.0.1", help="vLLM server IP")
|
|
|
+ parser.add_argument("--port", type=int, default=8101, help="vLLM server port")
|
|
|
+ parser.add_argument("--model_name", type=str, default="DotsOCR", help="Model name")
|
|
|
+ parser.add_argument("--prompt_mode", type=str, default="prompt_layout_all_en",
|
|
|
+ choices=list(dict_promptmode_to_prompt.keys()), help="Prompt mode")
|
|
|
+ parser.add_argument("--min_pixels", type=int, default=MIN_PIXELS, help="Minimum pixels")
|
|
|
+ parser.add_argument("--max_pixels", type=int, default=MAX_PIXELS, help="Maximum pixels")
|
|
|
+ parser.add_argument("--dpi", type=int, default=200, help="PDF processing DPI")
|
|
|
|
|
|
+ # 处理参数
|
|
|
+ parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
|
|
|
+ parser.add_argument("--input_pattern", type=str, default="*", help="Input file pattern")
|
|
|
+ parser.add_argument("--test_mode", action="store_true", help="Test mode (process only 10 images)")
|
|
|
+ parser.add_argument("--collect_results", type=str, help="收集处理结果到指定CSV文件")
|
|
|
+
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
- # 检查输入目录
|
|
|
- if not os.path.exists(args.images_dir):
|
|
|
- print(f"❌ 输入目录不存在: {args.images_dir}")
|
|
|
- return
|
|
|
-
|
|
|
- print(f"🚀 开始批量处理 OmniDocBench 图片")
|
|
|
- print(f"📁 输入目录: {args.images_dir}")
|
|
|
- print(f"📁 输出目录: {args.output_dir}")
|
|
|
- print("="*60)
|
|
|
-
|
|
|
- # 创建处理器
|
|
|
- processor = OmniDocBenchProcessor(
|
|
|
- ip=args.ip,
|
|
|
- port=args.port,
|
|
|
- model_name=args.model_name,
|
|
|
- prompt_mode=args.prompt_mode,
|
|
|
- dpi=args.dpi,
|
|
|
- min_pixels=args.min_pixels,
|
|
|
- max_pixels=args.max_pixels
|
|
|
- )
|
|
|
-
|
|
|
- # 开始批量处理
|
|
|
- processor.process_batch(args.images_dir, args.output_dir)
|
|
|
+ try:
|
|
|
+ # 获取图像文件列表
|
|
|
+ if args.input_csv:
|
|
|
+ # 从CSV文件读取
|
|
|
+ image_files = get_image_files_from_csv(args.input_csv, "fail")
|
|
|
+ print(f"📊 Loaded {len(image_files)} files from CSV with status filter: fail")
|
|
|
+ elif args.input_file_list:
|
|
|
+ # 从文件列表读取
|
|
|
+ image_files = get_image_files_from_list(args.input_file_list)
|
|
|
+ else:
|
|
|
+ # 从目录读取
|
|
|
+ input_dir = Path(args.input_dir).resolve()
|
|
|
+ print(f"📁 Input dir: {input_dir}")
|
|
|
+
|
|
|
+ if not input_dir.exists():
|
|
|
+ print(f"❌ Input directory does not exist: {input_dir}")
|
|
|
+ return 1
|
|
|
+
|
|
|
+ image_files = get_image_files_from_dir(input_dir, args.input_pattern)
|
|
|
+
|
|
|
+ output_dir = Path(args.output_dir).resolve()
|
|
|
+ print(f"📁 Output dir: {output_dir}")
|
|
|
+ print(f"📊 Found {len(image_files)} image files")
|
|
|
+
|
|
|
+ if args.test_mode:
|
|
|
+ image_files = image_files[:10]
|
|
|
+ print(f"🧪 Test mode: processing only {len(image_files)} images")
|
|
|
+
|
|
|
+ print(f"🌐 Using server: {args.ip}:{args.port}")
|
|
|
+ print(f"📦 Batch size: {args.batch_size}")
|
|
|
+ print(f"🎯 Prompt mode: {args.prompt_mode}")
|
|
|
+
|
|
|
+ # 创建处理器
|
|
|
+ processor = DotsOCRProcessor(
|
|
|
+ ip=args.ip,
|
|
|
+ port=args.port,
|
|
|
+ model_name=args.model_name,
|
|
|
+ prompt_mode=args.prompt_mode,
|
|
|
+ dpi=args.dpi,
|
|
|
+ min_pixels=args.min_pixels,
|
|
|
+ max_pixels=args.max_pixels
|
|
|
+ )
|
|
|
+
|
|
|
+ # 开始处理
|
|
|
+ start_time = time.time()
|
|
|
+ results = process_images_single_process(
|
|
|
+ image_files,
|
|
|
+ processor,
|
|
|
+ args.batch_size,
|
|
|
+ str(output_dir)
|
|
|
+ )
|
|
|
+ total_time = time.time() - start_time
|
|
|
+
|
|
|
+ # 统计结果
|
|
|
+ success_count = sum(1 for r in results if r.get('success', False))
|
|
|
+ skipped_count = sum(1 for r in results if r.get('skipped', False))
|
|
|
+ error_count = len(results) - success_count
|
|
|
+
|
|
|
+ print(f"\n" + "="*60)
|
|
|
+ print(f"✅ Processing completed!")
|
|
|
+ print(f"📊 Statistics:")
|
|
|
+ print(f" Total files: {len(image_files)}")
|
|
|
+ print(f" Successful: {success_count}")
|
|
|
+ print(f" Skipped: {skipped_count}")
|
|
|
+ print(f" Failed: {error_count}")
|
|
|
+ if len(image_files) > 0:
|
|
|
+ print(f" Success rate: {success_count / len(image_files) * 100:.2f}%")
|
|
|
+ print(f"⏱️ Performance:")
|
|
|
+ print(f" Total time: {total_time:.2f} seconds")
|
|
|
+ if total_time > 0:
|
|
|
+ print(f" Throughput: {len(image_files) / total_time:.2f} images/second")
|
|
|
+ print(f" Avg time per image: {total_time / len(image_files):.2f} seconds")
|
|
|
+
|
|
|
+ # 保存结果统计
|
|
|
+ stats = {
|
|
|
+ "total_files": len(image_files),
|
|
|
+ "success_count": success_count,
|
|
|
+ "skipped_count": skipped_count,
|
|
|
+ "error_count": error_count,
|
|
|
+ "success_rate": success_count / len(image_files) if len(image_files) > 0 else 0,
|
|
|
+ "total_time": total_time,
|
|
|
+ "throughput": len(image_files) / total_time if total_time > 0 else 0,
|
|
|
+ "avg_time_per_image": total_time / len(image_files) if len(image_files) > 0 else 0,
|
|
|
+ "batch_size": args.batch_size,
|
|
|
+ "server": f"{args.ip}:{args.port}",
|
|
|
+ "model": args.model_name,
|
|
|
+ "prompt_mode": args.prompt_mode,
|
|
|
+ "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
|
|
|
+ }
|
|
|
+
|
|
|
+ # 保存最终结果
|
|
|
+ output_file_name = Path(output_dir).name
|
|
|
+ output_file = os.path.join(output_dir, f"{output_file_name}_results.json")
|
|
|
+ final_results = {
|
|
|
+ "stats": stats,
|
|
|
+ "results": results
|
|
|
+ }
|
|
|
+
|
|
|
+ with open(output_file, 'w', encoding='utf-8') as f:
|
|
|
+ json.dump(final_results, f, ensure_ascii=False, indent=2)
|
|
|
+
|
|
|
+ print(f"💾 Results saved to: {output_file}")
|
|
|
+
|
|
|
+ # 收集处理结果
|
|
|
+ if args.collect_results:
|
|
|
+ processed_files = collect_pid_files(output_file)
|
|
|
+ output_file_processed = Path(args.collect_results).resolve()
|
|
|
+ with open(output_file_processed, 'w', encoding='utf-8') as f:
|
|
|
+ f.write("image_path,status\n")
|
|
|
+ for file_path, status in processed_files:
|
|
|
+ f.write(f"{file_path},{status}\n")
|
|
|
+ print(f"💾 Processed files saved to: {output_file_processed}")
|
|
|
+
|
|
|
+ return 0
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ print(f"❌ Processing failed: {e}", file=sys.stderr)
|
|
|
+ traceback.print_exc()
|
|
|
+ return 1
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
- print(f"🚀 启动单进程DotsOCR程序...")
|
|
|
+ print(f"🚀 启动DotsOCR单进程程序...")
|
|
|
print(f"🔧 CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set')}")
|
|
|
|
|
|
if len(sys.argv) == 1:
|
|
|
@@ -430,14 +518,30 @@ if __name__ == "__main__":
|
|
|
# 默认配置
|
|
|
default_config = {
|
|
|
"input_dir": "../../OmniDocBench/OpenDataLab___OmniDocBench/images",
|
|
|
- "output_dir": "./OmniDocBench_Results_Single",
|
|
|
+ "output_dir": "./OmniDocBench_DotsOCR_Results",
|
|
|
+ "ip": "10.192.72.11",
|
|
|
+ "port": "8101",
|
|
|
+ "model_name": "DotsOCR",
|
|
|
+ "prompt_mode": "prompt_layout_all_en",
|
|
|
+ "batch_size": "1",
|
|
|
+ "collect_results": "./OmniDocBench_DotsOCR_Results/processed_files.csv",
|
|
|
}
|
|
|
+
|
|
|
+ # 如果需要处理失败的文件,可以使用这个配置
|
|
|
+ # default_config = {
|
|
|
+ # "input_csv": "./OmniDocBench_DotsOCR_Results/processed_files.csv",
|
|
|
+ # "output_dir": "./OmniDocBench_DotsOCR_Results",
|
|
|
+ # "ip": "127.0.0.1",
|
|
|
+ # "port": "8101",
|
|
|
+ # "collect_results": f"./OmniDocBench_DotsOCR_Results/processed_files_{time.strftime('%Y%m%d_%H%M%S')}.csv",
|
|
|
+ # }
|
|
|
+
|
|
|
# 构造参数
|
|
|
sys.argv = [sys.argv[0]]
|
|
|
for key, value in default_config.items():
|
|
|
sys.argv.extend([f"--{key}", str(value)])
|
|
|
|
|
|
# 测试模式
|
|
|
- # sys.argv.append("--test_mode")
|
|
|
+ sys.argv.append("--test_mode")
|
|
|
|
|
|
sys.exit(main())
|