# ppstructurev3_parallel_predict.py
  1. # zhch/omnidocbench_parallel_eval.py
  2. import json
  3. import time
  4. import os
  5. import glob
  6. import traceback
  7. from pathlib import Path
  8. from typing import List, Dict, Any, Tuple
  9. from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
  10. from multiprocessing import Queue, Manager
  11. import cv2
  12. import numpy as np
  13. from paddlex import create_pipeline
  14. from tqdm import tqdm
  15. import threading
  16. class PPStructureV3ParallelPredictor:
  17. """
  18. PP-StructureV3并行预测器,支持多进程批处理
  19. """
  20. def __init__(self, pipeline_config_path: str = "PP-StructureV3", output_path: str = "output", use_gpu: bool = True):
  21. """
  22. 初始化预测器
  23. Args:
  24. pipeline_config_path: PaddleX pipeline配置文件路径
  25. """
  26. self.pipeline_config = pipeline_config_path
  27. self.pipeline = create_pipeline(pipeline=self.pipeline_config)
  28. self.output_path = output_path
  29. self.use_gpu = use_gpu
  30. def create_pipeline(self):
  31. """创建pipeline实例(每个进程单独创建)"""
  32. if self.pipeline is not None:
  33. return self.pipeline
  34. return create_pipeline(pipeline=self.pipeline_config)
  35. def process_single_image(self, image_path: str) -> Dict[str, Any]:
  36. """
  37. 处理单张图像
  38. Args:
  39. image_path: 图像路径
  40. output_path: 输出路径
  41. use_gpu: 是否使用GPU
  42. Returns:
  43. 处理结果{"image_path": str, "success": bool, "processing_time": float, "error": str}
  44. """
  45. try:
  46. # 读取图像获取尺寸信息
  47. image = cv2.imread(image_path)
  48. if image is None:
  49. return {
  50. "image_path": Path(image_path).name,
  51. "error": "无法读取图像",
  52. "success": False,
  53. "processing_time": 0
  54. }
  55. height, width = image.shape[:2]
  56. # 运行PaddleX pipeline
  57. start_time = time.time()
  58. output = list(self.pipeline.predict(
  59. input=image_path,
  60. device="gpu" if self.use_gpu else "cpu",
  61. use_doc_orientation_classify=True,
  62. use_doc_unwarping=False,
  63. use_seal_recognition=True,
  64. use_chart_recognition=True,
  65. use_table_recognition=True,
  66. use_formula_recognition=True,
  67. ))
  68. output.save_to_json(save_path=self.output_path) # 保存JSON结果
  69. output.save_to_markdown(save_path=self.output_path) # 保存Markdown结果
  70. process_time = time.time() - start_time
  71. # 添加处理时间信息
  72. result = {"image_path": Path(image_path).name}
  73. if output:
  74. result["processing_time"] = process_time
  75. result["success"] = True
  76. return result
  77. except Exception as e:
  78. return {
  79. "image_path": Path(image_path).name,
  80. "error": str(e),
  81. "success": False,
  82. "processing_time": 0
  83. }
  84. def process_batch(self, image_paths: List[str]) -> List[Dict[str, Any]]:
  85. """
  86. 批处理图像
  87. Args:
  88. image_paths: 图像路径列表
  89. use_gpu: 是否使用GPU
  90. Returns:
  91. 结果列表
  92. """
  93. results = []
  94. for image_path in image_paths:
  95. result = self.process_single_image(image_path=image_path)
  96. results.append(result)
  97. return results
  98. def parallel_process_with_threading(self,
  99. image_paths: List[str],
  100. batch_size: int = 4,
  101. max_workers: int = 4
  102. ) -> List[Dict[str, Any]]:
  103. """
  104. 使用多线程并行处理(推荐用于GPU)
  105. Args:
  106. image_paths: 图像路径列表
  107. batch_size: 批处理大小
  108. max_workers: 最大工作线程数
  109. Returns:
  110. 处理结果列表
  111. """
  112. # 将图像路径分批
  113. batches = [image_paths[i:i + batch_size] for i in range(0, len(image_paths), batch_size)]
  114. all_results = []
  115. completed_count = 0
  116. total_images = len(image_paths)
  117. # 创建进度条
  118. with tqdm(total=total_images, desc="处理图像", unit="张") as pbar:
  119. with ThreadPoolExecutor(max_workers=max_workers) as executor:
  120. # 提交所有批处理任务
  121. future_to_batch = {
  122. executor.submit(self.process_batch, batch): batch
  123. for batch in batches
  124. }
  125. # 收集结果
  126. for future in as_completed(future_to_batch):
  127. batch = future_to_batch[future]
  128. try:
  129. batch_results = future.result()
  130. all_results.extend(batch_results)
  131. completed_count += len(batch)
  132. pbar.update(len(batch))
  133. # 更新进度条描述
  134. success_count = sum(1 for r in batch_results if r.get('success', False))
  135. pbar.set_postfix({
  136. 'batch_success': f"{success_count}/{len(batch)}",
  137. 'total_success': f"{sum(1 for r in all_results if r.get('success', False))}/{completed_count}"
  138. })
  139. except Exception as e:
  140. print(f"批处理失败: {e}")
  141. # 为失败的批次创建错误结果
  142. for img_path in batch:
  143. error_result = {
  144. "image_path": Path(img_path).name,
  145. "error": str(e),
  146. "success": False,
  147. "processing_time": 0
  148. }
  149. all_results.append(error_result)
  150. pbar.update(len(batch))
  151. return all_results
  152. def save_results_incrementally(self,
  153. results: List[Dict[str, Any]],
  154. output_file: str,
  155. save_interval: int = 50):
  156. """
  157. 增量保存结果
  158. Args:
  159. results: 结果列表
  160. output_file: 输出文件路径
  161. save_interval: 保存间隔
  162. """
  163. if len(results) % save_interval == 0 and len(results) > 0:
  164. try:
  165. with open(output_file, 'w', encoding='utf-8') as f:
  166. json.dump(results, f, ensure_ascii=False, indent=2)
  167. print(f"已保存 {len(results)} 个结果到 {output_file}")
  168. except Exception as e:
  169. print(f"保存结果时出错: {e}")
  170. def process_batch_worker(image_paths: List[str], pipeline_config: str, output_path: str, use_gpu: bool) -> List[Dict[str, Any]]:
  171. """
  172. 多进程工作函数
  173. """
  174. try:
  175. # 在每个进程中创建pipeline实例
  176. predictor = PPStructureV3ParallelPredictor(pipeline_config, output_path=output_path, use_gpu=use_gpu)
  177. return predictor.process_batch(image_paths)
  178. except Exception as e:
  179. # 返回错误结果
  180. error_results = []
  181. for img_path in image_paths:
  182. error_results.append({
  183. "image_path": Path(img_path).name,
  184. "error": str(e),
  185. "success": False,
  186. "processing_time": 0
  187. })
  188. return error_results
  189. def parallel_process_with_multiprocessing(image_paths: List[str],
  190. batch_size: int = 4,
  191. max_workers: int = 4,
  192. pipeline_config: str = "PP-StructureV3",
  193. output_path: str = "./output",
  194. use_gpu: bool = True
  195. ) -> List[Dict[str, Any]]:
  196. """
  197. 使用多进程并行处理(推荐用于CPU)
  198. Args:
  199. image_paths: 图像路径列表
  200. batch_size: 批处理大小
  201. max_workers: 最大工作进程数
  202. use_gpu: 是否使用GPU
  203. Returns:
  204. 处理结果列表
  205. """
  206. # 将图像路径分批
  207. batches = [image_paths[i:i + batch_size] for i in range(0, len(image_paths), batch_size)]
  208. all_results = []
  209. completed_count = 0
  210. total_images = len(image_paths)
  211. # 创建进度条
  212. with tqdm(total=total_images, desc="处理图像", unit="张") as pbar:
  213. with ProcessPoolExecutor(max_workers=max_workers) as executor:
  214. # 提交所有批处理任务
  215. future_to_batch = {
  216. executor.submit(process_batch_worker, batch, pipeline_config, output_path, use_gpu): batch
  217. for batch in batches
  218. }
  219. # 收集结果
  220. for future in as_completed(future_to_batch):
  221. batch = future_to_batch[future]
  222. try:
  223. batch_results = future.result()
  224. all_results.extend(batch_results)
  225. completed_count += len(batch)
  226. pbar.update(len(batch))
  227. # 更新进度条描述
  228. success_count = sum(1 for r in batch_results if r.get('success', False))
  229. pbar.set_postfix({
  230. 'batch_success': f"{success_count}/{len(batch)}",
  231. 'total_success': f"{sum(1 for r in all_results if r.get('success', False))}/{completed_count}"
  232. })
  233. except Exception as e:
  234. print(f"批处理失败: {e}")
  235. # 为失败的批次创建错误结果
  236. for img_path in batch:
  237. error_result = {
  238. "image_path": Path(img_path).name,
  239. "error": str(e),
  240. "success": False,
  241. "processing_time": 0
  242. }
  243. all_results.append(error_result)
  244. pbar.update(len(batch))
  245. return all_results
  246. def main():
  247. """主函数 - 并行处理OmniDocBench数据集"""
  248. # 配置参数
  249. dataset_path = "/Users/zhch158/workspace/repository.git/OmniDocBench/OpenDataLab___OmniDocBench/images"
  250. output_dir = "/Users/zhch158/workspace/repository.git/PaddleX/zhch/OmniDocBench_Results"
  251. pipeline_config = "PP-StructureV3"
  252. # 并行处理参数
  253. batch_size = 4 # 批处理大小
  254. max_workers = 4 # 最大工作进程/线程数
  255. use_gpu = True # 是否使用GPU
  256. use_multiprocessing = False # False=多线程(GPU推荐), True=多进程(CPU推荐)
  257. # 确保输出目录存在
  258. os.makedirs(output_dir, exist_ok=True)
  259. print("="*60)
  260. print("OmniDocBench 并行评估开始")
  261. print("="*60)
  262. print(f"数据集路径: {dataset_path}")
  263. print(f"输出目录: {output_dir}")
  264. print(f"批处理大小: {batch_size}")
  265. print(f"最大工作线程/进程数: {max_workers}")
  266. print(f"使用GPU: {use_gpu}")
  267. print(f"并行方式: {'多进程' if use_multiprocessing else '多线程'}")
  268. # 查找所有图像文件
  269. image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.bmp', '*.tiff']
  270. image_files = []
  271. for ext in image_extensions:
  272. image_files.extend(glob.glob(os.path.join(dataset_path, ext)))
  273. print(f"找到 {len(image_files)} 个图像文件")
  274. if not image_files:
  275. print("未找到任何图像文件,程序终止")
  276. return
  277. # 开始处理
  278. start_time = time.time()
  279. if use_multiprocessing:
  280. # 多进程处理(推荐用于CPU)
  281. print("使用多进程并行处理...")
  282. results = parallel_process_with_multiprocessing(
  283. image_files, batch_size, max_workers
  284. )
  285. else:
  286. # 多线程处理(推荐用于GPU)
  287. print("使用多线程并行处理...")
  288. predictor = PPStructureV3ParallelPredictor(pipeline_config, output_path=output_dir, use_gpu=use_gpu)
  289. results = predictor.parallel_process_with_threading(
  290. image_files, batch_size, max_workers
  291. )
  292. total_time = time.time() - start_time
  293. # 保存最终结果
  294. output_file = os.path.join(output_dir, f"OmniDocBench_PPStructureV3_batch{batch_size}.json")
  295. try:
  296. # 统计信息
  297. success_count = sum(1 for r in results if r.get('success', False))
  298. error_count = len(results) - success_count
  299. total_processing_time = sum(r.get('processing_time', 0) for r in results if r.get('success', False))
  300. avg_processing_time = total_processing_time / success_count if success_count > 0 else 0
  301. print(f"总文件数: {len(image_files)}")
  302. print(f"成功处理: {success_count}")
  303. print(f"失败数量: {error_count}")
  304. print(f"成功率: {success_count / len(image_files) * 100:.2f}%")
  305. print(f"总耗时: {total_time:.2f}秒")
  306. print(f"平均处理时间: {avg_processing_time:.2f}秒/张")
  307. print(f"吞吐量: {len(image_files) / total_time:.2f}张/秒")
  308. print(f"结果保存至: {output_file}")
  309. # 保存统计信息
  310. stats = {
  311. "total_files": len(image_files),
  312. "success_count": success_count,
  313. "error_count": error_count,
  314. "success_rate": success_count / len(image_files),
  315. "total_time": total_time,
  316. "avg_processing_time": avg_processing_time,
  317. "throughput": len(image_files) / total_time,
  318. "batch_size": batch_size,
  319. "max_workers": max_workers,
  320. "use_gpu": use_gpu,
  321. "use_multiprocessing": use_multiprocessing
  322. }
  323. results['stats'] = stats
  324. with open(output_file, 'w', encoding='utf-8') as f:
  325. json.dump(results, f, ensure_ascii=False, indent=2)
  326. print("\n" + "="*60)
  327. print("处理完成!")
  328. print("="*60)
  329. except Exception as e:
  330. print(f"保存结果文件时发生错误: {str(e)}")
  331. traceback.print_exc()
  332. if __name__ == "__main__":
  333. main()