fix(model): improve VRAM detection and handling

- Refactor VRAM detection logic for better readability and efficiency
- Add fallback mechanism for unknown VRAM sizes
- Improve device checking in get_vram function
myhloli, 7 months ago
parent commit d32a63cada
2 changed files with 8 additions and 3 deletions
  1. +7 -2 magic_pdf/model/doc_analyze_by_custom_model.py
  2. +1 -1 magic_pdf/model/sub_modules/model_utils.py

+ 7 - 2
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -255,8 +255,9 @@ def may_batch_image_analyze(
             torch.npu.set_compile_mode(jit_compile=False)
 
     if str(device).startswith('npu') or str(device).startswith('cuda'):
-        gpu_memory = int(os.getenv('VIRTUAL_VRAM_SIZE', round(get_vram(device))))
-        if gpu_memory is not None:
+        vram = get_vram(device)
+        if vram is not None:
+            gpu_memory = int(os.getenv('VIRTUAL_VRAM_SIZE', round(vram)))
             if gpu_memory >= 16:
                 batch_ratio = 16
             elif gpu_memory >= 12:
@@ -268,6 +269,10 @@ def may_batch_image_analyze(
             else:
                 batch_ratio = 1
             logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
+        else:
+            # Default batch_ratio when VRAM can't be determined
+            batch_ratio = 1
+            logger.info(f'Could not determine GPU memory, using default batch_ratio: {batch_ratio}')
 
 
     # doc_analyze_start = time.time()
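
Note on the fallback: in the old code, round(get_vram(device)) was evaluated eagerly as the os.getenv default, so a None return from get_vram raised a TypeError inside round() before the gpu_memory is not None guard could ever run. The reordered logic reads as a small standalone helper. Below is a hedged sketch, not the project's API: resolve_gpu_memory is a hypothetical name, and the batch-ratio tiers between the 12 GB threshold and the final else branch are elided by the hunk, so they are not reproduced here.

import os

def resolve_gpu_memory(vram_gb):
    # Hypothetical helper mirroring the reordered logic above. Returns the
    # effective GPU memory in GB, or None when VRAM could not be detected,
    # in which case the caller now falls back to batch_ratio = 1.
    if vram_gb is None:
        return None
    # VIRTUAL_VRAM_SIZE overrides the detected value; os.getenv returns a
    # string when the variable is set, so int() normalizes both cases.
    return int(os.getenv('VIRTUAL_VRAM_SIZE', round(vram_gb)))

Setting the variable, e.g. VIRTUAL_VRAM_SIZE=8, pins the effective memory to 8 GB regardless of what detection reports, which is one way to force a smaller batch_ratio on large cards.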

+ 1 - 1
magic_pdf/model/sub_modules/model_utils.py

@@ -57,7 +57,7 @@ def clean_vram(device, vram_threshold=8):
 
 
 def get_vram(device):
-    if torch.cuda.is_available() and device != 'cpu':
+    if torch.cuda.is_available() and str(device).startswith("cuda"):
         total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)  # convert bytes to GB
         return total_memory
     elif str(device).startswith("npu"):
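
Note on the guard: the old condition device != 'cpu' entered the CUDA branch for any non-CPU device, including 'npu' strings, whenever a CUDA runtime happened to be available; the new str(device).startswith("cuda") test restricts the branch to CUDA devices and also handles torch.device objects, whose string form is e.g. "cuda:0". A minimal sketch of the fixed branch follows; the npu tail is truncated in the hunk above, so it is represented here by a hedged None return meaning "unknown VRAM":

import torch

def get_vram_sketch(device):
    # Hedged sketch of the fixed branch; not the full upstream function.
    dev = str(device)  # normalizes both plain strings and torch.device objects
    if torch.cuda.is_available() and dev.startswith("cuda"):
        props = torch.cuda.get_device_properties(device)
        return props.total_memory / (1024 ** 3)  # bytes -> GB
    # The npu branch continues in the original (truncated above). Returning
    # None signals "unknown VRAM", which triggers the new batch_ratio = 1
    # fallback in doc_analyze_by_custom_model.py.
    return None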