#!/usr/bin/env python3
"""
Batch-process image/PDF files and produce prediction results that meet the
evaluation requirements (DotsOCR version).

Per the OmniDocBench evaluation requirements:
- Input: PDF files and common image formats (all passed via the --input argument)
- Output: one .md file, one .json file, and one annotated layout image per input file
- Invocation: processing is delegated to a DotsOCR vLLM server

Usage:
    python main.py --input document.pdf --output_dir ./output
    python main.py --input ./images/ --output_dir ./output
    python main.py --input file_list.txt --output_dir ./output
    python main.py --input results.csv --output_dir ./output --dry_run
"""
import os
import sys
import json
import time
import traceback
from pathlib import Path
from typing import List, Dict, Any

from tqdm import tqdm
import argparse
from loguru import logger

# Make ocr_utils importable from the OCR platform root
ocr_platform_root = Path(__file__).parents[2]
if str(ocr_platform_root) not in sys.path:
    sys.path.insert(0, str(ocr_platform_root))

from ocr_utils import (
    get_input_files,
    collect_pid_files,
    setup_logging
)

# Import the processor (package-relative import first, local fallback otherwise)
try:
    from .processor import DotsOCRProcessor
except ImportError:
    from processor import DotsOCRProcessor

# dots.ocr modules
from dots_ocr.utils import dict_promptmode_to_prompt
from dots_ocr.utils.consts import MIN_PIXELS, MAX_PIXELS
def process_images_single_process(
    image_paths: List[str],
    processor: DotsOCRProcessor,
    batch_size: int = 1,
    output_dir: str = "./output"
) -> List[Dict[str, Any]]:
    """
    Single-process image processing.

    Args:
        image_paths: list of image file paths
        processor: DotsOCR processor instance
        batch_size: batch size
        output_dir: output directory

    Returns:
        List of per-image result dicts
    """
    # Create the output directory
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    all_results = []
    total_images = len(image_paths)
    logger.info(f"Processing {total_images} images with batch size {batch_size}")

    with tqdm(total=total_images, desc="Processing images", unit="img",
              bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') as pbar:
        for i in range(0, total_images, batch_size):
            batch = image_paths[i:i + batch_size]
            batch_start_time = time.time()
            batch_results = []
            try:
                for image_path in batch:
                    try:
                        result = processor.process_single_image(image_path, output_dir)
                        batch_results.append(result)
                    except Exception as e:
                        logger.error(f"Error processing {image_path}: {e}")
                        batch_results.append({
                            "image_path": image_path,
                            "processing_time": 0,
                            "success": False,
                            "device": f"{processor.ip}:{processor.port}",
                            "error": str(e)
                        })

                batch_processing_time = time.time() - batch_start_time
                all_results.extend(batch_results)

                # Update the progress bar
                total_success = sum(1 for r in all_results if r.get('success', False))
                total_skipped = sum(1 for r in all_results if r.get('skipped', False))
                avg_time = batch_processing_time / len(batch) if len(batch) > 0 else 0

                pbar.update(len(batch))
                pbar.set_postfix({
                    'batch_time': f"{batch_processing_time:.2f}s",
                    'avg_time': f"{avg_time:.2f}s/img",
                    'success': f"{total_success}/{len(all_results)}",
                    'skipped': f"{total_skipped}",
                    'rate': f"{total_success / len(all_results) * 100:.1f}%" if len(all_results) > 0 else "0%"
                })
            except Exception as e:
                logger.error(f"Error processing batch {[Path(p).name for p in batch]}: {e}")
                error_results = []
                for img_path in batch:
                    error_results.append({
                        "image_path": str(img_path),
                        "processing_time": 0,
                        "success": False,
                        "device": f"{processor.ip}:{processor.port}",
                        "error": str(e)
                    })
                all_results.extend(error_results)
                pbar.update(len(batch))

    return all_results
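# Illustrative usage sketch (kept as a comment, not executed): the constructor
# arguments mirror the ones main() passes further below; the server address is
# only the script default and must point at a reachable DotsOCR vLLM instance.
#
#   processor = DotsOCRProcessor(
#       ip="10.192.72.11", port=8101, model_name="DotsOCR",
#       prompt_mode="prompt_layout_all_en", dpi=200,
#       min_pixels=MIN_PIXELS, max_pixels=MAX_PIXELS,
#       normalize_numbers=True, log_level="INFO",
#   )
#   results = process_images_single_process(
#       ["page_1.png", "page_2.png"], processor, batch_size=1, output_dir="./output"
#   )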
def process_images_concurrent(
    image_paths: List[str],
    processor: DotsOCRProcessor,
    batch_size: int = 1,
    output_dir: str = "./output",
    max_workers: int = 3
) -> List[Dict[str, Any]]:
    """Concurrent (thread-pool) image processing."""
    from concurrent.futures import ThreadPoolExecutor, as_completed

    Path(output_dir).mkdir(parents=True, exist_ok=True)

    def process_batch(batch_images):
        """Process one batch of images."""
        batch_results = []
        for image_path in batch_images:
            try:
                result = processor.process_single_image(image_path, output_dir)
                batch_results.append(result)
            except Exception as e:
                batch_results.append({
                    "image_path": image_path,
                    "processing_time": 0,
                    "success": False,
                    "device": f"{processor.ip}:{processor.port}",
                    "error": str(e)
                })
        return batch_results

    # Split the images into batches
    batches = [image_paths[i:i + batch_size] for i in range(0, len(image_paths), batch_size)]
    all_results = []

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit all batches
        future_to_batch = {executor.submit(process_batch, batch): batch for batch in batches}

        # Show progress with tqdm
        with tqdm(total=len(image_paths), desc="Processing images") as pbar:
            for future in as_completed(future_to_batch):
                try:
                    batch_results = future.result()
                    all_results.extend(batch_results)

                    # Update progress
                    success_count = sum(1 for r in batch_results if r.get('success', False))
                    pbar.update(len(batch_results))
                    pbar.set_postfix({'batch_success': f"{success_count}/{len(batch_results)}"})
                except Exception as e:
                    batch = future_to_batch[future]
                    # Record an error result for every image in the failed batch
                    error_results = [
                        {
                            "image_path": img_path,
                            "processing_time": 0,
                            "success": False,
                            "device": f"{processor.ip}:{processor.port}",
                            "error": str(e)
                        }
                        for img_path in batch
                    ]
                    all_results.extend(error_results)
                    pbar.update(len(batch))

    return all_results
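# Illustrative usage sketch (comment only): with --use_threading, main() calls this
# function instead of process_images_single_process; max_workers is meant to match
# the vLLM server's data-parallel size (see the --max_workers help text below).
#
#   results = process_images_concurrent(
#       image_files, processor, batch_size=1, output_dir="./output", max_workers=3
#   )
#   failed = [r["image_path"] for r in results if not r.get("success", False)]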
def main():
    """Entry point."""
    parser = argparse.ArgumentParser(
        description="DotsOCR vLLM Batch Processing",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Process a single PDF file
  python main.py --input document.pdf --output_dir ./output

  # Process a directory of images
  python main.py --input ./images/ --output_dir ./output

  # Process a file list
  python main.py --input file_list.txt --output_dir ./output

  # Re-process files listed in a CSV (e.g. previously failed files)
  python main.py --input results.csv --output_dir ./output

  # Restrict the page range (PDF only)
  python main.py --input document.pdf --output_dir ./output --pages "1-5,7"

  # Only validate the configuration (dry run)
  python main.py --input document.pdf --output_dir ./output --dry_run

  # Use DEBUG log level for detailed error information
  python main.py --input document.pdf --output_dir ./output --log_level DEBUG
        """
    )

    # Input argument (a single --input handles all input types)
    parser.add_argument(
        "--input", "-i",
        required=True,
        type=str,
        help="Input path (PDF file, image file, image directory, .txt file list, or CSV file)"
    )

    # Output argument
    parser.add_argument(
        "--output_dir", "-o",
        type=str,
        required=True,
        help="Output directory"
    )

    # DotsOCR vLLM arguments
    parser.add_argument(
        "--ip",
        type=str,
        default="10.192.72.11",
        help="vLLM server IP"
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8101,
        help="vLLM server port"
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="DotsOCR",
        help="Model name"
    )
    parser.add_argument(
        "--prompt_mode",
        type=str,
        default="prompt_layout_all_en",
        choices=list(dict_promptmode_to_prompt.keys()),
        help="Prompt mode"
    )
    parser.add_argument(
        "--min_pixels",
        type=int,
        default=MIN_PIXELS,
        help="Minimum number of pixels"
    )
    parser.add_argument(
        "--max_pixels",
        type=int,
        default=MAX_PIXELS,
        help="Maximum number of pixels"
    )
    parser.add_argument(
        "--dpi",
        type=int,
        default=200,
        help="DPI used when rendering PDF pages to images"
    )
    parser.add_argument(
        '--no-normalize',
        action='store_true',
        help='Disable number normalization'
    )

    # Processing arguments
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Batch size"
    )
    parser.add_argument(
        "--pages", "-p",
        type=str,
        help="Page range (PDFs and image directories only), e.g. '1-5,7,9-12', '1-', '-10'"
    )
    parser.add_argument(
        "--collect_results",
        type=str,
        help="Collect the processing results into the given CSV file"
    )

    # Concurrency arguments
    parser.add_argument(
        "--max_workers",
        type=int,
        default=3,
        help="Maximum number of concurrent workers (should match vLLM data-parallel-size)"
    )
    parser.add_argument(
        "--use_threading",
        action="store_true",
        help="Use multi-threading"
    )

    # Logging arguments
    parser.add_argument(
        "--log_level",
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        help="Log level (default: INFO)"
    )
    parser.add_argument(
        "--log_file",
        type=str,
        help="Log file path"
    )

    # Dry run argument
    parser.add_argument(
        "--dry_run",
        action="store_true",
        help="Only validate the configuration and inputs; do not run any processing"
    )
    args = parser.parse_args()

    # Set up logging
    setup_logging(args.log_level, args.log_file)

    try:
        # Lightweight argument shim expected by get_input_files
        class Args:
            def __init__(self, input_path, output_dir, pdf_dpi):
                self.input = input_path
                self.output_dir = output_dir
                self.pdf_dpi = pdf_dpi

        args_obj = Args(args.input, args.output_dir, args.dpi)

        # Gather and preprocess the input files (page-range filtering happens inside get_input_files)
        logger.info("🔄 Preprocessing input files...")
        if args.pages:
            logger.info(f"📄 Page range: {args.pages}")
        image_files = get_input_files(args_obj, page_range=args.pages)

        if not image_files:
            logger.error("❌ No input files found or processed")
            return 1

        output_dir = Path(args.output_dir).resolve()
        logger.info(f"📁 Output dir: {output_dir}")
        logger.info(f"📊 Found {len(image_files)} image files to process")

        # Dry run mode
        if args.dry_run:
            logger.info("🔍 Dry run mode: validating the configuration only, no processing")
            logger.info("📋 Configuration:")
            logger.info(f"  - Input: {args.input}")
            logger.info(f"  - Output dir: {output_dir}")
            logger.info(f"  - Server: {args.ip}:{args.port}")
            logger.info(f"  - Model: {args.model_name}")
            logger.info(f"  - Prompt mode: {args.prompt_mode}")
            logger.info(f"  - Batch size: {args.batch_size}")
            logger.info(f"  - PDF DPI: {args.dpi}")
            logger.info(f"  - Number normalization: {not args.no_normalize}")
            logger.info(f"  - Log level: {args.log_level}")
            if args.pages:
                logger.info(f"  - Page range: {args.pages}")
            if args.use_threading:
                logger.info(f"  - Concurrent workers: {args.max_workers}")
            logger.info(f"📋 Files to process ({len(image_files)}):")
            for i, img_file in enumerate(image_files[:20], 1):  # show the first 20 only
                logger.info(f"  {i}. {img_file}")
            if len(image_files) > 20:
                logger.info(f"  ... and {len(image_files) - 20} more")
            logger.info("✅ Dry run finished: configuration is valid")
            return 0

        logger.info(f"🌐 Using server: {args.ip}:{args.port}")
        logger.info(f"📦 Batch size: {args.batch_size}")
        logger.info(f"🎯 Prompt mode: {args.prompt_mode}")

        # Create the processor
        processor = DotsOCRProcessor(
            ip=args.ip,
            port=args.port,
            model_name=args.model_name,
            prompt_mode=args.prompt_mode,
            dpi=args.dpi,
            min_pixels=args.min_pixels,
            max_pixels=args.max_pixels,
            normalize_numbers=not args.no_normalize,
            log_level=args.log_level
        )
        # Start processing
        start_time = time.time()

        # Choose the processing strategy
        if args.use_threading:
            results = process_images_concurrent(
                image_files,
                processor,
                args.batch_size,
                str(output_dir),
                args.max_workers
            )
        else:
            results = process_images_single_process(
                image_files,
                processor,
                args.batch_size,
                str(output_dir)
            )

        total_time = time.time() - start_time

        # Summarize the results
        success_count = sum(1 for r in results if r.get('success', False))
        skipped_count = sum(1 for r in results if r.get('skipped', False))
        error_count = len(results) - success_count
        pdf_page_count = sum(1 for r in results if r.get('is_pdf_page', False))

        print("\n" + "=" * 60)
        print("✅ Processing completed!")
        print("📊 Statistics:")
        print(f"   Total files processed: {len(image_files)}")
        print(f"   PDF pages processed: {pdf_page_count}")
        print(f"   Regular images processed: {len(image_files) - pdf_page_count}")
        print(f"   Successful: {success_count}")
        print(f"   Skipped: {skipped_count}")
        print(f"   Failed: {error_count}")
        if len(image_files) > 0:
            print(f"   Success rate: {success_count / len(image_files) * 100:.2f}%")
        print("⏱️ Performance:")
        print(f"   Total time: {total_time:.2f} seconds")
        if total_time > 0:
            print(f"   Throughput: {len(image_files) / total_time:.2f} images/second")
            print(f"   Avg time per image: {total_time / len(image_files):.2f} seconds")
        print("\n📁 Output Structure:")
        print("   output_dir/")
        print("   ├── filename.md          # Markdown content")
        print("   ├── filename.json        # Layout info JSON")
        print("   └── filename_layout.jpg  # Layout visualization")

        # Summary statistics
        stats = {
            "total_files": len(image_files),
            "pdf_pages": pdf_page_count,
            "regular_images": len(image_files) - pdf_page_count,
            "success_count": success_count,
            "skipped_count": skipped_count,
            "error_count": error_count,
            "success_rate": success_count / len(image_files) if len(image_files) > 0 else 0,
            "total_time": total_time,
            "throughput": len(image_files) / total_time if total_time > 0 else 0,
            "avg_time_per_image": total_time / len(image_files) if len(image_files) > 0 else 0,
            "batch_size": args.batch_size,
            "server": f"{args.ip}:{args.port}",
            "model": args.model_name,
            "prompt_mode": args.prompt_mode,
            "pdf_dpi": args.dpi,
            "normalization_enabled": not args.no_normalize,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
        }

        # Save the final results
        output_file_name = Path(output_dir).name
        output_file = output_dir / f"{output_file_name}_results.json"
        final_results = {
            "stats": stats,
            "results": results
        }
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(final_results, f, ensure_ascii=False, indent=2)
        logger.info(f"💾 Results saved to: {output_file}")

        # Collect the per-file processing status
        if not args.collect_results:
            output_file_processed = output_dir / f"processed_files_{time.strftime('%Y%m%d_%H%M%S')}.csv"
        else:
            output_file_processed = Path(args.collect_results).resolve()
        processed_files = collect_pid_files(str(output_file))
        with open(output_file_processed, 'w', encoding='utf-8') as f:
            f.write("image_path,status\n")
            for file_path, status in processed_files:
                f.write(f"{file_path},{status}\n")
        logger.info(f"💾 Processed files saved to: {output_file_processed}")

        return 0

    except Exception as e:
        logger.error(f"Processing failed: {e}")
        traceback.print_exc()
        return 1
if __name__ == "__main__":
    logger.info("🚀 Starting the DotsOCR vLLM unified PDF/image processing tool...")
    logger.info(f"🔧 CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set')}")

    if len(sys.argv) == 1:
        # No command-line arguments: fall back to the default configuration below
        logger.info("ℹ️ No command line arguments provided. Running with default configuration...")

        # Default configuration
        default_config = {
            # "input": "/Users/zhch158/workspace/data/流水分析/马公账流水_工商银行.pdf",
            "input": "/Users/zhch158/workspace/repository.git/ocr_platform/ocr_tools/dots.ocr_vl_tool/output/processed_files_20251218_164332.csv",
            "output_dir": "./output",
            "ip": "10.192.72.11",
            "port": "8101",
            "model_name": "DotsOCR",
            "prompt_mode": "prompt_layout_all_en",
            "batch_size": "1",
            "dpi": "200",
            "pages": "-2",
        }

        # Build the argument vector from the default configuration
        sys.argv = [sys.argv[0]]
        for key, value in default_config.items():
            sys.argv.extend([f"--{key}", str(value)])

    sys.exit(main())