|
|
@@ -1,3 +1,4 @@
|
|
|
+import os
|
|
|
import statistics
|
|
|
import time
|
|
|
|
|
|
@@ -9,6 +10,7 @@ import torch
|
|
|
|
|
|
from magic_pdf.libs.clean_memory import clean_memory
|
|
|
from magic_pdf.libs.commons import fitz, get_delta_time
|
|
|
+from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir
|
|
|
from magic_pdf.libs.convert_utils import dict_to_list
|
|
|
from magic_pdf.libs.drop_reason import DropReason
|
|
|
from magic_pdf.libs.hash_utils import compute_md5
|
|
|
@@ -95,7 +97,7 @@ def replace_text_span(pymu_spans, ocr_spans):
|
|
|
return list(filter(lambda x: x["type"] != ContentType.Text, ocr_spans)) + pymu_spans
|
|
|
|
|
|
|
|
|
-def model_init(model_name: str, local_path=None):
|
|
|
+def model_init(model_name: str):
|
|
|
from transformers import LayoutLMv3ForTokenClassification
|
|
|
if torch.cuda.is_available():
|
|
|
device = torch.device("cuda")
|
|
|
@@ -108,9 +110,13 @@ def model_init(model_name: str, local_path=None):
|
|
|
supports_bfloat16 = False
|
|
|
|
|
|
if model_name == "layoutreader":
|
|
|
- if local_path:
|
|
|
- model = LayoutLMv3ForTokenClassification.from_pretrained(local_path)
|
|
|
+ # Check whether the local ModelScope cache directory exists
|
|
|
+ layoutreader_model_dir = get_local_layoutreader_model_dir()
|
|
|
+ if os.path.exists(layoutreader_model_dir):
|
|
|
+ model = LayoutLMv3ForTokenClassification.from_pretrained(layoutreader_model_dir)
|
|
|
else:
|
|
|
+ logger.warning(
|
|
|
+ f"local layoutreader model not exists, use online model from huggingface")
|
|
|
model = LayoutLMv3ForTokenClassification.from_pretrained("hantian/layoutreader")
|
|
|
# 检查设备是否支持 bfloat16
|
|
|
if supports_bfloat16:
|
|
|
@@ -131,12 +137,9 @@ class ModelSingleton:
|
|
|
cls._instance = super().__new__(cls)
|
|
|
return cls._instance
|
|
|
|
|
|
- def get_model(self, model_name: str, local_path=None):
|
|
|
+ def get_model(self, model_name: str):
|
|
|
if model_name not in self._models:
|
|
|
- if local_path:
|
|
|
- self._models[model_name] = model_init(model_name=model_name, local_path=local_path)
|
|
|
- else:
|
|
|
- self._models[model_name] = model_init(model_name=model_name)
|
|
|
+ self._models[model_name] = model_init(model_name=model_name)
|
|
|
return self._models[model_name]
|
|
|
|
|
|
|