@@ -341,21 +341,14 @@ def model_init(model_name: str):
     device = get_device()
     if torch.cuda.is_available():
         device = torch.device('cuda')
-        if torch.cuda.is_bf16_supported():
-            supports_bfloat16 = True
-        else:
-            supports_bfloat16 = False
     elif str(device).startswith("npu"):
         import torch_npu
         if torch_npu.npu.is_available():
             device = torch.device('npu')
-            supports_bfloat16 = False
         else:
             device = torch.device('cpu')
-            supports_bfloat16 = False
     else:
         device = torch.device('cpu')
-        supports_bfloat16 = False

     if model_name == 'layoutreader':
         # Check whether the modelscope cache directory exists
@@ -371,9 +364,6 @@ def model_init(model_name: str):
         model = LayoutLMv3ForTokenClassification.from_pretrained(
             'hantian/layoutreader'
         )
-        # Check whether the device supports bfloat16
-        if supports_bfloat16:
-            model.bfloat16()
         model.to(device).eval()
     else:
         logger.error('model name not allow')