Bläddra i källkod

Merge pull request #1661 from myhloli/dev

perf(model): optimize batch ratio for different GPU memory sizes
Xiaomeng Zhao 9 månader sedan
förälder
incheckning
9bb2d58139
1 ändrade filer med 11 tillägg och 12 borttagningar
  1. 11 12
      magic_pdf/model/doc_analyze_by_custom_model.py

+ 11 - 12
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -178,21 +178,20 @@ def doc_analyze(
         gpu_memory = int(os.getenv("VIRTUAL_VRAM_SIZE", round(get_vram(device))))
         gpu_memory = int(os.getenv("VIRTUAL_VRAM_SIZE", round(get_vram(device))))
         if gpu_memory is not None and gpu_memory >= 8:
         if gpu_memory is not None and gpu_memory >= 8:
 
 
-            if 8 <= gpu_memory < 10:
-                batch_ratio = 2
-            elif 10 <= gpu_memory <= 12:
-                batch_ratio = 4
-            elif 12 < gpu_memory <= 20:
-                batch_ratio = 8
-            elif 20 < gpu_memory <= 32:
+            if gpu_memory >= 40:
+                batch_ratio = 32
+            elif gpu_memory >=20:
                 batch_ratio = 16
                 batch_ratio = 16
+            elif gpu_memory >= 16:
+                batch_ratio = 8
+            elif gpu_memory >= 10:
+                batch_ratio = 4
             else:
             else:
-                batch_ratio = 32
+                batch_ratio = 2
 
 
-            if batch_ratio >= 1:
-                logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
-                batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
-                batch_analyze = True
+            logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
+            batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
+            batch_analyze = True
 
 
     model_json = []
     model_json = []
     doc_analyze_start = time.time()
     doc_analyze_start = time.time()