Merge pull request #698 from myhloli/dev

feat(layoutreader): support local model directory and improve model loading
Xiaomeng Zhao 1 year ago
parent commit
8786d208fb

+ 1 - 0
README.md

@@ -395,6 +395,7 @@ This project currently uses PyMuPDF to achieve advanced functionality. However,
 - [StructEqTable](https://github.com/UniModal4Reasoning/StructEqTable-Deploy)
 - [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR)
 - [PyMuPDF](https://github.com/pymupdf/PyMuPDF)
+- [layoutreader](https://github.com/ppaanngggg/layoutreader)
 - [fast-langdetect](https://github.com/LlmKira/fast-langdetect)
 - [pdfminer.six](https://github.com/pdfminer/pdfminer.six)
 

+ 1 - 0
README_zh-CN.md

@@ -400,6 +400,7 @@ TODO
 - [StructEqTable](https://github.com/UniModal4Reasoning/StructEqTable-Deploy)
 - [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR)
 - [PyMuPDF](https://github.com/pymupdf/PyMuPDF)
+- [layoutreader](https://github.com/ppaanngggg/layoutreader)
 - [fast-langdetect](https://github.com/LlmKira/fast-langdetect)
 - [pdfminer.six](https://github.com/pdfminer/pdfminer.six)
 

+ 1 - 0
docs/download_models.py

@@ -1,4 +1,5 @@
 # use modelscope sdk download models
 from modelscope import snapshot_download
 model_dir = snapshot_download('opendatalab/PDF-Extract-Kit')
+layoutreader_model_dir = snapshot_download('ppaanngggg/layoutreader')
 print(f"model dir is: {model_dir}/models")
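The ModelScope script now fetches the layoutreader weights as well, but it only prints the PDF-Extract-Kit path. A minimal sketch of how the second path could also be surfaced so it can be copied into the `layoutreader-model-dir` config key (the extra print is a suggestion, not part of this diff):

```
# sketch: also report where layoutreader was downloaded, so the path can be
# copied into "layoutreader-model-dir" in magic-pdf.json if desired
from modelscope import snapshot_download

model_dir = snapshot_download('opendatalab/PDF-Extract-Kit')
layoutreader_model_dir = snapshot_download('ppaanngggg/layoutreader')
print(f"model dir is: {model_dir}/models")
print(f"layoutreader model dir is: {layoutreader_model_dir}")
```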

+ 1 - 0
docs/download_models_hf.py

@@ -1,3 +1,4 @@
 from huggingface_hub import snapshot_download
 model_dir = snapshot_download('opendatalab/PDF-Extract-Kit')
+layoutreader_model_dir = snapshot_download('hantian/layoutreader')
 print(f"model dir is: {model_dir}/models")
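If the Hugging Face copy should take precedence over the ModelScope cache fallback, its path can be written into the config. A hypothetical helper, assuming the config file lives at `~/magic-pdf.json` and using the `layoutreader-model-dir` key from magic-pdf.template.json:

```
import json
import os

from huggingface_hub import snapshot_download

# hypothetical: download layoutreader from Hugging Face and point the
# "layoutreader-model-dir" key at that copy (the ~/magic-pdf.json location
# is an assumption, not confirmed by this diff)
layoutreader_model_dir = snapshot_download('hantian/layoutreader')

config_path = os.path.expanduser("~/magic-pdf.json")
with open(config_path, "r", encoding="utf-8") as f:
    config = json.load(f)
config["layoutreader-model-dir"] = layoutreader_model_dir
with open(config_path, "w", encoding="utf-8") as f:
    json.dump(config, f, ensure_ascii=False, indent=4)
```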

+ 7 - 0
docs/how_to_download_models_zh_cn.md

@@ -38,6 +38,13 @@ After the Python script finishes, it prints the model download directory
 
 If you previously downloaded the model files via git lfs, you can go into the earlier download directory and update the models with the `git pull` command.
 
+> Versions 0.9.x and later add a new layout sorting model. Since this model is not in the same repository as the earlier models, it cannot be updated with the `git pull` command and needs to be downloaded separately.
+> 
+>``` 
+>from modelscope import snapshot_download
+>snapshot_download('ppaanngggg/layoutreader')
+>```
+
 ## 2. Models previously downloaded via Hugging Face or Model Scope
 
 If you previously downloaded models via Hugging Face or Model Scope, you can re-run the earlier model download Python script and the model directory will be updated to the latest version automatically.
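The added note covers the ModelScope download; for users who originally cloned via git lfs from Hugging Face, the analogous one-off download would presumably be (repo id taken from docs/download_models_hf.py):

```
from huggingface_hub import snapshot_download
snapshot_download('hantian/layoutreader')
```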

+ 1 - 0
magic-pdf.template.json

@@ -4,6 +4,7 @@
         "bucket-name-2":["ak", "sk", "endpoint"]
         "bucket-name-2":["ak", "sk", "endpoint"]
     },
     },
     "models-dir":"/tmp/models",
     "models-dir":"/tmp/models",
+    "layoutreader-model-dir":"/tmp/layoutreader",
     "device-mode":"cpu",
     "device-mode":"cpu",
     "table-config": {
     "table-config": {
         "model": "TableMaster",
         "model": "TableMaster",

+ 12 - 0
magic_pdf/libs/config_reader.py

@@ -67,6 +67,18 @@ def get_local_models_dir():
         return models_dir
 
 
+def get_local_layoutreader_model_dir():
+    config = read_config()
+    layoutreader_model_dir = config.get("layoutreader-model-dir")
+    if layoutreader_model_dir is None or not os.path.exists(layoutreader_model_dir):
+        home_dir = os.path.expanduser("~")
+        layoutreader_at_modelscope_dir_path = os.path.join(home_dir, ".cache/modelscope/hub/ppaanngggg/layoutreader")
+        logger.warning(f"'layoutreader-model-dir' does not exist, using {layoutreader_at_modelscope_dir_path} as default")
+        return layoutreader_at_modelscope_dir_path
+    else:
+        return layoutreader_model_dir
+
+
 def get_device():
     config = read_config()
     device = config.get("device-mode")
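A minimal usage sketch of the new helper: it returns `layoutreader-model-dir` from the config when that directory exists, otherwise it logs a warning and falls back to the ModelScope cache location (the call below is illustrative, not part of the diff):

```
from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir

# either the configured directory or
# ~/.cache/modelscope/hub/ppaanngggg/layoutreader as a fallback
layoutreader_dir = get_local_layoutreader_model_dir()
print(f"layoutreader model dir is: {layoutreader_dir}")
```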

+ 11 - 8
magic_pdf/pdf_parse_union_core_v2.py

@@ -1,3 +1,4 @@
+import os
 import statistics
 import time
 
@@ -9,6 +10,7 @@ import torch
 
 from magic_pdf.libs.clean_memory import clean_memory
 from magic_pdf.libs.commons import fitz, get_delta_time
+from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir
 from magic_pdf.libs.convert_utils import dict_to_list
 from magic_pdf.libs.drop_reason import DropReason
 from magic_pdf.libs.hash_utils import compute_md5
@@ -95,7 +97,7 @@ def replace_text_span(pymu_spans, ocr_spans):
     return list(filter(lambda x: x["type"] != ContentType.Text, ocr_spans)) + pymu_spans
 
 
-def model_init(model_name: str, local_path=None):
+def model_init(model_name: str):
     from transformers import LayoutLMv3ForTokenClassification
     if torch.cuda.is_available():
         device = torch.device("cuda")
@@ -108,9 +110,13 @@ def model_init(model_name: str, local_path=None):
         supports_bfloat16 = False
 
     if model_name == "layoutreader":
-        if local_path:
-            model = LayoutLMv3ForTokenClassification.from_pretrained(local_path)
+        # check whether the modelscope cache directory exists
+        layoutreader_model_dir = get_local_layoutreader_model_dir()
+        if os.path.exists(layoutreader_model_dir):
+            model = LayoutLMv3ForTokenClassification.from_pretrained(layoutreader_model_dir)
         else:
+            logger.warning(
+                "local layoutreader model does not exist, using online model from huggingface")
             model = LayoutLMv3ForTokenClassification.from_pretrained("hantian/layoutreader")
         # check whether the device supports bfloat16
         if supports_bfloat16:
@@ -131,12 +137,9 @@ class ModelSingleton:
             cls._instance = super().__new__(cls)
         return cls._instance
 
-    def get_model(self, model_name: str, local_path=None):
+    def get_model(self, model_name: str):
         if model_name not in self._models:
-            if local_path:
-                self._models[model_name] = model_init(model_name=model_name, local_path=local_path)
-            else:
-                self._models[model_name] = model_init(model_name=model_name)
+            self._models[model_name] = model_init(model_name=model_name)
         return self._models[model_name]
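With `local_path` removed from both `model_init` and `ModelSingleton.get_model`, callers only pass the model name and the local-directory lookup happens inside `model_init`. A hypothetical call site (the import path matches this module; actual callers in the codebase may differ):

```
from magic_pdf.pdf_parse_union_core_v2 import ModelSingleton

# the first call loads LayoutLMv3 (local dir if configured, otherwise the
# hantian/layoutreader weights from Hugging Face); later calls reuse it
model = ModelSingleton().get_model("layoutreader")
```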