Ver código fonte

fix: add conditional imports for torch and torch_npu in model_utils.py

myhloli 5 meses atrás
pai
commit
d58b24b5dd
1 arquivos alterados com 6 adições e 6 exclusões
  1. 6 6
      mineru/utils/model_utils.py

+ 6 - 6
mineru/utils/model_utils.py

@@ -6,6 +6,12 @@ import numpy as np
 
 from mineru.utils.boxbase import get_minbox_if_overlap_by_ratio
 
+try:
+    import torch
+    import torch_npu
+except ImportError:
+    pass
+
 
 def crop_img(input_res, input_img, crop_paste_x=0, crop_paste_y=0):
 
@@ -297,14 +303,11 @@ def get_res_list_from_layout_res(layout_res, iou_threshold=0.7, overlap_threshol
 
 
 def clean_memory(device='cuda'):
-    import torch
-
     if device == 'cuda':
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
             torch.cuda.ipc_collect()
     elif str(device).startswith("npu"):
-        import torch_npu
         if torch_npu.npu.is_available():
             torch_npu.npu.empty_cache()
     elif str(device).startswith("mps"):
@@ -322,13 +325,10 @@ def clean_vram(device, vram_threshold=8):
 
 
 def get_vram(device):
-    import torch
-
     if torch.cuda.is_available() and str(device).startswith("cuda"):
         total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)  # 将字节转换为 GB
         return total_memory
     elif str(device).startswith("npu"):
-        import torch_npu
         if torch_npu.npu.is_available():
             total_memory = torch_npu.npu.get_device_properties(device).total_memory / (1024 ** 3)  # 转为 GB
             return total_memory