
fix(npu): correct module name for NPU operations

- Update `clean_memory.py` to use `torch_npu.npu` instead of `torch.npu`
- Update `model_utils.py` to use `torch_npu.npu` instead of `torch.npu`
- Simplify the NPU availability check and disable the bfloat16 support probe in `pdf_parse_union_core_v2.py` (see the sketches below)
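For reference, a minimal sketch of the naming convention this commit standardizes on (assuming the Ascend `torch_npu` plugin is installed; the device API lives under the `torch_npu.npu` submodule, mirroring `torch.cuda`, whereas the `torch.npu` attribute may not be present depending on plugin version):

import torch
import torch_npu  # Ascend NPU plugin for PyTorch

# Query the NPU through the explicit torch_npu.npu submodule rather
# than relying on a patched torch.npu attribute.
if torch_npu.npu.is_available():
    device = torch.device('npu')
else:
    device = torch.device('cpu')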
myhloli 10 months ago
Parent
Commit
2684e7753b

+ 2 - 3
magic_pdf/libs/clean_memory.py

@@ -10,7 +10,6 @@ def clean_memory(device='cuda'):
             torch.cuda.ipc_collect()
     elif str(device).startswith("npu"):
         import torch_npu
-        if torch.npu.is_available():
-            torch_npu.empty_cache()
-            torch_npu.ipc_collect()
+        if torch_npu.npu.is_available():
+            torch_npu.npu.empty_cache()
     gc.collect()
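Putting the hunk in context, the patched clean_memory reads roughly as follows (the cuda branch is reconstructed from the surrounding context lines, so treat this as a sketch rather than the verbatim file):

import gc
import torch

def clean_memory(device='cuda'):
    if str(device).startswith("cuda"):
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
    elif str(device).startswith("npu"):
        import torch_npu
        if torch_npu.npu.is_available():
            # The old torch_npu.ipc_collect() call is dropped entirely;
            # only the cache release remains on the NPU path.
            torch_npu.npu.empty_cache()
    gc.collect()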

+ 2 - 2
magic_pdf/model/sub_modules/model_utils.py

@@ -56,8 +56,8 @@ def get_vram(device):
         return total_memory
     elif str(device).startswith("npu"):
         import torch_npu
-        if torch.npu.is_available():
-            total_memory = torch.npu.get_device_properties(device).total_memory / (1024 ** 3)  # convert to GB
+        if torch_npu.npu.is_available():
+            total_memory = torch_npu.npu.get_device_properties(device).total_memory / (1024 ** 3)  # convert to GB
             return total_memory
     else:
         return None
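A hypothetical call site, to illustrate the contract of the fixed branch: get_vram returns total device memory in GB, or None for device strings it does not recognize.

# Hypothetical usage; the import path comes from the file header above.
from magic_pdf.model.sub_modules.model_utils import get_vram

vram = get_vram('npu:0')  # float in GB, or None for unknown device types
if vram is not None:
    print(f'Detected {vram:.1f} GB of device memory')
else:
    print('Unknown device type; falling back to defaults')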

+ 2 - 5
magic_pdf/pdf_parse_union_core_v2.py

@@ -286,12 +286,9 @@ def model_init(model_name: str):
             supports_bfloat16 = False
     elif str(device).startswith("npu"):
         import torch_npu
-        if torch.npu.is_available():
+        if torch_npu.npu.is_available():
             device = torch.device('npu')
-            if torch.npu.is_bf16_supported():
-                supports_bfloat16 = True
-            else:
-                supports_bfloat16 = False
+            supports_bfloat16 = False
         else:
             device = torch.device('cpu')
             supports_bfloat16 = False
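The net effect of this hunk, collapsed into a standalone sketch (the helper name is invented for illustration; the commit removes the torch.npu.is_bf16_supported() probe, so bfloat16 is now unconditionally disabled on the NPU path):

import torch

def select_npu_device():
    # Sketch of the post-patch NPU branch of model_init.
    import torch_npu
    if torch_npu.npu.is_available():
        device = torch.device('npu')
        supports_bfloat16 = False  # bf16 probe removed; always off on NPU
    else:
        device = torch.device('cpu')
        supports_bfloat16 = False
    return device, supports_bfloat16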