OmniDocBench_DotsOCR_multthreads.py 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783
  1. """
  2. 批量处理 OmniDocBench 图片并生成符合评测要求的预测结果
  3. 根据 OmniDocBench 评测要求:
  4. - 输入:OpenDataLab___OmniDocBench/images 下的所有 .jpg 图片,以及PDF文件
  5. - 输出:每个图片对应的 .md、.json 和带标注的 layout 图片文件
  6. - 输出目录:用于后续的 end2end 评测
  7. """
  8. import os
  9. import sys
  10. import json
  11. import tempfile
  12. import uuid
  13. import shutil
  14. import time
  15. import traceback
  16. import warnings
  17. from pathlib import Path
  18. from typing import List, Dict, Any
  19. from PIL import Image
  20. from tqdm import tqdm
  21. import argparse
  22. # 导入 dots.ocr 相关模块
  23. from dots_ocr.parser import DotsOCRParser
  24. from dots_ocr.utils import dict_promptmode_to_prompt
  25. from dots_ocr.utils.consts import MIN_PIXELS, MAX_PIXELS
  26. from dots_ocr.utils.doc_utils import load_images_from_pdf
  27. # 导入工具函数
  28. from utils import (
  29. get_image_files_from_dir,
  30. get_image_files_from_list,
  31. get_image_files_from_csv,
  32. collect_pid_files,
  33. normalize_markdown_table,
  34. normalize_json_table
  35. )
  36. def convert_pdf_to_images(pdf_file: str, output_dir: str | None = None, dpi: int = 200) -> List[str]:
  37. """
  38. 将PDF转换为图像文件
  39. Args:
  40. pdf_file: PDF文件路径
  41. output_dir: 输出目录
  42. dpi: 图像分辨率
  43. Returns:
  44. 生成的图像文件路径列表
  45. """
  46. pdf_path = Path(pdf_file)
  47. if not pdf_path.exists() or pdf_path.suffix.lower() != '.pdf':
  48. print(f"❌ Invalid PDF file: {pdf_path}")
  49. return []
  50. # 如果没有指定输出目录,使用PDF同名目录
  51. if output_dir is None:
  52. output_path = pdf_path.parent / f"{pdf_path.stem}"
  53. else:
  54. output_path = Path(output_dir) / f"{pdf_path.stem}"
  55. output_path = output_path.resolve()
  56. output_path.mkdir(parents=True, exist_ok=True)
  57. try:
  58. # 使用utils中的函数加载PDF图像
  59. images = load_images_from_pdf(str(pdf_path), dpi=dpi)
  60. image_paths = []
  61. for i, image in enumerate(images):
  62. # 生成图像文件名
  63. image_filename = f"{pdf_path.stem}_page_{i+1:03d}.png"
  64. image_path = output_path / image_filename
  65. # 保存图像
  66. image.save(str(image_path))
  67. image_paths.append(str(image_path))
  68. print(f"✅ Converted {len(images)} pages from {pdf_path.name} to images")
  69. return image_paths
  70. except Exception as e:
  71. print(f"❌ Error converting PDF {pdf_path}: {e}")
  72. traceback.print_exc()
  73. return []
  74. def get_input_files(args) -> List[str]:
  75. """
  76. 获取输入文件列表,统一处理PDF和图像文件
  77. Args:
  78. args: 命令行参数
  79. Returns:
  80. 处理后的图像文件路径列表
  81. """
  82. input_files = []
  83. # 获取原始输入文件
  84. if args.input_csv:
  85. raw_files = get_image_files_from_csv(args.input_csv, "fail")
  86. elif args.input_file_list:
  87. raw_files = get_image_files_from_list(args.input_file_list)
  88. elif args.input_file:
  89. raw_files = [Path(args.input_file).resolve()]
  90. else:
  91. input_dir = Path(args.input_dir).resolve()
  92. if not input_dir.exists():
  93. print(f"❌ Input directory does not exist: {input_dir}")
  94. return []
  95. # 获取所有支持的文件(图像和PDF)
  96. image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']
  97. pdf_extensions = ['.pdf']
  98. raw_files = []
  99. for ext in image_extensions + pdf_extensions:
  100. raw_files.extend(list(input_dir.glob(f"*{ext}")))
  101. raw_files.extend(list(input_dir.glob(f"*{ext.upper()}")))
  102. raw_files = [str(f) for f in raw_files]
  103. # 分别处理PDF和图像文件
  104. pdf_count = 0
  105. image_count = 0
  106. for file_path in raw_files:
  107. file_path = Path(file_path)
  108. if file_path.suffix.lower() == '.pdf':
  109. # 转换PDF为图像
  110. print(f"📄 Processing PDF: {file_path.name}")
  111. pdf_images = convert_pdf_to_images(
  112. str(file_path),
  113. args.output_dir,
  114. dpi=args.dpi
  115. )
  116. input_files.extend(pdf_images)
  117. pdf_count += 1
  118. else:
  119. # 直接添加图像文件
  120. if file_path.exists():
  121. input_files.append(str(file_path))
  122. image_count += 1
  123. print(f"📊 Input summary:")
  124. print(f" PDF files processed: {pdf_count}")
  125. print(f" Image files found: {image_count}")
  126. print(f" Total image files to process: {len(input_files)}")
  127. return input_files
  128. class DotsOCRProcessor:
  129. """DotsOCR 处理器"""
  130. def __init__(self,
  131. ip: str = "127.0.0.1",
  132. port: int = 8101,
  133. model_name: str = "DotsOCR",
  134. prompt_mode: str = "prompt_layout_all_en",
  135. dpi: int = 200,
  136. min_pixels: int = MIN_PIXELS,
  137. max_pixels: int = MAX_PIXELS,
  138. normalize_numbers: bool = False):
  139. """
  140. 初始化处理器
  141. Args:
  142. ip: vLLM 服务器 IP
  143. port: vLLM 服务器端口
  144. model_name: 模型名称
  145. prompt_mode: 提示模式
  146. dpi: PDF 处理 DPI
  147. min_pixels: 最小像素数
  148. max_pixels: 最大像素数
  149. """
  150. self.ip = ip
  151. self.port = port
  152. self.model_name = model_name
  153. self.prompt_mode = prompt_mode
  154. self.dpi = dpi
  155. self.min_pixels = min_pixels
  156. self.max_pixels = max_pixels
  157. self.normalize_numbers = normalize_numbers
  158. # 初始化解析器
  159. self.parser = DotsOCRParser(
  160. ip=ip,
  161. port=port,
  162. dpi=dpi,
  163. min_pixels=min_pixels,
  164. max_pixels=max_pixels,
  165. model_name=model_name
  166. )
  167. print(f"DotsOCR Parser 初始化完成:")
  168. print(f" - 服务器: {ip}:{port}")
  169. print(f" - 模型: {model_name}")
  170. print(f" - 提示模式: {prompt_mode}")
  171. print(f" - 像素范围: {min_pixels} - {max_pixels}")
  172. def create_temp_session_dir(self) -> tuple:
  173. """创建临时会话目录"""
  174. session_id = uuid.uuid4().hex[:8]
  175. temp_dir = os.path.join(tempfile.gettempdir(), f"omnidocbench_batch_{session_id}")
  176. os.makedirs(temp_dir, exist_ok=True)
  177. return temp_dir, session_id
  178. def save_results_to_output_dir(self, result: Dict, image_name: str, output_dir: str) -> Dict[str, str]:
  179. """
  180. 将处理结果保存到输出目录
  181. Args:
  182. result: 解析结果
  183. image_name: 图片文件名(不含扩展名)
  184. output_dir: 输出目录
  185. Returns:
  186. dict: 保存的文件路径
  187. """
  188. saved_files = {}
  189. try:
  190. # 1. 保存 Markdown 文件(OmniDocBench 评测必需)
  191. output_md_path = os.path.join(output_dir, f"{image_name}.md")
  192. md_content = ""
  193. # 优先使用无页眉页脚的版本(符合 OmniDocBench 评测要求)
  194. if 'md_content_nohf_path' in result and os.path.exists(result['md_content_nohf_path']):
  195. with open(result['md_content_nohf_path'], 'r', encoding='utf-8') as f:
  196. md_content = f.read()
  197. elif 'md_content_path' in result and os.path.exists(result['md_content_path']):
  198. with open(result['md_content_path'], 'r', encoding='utf-8') as f:
  199. md_content = f.read()
  200. else:
  201. md_content = "# 解析失败\n\n未能提取到有效的文档内容。"
  202. # 如果启用数字标准化,处理 Markdown 内容
  203. original_text = md_content
  204. if self.normalize_numbers:
  205. # generated_text = normalize_financial_numbers(generated_text)
  206. # 只对Markdown表格进行数字标准化
  207. generated_text = normalize_markdown_table(md_content)
  208. # 统计标准化的变化
  209. changes_count = len([1 for o, n in zip(original_text, generated_text) if o != n])
  210. if changes_count > 0:
  211. saved_files['md_normalized'] = f"✅ 已标准化 {changes_count} 个字符(全角→半角)"
  212. else:
  213. saved_files['md_normalized'] = "ℹ️ 无需标准化(已是标准格式)"
  214. with open(output_md_path, 'w', encoding='utf-8') as f:
  215. f.write(generated_text)
  216. saved_files['md'] = output_md_path
  217. # 如果启用了标准化,也保存原始版本用于对比
  218. if self.normalize_numbers and original_text != generated_text:
  219. original_markdown_path = Path(output_dir) / f"{Path(image_name).stem}_original.md"
  220. with open(original_markdown_path, 'w', encoding='utf-8') as f:
  221. f.write(original_text)
  222. # 2. 保存 JSON 文件
  223. output_json_path = os.path.join(output_dir, f"{image_name}.json")
  224. json_data = {}
  225. if 'layout_info_path' in result and os.path.exists(result['layout_info_path']):
  226. with open(result['layout_info_path'], 'r', encoding='utf-8') as f:
  227. json_content = f.read()
  228. else:
  229. json_content = f'{{"error": "未能提取到有效的布局信息"}}'
  230. # 对json中的表格内容进行数字标准化,
  231. original_json_text = json_content
  232. if self.normalize_numbers:
  233. json_content = normalize_json_table(json_content)
  234. # 统计标准化的变化
  235. changes_count = len([1 for o, n in zip(original_json_text, json_content) if o != n])
  236. if changes_count > 0:
  237. saved_files['json_normalized'] = f"✅ 已标准化 {changes_count} 个字符(全角→半角)"
  238. else:
  239. saved_files['json_normalized'] = "ℹ️ 无需标准化(已是标准格式)"
  240. with open(output_json_path, 'w', encoding='utf-8') as f:
  241. f.write(json_content)
  242. saved_files['json'] = output_json_path
  243. # 如果启用了标准化,也保存原始版本用于对比
  244. if self.normalize_numbers and original_json_text != json_content:
  245. original_json_path = Path(output_dir) / f"{Path(image_name).stem}_original.json"
  246. with open(original_json_path, 'w', encoding='utf-8') as f:
  247. f.write(original_json_text)
  248. # 3. 保存带标注的布局图片
  249. output_layout_image_path = os.path.join(output_dir, f"{image_name}_layout.jpg")
  250. if 'layout_image_path' in result and os.path.exists(result['layout_image_path']):
  251. # 直接复制布局图片
  252. shutil.copy2(result['layout_image_path'], output_layout_image_path)
  253. saved_files['layout_image'] = output_layout_image_path
  254. else:
  255. # 如果没有布局图片,使用原始图片作为占位符
  256. try:
  257. original_image = Image.open(result.get('original_image_path', ''))
  258. original_image.save(output_layout_image_path, 'JPEG', quality=95)
  259. saved_files['layout_image'] = output_layout_image_path
  260. except Exception as e:
  261. saved_files['layout_image'] = None
  262. except Exception as e:
  263. print(f"Error saving results for {image_name}: {e}")
  264. return saved_files
  265. def process_single_image(self, image_path: str, output_dir: str) -> Dict[str, Any]:
  266. """
  267. 处理单张图片
  268. Args:
  269. image_path: 图片路径
  270. output_dir: 输出目录
  271. Returns:
  272. dict: 处理结果
  273. """
  274. start_time = time.time()
  275. image_name = Path(image_path).stem
  276. result_info = {
  277. "image_path": image_path,
  278. "processing_time": 0,
  279. "success": False,
  280. "device": f"{self.ip}:{self.port}",
  281. "error": None,
  282. "output_files": {},
  283. "is_pdf_page": "_page_" in Path(image_path).name # 标记是否为PDF页面
  284. }
  285. try:
  286. # 检查输出文件是否已存在
  287. output_md_path = os.path.join(output_dir, f"{image_name}.md")
  288. output_json_path = os.path.join(output_dir, f"{image_name}.json")
  289. output_layout_path = os.path.join(output_dir, f"{image_name}_layout.jpg")
  290. if all(os.path.exists(p) for p in [output_md_path, output_json_path, output_layout_path]):
  291. result_info.update({
  292. "success": True,
  293. "processing_time": 0,
  294. "output_files": {
  295. "md": output_md_path,
  296. "json": output_json_path,
  297. "layout_image": output_layout_path
  298. },
  299. "skipped": True
  300. })
  301. return result_info
  302. # 创建临时会话目录
  303. temp_dir, session_id = self.create_temp_session_dir()
  304. try:
  305. # 读取图片
  306. image = Image.open(image_path)
  307. # 使用 DotsOCRParser 处理图片
  308. filename = f"omnidocbench_{session_id}"
  309. results = self.parser.parse_image(
  310. input_path=image,
  311. filename=filename,
  312. prompt_mode=self.prompt_mode,
  313. save_dir=temp_dir,
  314. fitz_preprocess=True # 对图片使用 fitz 预处理
  315. )
  316. # 解析结果
  317. if not results:
  318. raise Exception("未返回解析结果")
  319. result = results[0] # parse_image 返回单个结果的列表
  320. # 保存所有结果文件到输出目录
  321. saved_files = self.save_results_to_output_dir(result, image_name, output_dir)
  322. # 验证保存结果
  323. success_count = sum(1 for path in saved_files.values() if path and os.path.exists(path))
  324. if success_count >= 2: # 至少保存了 md 和 json
  325. result_info.update({
  326. "success": True,
  327. "output_files": saved_files
  328. })
  329. else:
  330. raise Exception(f"保存文件不完整 ({success_count}/3)")
  331. finally:
  332. # 清理临时目录
  333. if os.path.exists(temp_dir):
  334. shutil.rmtree(temp_dir, ignore_errors=True)
  335. except Exception as e:
  336. result_info["error"] = str(e)
  337. print(f"❌ Error processing {image_name}: {e}")
  338. finally:
  339. result_info["processing_time"] = time.time() - start_time
  340. return result_info
  341. def process_images_single_process(image_paths: List[str],
  342. processor: DotsOCRProcessor,
  343. batch_size: int = 1,
  344. output_dir: str = "./output") -> List[Dict[str, Any]]:
  345. """
  346. 单进程版本的图像处理函数
  347. Args:
  348. image_paths: 图像路径列表
  349. processor: DotsOCR处理器实例
  350. batch_size: 批处理大小
  351. output_dir: 输出目录
  352. Returns:
  353. 处理结果列表
  354. """
  355. # 创建输出目录
  356. output_path = Path(output_dir)
  357. output_path.mkdir(parents=True, exist_ok=True)
  358. all_results = []
  359. total_images = len(image_paths)
  360. print(f"Processing {total_images} images with batch size {batch_size}")
  361. # 使用tqdm显示进度,添加更多统计信息
  362. with tqdm(total=total_images, desc="Processing images", unit="img",
  363. bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') as pbar:
  364. # 按批次处理图像(DotsOCR通常单张处理)
  365. for i in range(0, total_images, batch_size):
  366. batch = image_paths[i:i + batch_size]
  367. batch_start_time = time.time()
  368. batch_results = []
  369. try:
  370. # 处理批次中的每张图片
  371. for image_path in batch:
  372. try:
  373. result = processor.process_single_image(image_path, output_dir)
  374. batch_results.append(result)
  375. except Exception as e:
  376. print(f"Error processing {image_path}: {e}", file=sys.stderr)
  377. traceback.print_exc()
  378. batch_results.append({
  379. "image_path": image_path,
  380. "processing_time": 0,
  381. "success": False,
  382. "device": f"{processor.ip}:{processor.port}",
  383. "error": str(e)
  384. })
  385. batch_processing_time = time.time() - batch_start_time
  386. all_results.extend(batch_results)
  387. # 更新进度条
  388. success_count = sum(1 for r in batch_results if r.get('success', False))
  389. skipped_count = sum(1 for r in batch_results if r.get('skipped', False))
  390. total_success = sum(1 for r in all_results if r.get('success', False))
  391. total_skipped = sum(1 for r in all_results if r.get('skipped', False))
  392. avg_time = batch_processing_time / len(batch)
  393. pbar.update(len(batch))
  394. pbar.set_postfix({
  395. 'batch_time': f"{batch_processing_time:.2f}s",
  396. 'avg_time': f"{avg_time:.2f}s/img",
  397. 'success': f"{total_success}/{len(all_results)}",
  398. 'skipped': f"{total_skipped}",
  399. 'rate': f"{total_success/len(all_results)*100:.1f}%"
  400. })
  401. except Exception as e:
  402. print(f"Error processing batch {[Path(p).name for p in batch]}: {e}", file=sys.stderr)
  403. traceback.print_exc()
  404. # 为批次中的所有图像添加错误结果
  405. error_results = []
  406. for img_path in batch:
  407. error_results.append({
  408. "image_path": str(img_path),
  409. "processing_time": 0,
  410. "success": False,
  411. "device": f"{processor.ip}:{processor.port}",
  412. "error": str(e)
  413. })
  414. all_results.extend(error_results)
  415. pbar.update(len(batch))
  416. return all_results
  417. def process_images_concurrent(image_paths: List[str],
  418. processor: DotsOCRProcessor,
  419. batch_size: int = 1,
  420. output_dir: str = "./output",
  421. max_workers: int = 3) -> List[Dict[str, Any]]:
  422. """并发版本的图像处理函数"""
  423. from concurrent.futures import ThreadPoolExecutor, as_completed
  424. Path(output_dir).mkdir(parents=True, exist_ok=True)
  425. def process_batch(batch_images):
  426. """处理一批图像"""
  427. batch_results = []
  428. for image_path in batch_images:
  429. try:
  430. result = processor.process_single_image(image_path, output_dir)
  431. batch_results.append(result)
  432. except Exception as e:
  433. batch_results.append({
  434. "image_path": image_path,
  435. "processing_time": 0,
  436. "success": False,
  437. "device": f"{processor.ip}:{processor.port}",
  438. "error": str(e)
  439. })
  440. return batch_results
  441. # 将图像分批
  442. batches = [image_paths[i:i + batch_size] for i in range(0, len(image_paths), batch_size)]
  443. all_results = []
  444. with ThreadPoolExecutor(max_workers=max_workers) as executor:
  445. # 提交所有批次
  446. future_to_batch = {executor.submit(process_batch, batch): batch for batch in batches}
  447. # 使用 tqdm 显示进度
  448. with tqdm(total=len(image_paths), desc="Processing images") as pbar:
  449. for future in as_completed(future_to_batch):
  450. try:
  451. batch_results = future.result()
  452. all_results.extend(batch_results)
  453. # 更新进度
  454. success_count = sum(1 for r in batch_results if r.get('success', False))
  455. pbar.update(len(batch_results))
  456. pbar.set_postfix({'batch_success': f"{success_count}/{len(batch_results)}"})
  457. except Exception as e:
  458. batch = future_to_batch[future]
  459. # 为批次中的所有图像添加错误结果
  460. error_results = [
  461. {
  462. "image_path": img_path,
  463. "processing_time": 0,
  464. "success": False,
  465. "device": f"{processor.ip}:{processor.port}",
  466. "error": str(e)
  467. }
  468. for img_path in batch
  469. ]
  470. all_results.extend(error_results)
  471. pbar.update(len(batch))
  472. return all_results
  473. def main():
  474. """主函数"""
  475. parser = argparse.ArgumentParser(description="DotsOCR OmniDocBench Processing with PDF Support")
  476. # 输入参数组
  477. input_group = parser.add_mutually_exclusive_group(required=True)
  478. input_group.add_argument("--input_file", type=str, help="Input file (supports both PDF and image file)")
  479. input_group.add_argument("--input_dir", type=str, help="Input directory (supports both PDF and image files)")
  480. input_group.add_argument("--input_file_list", type=str, help="Input file list (one file per line)")
  481. input_group.add_argument("--input_csv", type=str, help="Input CSV file with image_path and status columns")
  482. # 输出参数
  483. parser.add_argument("--output_dir", type=str, help="Output directory")
  484. # DotsOCR 参数
  485. parser.add_argument("--ip", type=str, default="127.0.0.1", help="vLLM server IP")
  486. parser.add_argument("--port", type=int, default=8101, help="vLLM server port")
  487. parser.add_argument("--model_name", type=str, default="DotsOCR", help="Model name")
  488. parser.add_argument("--prompt_mode", type=str, default="prompt_layout_all_en",
  489. choices=list(dict_promptmode_to_prompt.keys()), help="Prompt mode")
  490. parser.add_argument("--min_pixels", type=int, default=MIN_PIXELS, help="Minimum pixels")
  491. parser.add_argument("--max_pixels", type=int, default=MAX_PIXELS, help="Maximum pixels")
  492. parser.add_argument("--dpi", type=int, default=200, help="PDF processing DPI")
  493. parser.add_argument('--no-normalize', action='store_true', help='禁用数字标准化')
  494. # 处理参数
  495. parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
  496. parser.add_argument("--input_pattern", type=str, default="*", help="Input file pattern")
  497. parser.add_argument("--test_mode", action="store_true", help="Test mode (process only 10 images)")
  498. parser.add_argument("--collect_results", type=str, help="收集处理结果到指定CSV文件")
  499. # 并发参数
  500. parser.add_argument("--max_workers", type=int, default=3,
  501. help="Maximum number of concurrent workers (should match vLLM data-parallel-size)")
  502. parser.add_argument("--use_threading", action="store_true",
  503. help="Use multi-threading")
  504. args = parser.parse_args()
  505. try:
  506. # 获取并预处理输入文件
  507. print("🔄 Preprocessing input files...")
  508. image_files = get_input_files(args)
  509. if not image_files:
  510. print("❌ No input files found or processed")
  511. return 1
  512. output_dir = Path(args.output_dir).resolve()
  513. print(f"📁 Output dir: {output_dir}")
  514. print(f"📊 Found {len(image_files)} image files to process")
  515. if args.test_mode:
  516. image_files = image_files[:10]
  517. print(f"🧪 Test mode: processing only {len(image_files)} images")
  518. print(f"🌐 Using server: {args.ip}:{args.port}")
  519. print(f"📦 Batch size: {args.batch_size}")
  520. print(f"🎯 Prompt mode: {args.prompt_mode}")
  521. # 创建处理器
  522. processor = DotsOCRProcessor(
  523. ip=args.ip,
  524. port=args.port,
  525. model_name=args.model_name,
  526. prompt_mode=args.prompt_mode,
  527. dpi=args.dpi,
  528. min_pixels=args.min_pixels,
  529. max_pixels=args.max_pixels,
  530. normalize_numbers=not args.no_normalize
  531. )
  532. # 开始处理
  533. start_time = time.time()
  534. # 选择处理方式
  535. if args.use_threading:
  536. results = process_images_concurrent(
  537. image_files,
  538. processor,
  539. args.batch_size,
  540. str(output_dir),
  541. args.max_workers
  542. )
  543. else:
  544. results = process_images_single_process(
  545. image_files,
  546. processor,
  547. args.batch_size,
  548. str(output_dir)
  549. )
  550. total_time = time.time() - start_time
  551. # 统计结果
  552. success_count = sum(1 for r in results if r.get('success', False))
  553. skipped_count = sum(1 for r in results if r.get('skipped', False))
  554. error_count = len(results) - success_count
  555. pdf_page_count = sum(1 for r in results if r.get('is_pdf_page', False))
  556. print(f"\n" + "="*60)
  557. print(f"✅ Processing completed!")
  558. print(f"📊 Statistics:")
  559. print(f" Total files processed: {len(image_files)}")
  560. print(f" PDF pages processed: {pdf_page_count}")
  561. print(f" Regular images processed: {len(image_files) - pdf_page_count}")
  562. print(f" Successful: {success_count}")
  563. print(f" Skipped: {skipped_count}")
  564. print(f" Failed: {error_count}")
  565. if len(image_files) > 0:
  566. print(f" Success rate: {success_count / len(image_files) * 100:.2f}%")
  567. print(f"⏱️ Performance:")
  568. print(f" Total time: {total_time:.2f} seconds")
  569. if total_time > 0:
  570. print(f" Throughput: {len(image_files) / total_time:.2f} images/second")
  571. print(f" Avg time per image: {total_time / len(image_files):.2f} seconds")
  572. # 保存结果统计
  573. stats = {
  574. "total_files": len(image_files),
  575. "pdf_pages": pdf_page_count,
  576. "regular_images": len(image_files) - pdf_page_count,
  577. "success_count": success_count,
  578. "skipped_count": skipped_count,
  579. "error_count": error_count,
  580. "success_rate": success_count / len(image_files) if len(image_files) > 0 else 0,
  581. "total_time": total_time,
  582. "throughput": len(image_files) / total_time if total_time > 0 else 0,
  583. "avg_time_per_image": total_time / len(image_files) if len(image_files) > 0 else 0,
  584. "batch_size": args.batch_size,
  585. "server": f"{args.ip}:{args.port}",
  586. "model": args.model_name,
  587. "prompt_mode": args.prompt_mode,
  588. "pdf_dpi": args.dpi,
  589. "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
  590. }
  591. # 保存最终结果
  592. output_file_name = Path(output_dir).name
  593. output_file = os.path.join(output_dir, f"{output_file_name}_results.json")
  594. final_results = {
  595. "stats": stats,
  596. "results": results
  597. }
  598. with open(output_file, 'w', encoding='utf-8') as f:
  599. json.dump(final_results, f, ensure_ascii=False, indent=2)
  600. print(f"💾 Results saved to: {output_file}")
  601. # 收集处理结果
  602. if not args.collect_results:
  603. output_file_processed = Path(args.output_dir) / f"processed_files_{time.strftime('%Y%m%d_%H%M%S')}.csv"
  604. else:
  605. output_file_processed = Path(args.collect_results).resolve()
  606. processed_files = collect_pid_files(output_file)
  607. with open(output_file_processed, 'w', encoding='utf-8') as f:
  608. f.write("image_path,status\n")
  609. for file_path, status in processed_files:
  610. f.write(f"{file_path},{status}\n")
  611. print(f"💾 Processed files saved to: {output_file_processed}")
  612. return 0
  613. except Exception as e:
  614. print(f"❌ Processing failed: {e}", file=sys.stderr)
  615. traceback.print_exc()
  616. return 1
  617. if __name__ == "__main__":
  618. print(f"🚀 启动DotsOCR统一PDF/图像处理程序...")
  619. print(f"🔧 CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set')}")
  620. if len(sys.argv) == 1:
  621. # 如果没有命令行参数,使用默认配置运行
  622. print("ℹ️ No command line arguments provided. Running with default configuration...")
  623. # 默认配置
  624. default_config = {
  625. "input_file": "./sample_data/2023年度报告母公司_page_003.png",
  626. "output_dir": "./sample_data",
  627. "collect_results": "./sample_data/processed_files.csv",
  628. # "input_dir": "../../OmniDocBench/OpenDataLab___OmniDocBench/images",
  629. # "output_dir": "./OmniDocBench_DotsOCR_Results",
  630. # "collect_results": "./OmniDocBench_DotsOCR_Results/processed_files.csv",
  631. "ip": "10.192.72.11",
  632. # "ip": "127.0.0.1",
  633. "port": "8101",
  634. "model_name": "DotsOCR",
  635. "prompt_mode": "prompt_layout_all_en",
  636. "batch_size": "1",
  637. "max_workers": "3",
  638. "dpi": "200",
  639. }
  640. # 如果需要处理失败的文件,可以使用这个配置
  641. # default_config = {
  642. # "input_csv": "./OmniDocBench_DotsOCR_Results/processed_files.csv",
  643. # "output_dir": "./OmniDocBench_DotsOCR_Results",
  644. # "ip": "127.0.0.1",
  645. # "port": "8101",
  646. # "collect_results": f"./OmniDocBench_DotsOCR_Results/processed_files_{time.strftime('%Y%m%d_%H%M%S')}.csv",
  647. # }
  648. # 构造参数
  649. sys.argv = [sys.argv[0]]
  650. for key, value in default_config.items():
  651. sys.argv.extend([f"--{key}", str(value)])
  652. # 测试模式
  653. sys.argv.append("--use_threading")
  654. # sys.argv.append("--test_mode")
  655. sys.exit(main())