浏览代码

feat: 添加PP-StructureV3并行预测器,支持多进程和多线程处理OmniDocBench数据集

zhch158_admin 3 月之前
父节点
当前提交
56f65059bb
共有 1 个文件被更改,包括 393 次插入和 0 次删除
  1. 393 0
      zhch/ppstructurev3_parallel_predict.py

+ 393 - 0
zhch/ppstructurev3_parallel_predict.py

@@ -0,0 +1,393 @@
+# zhch/ppstructurev3_parallel_predict.py
+import json
+import time
+import os
+import glob
+import traceback
+from pathlib import Path
+from typing import List, Dict, Any, Tuple
+from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
+from multiprocessing import Queue, Manager
+import cv2
+import numpy as np
+from paddlex import create_pipeline
+from tqdm import tqdm
+import threading
+
+class PPStructureV3ParallelPredictor:
+    """
+    PP-StructureV3并行预测器,支持多进程批处理
+    """
+
+    def __init__(self, pipeline_config_path: str = "PP-StructureV3", output_path: str = "output", use_gpu: bool = True):
+        """
+        初始化预测器
+        
+        Args:
+            pipeline_config_path: PaddleX pipeline配置文件路径
+        """
+        self.pipeline_config = pipeline_config_path
+        self.pipeline = create_pipeline(pipeline=self.pipeline_config)
+        self.output_path = output_path
+        self.use_gpu = use_gpu
+
+    def create_pipeline(self):
+        """创建pipeline实例(每个进程单独创建)"""
+        if self.pipeline is not None:
+            return self.pipeline
+        return create_pipeline(pipeline=self.pipeline_config)
+
+    def process_single_image(self, image_path: str) -> Dict[str, Any]:
+        """
+        处理单张图像
+        
+        Args:
+            image_path: 图像路径
+            output_path: 输出路径
+            use_gpu: 是否使用GPU
+            
+        Returns:
+            处理结果{"image_path": str, "success": bool, "processing_time": float, "error": str}
+        """
+        try:
+            # 读取图像获取尺寸信息
+            image = cv2.imread(image_path)
+            if image is None:
+                return {
+                    "image_path": Path(image_path).name,
+                    "error": "无法读取图像",
+                    "success": False,
+                    "processing_time": 0
+                }
+                
+            height, width = image.shape[:2]
+            
+            # 运行PaddleX pipeline
+            start_time = time.time()
+            
+            output = list(self.pipeline.predict(
+                input=image_path,
+                device="gpu" if self.use_gpu else "cpu",
+                use_doc_orientation_classify=True,
+                use_doc_unwarping=False,
+                use_seal_recognition=True,
+                use_chart_recognition=True,
+                use_table_recognition=True,
+                use_formula_recognition=True,
+            ))
+            
+            output.save_to_json(save_path=self.output_path)  # 保存JSON结果
+            output.save_to_markdown(save_path=self.output_path)  # 保存Markdown结果
+            process_time = time.time() - start_time
+            
+            # 添加处理时间信息
+            result = {"image_path": Path(image_path).name}
+            if output:
+                result["processing_time"] = process_time
+                result["success"] = True
+            
+            return result
+            
+        except Exception as e:
+            return {
+                "image_path": Path(image_path).name,
+                "error": str(e),
+                "success": False,
+                "processing_time": 0
+            }
+    
+    def process_batch(self, image_paths: List[str]) -> List[Dict[str, Any]]:
+        """
+        批处理图像
+        
+        Args:
+            image_paths: 图像路径列表
+            use_gpu: 是否使用GPU
+            
+        Returns:
+            结果列表
+        """
+        results = []
+        
+        for image_path in image_paths:
+            result = self.process_single_image(image_path=image_path)
+            results.append(result)
+        
+        return results
+    
+    def parallel_process_with_threading(self, 
+                                      image_paths: List[str], 
+                                      batch_size: int = 4,
+                                      max_workers: int = 4
+                                      ) -> List[Dict[str, Any]]:
+        """
+        使用多线程并行处理(推荐用于GPU)
+        
+        Args:
+            image_paths: 图像路径列表
+            batch_size: 批处理大小
+            max_workers: 最大工作线程数
+            
+        Returns:
+            处理结果列表
+        """
+        # 将图像路径分批
+        batches = [image_paths[i:i + batch_size] for i in range(0, len(image_paths), batch_size)]
+        
+        all_results = []
+        completed_count = 0
+        total_images = len(image_paths)
+        
+        # 创建进度条
+        with tqdm(total=total_images, desc="处理图像", unit="张") as pbar:
+            with ThreadPoolExecutor(max_workers=max_workers) as executor:
+                # 提交所有批处理任务
+                future_to_batch = {
+                    executor.submit(self.process_batch, batch): batch 
+                    for batch in batches
+                }
+                
+                # 收集结果
+                for future in as_completed(future_to_batch):
+                    batch = future_to_batch[future]
+                    try:
+                        batch_results = future.result()
+                        all_results.extend(batch_results)
+                        completed_count += len(batch)
+                        pbar.update(len(batch))
+                        
+                        # 更新进度条描述
+                        success_count = sum(1 for r in batch_results if r.get('success', False))
+                        pbar.set_postfix({
+                            'batch_success': f"{success_count}/{len(batch)}",
+                            'total_success': f"{sum(1 for r in all_results if r.get('success', False))}/{completed_count}"
+                        })
+                        
+                    except Exception as e:
+                        print(f"批处理失败: {e}")
+                        # 为失败的批次创建错误结果
+                        for img_path in batch:
+                            error_result = {
+                                "image_path": Path(img_path).name,
+                                "error": str(e),
+                                "success": False,
+                                "processing_time": 0
+                            }
+                            all_results.append(error_result)
+                        pbar.update(len(batch))
+        
+        return all_results
+    
+    
+    def save_results_incrementally(self, 
+                                 results: List[Dict[str, Any]], 
+                                 output_file: str,
+                                 save_interval: int = 50):
+        """
+        增量保存结果
+        
+        Args:
+            results: 结果列表
+            output_file: 输出文件路径
+            save_interval: 保存间隔
+        """
+        if len(results) % save_interval == 0 and len(results) > 0:
+            try:
+                with open(output_file, 'w', encoding='utf-8') as f:
+                    json.dump(results, f, ensure_ascii=False, indent=2)
+                print(f"已保存 {len(results)} 个结果到 {output_file}")
+            except Exception as e:
+                print(f"保存结果时出错: {e}")
+
+def process_batch_worker(image_paths: List[str], pipeline_config: str, output_path: str, use_gpu: bool) -> List[Dict[str, Any]]:
+    """
+    多进程工作函数
+    """
+    try:
+        # 在每个进程中创建pipeline实例
+        predictor = PPStructureV3ParallelPredictor(pipeline_config, output_path=output_path, use_gpu=use_gpu)
+        return predictor.process_batch(image_paths)
+    except Exception as e:
+        # 返回错误结果
+        error_results = []
+        for img_path in image_paths:
+            error_results.append({
+                "image_path": Path(img_path).name,
+                "error": str(e),
+                "success": False,
+                "processing_time": 0
+            })
+        return error_results
+
+def parallel_process_with_multiprocessing(image_paths: List[str],
+                                        batch_size: int = 4,
+                                        max_workers: int = 4,
+                                        pipeline_config: str = "PP-StructureV3",
+                                        output_path: str = "./output",
+                                        use_gpu: bool = True
+                                        ) -> List[Dict[str, Any]]:
+    """
+    使用多进程并行处理(推荐用于CPU)
+    
+    Args:
+        image_paths: 图像路径列表
+        batch_size: 批处理大小
+        max_workers: 最大工作进程数
+        use_gpu: 是否使用GPU
+        
+    Returns:
+        处理结果列表
+    """
+    # 将图像路径分批
+    batches = [image_paths[i:i + batch_size] for i in range(0, len(image_paths), batch_size)]
+    
+    all_results = []
+    completed_count = 0
+    total_images = len(image_paths)
+    
+    # 创建进度条
+    with tqdm(total=total_images, desc="处理图像", unit="张") as pbar:
+        with ProcessPoolExecutor(max_workers=max_workers) as executor:
+            # 提交所有批处理任务
+            future_to_batch = {
+                executor.submit(process_batch_worker, batch, pipeline_config, output_path, use_gpu): batch
+                for batch in batches
+            }
+            
+            # 收集结果
+            for future in as_completed(future_to_batch):
+                batch = future_to_batch[future]
+                try:
+                    batch_results = future.result()
+                    all_results.extend(batch_results)
+                    completed_count += len(batch)
+                    pbar.update(len(batch))
+                    
+                    # 更新进度条描述
+                    success_count = sum(1 for r in batch_results if r.get('success', False))
+                    pbar.set_postfix({
+                        'batch_success': f"{success_count}/{len(batch)}",
+                        'total_success': f"{sum(1 for r in all_results if r.get('success', False))}/{completed_count}"
+                    })
+                    
+                except Exception as e:
+                    print(f"批处理失败: {e}")
+                    # 为失败的批次创建错误结果
+                    for img_path in batch:
+                        error_result = {
+                            "image_path": Path(img_path).name,
+                            "error": str(e),
+                            "success": False,
+                            "processing_time": 0
+                        }
+                        all_results.append(error_result)
+                    pbar.update(len(batch))
+    
+    return all_results
+
+def main():
+    """主函数 - 并行处理OmniDocBench数据集"""
+    
+    # 配置参数
+    dataset_path = "/Users/zhch158/workspace/repository.git/OmniDocBench/OpenDataLab___OmniDocBench/images"
+    output_dir = "/Users/zhch158/workspace/repository.git/PaddleX/zhch/OmniDocBench_Results"
+    pipeline_config = "PP-StructureV3"
+    
+    # 并行处理参数
+    batch_size = 4          # 批处理大小
+    max_workers = 4         # 最大工作进程/线程数
+    use_gpu = True          # 是否使用GPU
+    use_multiprocessing = False  # False=多线程(GPU推荐), True=多进程(CPU推荐)
+    
+    # 确保输出目录存在
+    os.makedirs(output_dir, exist_ok=True)
+    
+    print("="*60)
+    print("OmniDocBench 并行评估开始")
+    print("="*60)
+    print(f"数据集路径: {dataset_path}")
+    print(f"输出目录: {output_dir}")
+    print(f"批处理大小: {batch_size}")
+    print(f"最大工作线程/进程数: {max_workers}")
+    print(f"使用GPU: {use_gpu}")
+    print(f"并行方式: {'多进程' if use_multiprocessing else '多线程'}")
+    
+    # 查找所有图像文件
+    image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.bmp', '*.tiff']
+    image_files = []
+    
+    for ext in image_extensions:
+        image_files.extend(glob.glob(os.path.join(dataset_path, ext)))
+    
+    print(f"找到 {len(image_files)} 个图像文件")
+    
+    if not image_files:
+        print("未找到任何图像文件,程序终止")
+        return
+    
+    
+    # 开始处理
+    start_time = time.time()
+    
+    if use_multiprocessing:
+        # 多进程处理(推荐用于CPU)
+        print("使用多进程并行处理...")
+        results = parallel_process_with_multiprocessing(
+            image_files, batch_size, max_workers
+        )
+    else:
+        # 多线程处理(推荐用于GPU)
+        print("使用多线程并行处理...")
+        predictor = PPStructureV3ParallelPredictor(pipeline_config, output_path=output_dir, use_gpu=use_gpu)
+        results = predictor.parallel_process_with_threading(
+            image_files, batch_size, max_workers
+        )
+    
+    total_time = time.time() - start_time
+    
+    # 保存最终结果
+    output_file = os.path.join(output_dir, f"OmniDocBench_PPStructureV3_batch{batch_size}.json")
+    try:
+        # 统计信息
+        success_count = sum(1 for r in results if r.get('success', False))
+        error_count = len(results) - success_count
+        total_processing_time = sum(r.get('processing_time', 0) for r in results if r.get('success', False))
+        avg_processing_time = total_processing_time / success_count if success_count > 0 else 0
+        
+        print(f"总文件数: {len(image_files)}")
+        print(f"成功处理: {success_count}")
+        print(f"失败数量: {error_count}")
+        print(f"成功率: {success_count / len(image_files) * 100:.2f}%")
+        print(f"总耗时: {total_time:.2f}秒")
+        print(f"平均处理时间: {avg_processing_time:.2f}秒/张")
+        print(f"吞吐量: {len(image_files) / total_time:.2f}张/秒")
+        print(f"结果保存至: {output_file}")
+        
+        # 保存统计信息
+        stats = {
+            "total_files": len(image_files),
+            "success_count": success_count,
+            "error_count": error_count,
+            "success_rate": success_count / len(image_files),
+            "total_time": total_time,
+            "avg_processing_time": avg_processing_time,
+            "throughput": len(image_files) / total_time,
+            "batch_size": batch_size,
+            "max_workers": max_workers,
+            "use_gpu": use_gpu,
+            "use_multiprocessing": use_multiprocessing
+        }
+        results['stats'] = stats
+        with open(output_file, 'w', encoding='utf-8') as f:
+            json.dump(results, f, ensure_ascii=False, indent=2)
+        
+        print("\n" + "="*60)
+        print("处理完成!")
+        print("="*60)
+        
+    except Exception as e:
+        print(f"保存结果文件时发生错误: {str(e)}")
+        traceback.print_exc()
+
+# Script entry point: run main() only when this file is executed directly.
+if __name__ == "__main__":
+    main()