Ver código fonte

feat(zhch): 增加GPU缓存清理和单精度设置,优化worker初始化

zhch158_admin 3 meses atrás
pai
commit
596eb32176
1 arquivos alterados com 12 adições e 0 exclusões
  1. 12 0
      zhch/ppstructurev3_multi_gpu_multiprocess_official.py

+ 12 - 0
zhch/ppstructurev3_multi_gpu_multiprocess_official.py

@@ -48,7 +48,14 @@ def worker(pipeline_name_or_config_path: str,
         # 设置子进程的CUDA设备
         device_id = device.split(':')[1] if ':' in device else '0'
         os.environ['CUDA_VISIBLE_DEVICES'] = device_id
+
+        # 设置paddle使用单精度,避免混合精度问题
+        paddle.set_default_dtype("float32")
         
+        # 清理GPU缓存
+        if paddle.device.cuda.device_count() > 0:
+            paddle.device.cuda.empty_cache()        
+
         # 直接创建pipeline,让PaddleX自动处理设备初始化
         pipeline = create_pipeline(pipeline_name_or_config_path, device=device)
         print(f"Worker {worker_id} initialized with device {device}")
@@ -174,6 +181,11 @@ def worker(pipeline_name_or_config_path: str,
         print(f"Worker {worker_id} ({device}) initialization failed: {e}", file=sys.stderr)
         traceback.print_exc()
     finally:
+        # 清理GPU缓存
+        try:
+            paddle.device.cuda.empty_cache()
+        except Exception as e:
+            print(f"Error clearing GPU cache: {e}", file=sys.stderr)
         print(f"Worker {worker_id} ({device}) finished")
 
 def parallel_process_with_official_approach(image_paths: List[str],