@@ -179,13 +179,14 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
 def model_init(model_name: str):
     from transformers import LayoutLMv3ForTokenClassification
     device_name = get_device()
+    device = torch.device(device_name)
     bf_16_support = False
     if device_name.startswith("cuda"):
-        bf_16_support = torch.cuda.is_bf16_supported()
+        if torch.cuda.get_device_properties(device).major >= 8:
+            bf_16_support = True
     elif device_name.startswith("mps"):
         bf_16_support = True

-    device = torch.device(device_name)
     if model_name == 'layoutreader':
         # Check whether the modelscope cache directory exists
         layoutreader_model_dir = os.path.join(auto_download_and_get_model_root_path(ModelPath.layout_reader), ModelPath.layout_reader)
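For context, the patch gates bfloat16 on CUDA compute capability rather than `torch.cuda.is_bf16_supported()`: GPUs with compute capability 8.x (Ampere) or newer enable bf16, while MPS keeps it on and everything else falls back. The sketch below is a minimal standalone illustration of that detection logic, not the project's code; `pick_device_name` is a hypothetical stand-in for the `get_device()` helper, which is not shown in this hunk.

```python
import torch


def pick_device_name() -> str:
    # Hypothetical stand-in for the project's get_device() helper:
    # prefer CUDA, then Apple MPS, otherwise CPU.
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def detect_bf16_support(device_name: str) -> bool:
    # Mirrors the patched check: on CUDA, require compute capability >= 8
    # (Ampere or newer); on MPS, treat bf16 as supported; otherwise False.
    if device_name.startswith("cuda"):
        device = torch.device(device_name)
        return torch.cuda.get_device_properties(device).major >= 8
    if device_name.startswith("mps"):
        return True
    return False


if __name__ == "__main__":
    name = pick_device_name()
    print(f"device={name}, bf16={detect_bf16_support(name)}")
```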