ソースを参照

feat(layoutreader): support local model directory and improve model loading

- Add function to get local LayoutReader model directory- Check and use local model directory if available
- Fall back to online model if local directory not found
- Update model initialization to support local path
- Refactor model loading in singleton class
myhloli 1 年間 前
コミット
ded2818ac2

+ 1 - 0
docs/download_models.py

@@ -1,4 +1,5 @@
 # use modelscope sdk download models
 from modelscope import snapshot_download
 model_dir = snapshot_download('opendatalab/PDF-Extract-Kit')
+layoutreader_model_dir = snapshot_download('ppaanngggg/layoutreader')
 print(f"model dir is: {model_dir}/models")

+ 1 - 0
docs/download_models_hf.py

@@ -1,3 +1,4 @@
 from huggingface_hub import snapshot_download
 model_dir = snapshot_download('opendatalab/PDF-Extract-Kit')
+layoutreader_model_dir = snapshot_download('hantian/layoutreader')
 print(f"model dir is: {model_dir}/models")

+ 1 - 0
magic-pdf.template.json

@@ -4,6 +4,7 @@
         "bucket-name-2":["ak", "sk", "endpoint"]
     },
     "models-dir":"/tmp/models",
+    "layoutreader-model-dir":"/tmp/layoutreader",
     "device-mode":"cpu",
     "table-config": {
         "model": "TableMaster",

+ 12 - 0
magic_pdf/libs/config_reader.py

@@ -67,6 +67,18 @@ def get_local_models_dir():
         return models_dir
 
 
+def get_local_layoutreader_model_dir():
+    config = read_config()
+    layoutreader_model_dir = config.get("layoutreader-model-dir")
+    if layoutreader_model_dir is None or not os.path.exists(layoutreader_model_dir):
+        home_dir = os.path.expanduser("~")
+        layoutreader_at_modelscope_dir_path = os.path.join(home_dir, ".cache/modelscope/hub/ppaanngggg/layoutreader")
+        logger.warning(f"'layoutreader-model-dir' not exists, use {layoutreader_at_modelscope_dir_path} as default")
+        return layoutreader_at_modelscope_dir_path
+    else:
+        return layoutreader_model_dir
+
+
 def get_device():
     config = read_config()
     device = config.get("device-mode")

+ 11 - 8
magic_pdf/pdf_parse_union_core_v2.py

@@ -1,3 +1,4 @@
+import os
 import statistics
 import time
 
@@ -9,6 +10,7 @@ import torch
 
 from magic_pdf.libs.clean_memory import clean_memory
 from magic_pdf.libs.commons import fitz, get_delta_time
+from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir
 from magic_pdf.libs.convert_utils import dict_to_list
 from magic_pdf.libs.drop_reason import DropReason
 from magic_pdf.libs.hash_utils import compute_md5
@@ -95,7 +97,7 @@ def replace_text_span(pymu_spans, ocr_spans):
     return list(filter(lambda x: x["type"] != ContentType.Text, ocr_spans)) + pymu_spans
 
 
-def model_init(model_name: str, local_path=None):
+def model_init(model_name: str):
     from transformers import LayoutLMv3ForTokenClassification
     if torch.cuda.is_available():
         device = torch.device("cuda")
@@ -108,9 +110,13 @@ def model_init(model_name: str, local_path=None):
         supports_bfloat16 = False
 
     if model_name == "layoutreader":
-        if local_path:
-            model = LayoutLMv3ForTokenClassification.from_pretrained(local_path)
+        # 检测modelscope的缓存目录是否存在
+        layoutreader_model_dir = get_local_layoutreader_model_dir()
+        if os.path.exists(layoutreader_model_dir):
+            model = LayoutLMv3ForTokenClassification.from_pretrained(layoutreader_model_dir)
         else:
+            logger.warning(
+                f"local layoutreader model not exists, use online model from huggingface")
             model = LayoutLMv3ForTokenClassification.from_pretrained("hantian/layoutreader")
         # 检查设备是否支持 bfloat16
         if supports_bfloat16:
@@ -131,12 +137,9 @@ class ModelSingleton:
             cls._instance = super().__new__(cls)
         return cls._instance
 
-    def get_model(self, model_name: str, local_path=None):
+    def get_model(self, model_name: str):
         if model_name not in self._models:
-            if local_path:
-                self._models[model_name] = model_init(model_name=model_name, local_path=local_path)
-            else:
-                self._models[model_name] = model_init(model_name=model_name)
+            self._models[model_name] = model_init(model_name=model_name)
         return self._models[model_name]