|
|
@@ -338,17 +338,7 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang
|
|
|
|
|
|
def model_init(model_name: str):
|
|
|
from transformers import LayoutLMv3ForTokenClassification
|
|
|
- device = get_device()
|
|
|
- if torch.cuda.is_available():
|
|
|
- device = torch.device('cuda')
|
|
|
- elif str(device).startswith("npu"):
|
|
|
- import torch_npu
|
|
|
- if torch_npu.npu.is_available():
|
|
|
- device = torch.device('npu')
|
|
|
- else:
|
|
|
- device = torch.device('cpu')
|
|
|
- else:
|
|
|
- device = torch.device('cpu')
|
|
|
+ device = torch.device(get_device())
|
|
|
|
|
|
if model_name == 'layoutreader':
|
|
|
# 检测modelscope的缓存目录是否存在
|