Explorar o código

perf(magic_pdf): optimize batch processing for GPU

- Improve batch ratio calculation based on GPU memory
- Enhance performance for devices with 8GB or more VRAM
myhloli hai 10 meses
pai
achega
55447c8b64
Modificáronse 1 ficheiros con 12 adicións e 1 borrados
  1. 12 1
      magic_pdf/model/doc_analyze_by_custom_model.py

+ 12 - 1
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -177,7 +177,18 @@ def doc_analyze(
     if torch.cuda.is_available() and device != 'cpu' or npu_support:
         gpu_memory = int(os.getenv("VIRTUAL_VRAM_SIZE", round(get_vram(device))))
         if gpu_memory is not None and gpu_memory >= 8:
-            batch_ratio = int(gpu_memory-6)
+
+            if 8 <= gpu_memory <= 10:
+                batch_ratio = 2
+            elif 10 < gpu_memory <= 12:
+                batch_ratio = 4
+            elif 12 < gpu_memory <= 16:
+                batch_ratio = 8
+            elif 16 < gpu_memory <= 24:
+                batch_ratio = 16
+            else:
+                batch_ratio = 32
+
             if batch_ratio >= 1:
                 logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
                 batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)