| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783 |
- """
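- Batch-process OmniDocBench images and produce predictions in the format the benchmark expects.
- Per the OmniDocBench evaluation requirements:
- - Input: all .jpg images under OpenDataLab___OmniDocBench/images, as well as PDF files
- - Output: for each image, a .md file, a .json file and an annotated layout image
- - Output directory: consumed by the subsequent end2end evaluation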
- """
- import os
- import sys
- import json
- import tempfile
- import uuid
- import shutil
- import time
- import traceback
- import warnings
- from pathlib import Path
- from typing import List, Dict, Any
- from PIL import Image
- from tqdm import tqdm
- import argparse
- # 导入 dots.ocr 相关模块
- from dots_ocr.parser import DotsOCRParser
- from dots_ocr.utils import dict_promptmode_to_prompt
- from dots_ocr.utils.consts import MIN_PIXELS, MAX_PIXELS
- from dots_ocr.utils.doc_utils import load_images_from_pdf
- # 导入工具函数
- from utils import (
- get_image_files_from_dir,
- get_image_files_from_list,
- get_image_files_from_csv,
- collect_pid_files,
- normalize_markdown_table,
- normalize_json_table
- )
- def convert_pdf_to_images(pdf_file: str, output_dir: str | None = None, dpi: int = 200) -> List[str]:
- """
- 将PDF转换为图像文件
-
- Args:
- pdf_file: PDF文件路径
- output_dir: 输出目录
- dpi: 图像分辨率
-
- Returns:
- 生成的图像文件路径列表
- """
- pdf_path = Path(pdf_file)
- if not pdf_path.exists() or pdf_path.suffix.lower() != '.pdf':
- print(f"❌ Invalid PDF file: {pdf_path}")
- return []
- # 如果没有指定输出目录,使用PDF同名目录
- if output_dir is None:
- output_path = pdf_path.parent / f"{pdf_path.stem}"
- else:
- output_path = Path(output_dir) / f"{pdf_path.stem}"
- output_path = output_path.resolve()
- output_path.mkdir(parents=True, exist_ok=True)
- try:
- # 使用utils中的函数加载PDF图像
- images = load_images_from_pdf(str(pdf_path), dpi=dpi)
-
- image_paths = []
- for i, image in enumerate(images):
- # 生成图像文件名
- image_filename = f"{pdf_path.stem}_page_{i+1:03d}.png"
- image_path = output_path / image_filename
- # 保存图像
- image.save(str(image_path))
- image_paths.append(str(image_path))
-
- print(f"✅ Converted {len(images)} pages from {pdf_path.name} to images")
- return image_paths
-
- except Exception as e:
- print(f"❌ Error converting PDF {pdf_path}: {e}")
- traceback.print_exc()
- return []
def get_input_files(args) -> List[str]:
    """
    Build the list of image files to process, expanding PDFs into page images.

    The input source is chosen in this order: ``args.input_csv``,
    ``args.input_file_list``, ``args.input_file``, then ``args.input_dir``.
    PDF inputs are converted with ``convert_pdf_to_images`` (using
    ``args.output_dir`` and ``args.dpi``) and replaced by the resulting page
    images; other inputs are kept as-is when they exist on disk.

    Directory scans match suffixes case-insensitively and return each file
    once, in sorted order, so the result is deterministic and free of
    duplicates on case-insensitive filesystems.

    Args:
        args: Parsed command-line arguments.

    Returns:
        Paths of the image files to process (possibly empty).
    """
    if args.input_csv:
        raw_files = get_image_files_from_csv(args.input_csv, "fail")
    elif args.input_file_list:
        raw_files = get_image_files_from_list(args.input_file_list)
    elif args.input_file:
        raw_files = [Path(args.input_file).resolve()]
    else:
        input_dir = Path(args.input_dir).resolve()
        if not input_dir.is_dir():
            print(f"❌ Input directory does not exist: {input_dir}")
            return []

        # Supported inputs: images plus PDF. Comparing the lower-cased suffix
        # covers ".jpg", ".JPG" and ".Jpg" alike in a single pass.
        supported_suffixes = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif', '.pdf'}
        raw_files = sorted(
            entry for entry in input_dir.iterdir()
            if entry.is_file() and entry.suffix.lower() in supported_suffixes
        )

    input_files: List[str] = []
    pdf_count = 0
    image_count = 0

    for raw in raw_files:
        file_path = Path(raw)

        if file_path.suffix.lower() == '.pdf':
            # Expand the PDF into one image per page.
            print(f"📄 Processing PDF: {file_path.name}")
            pdf_images = convert_pdf_to_images(
                str(file_path),
                args.output_dir,
                dpi=args.dpi
            )
            input_files.extend(pdf_images)
            pdf_count += 1
        elif file_path.exists():
            input_files.append(str(file_path))
            image_count += 1
        else:
            # Make dropped inputs visible instead of ignoring them silently.
            print(f"⚠️ Skipping missing input file: {file_path}")

    print(f"📊 Input summary:")
    print(f" PDF files processed: {pdf_count}")
    print(f" Image files found: {image_count}")
    print(f" Total image files to process: {len(input_files)}")

    return input_files
class DotsOCRProcessor:
    """Run DotsOCR on individual images and store the outputs per image.

    Wraps a ``DotsOCRParser`` and, for every processed image, writes
    ``<name>.md``, ``<name>.json`` and ``<name>_layout.jpg`` into an output
    directory.
    """

    def __init__(self,
                 ip: str = "127.0.0.1",
                 port: int = 8101,
                 model_name: str = "DotsOCR",
                 prompt_mode: str = "prompt_layout_all_en",
                 dpi: int = 200,
                 min_pixels: int = MIN_PIXELS,
                 max_pixels: int = MAX_PIXELS,
                 normalize_numbers: bool = False):
        """
        Create the underlying parser and remember the run configuration.

        Args:
            ip: vLLM server IP.
            port: vLLM server port.
            model_name: Model name passed to the parser.
            prompt_mode: Prompt mode used for every ``parse_image`` call.
            dpi: DPI passed to the parser.
            min_pixels: Minimum pixel count passed to the parser.
            max_pixels: Maximum pixel count passed to the parser.
            normalize_numbers: When true, the Markdown and JSON outputs are
                passed through ``normalize_markdown_table`` /
                ``normalize_json_table`` before being written.
        """
        self.ip = ip
        self.port = port
        self.model_name = model_name
        self.prompt_mode = prompt_mode
        self.dpi = dpi
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.normalize_numbers = normalize_numbers

        self.parser = DotsOCRParser(
            ip=ip,
            port=port,
            dpi=dpi,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
            model_name=model_name
        )

        print(f"DotsOCR Parser 初始化完成:")
        print(f" - 服务器: {ip}:{port}")
        print(f" - 模型: {model_name}")
        print(f" - 提示模式: {prompt_mode}")
        print(f" - 像素范围: {min_pixels} - {max_pixels}")

    def create_temp_session_dir(self) -> tuple:
        """Create a uniquely named scratch directory under the system temp dir.

        Returns:
            ``(temp_dir, session_id)`` where ``session_id`` is the 8-character
            hex token embedded in the directory name.
        """
        session_id = uuid.uuid4().hex[:8]
        temp_dir = os.path.join(tempfile.gettempdir(), f"omnidocbench_batch_{session_id}")
        os.makedirs(temp_dir, exist_ok=True)
        return temp_dir, session_id

    def save_results_to_output_dir(self, result: Dict, image_name: str, output_dir: str) -> Dict[str, Any]:
        """
        Copy one parse result into the output directory.

        Reads the files referenced by ``result`` (keys
        ``md_content_nohf_path``, ``md_content_path``, ``layout_info_path``,
        ``layout_image_path``, ``original_image_path``) and writes
        ``<image_name>.md``, ``<image_name>.json`` and
        ``<image_name>_layout.jpg``. When normalization is enabled and changes
        the content, the unnormalized text is also kept as
        ``<image_name>_original.md`` / ``.json``.

        Args:
            result: Parse result for a single image.
            image_name: Image file name without extension.
            output_dir: Directory receiving the output files.

        Returns:
            Mapping with the written paths under ``md``, ``json`` and
            ``layout_image`` (``None`` if no layout image could be produced),
            plus ``md_normalized`` / ``json_normalized`` status messages when
            normalization is enabled. Errors are printed and yield a partial
            (possibly empty) mapping instead of raising.
        """
        saved_files: Dict[str, Any] = {}

        try:
            # 1. Markdown file (required by the OmniDocBench evaluation).
            output_md_path = os.path.join(output_dir, f"{image_name}.md")

            # Prefer the variant without page headers/footers, then the full
            # Markdown, then a fixed failure notice.
            nohf_path = result.get('md_content_nohf_path')
            full_md_path = result.get('md_content_path')
            if nohf_path and os.path.exists(nohf_path):
                with open(nohf_path, 'r', encoding='utf-8') as f:
                    md_content = f.read()
            elif full_md_path and os.path.exists(full_md_path):
                with open(full_md_path, 'r', encoding='utf-8') as f:
                    md_content = f.read()
            else:
                md_content = "# 解析失败\n\n未能提取到有效的文档内容。"

            original_text = md_content
            # Default to the raw content so the write below also works when
            # normalization is disabled (previously a NameError).
            generated_text = md_content
            if self.normalize_numbers:
                # Only Markdown tables are normalized.
                generated_text = normalize_markdown_table(md_content)

                # Position-wise count of characters that differ.
                changes_count = sum(1 for o, n in zip(original_text, generated_text) if o != n)
                if changes_count > 0:
                    saved_files['md_normalized'] = f"✅ 已标准化 {changes_count} 个字符(全角→半角)"
                else:
                    saved_files['md_normalized'] = "ℹ️ 无需标准化(已是标准格式)"

            with open(output_md_path, 'w', encoding='utf-8') as f:
                f.write(generated_text)
            saved_files['md'] = output_md_path

            # Keep the unnormalized text alongside for comparison.
            if self.normalize_numbers and original_text != generated_text:
                original_markdown_path = Path(output_dir) / f"{image_name}_original.md"
                with open(original_markdown_path, 'w', encoding='utf-8') as f:
                    f.write(original_text)

            # 2. JSON file.
            output_json_path = os.path.join(output_dir, f"{image_name}.json")

            layout_info_path = result.get('layout_info_path')
            if layout_info_path and os.path.exists(layout_info_path):
                with open(layout_info_path, 'r', encoding='utf-8') as f:
                    json_content = f.read()
            else:
                json_content = '{"error": "未能提取到有效的布局信息"}'

            original_json_text = json_content
            if self.normalize_numbers:
                json_content = normalize_json_table(json_content)

                changes_count = sum(1 for o, n in zip(original_json_text, json_content) if o != n)
                if changes_count > 0:
                    saved_files['json_normalized'] = f"✅ 已标准化 {changes_count} 个字符(全角→半角)"
                else:
                    saved_files['json_normalized'] = "ℹ️ 无需标准化(已是标准格式)"

            with open(output_json_path, 'w', encoding='utf-8') as f:
                f.write(json_content)
            saved_files['json'] = output_json_path

            if self.normalize_numbers and original_json_text != json_content:
                original_json_path = Path(output_dir) / f"{image_name}_original.json"
                with open(original_json_path, 'w', encoding='utf-8') as f:
                    f.write(original_json_text)

            # 3. Annotated layout image.
            output_layout_image_path = os.path.join(output_dir, f"{image_name}_layout.jpg")

            layout_image_path = result.get('layout_image_path')
            if layout_image_path and os.path.exists(layout_image_path):
                shutil.copy2(layout_image_path, output_layout_image_path)
                saved_files['layout_image'] = output_layout_image_path
            else:
                # No layout image: fall back to the original image as a
                # placeholder, if the result references one.
                try:
                    with Image.open(result.get('original_image_path', '')) as original_image:
                        # JPEG cannot store alpha/palette modes; convert first.
                        if original_image.mode in ('RGB', 'L', 'CMYK'):
                            placeholder = original_image
                        else:
                            placeholder = original_image.convert('RGB')
                        placeholder.save(output_layout_image_path, 'JPEG', quality=95)
                    saved_files['layout_image'] = output_layout_image_path
                except Exception:
                    saved_files['layout_image'] = None

        except Exception as e:
            print(f"Error saving results for {image_name}: {e}")

        return saved_files

    def process_single_image(self, image_path: str, output_dir: str) -> Dict[str, Any]:
        """
        Parse one image and store its outputs, skipping already finished work.

        If ``<name>.md``, ``<name>.json`` and ``<name>_layout.jpg`` all exist
        in ``output_dir`` the image is reported as successful with
        ``skipped=True`` and the parser is not called.

        Args:
            image_path: Path of the image to process.
            output_dir: Directory receiving the output files.

        Returns:
            Result record with ``image_path``, ``processing_time``,
            ``success``, ``device``, ``error``, ``output_files`` and
            ``is_pdf_page`` (plus ``skipped`` when applicable). Failures are
            reported through ``success``/``error`` rather than raised.
        """
        start_time = time.time()
        image_name = Path(image_path).stem

        result_info = {
            "image_path": image_path,
            "processing_time": 0,
            "success": False,
            "device": f"{self.ip}:{self.port}",
            "error": None,
            "output_files": {},
            # Page images produced from PDFs carry "_page_" in their name.
            "is_pdf_page": "_page_" in Path(image_path).name
        }

        try:
            expected_outputs = {
                "md": os.path.join(output_dir, f"{image_name}.md"),
                "json": os.path.join(output_dir, f"{image_name}.json"),
                "layout_image": os.path.join(output_dir, f"{image_name}_layout.jpg"),
            }

            if all(os.path.exists(p) for p in expected_outputs.values()):
                result_info.update({
                    "success": True,
                    "processing_time": 0,
                    "output_files": expected_outputs,
                    "skipped": True
                })
                return result_info

            temp_dir, session_id = self.create_temp_session_dir()

            try:
                filename = f"omnidocbench_{session_id}"
                # The context manager releases the image file handle once
                # parsing has finished.
                with Image.open(image_path) as image:
                    results = self.parser.parse_image(
                        input_path=image,
                        filename=filename,
                        prompt_mode=self.prompt_mode,
                        save_dir=temp_dir,
                        fitz_preprocess=True
                    )

                if not results:
                    raise Exception("未返回解析结果")

                # Only the first entry of the returned list is used.
                result = results[0]

                saved_files = self.save_results_to_output_dir(result, image_name, output_dir)

                # Status messages and None entries are not existing paths and
                # therefore do not count.
                success_count = sum(1 for path in saved_files.values() if path and os.path.exists(path))

                if success_count >= 2:  # at least md and json were written
                    result_info.update({
                        "success": True,
                        "output_files": saved_files
                    })
                else:
                    raise Exception(f"保存文件不完整 ({success_count}/3)")

            finally:
                # The results have been copied out; drop the scratch directory.
                if os.path.exists(temp_dir):
                    shutil.rmtree(temp_dir, ignore_errors=True)

        except Exception as e:
            result_info["error"] = str(e)
            print(f"❌ Error processing {image_name}: {e}")
        finally:
            # Runs on every path, including the early "skipped" return.
            result_info["processing_time"] = time.time() - start_time

        return result_info
def process_images_single_process(image_paths: List[str],
                                  processor: DotsOCRProcessor,
                                  batch_size: int = 1,
                                  output_dir: str = "./output") -> List[Dict[str, Any]]:
    """
    Process images sequentially, reporting progress per batch.

    Images are handled one at a time through
    ``processor.process_single_image``; ``batch_size`` only controls how
    often the progress bar is refreshed. Exactly one result record is
    produced per input image.

    Args:
        image_paths: Paths of the images to process.
        processor: DotsOCR processor instance.
        batch_size: Number of images per progress update; values below 1 are
            treated as 1.
        output_dir: Directory receiving the output files (created if needed).

    Returns:
        One result record per image, in input order.
    """
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    # A zero step would make range() raise and a negative one would silently
    # skip every image.
    batch_size = max(1, batch_size)

    all_results: List[Dict[str, Any]] = []
    total_images = len(image_paths)

    print(f"Processing {total_images} images with batch size {batch_size}")

    with tqdm(total=total_images, desc="Processing images", unit="img",
              bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') as pbar:

        for i in range(0, total_images, batch_size):
            batch = image_paths[i:i + batch_size]
            batch_start_time = time.time()
            batch_results = []

            for image_path in batch:
                try:
                    batch_results.append(processor.process_single_image(image_path, output_dir))
                except Exception as e:
                    print(f"Error processing {image_path}: {e}", file=sys.stderr)
                    traceback.print_exc()

                    # Same keys as a regular record so downstream statistics
                    # treat failures uniformly.
                    batch_results.append({
                        "image_path": image_path,
                        "processing_time": 0,
                        "success": False,
                        "device": f"{processor.ip}:{processor.port}",
                        "error": str(e),
                        "output_files": {},
                        "is_pdf_page": "_page_" in Path(image_path).name
                    })

            batch_processing_time = time.time() - batch_start_time
            all_results.extend(batch_results)

            total_success = sum(1 for r in all_results if r.get('success', False))
            total_skipped = sum(1 for r in all_results if r.get('skipped', False))
            avg_time = batch_processing_time / len(batch)

            pbar.update(len(batch))
            pbar.set_postfix({
                'batch_time': f"{batch_processing_time:.2f}s",
                'avg_time': f"{avg_time:.2f}s/img",
                'success': f"{total_success}/{len(all_results)}",
                'skipped': f"{total_skipped}",
                'rate': f"{total_success/len(all_results)*100:.1f}%"
            })

    return all_results
def process_images_concurrent(image_paths: List[str],
                              processor: DotsOCRProcessor,
                              batch_size: int = 1,
                              output_dir: str = "./output",
                              max_workers: int = 3) -> List[Dict[str, Any]]:
    """
    Process images on a thread pool, one task per batch.

    The images are split into batches of ``batch_size``; each batch runs as
    one task that calls ``processor.process_single_image`` for its images in
    turn. All tasks share the same ``processor`` instance.

    Args:
        image_paths: Paths of the images to process.
        processor: DotsOCR processor instance shared by all workers.
        batch_size: Number of images per task; values below 1 are treated
            as 1.
        output_dir: Directory receiving the output files (created if needed).
        max_workers: Number of worker threads; values below 1 are treated
            as 1.

    Returns:
        One result record per image, in task completion order (not input
        order).
    """
    from concurrent.futures import ThreadPoolExecutor, as_completed

    Path(output_dir).mkdir(parents=True, exist_ok=True)

    # range() rejects a zero step and ThreadPoolExecutor rejects
    # max_workers <= 0; clamp both instead of failing the whole run.
    batch_size = max(1, batch_size)
    max_workers = max(1, max_workers)

    def failure_record(image_path, error):
        """Build a failed-result record with the same keys as a regular one."""
        return {
            "image_path": image_path,
            "processing_time": 0,
            "success": False,
            "device": f"{processor.ip}:{processor.port}",
            "error": str(error),
            "output_files": {},
            "is_pdf_page": "_page_" in Path(image_path).name
        }

    def process_batch(batch_images):
        """Process one batch sequentially, never raising for a single image."""
        batch_results = []
        for image_path in batch_images:
            try:
                batch_results.append(processor.process_single_image(image_path, output_dir))
            except Exception as e:
                batch_results.append(failure_record(image_path, e))
        return batch_results

    batches = [image_paths[i:i + batch_size] for i in range(0, len(image_paths), batch_size)]

    all_results: List[Dict[str, Any]] = []

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_batch = {executor.submit(process_batch, batch): batch for batch in batches}

        with tqdm(total=len(image_paths), desc="Processing images") as pbar:
            for future in as_completed(future_to_batch):
                try:
                    batch_results = future.result()
                    all_results.extend(batch_results)

                    success_count = sum(1 for r in batch_results if r.get('success', False))
                    pbar.update(len(batch_results))
                    pbar.set_postfix({'batch_success': f"{success_count}/{len(batch_results)}"})

                except Exception as e:
                    # The whole task failed: record every image of its batch
                    # as failed.
                    batch = future_to_batch[future]
                    all_results.extend(failure_record(img_path, e) for img_path in batch)
                    pbar.update(len(batch))

    return all_results
def main():
    """
    Command-line entry point: collect inputs, run DotsOCR, write reports.

    Parses the arguments, gathers the input images (converting PDFs to page
    images), processes them sequentially or on a thread pool, prints a
    summary, writes ``<output dir name>_results.json`` into the output
    directory and finally writes an ``image_path,status`` CSV built from
    ``collect_pid_files``.

    Returns:
        ``0`` on success, ``1`` when no input was found or processing failed.
    """
    parser = argparse.ArgumentParser(description="DotsOCR OmniDocBench Processing with PDF Support")

    # Input sources (exactly one is required).
    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument("--input_file", type=str, help="Input file (supports both PDF and image file)")
    input_group.add_argument("--input_dir", type=str, help="Input directory (supports both PDF and image files)")
    input_group.add_argument("--input_file_list", type=str, help="Input file list (one file per line)")
    input_group.add_argument("--input_csv", type=str, help="Input CSV file with image_path and status columns")

    # Output. Required: every later step derives paths from it, and without
    # it the run used to fail with a TypeError after the PDFs had already
    # been converted.
    parser.add_argument("--output_dir", type=str, required=True, help="Output directory")

    # DotsOCR parameters.
    parser.add_argument("--ip", type=str, default="127.0.0.1", help="vLLM server IP")
    parser.add_argument("--port", type=int, default=8101, help="vLLM server port")
    parser.add_argument("--model_name", type=str, default="DotsOCR", help="Model name")
    parser.add_argument("--prompt_mode", type=str, default="prompt_layout_all_en",
                        choices=list(dict_promptmode_to_prompt.keys()), help="Prompt mode")
    parser.add_argument("--min_pixels", type=int, default=MIN_PIXELS, help="Minimum pixels")
    parser.add_argument("--max_pixels", type=int, default=MAX_PIXELS, help="Maximum pixels")
    parser.add_argument("--dpi", type=int, default=200, help="PDF processing DPI")
    parser.add_argument('--no-normalize', action='store_true', help='禁用数字标准化')

    # Processing parameters.
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
    parser.add_argument("--input_pattern", type=str, default="*", help="Input file pattern")
    parser.add_argument("--test_mode", action="store_true", help="Test mode (process only 10 images)")
    parser.add_argument("--collect_results", type=str, help="收集处理结果到指定CSV文件")

    # Concurrency parameters.
    parser.add_argument("--max_workers", type=int, default=3,
                        help="Maximum number of concurrent workers (should match vLLM data-parallel-size)")
    parser.add_argument("--use_threading", action="store_true",
                        help="Use multi-threading")

    args = parser.parse_args()

    try:
        print("🔄 Preprocessing input files...")
        image_files = get_input_files(args)

        if not image_files:
            print("❌ No input files found or processed")
            return 1

        output_dir = Path(args.output_dir).resolve()
        print(f"📁 Output dir: {output_dir}")
        print(f"📊 Found {len(image_files)} image files to process")

        if args.test_mode:
            image_files = image_files[:10]
            print(f"🧪 Test mode: processing only {len(image_files)} images")

        print(f"🌐 Using server: {args.ip}:{args.port}")
        print(f"📦 Batch size: {args.batch_size}")
        print(f"🎯 Prompt mode: {args.prompt_mode}")

        processor = DotsOCRProcessor(
            ip=args.ip,
            port=args.port,
            model_name=args.model_name,
            prompt_mode=args.prompt_mode,
            dpi=args.dpi,
            min_pixels=args.min_pixels,
            max_pixels=args.max_pixels,
            normalize_numbers=not args.no_normalize
        )

        start_time = time.time()

        if args.use_threading:
            results = process_images_concurrent(
                image_files,
                processor,
                args.batch_size,
                str(output_dir),
                args.max_workers
            )
        else:
            results = process_images_single_process(
                image_files,
                processor,
                args.batch_size,
                str(output_dir)
            )

        total_time = time.time() - start_time

        # Summary statistics. Skipped images are also counted as successful.
        success_count = sum(1 for r in results if r.get('success', False))
        skipped_count = sum(1 for r in results if r.get('skipped', False))
        error_count = len(results) - success_count
        pdf_page_count = sum(1 for r in results if r.get('is_pdf_page', False))

        print(f"\n" + "="*60)
        print(f"✅ Processing completed!")
        print(f"📊 Statistics:")
        print(f" Total files processed: {len(image_files)}")
        print(f" PDF pages processed: {pdf_page_count}")
        print(f" Regular images processed: {len(image_files) - pdf_page_count}")
        print(f" Successful: {success_count}")
        print(f" Skipped: {skipped_count}")
        print(f" Failed: {error_count}")
        if len(image_files) > 0:
            print(f" Success rate: {success_count / len(image_files) * 100:.2f}%")
        print(f"⏱️ Performance:")
        print(f" Total time: {total_time:.2f} seconds")
        if total_time > 0:
            print(f" Throughput: {len(image_files) / total_time:.2f} images/second")
            print(f" Avg time per image: {total_time / len(image_files):.2f} seconds")

        stats = {
            "total_files": len(image_files),
            "pdf_pages": pdf_page_count,
            "regular_images": len(image_files) - pdf_page_count,
            "success_count": success_count,
            "skipped_count": skipped_count,
            "error_count": error_count,
            "success_rate": success_count / len(image_files) if len(image_files) > 0 else 0,
            "total_time": total_time,
            "throughput": len(image_files) / total_time if total_time > 0 else 0,
            "avg_time_per_image": total_time / len(image_files) if len(image_files) > 0 else 0,
            "batch_size": args.batch_size,
            "server": f"{args.ip}:{args.port}",
            "model": args.model_name,
            "prompt_mode": args.prompt_mode,
            "pdf_dpi": args.dpi,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
        }

        # Full report: statistics plus the per-image records.
        output_file_name = Path(output_dir).name
        output_file = os.path.join(output_dir, f"{output_file_name}_results.json")
        final_results = {
            "stats": stats,
            "results": results
        }

        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(final_results, f, ensure_ascii=False, indent=2)

        print(f"💾 Results saved to: {output_file}")

        # Processed-files CSV: explicit target if given, otherwise a
        # timestamped file inside the output directory.
        if not args.collect_results:
            output_file_processed = Path(args.output_dir) / f"processed_files_{time.strftime('%Y%m%d_%H%M%S')}.csv"
        else:
            output_file_processed = Path(args.collect_results).resolve()

        # A custom CSV location may point into a directory that does not
        # exist yet.
        output_file_processed.parent.mkdir(parents=True, exist_ok=True)

        processed_files = collect_pid_files(output_file)
        with open(output_file_processed, 'w', encoding='utf-8') as f:
            f.write("image_path,status\n")
            for file_path, status in processed_files:
                f.write(f"{file_path},{status}\n")
        print(f"💾 Processed files saved to: {output_file_processed}")
        return 0

    except Exception as e:
        print(f"❌ Processing failed: {e}", file=sys.stderr)
        traceback.print_exc()
        return 1
if __name__ == "__main__":
    # Startup banner, including the GPU visibility setting for reference.
    cuda_devices = os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set')
    print("🚀 启动DotsOCR统一PDF/图像处理程序...")
    print(f"🔧 CUDA_VISIBLE_DEVICES: {cuda_devices}")

    launched_without_args = len(sys.argv) == 1
    if launched_without_args:
        # No CLI arguments: synthesize an argument list from the built-in
        # default configuration below.
        print("ℹ️ No command line arguments provided. Running with default configuration...")

        default_config = {
            "input_file": "./sample_data/2023年度报告母公司_page_003.png",
            "output_dir": "./sample_data",
            "collect_results": "./sample_data/processed_files.csv",
            # Alternative: run over the whole OmniDocBench image set.
            # "input_dir": "../../OmniDocBench/OpenDataLab___OmniDocBench/images",
            # "output_dir": "./OmniDocBench_DotsOCR_Results",
            # "collect_results": "./OmniDocBench_DotsOCR_Results/processed_files.csv",
            "ip": "10.192.72.11",
            # "ip": "127.0.0.1",
            "port": "8101",
            "model_name": "DotsOCR",
            "prompt_mode": "prompt_layout_all_en",
            "batch_size": "1",
            "max_workers": "3",
            "dpi": "200",
        }

        # Alternative configuration for re-processing previously failed files:
        # default_config = {
        #     "input_csv": "./OmniDocBench_DotsOCR_Results/processed_files.csv",
        #     "output_dir": "./OmniDocBench_DotsOCR_Results",
        #     "ip": "127.0.0.1",
        #     "port": "8101",
        #     "collect_results": f"./OmniDocBench_DotsOCR_Results/processed_files_{time.strftime('%Y%m%d_%H%M%S')}.csv",
        # }

        # Turn every entry into a "--key value" pair after the program name.
        synthesized_argv = [sys.argv[0]]
        for option, setting in default_config.items():
            synthesized_argv += [f"--{option}", str(setting)]

        # The default run uses the thread pool; test mode stays off.
        synthesized_argv.append("--use_threading")
        # synthesized_argv.append("--test_mode")

        sys.argv = synthesized_argv

    sys.exit(main())
|