@@ -286,12 +286,9 @@ def model_init(model_name: str):
         supports_bfloat16 = False
     elif str(device).startswith("npu"):
         import torch_npu
-        if torch.npu.is_available():
+        if torch_npu.npu.is_available():
             device = torch.device('npu')
-            if torch.npu.is_bf16_supported():
-                supports_bfloat16 = True
-            else:
-                supports_bfloat16 = False
+            supports_bfloat16 = False
     else:
         device = torch.device('cpu')
         supports_bfloat16 = False