# ppstructurev3_parallel_predict.py
  1. # zhch/omnidocbench_parallel_eval.py
  2. import json
  3. import time
  4. import os
  5. import glob
  6. import traceback
  7. from pathlib import Path
  8. from typing import List, Dict, Any, Tuple
  9. from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
  10. from multiprocessing import Queue, Manager
  11. import cv2
  12. import numpy as np
  13. from paddlex import create_pipeline
  14. from tqdm import tqdm
  15. import threading
  16. class PPStructureV3ParallelPredictor:
  17. """
  18. PP-StructureV3并行预测器,支持多进程批处理
  19. """
  20. def __init__(self, pipeline_config_path: str = "PP-StructureV3", output_path: str = "output", use_gpu: bool = True):
  21. """
  22. 初始化预测器
  23. Args:
  24. pipeline_config_path: PaddleX pipeline配置文件路径
  25. """
  26. self.pipeline_config = pipeline_config_path
  27. self.pipeline = create_pipeline(pipeline=self.pipeline_config)
  28. self.output_path = output_path
  29. self.use_gpu = use_gpu
  30. def create_pipeline(self):
  31. """创建pipeline实例(每个进程单独创建)"""
  32. if self.pipeline is not None:
  33. return self.pipeline
  34. return create_pipeline(pipeline=self.pipeline_config)
  35. def process_single_image(self, image_path: str) -> Dict[str, Any]:
  36. """
  37. 处理单张图像
  38. Args:
  39. image_path: 图像路径
  40. output_path: 输出路径
  41. use_gpu: 是否使用GPU
  42. Returns:
  43. 处理结果{"image_path": str, "success": bool, "processing_time": float, "error": str}
  44. """
  45. try:
  46. # 读取图像获取尺寸信息
  47. image = cv2.imread(image_path)
  48. if image is None:
  49. return {
  50. "image_path": Path(image_path).name,
  51. "error": "无法读取图像",
  52. "success": False,
  53. "processing_time": 0
  54. }
  55. height, width = image.shape[:2]
  56. # 运行PaddleX pipeline
  57. start_time = time.time()
  58. output = self.pipeline.predict(
  59. input=image_path,
  60. device="gpu" if self.use_gpu else "cpu",
  61. use_doc_orientation_classify=True,
  62. use_doc_unwarping=False,
  63. use_seal_recognition=True,
  64. use_chart_recognition=True,
  65. use_table_recognition=True,
  66. use_formula_recognition=True,
  67. )
  68. # 可视化结果并保存 json 结果
  69. for res in output:
  70. res.save_to_json(save_path=self.output_path) # 保存所有结果到指定路径
  71. res.save_to_markdown(save_path=self.output_path) # 保存所有结果到指定路径
  72. process_time = time.time() - start_time
  73. # 添加处理时间信息
  74. result = {"image_path": Path(image_path).name}
  75. if output:
  76. result["processing_time"] = process_time
  77. result["success"] = True
  78. return result
  79. except Exception as e:
  80. return {
  81. "image_path": Path(image_path).name,
  82. "error": str(e),
  83. "success": False,
  84. "processing_time": 0
  85. }
  86. def process_batch(self, image_paths: List[str]) -> List[Dict[str, Any]]:
  87. """
  88. 批处理图像
  89. Args:
  90. image_paths: 图像路径列表
  91. use_gpu: 是否使用GPU
  92. Returns:
  93. 结果列表
  94. """
  95. results = []
  96. for image_path in image_paths:
  97. result = self.process_single_image(image_path=image_path)
  98. results.append(result)
  99. return results
  100. def parallel_process_with_threading(self,
  101. image_paths: List[str],
  102. batch_size: int = 4,
  103. max_workers: int = 4
  104. ) -> List[Dict[str, Any]]:
  105. """
  106. 使用多线程并行处理(推荐用于GPU)
  107. Args:
  108. image_paths: 图像路径列表
  109. batch_size: 批处理大小
  110. max_workers: 最大工作线程数
  111. Returns:
  112. 处理结果列表
  113. """
  114. # 将图像路径分批
  115. batches = [image_paths[i:i + batch_size] for i in range(0, len(image_paths), batch_size)]
  116. all_results = []
  117. completed_count = 0
  118. total_images = len(image_paths)
  119. # 创建进度条
  120. with tqdm(total=total_images, desc="处理图像", unit="张") as pbar:
  121. with ThreadPoolExecutor(max_workers=max_workers) as executor:
  122. # 提交所有批处理任务
  123. future_to_batch = {
  124. executor.submit(self.process_batch, batch): batch
  125. for batch in batches
  126. }
  127. # 收集结果
  128. for future in as_completed(future_to_batch):
  129. batch = future_to_batch[future]
  130. try:
  131. batch_results = future.result()
  132. all_results.extend(batch_results)
  133. completed_count += len(batch)
  134. pbar.update(len(batch))
  135. # 更新进度条描述
  136. success_count = sum(1 for r in batch_results if r.get('success', False))
  137. pbar.set_postfix({
  138. 'batch_success': f"{success_count}/{len(batch)}",
  139. 'total_success': f"{sum(1 for r in all_results if r.get('success', False))}/{completed_count}"
  140. })
  141. except Exception as e:
  142. print(f"批处理失败: {e}")
  143. # 为失败的批次创建错误结果
  144. for img_path in batch:
  145. error_result = {
  146. "image_path": Path(img_path).name,
  147. "error": str(e),
  148. "success": False,
  149. "processing_time": 0
  150. }
  151. all_results.append(error_result)
  152. pbar.update(len(batch))
  153. return all_results
  154. def save_results_incrementally(self,
  155. results: List[Dict[str, Any]],
  156. output_file: str,
  157. save_interval: int = 50):
  158. """
  159. 增量保存结果
  160. Args:
  161. results: 结果列表
  162. output_file: 输出文件路径
  163. save_interval: 保存间隔
  164. """
  165. if len(results) % save_interval == 0 and len(results) > 0:
  166. try:
  167. with open(output_file, 'w', encoding='utf-8') as f:
  168. json.dump(results, f, ensure_ascii=False, indent=2)
  169. print(f"已保存 {len(results)} 个结果到 {output_file}")
  170. except Exception as e:
  171. print(f"保存结果时出错: {e}")
  172. def process_batch_worker(image_paths: List[str], pipeline_config: str, output_path: str, use_gpu: bool) -> List[Dict[str, Any]]:
  173. """
  174. 多进程工作函数
  175. """
  176. try:
  177. # 在每个进程中创建pipeline实例
  178. predictor = PPStructureV3ParallelPredictor(pipeline_config, output_path=output_path, use_gpu=use_gpu)
  179. return predictor.process_batch(image_paths)
  180. except Exception as e:
  181. # 返回错误结果
  182. error_results = []
  183. for img_path in image_paths:
  184. error_results.append({
  185. "image_path": Path(img_path).name,
  186. "error": str(e),
  187. "success": False,
  188. "processing_time": 0
  189. })
  190. return error_results
  191. def parallel_process_with_multiprocessing(image_paths: List[str],
  192. batch_size: int = 4,
  193. max_workers: int = 4,
  194. pipeline_config: str = "PP-StructureV3",
  195. output_path: str = "./output",
  196. use_gpu: bool = True
  197. ) -> List[Dict[str, Any]]:
  198. """
  199. 使用多进程并行处理(推荐用于CPU)
  200. Args:
  201. image_paths: 图像路径列表
  202. batch_size: 批处理大小
  203. max_workers: 最大工作进程数
  204. use_gpu: 是否使用GPU
  205. Returns:
  206. 处理结果列表
  207. """
  208. # 将图像路径分批
  209. batches = [image_paths[i:i + batch_size] for i in range(0, len(image_paths), batch_size)]
  210. all_results = []
  211. completed_count = 0
  212. total_images = len(image_paths)
  213. # 创建进度条
  214. with tqdm(total=total_images, desc="处理图像", unit="张") as pbar:
  215. with ProcessPoolExecutor(max_workers=max_workers) as executor:
  216. # 提交所有批处理任务
  217. future_to_batch = {
  218. executor.submit(process_batch_worker, batch, pipeline_config, output_path, use_gpu): batch
  219. for batch in batches
  220. }
  221. # 收集结果
  222. for future in as_completed(future_to_batch):
  223. batch = future_to_batch[future]
  224. try:
  225. batch_results = future.result()
  226. all_results.extend(batch_results)
  227. completed_count += len(batch)
  228. pbar.update(len(batch))
  229. # 更新进度条描述
  230. success_count = sum(1 for r in batch_results if r.get('success', False))
  231. pbar.set_postfix({
  232. 'batch_success': f"{success_count}/{len(batch)}",
  233. 'total_success': f"{sum(1 for r in all_results if r.get('success', False))}/{completed_count}"
  234. })
  235. except Exception as e:
  236. print(f"批处理失败: {e}")
  237. # 为失败的批次创建错误结果
  238. for img_path in batch:
  239. error_result = {
  240. "image_path": Path(img_path).name,
  241. "error": str(e),
  242. "success": False,
  243. "processing_time": 0
  244. }
  245. all_results.append(error_result)
  246. pbar.update(len(batch))
  247. return all_results
  248. def main():
  249. """主函数 - 并行处理OmniDocBench数据集"""
  250. # 配置参数
  251. dataset_path = "../../OmniDocBench/OpenDataLab___OmniDocBench/images"
  252. output_dir = "./OmniDocBench_Results"
  253. pipeline_config = "PP-StructureV3"
  254. # 并行处理参数
  255. batch_size = 4 # 批处理大小
  256. max_workers = 4 # 最大工作进程/线程数
  257. use_gpu = True # 是否使用GPU
  258. use_multiprocessing = True # False=多线程(GPU推荐), True=多进程(CPU推荐)
  259. # 确保输出目录存在
  260. print(f"输出目录: {Path(output_dir).absolute()}")
  261. os.makedirs(output_dir, exist_ok=True)
  262. dataset_path = Path(dataset_path).resolve()
  263. output_dir = Path(output_dir).resolve()
  264. print("="*60)
  265. print("OmniDocBench 并行评估开始")
  266. print("="*60)
  267. print(f"数据集路径: {dataset_path}")
  268. print(f"输出目录: {output_dir}")
  269. print(f"批处理大小: {batch_size}")
  270. print(f"最大工作线程/进程数: {max_workers}")
  271. print(f"使用GPU: {use_gpu}")
  272. print(f"并行方式: {'多进程' if use_multiprocessing else '多线程'}")
  273. # 查找所有图像文件
  274. image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.bmp', '*.tiff']
  275. image_files = []
  276. for ext in image_extensions:
  277. image_files.extend(glob.glob(os.path.join(dataset_path, ext)))
  278. print(f"找到 {len(image_files)} 个图像文件")
  279. if not image_files:
  280. print("未找到任何图像文件,程序终止")
  281. return
  282. # 开始处理
  283. start_time = time.time()
  284. if use_multiprocessing:
  285. # 多进程处理(推荐用于CPU)
  286. print("使用多进程并行处理...")
  287. results = parallel_process_with_multiprocessing(
  288. image_files, batch_size, max_workers, pipeline_config, output_dir, use_gpu
  289. )
  290. else:
  291. # 多线程处理(推荐用于GPU)
  292. print("使用多线程并行处理...")
  293. predictor = PPStructureV3ParallelPredictor(pipeline_config, output_path=output_dir, use_gpu=use_gpu)
  294. results = predictor.parallel_process_with_threading(
  295. image_files, batch_size, max_workers
  296. )
  297. total_time = time.time() - start_time
  298. # 保存最终结果
  299. output_file = os.path.join(output_dir, f"OmniDocBench_PPStructureV3_batch{batch_size}.json")
  300. try:
  301. # 统计信息
  302. success_count = sum(1 for r in results if r.get('success', False))
  303. error_count = len(results) - success_count
  304. total_processing_time = sum(r.get('processing_time', 0) for r in results if r.get('success', False))
  305. avg_processing_time = total_processing_time / success_count if success_count > 0 else 0
  306. print(f"总文件数: {len(image_files)}")
  307. print(f"成功处理: {success_count}")
  308. print(f"失败数量: {error_count}")
  309. print(f"成功率: {success_count / len(image_files) * 100:.2f}%")
  310. print(f"总耗时: {total_time:.2f}秒")
  311. print(f"平均处理时间: {avg_processing_time:.2f}秒/张")
  312. print(f"吞吐量: {len(image_files) / total_time:.2f}张/秒")
  313. print(f"结果保存至: {output_file}")
  314. # 保存统计信息
  315. stats = {
  316. "total_files": len(image_files),
  317. "success_count": success_count,
  318. "error_count": error_count,
  319. "success_rate": success_count / len(image_files),
  320. "total_time": total_time,
  321. "avg_processing_time": avg_processing_time,
  322. "throughput": len(image_files) / total_time,
  323. "batch_size": batch_size,
  324. "max_workers": max_workers,
  325. "use_gpu": use_gpu,
  326. "use_multiprocessing": use_multiprocessing
  327. }
  328. results['stats'] = stats
  329. with open(output_file, 'w', encoding='utf-8') as f:
  330. json.dump(results, f, ensure_ascii=False, indent=2)
  331. print("\n" + "="*60)
  332. print("处理完成!")
  333. print("="*60)
  334. except Exception as e:
  335. print(f"保存结果文件时发生错误: {str(e)}")
  336. traceback.print_exc()
  337. if __name__ == "__main__":
  338. main()