# zhch/omnidocbench_parallel_eval.py
import json
import time
import os
import glob
import traceback
from pathlib import Path
from typing import List, Dict, Any
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed

import cv2
from paddlex import create_pipeline
from tqdm import tqdm

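# NOTE: thread-based parallelism keeps all workers in one Python process, so any
# speedup assumes the pipeline releases the GIL while native inference runs;
# process-based parallelism sidesteps the GIL at the cost of one pipeline per process.
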
class OmniDocBenchParallelEvaluator:
    """
    OmniDocBench parallel evaluator supporting multi-process/multi-thread batching.
    """

    def __init__(self, pipeline_config_path: str = "PP-StructureV3"):
        """
        Initialize the evaluator.

        Args:
            pipeline_config_path: PaddleX pipeline name or config file path
        """
        self.pipeline_config = pipeline_config_path
        self.category_mapping = self._get_category_mapping()
    def _get_category_mapping(self) -> Dict[str, str]:
        """Map PaddleX block labels to OmniDocBench category types."""
        return {
            'title': 'title',
            'text': 'text_block',
            'figure': 'figure',
            'figure_caption': 'figure_caption',
            'table': 'table',
            'table_caption': 'table_caption',
            'equation': 'equation_isolated',
            'header': 'header',
            'footer': 'footer',
            'reference': 'reference',
            'seal': 'abandon',
            'number': 'page_number',
        }
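    # Labels missing from the mapping above fall back to 'text_block'
    # (see _convert_to_omnidocbench_format).
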
    def create_pipeline(self):
        """Create a pipeline instance (each worker creates its own)."""
        return create_pipeline(pipeline=self.pipeline_config)
    def process_single_image(self, image_path: str, use_gpu: bool = True) -> Dict[str, Any]:
        """
        Process a single image.

        Args:
            image_path: path to the image
            use_gpu: whether to run on GPU

        Returns:
            Result dict in OmniDocBench format, or None if the image cannot be read.
        """
        try:
            # Each call creates its own pipeline
            pipeline = self.create_pipeline()
            # Read the image to get its dimensions
            image = cv2.imread(image_path)
            if image is None:
                return None
            height, width = image.shape[:2]
            # Run the PaddleX pipeline
            start_time = time.time()
            output = list(pipeline.predict(
                input=image_path,
                device="gpu" if use_gpu else "cpu",
                use_doc_orientation_classify=True,
                use_doc_unwarping=False,
                use_seal_recognition=True,
                use_chart_recognition=True,
                use_table_recognition=True,
                use_formula_recognition=True,
            ))
            process_time = time.time() - start_time
            # Convert to OmniDocBench format
            result = self._convert_to_omnidocbench_format(
                output, image_path, width, height
            )
            # Attach timing information
            if result:
                result["processing_time"] = process_time
                result["success"] = True
            return result
        except Exception as e:
            return {
                "image_path": Path(image_path).name,
                "error": str(e),
                "success": False,
                "processing_time": 0
            }
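    # Creating a pipeline per call is expensive; for whole directories prefer
    # process_batch below, which builds one pipeline and reuses it across the batch.
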
    def process_batch(self, image_paths: List[str], use_gpu: bool = True) -> List[Dict[str, Any]]:
        """
        Process a batch of images with a single shared pipeline.

        Args:
            image_paths: list of image paths
            use_gpu: whether to run on GPU

        Returns:
            List of result dicts.
        """
        results = []
        pipeline = self.create_pipeline()
        for image_path in image_paths:
            try:
                result = self._process_with_pipeline(pipeline, image_path, use_gpu)
                if result:
                    results.append(result)
            except Exception as e:
                error_result = {
                    "image_path": Path(image_path).name,
                    "error": str(e),
                    "success": False,
                    "processing_time": 0
                }
                results.append(error_result)
        return results
    def _process_with_pipeline(self, pipeline, image_path: str, use_gpu: bool) -> Dict[str, Any]:
        """Process one image with an existing pipeline instance."""
        # Read the image to get its dimensions
        image = cv2.imread(image_path)
        if image is None:
            return None
        height, width = image.shape[:2]
        # Run the pipeline
        start_time = time.time()
        output = list(pipeline.predict(
            input=image_path,
            device="gpu" if use_gpu else "cpu",
            use_doc_orientation_classify=True,
            use_doc_unwarping=False,
            use_seal_recognition=True,
            use_chart_recognition=True,
            use_table_recognition=True,
            use_formula_recognition=True,
        ))
        process_time = time.time() - start_time
        # Convert the output format
        result = self._convert_to_omnidocbench_format(
            output, image_path, width, height
        )
        if result:
            result["processing_time"] = process_time
            result["success"] = True
        return result
    def _convert_to_omnidocbench_format(self,
                                        paddlex_output: List,
                                        image_path: str,
                                        width: int,
                                        height: int) -> Dict[str, Any]:
        """Convert PaddleX output into the OmniDocBench result format."""
        layout_dets = []
        anno_id_counter = 0
        # Walk the PaddleX results
        for res in paddlex_output:
            res_json = res.json.get('res', {})
            parsing_list = res_json.get('parsing_res_list', [])
            for item in parsing_list:
                bbox = item.get('block_bbox', [])
                category = item.get('block_label', 'text_block')
                content = item.get('block_content', '')
                # Convert [x1, y1, x2, y2] into an 8-point polygon
                if len(bbox) == 4:
                    x1, y1, x2, y2 = bbox
                    poly = [x1, y1, x2, y1, x2, y2, x1, y2]
                else:
                    poly = bbox
                # Map the category
                omni_category = self.category_mapping.get(category, 'text_block')
                # Build the layout detection entry
                layout_det = {
                    "category_type": omni_category,
                    "poly": poly,
                    "ignore": False,
                    "order": anno_id_counter,
                    "anno_id": anno_id_counter,
                }
                # Attach content
                if content and content.strip():
                    if omni_category == 'table':
                        layout_det["html"] = content
                    else:
                        layout_det["text"] = content.strip()
                # Attach attributes
                layout_det["attribute"] = self._extract_attributes(item, omni_category)
                layout_det["line_with_spans"] = []
                layout_dets.append(layout_det)
                anno_id_counter += 1
        # Assemble the full result
        result = {
            "layout_dets": layout_dets,
            "page_info": {
                "page_no": 0,
                "height": height,
                "width": width,
                "image_path": Path(image_path).name,
                "page_attribute": {
                    "data_source": "research_report",
                    "language": "simplified_chinese",
                    "layout": "single_column",
                    "watermark": False,
                    "fuzzy_scan": False,
                    "colorful_backgroud": False
                }
            },
            "extra": {
                "relation": []
            }
        }
        return result
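    # The page_attribute values above are fixed defaults (single-column,
    # simplified-Chinese research reports); adjust them per dataset if needed.
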
    def _extract_attributes(self, item: Dict, category: str) -> Dict:
        """Extract attribute tags for a block."""
        attributes = {}
        if category == 'table':
            attributes.update({
                "table_layout": "vertical",
                "with_span": False,
                "line": "full_line",
                "language": "table_simplified_chinese",
                "include_equation": False,
                "include_backgroud": False,
                "table_vertical": False
            })
            # Simple heuristic: any colspan/rowspan in the HTML marks a spanning table
            content = item.get('block_content', '')
            if 'colspan' in content or 'rowspan' in content:
                attributes["with_span"] = True
        elif category in ['text_block', 'title']:
            attributes.update({
                "text_language": "text_simplified_chinese",
                "text_background": "white",
                "text_rotate": "normal"
            })
        elif 'equation' in category:
            attributes.update({
                "formula_type": "print"
            })
        return attributes
    def parallel_process_with_threading(self,
                                        image_paths: List[str],
                                        batch_size: int = 4,
                                        max_workers: int = 4,
                                        use_gpu: bool = True) -> List[Dict[str, Any]]:
        """
        Process in parallel with threads (recommended for GPU).

        Args:
            image_paths: list of image paths
            batch_size: images per batch
            max_workers: maximum number of worker threads
            use_gpu: whether to run on GPU

        Returns:
            List of result dicts.
        """
        # Split the image paths into batches
        batches = [image_paths[i:i + batch_size] for i in range(0, len(image_paths), batch_size)]
        all_results = []
        completed_count = 0
        total_images = len(image_paths)
        # Progress bar
        with tqdm(total=total_images, desc="Processing images", unit="img") as pbar:
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                # Submit every batch
                future_to_batch = {
                    executor.submit(self.process_batch, batch, use_gpu): batch
                    for batch in batches
                }
                # Collect results as they complete
                for future in as_completed(future_to_batch):
                    batch = future_to_batch[future]
                    try:
                        batch_results = future.result()
                        all_results.extend(batch_results)
                        completed_count += len(batch)
                        pbar.update(len(batch))
                        # Update the progress bar postfix
                        success_count = sum(1 for r in batch_results if r.get('success', False))
                        pbar.set_postfix({
                            'batch_success': f"{success_count}/{len(batch)}",
                            'total_success': f"{sum(1 for r in all_results if r.get('success', False))}/{completed_count}"
                        })
                    except Exception as e:
                        print(f"Batch failed: {e}")
                        # Record an error result for every image in the failed batch
                        for img_path in batch:
                            error_result = {
                                "image_path": Path(img_path).name,
                                "error": str(e),
                                "success": False,
                                "processing_time": 0
                            }
                            all_results.append(error_result)
                        pbar.update(len(batch))
        return all_results
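    # Each worker thread builds its own pipeline inside process_batch, so up to
    # max_workers pipelines are resident at once; size max_workers to fit memory.
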
    def parallel_process_with_multiprocessing(self,
                                              image_paths: List[str],
                                              batch_size: int = 4,
                                              max_workers: int = 4,
                                              use_gpu: bool = False) -> List[Dict[str, Any]]:
        """
        Process in parallel with processes (recommended for CPU).

        Args:
            image_paths: list of image paths
            batch_size: images per batch
            max_workers: maximum number of worker processes
            use_gpu: whether to run on GPU

        Returns:
            List of result dicts.
        """
        # Split the image paths into batches
        batches = [image_paths[i:i + batch_size] for i in range(0, len(image_paths), batch_size)]
        all_results = []
        completed_count = 0
        total_images = len(image_paths)
        # Progress bar
        with tqdm(total=total_images, desc="Processing images", unit="img") as pbar:
            with ProcessPoolExecutor(max_workers=max_workers) as executor:
                # Submit every batch to the module-level worker
                future_to_batch = {
                    executor.submit(process_batch_worker, batch, self.pipeline_config, use_gpu): batch
                    for batch in batches
                }
                # Collect results as they complete
                for future in as_completed(future_to_batch):
                    batch = future_to_batch[future]
                    try:
                        batch_results = future.result()
                        all_results.extend(batch_results)
                        completed_count += len(batch)
                        pbar.update(len(batch))
                        # Update the progress bar postfix
                        success_count = sum(1 for r in batch_results if r.get('success', False))
                        pbar.set_postfix({
                            'batch_success': f"{success_count}/{len(batch)}",
                            'total_success': f"{sum(1 for r in all_results if r.get('success', False))}/{completed_count}"
                        })
                    except Exception as e:
                        print(f"Batch failed: {e}")
                        # Record an error result for every image in the failed batch
                        for img_path in batch:
                            error_result = {
                                "image_path": Path(img_path).name,
                                "error": str(e),
                                "success": False,
                                "processing_time": 0
                            }
                            all_results.append(error_result)
                        pbar.update(len(batch))
        return all_results
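    # The pool submits the module-level process_batch_worker, which rebuilds the
    # evaluator (and its pipeline) inside each worker process.
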
    def save_results_incrementally(self,
                                   results: List[Dict[str, Any]],
                                   output_file: str,
                                   save_interval: int = 50):
        """
        Save results incrementally.

        Args:
            results: accumulated result list
            output_file: output file path
            save_interval: save every N results
        """
        if len(results) % save_interval == 0 and len(results) > 0:
            try:
                with open(output_file, 'w', encoding='utf-8') as f:
                    json.dump(results, f, ensure_ascii=False, indent=2)
                print(f"Saved {len(results)} results to {output_file}")
            except Exception as e:
                print(f"Error while saving results: {e}")
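    # save_results_incrementally is not wired into the parallel loops above; a
    # minimal sketch of where it could hook in (inside the as_completed loop,
    # with checkpoint_file being a caller-chosen path):
    #
    #     all_results.extend(batch_results)
    #     evaluator.save_results_incrementally(all_results, checkpoint_file)
    #
    # so that a crash loses at most save_interval results.
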

def process_batch_worker(image_paths: List[str], pipeline_config: str, use_gpu: bool) -> List[Dict[str, Any]]:
    """
    Worker function for multiprocessing.
    """
    try:
        # Build a fresh evaluator inside each worker process
        evaluator = OmniDocBenchParallelEvaluator(pipeline_config)
        return evaluator.process_batch(image_paths, use_gpu)
    except Exception as e:
        # Return an error result for every image in the batch
        error_results = []
        for img_path in image_paths:
            error_results.append({
                "image_path": Path(img_path).name,
                "error": str(e),
                "success": False,
                "processing_time": 0
            })
        return error_results

def main():
    """Entry point: run the parallel evaluation over an OmniDocBench image directory."""
    # Configuration
    dataset_path = "/Users/zhch158/workspace/repository.git/OmniDocBench/OpenDataLab___OmniDocBench/images"
    output_dir = "/Users/zhch158/workspace/repository.git/PaddleX/zhch/OmniDocBench_Results"
    pipeline_config = "PP-StructureV3"
    # Parallelism parameters
    batch_size = 4               # images per batch
    max_workers = 4              # maximum number of worker threads/processes
    use_gpu = True               # whether to run on GPU
    use_multiprocessing = False  # False = threads (GPU), True = processes (CPU)
    # Make sure the output directory exists
    os.makedirs(output_dir, exist_ok=True)
    print("=" * 60)
    print("OmniDocBench parallel evaluation started")
    print("=" * 60)
    print(f"Dataset path: {dataset_path}")
    print(f"Output directory: {output_dir}")
    print(f"Batch size: {batch_size}")
    print(f"Max workers: {max_workers}")
    print(f"Use GPU: {use_gpu}")
    print(f"Parallel mode: {'multiprocessing' if use_multiprocessing else 'multithreading'}")
    # Collect all image files
    image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.bmp', '*.tiff']
    image_files = []
    for ext in image_extensions:
        image_files.extend(glob.glob(os.path.join(dataset_path, ext)))
    print(f"Found {len(image_files)} image files")
    if not image_files:
        print("No image files found, exiting")
        return
    # Create the evaluator
    evaluator = OmniDocBenchParallelEvaluator(pipeline_config)
    # Start processing
    start_time = time.time()
    if use_multiprocessing:
        # Multiprocessing (recommended for CPU)
        print("Running with multiprocessing...")
        results = evaluator.parallel_process_with_multiprocessing(
            image_files, batch_size, max_workers, use_gpu
        )
    else:
        # Multithreading (recommended for GPU)
        print("Running with multithreading...")
        results = evaluator.parallel_process_with_threading(
            image_files, batch_size, max_workers, use_gpu
        )
    total_time = time.time() - start_time
    # Save the final results
    output_file = os.path.join(output_dir, f"OmniDocBench_PPStructureV3_batch{batch_size}.json")
    try:
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=2)
        print("\n" + "=" * 60)
        print("Processing finished!")
        print("=" * 60)
        # Statistics
        success_count = sum(1 for r in results if r.get('success', False))
        error_count = len(results) - success_count
        total_processing_time = sum(r.get('processing_time', 0) for r in results if r.get('success', False))
        avg_processing_time = total_processing_time / success_count if success_count > 0 else 0
        print(f"Total files: {len(image_files)}")
        print(f"Succeeded: {success_count}")
        print(f"Failed: {error_count}")
        print(f"Success rate: {success_count / len(image_files) * 100:.2f}%")
        print(f"Total wall time: {total_time:.2f}s")
        print(f"Average processing time: {avg_processing_time:.2f}s per image")
        print(f"Throughput: {len(image_files) / total_time:.2f} images/s")
        print(f"Results saved to: {output_file}")
        # Save the run statistics
        stats = {
            "total_files": len(image_files),
            "success_count": success_count,
            "error_count": error_count,
            "success_rate": success_count / len(image_files),
            "total_time": total_time,
            "avg_processing_time": avg_processing_time,
            "throughput": len(image_files) / total_time,
            "batch_size": batch_size,
            "max_workers": max_workers,
            "use_gpu": use_gpu,
            "use_multiprocessing": use_multiprocessing
        }
        stats_file = os.path.join(output_dir, f"processing_stats_batch{batch_size}.json")
        with open(stats_file, 'w', encoding='utf-8') as f:
            json.dump(stats, f, ensure_ascii=False, indent=2)
        print(f"Statistics saved to: {stats_file}")
    except Exception as e:
        print(f"Error while saving the results file: {str(e)}")
        traceback.print_exc()


if __name__ == "__main__":
    main()