Просмотр исходного кода

feat(zhch): 优化多进程处理逻辑并增强错误处理

- 改进 worker 函数,增加异常捕获和处理
- 增加 None 类型检查以支持结束信号
- 优化进程结束逻辑,去除超时设置
- 保存结果统计和最终结果到 JSON 文件
- 更新示例配置,使用多个 GPU 设备
zhch158_admin 3 месяцев назад
Родитель
Сommit
e7cc5250e3
1 измененных файлов с 42 добавлено и 9 удалено
  1. 42 9
      zhch/ppstructurev3_multi_gpu_multiprocess_official.py

+ 42 - 9
zhch/ppstructurev3_multi_gpu_multiprocess_official.py

@@ -75,9 +75,18 @@ def worker(pipeline_name_or_config_path: str,
                 input_path = task_queue.get_nowait()
             except Empty:
                 should_end = True
+            except Exception as e:
+                # 处理其他可能的异常
+                print(f"Unexpected error while getting task: {e}", file=sys.stderr)
+                traceback.print_exc()
+                should_end = True
             else:
-                batch.append(input_path)
-            
+                # 检查是否为结束信号
+                if input_path is None:
+                    should_end = True
+                else:
+                    batch.append(input_path)            
+
             if batch and (len(batch) == batch_size or should_end):
                 try:
                     start_time = time.time()
@@ -157,6 +166,7 @@ def worker(pipeline_name_or_config_path: str,
                     result_queue.put(error_results)
                     
                     print(f"Error processing batch {batch} on {device}: {e}", file=sys.stderr)
+                    traceback.print_exc()
                 
                 batch.clear()
     
@@ -265,7 +275,7 @@ def parallel_process_with_official_approach(image_paths: List[str],
             
             while completed_count < total_images:
                 try:
-                    batch_results = result_queue.get(timeout=300)  # 5分钟超时
+                    batch_results = result_queue.get(timeout=600)  # 10分钟超时
                     all_results.extend(batch_results)
                     
                     # 更新进度条
@@ -285,9 +295,7 @@ def parallel_process_with_official_approach(image_paths: List[str],
         
         # 等待所有进程结束
         for p in processes:
-            p.join(timeout=10)
-            if p.is_alive():
-                p.terminate()
+            p.join()
     
     return all_results
 
@@ -309,7 +317,9 @@ def main():
     
     try:
         # 获取图像文件列表
-        input_dir = Path(args.input_dir)
+        input_dir = Path(args.input_dir).resolve()
+        output_dir = Path(args.output_dir).resolve()
+        print(f"Input dir: {input_dir}, Output dir: {output_dir}")
         if not input_dir.exists():
             print(f"Input directory does not exist: {input_dir}")
             return 1
@@ -340,7 +350,7 @@ def main():
             args.device,
             args.instances_per_device,
             args.batch_size,
-            args.output_dir
+            str(output_dir)
         )
         total_time = time.time() - start_time
         
@@ -356,6 +366,29 @@ def main():
         print(f"Success rate: {success_count / len(image_files) * 100:.2f}%")
         print(f"Total time: {total_time:.2f} seconds")
         print(f"Throughput: {len(image_files) / total_time:.2f} images/second")
+
+                # 保存结果统计
+        stats = {
+            "total_files": len(image_files),
+            "success_count": success_count,
+            "error_count": error_count,
+            "success_rate": success_count / len(image_files),
+            "total_time": total_time,
+            "throughput": len(image_files) / total_time,
+            "batch_size": args.batch_size,
+            "gpu_ids": args.device,
+            "pipelines_per_gpu": args.instances_per_device
+        }
+        
+        # 保存最终结果
+        output_file = os.path.join(output_dir, f"OmniDocBench_MultiGPU_batch{args.batch_size}.json")
+        final_results = {
+            "stats": stats,
+            "results": results
+        }
+        
+        with open(output_file, 'w', encoding='utf-8') as f:
+            json.dump(final_results, f, ensure_ascii=False, indent=2)
         
         return 0
         
@@ -385,7 +418,7 @@ if __name__ == "__main__":
             "input_dir": "../../OmniDocBench/OpenDataLab___OmniDocBench/images",
             "output_dir": "./OmniDocBench_Results_Official",
             "pipeline": "PP-StructureV3",
-            "device": "gpu:0",
+            "device": "gpu:0,1,2,3",
             "instances_per_device": 1,
             "batch_size": 4,
             # "test_mode": False