|
@@ -333,8 +333,14 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang
|
|
|
|
|
|
|
|
def model_init(model_name: str):
|
|
def model_init(model_name: str):
|
|
|
from transformers import LayoutLMv3ForTokenClassification
|
|
from transformers import LayoutLMv3ForTokenClassification
|
|
|
- device = torch.device(get_device())
|
|
|
|
|
-
|
|
|
|
|
|
|
+ device_name = get_device()
|
|
|
|
|
+ bf_16_support = False
|
|
|
|
|
+ if device_name.startswith("cuda"):
|
|
|
|
|
+ bf_16_support = torch.cuda.is_bf16_supported()
|
|
|
|
|
+ elif device_name.startswith("mps"):
|
|
|
|
|
+ bf_16_support = True
|
|
|
|
|
+
|
|
|
|
|
+ device = torch.device(device_name)
|
|
|
if model_name == 'layoutreader':
|
|
if model_name == 'layoutreader':
|
|
|
# 检测modelscope的缓存目录是否存在
|
|
# 检测modelscope的缓存目录是否存在
|
|
|
layoutreader_model_dir = get_local_layoutreader_model_dir()
|
|
layoutreader_model_dir = get_local_layoutreader_model_dir()
|
|
@@ -349,7 +355,10 @@ def model_init(model_name: str):
|
|
|
model = LayoutLMv3ForTokenClassification.from_pretrained(
|
|
model = LayoutLMv3ForTokenClassification.from_pretrained(
|
|
|
'hantian/layoutreader'
|
|
'hantian/layoutreader'
|
|
|
)
|
|
)
|
|
|
- model.to(device).eval()
|
|
|
|
|
|
|
+ if bf_16_support:
|
|
|
|
|
+ model.to(device).eval().bfloat16()
|
|
|
|
|
+ else:
|
|
|
|
|
+ model.to(device).eval()
|
|
|
else:
|
|
else:
|
|
|
logger.error('model name not allow')
|
|
logger.error('model name not allow')
|
|
|
exit(1)
|
|
exit(1)
|