소스 검색

feat(config-reader): add models-dir and device-mode configurations

Add new configuration options for custom model directories and device modeselection. This allows users to specify the directory where models are stored
and choose between CPU and GPU modes for model inference. The configurations
are read from a JSON file and can be easily extended to support additional
options in the future.
myhloli 1 년 전
부모
커밋
695b357994

+ 36 - 36
magic_pdf/resources/models/README.md → docs/how_to_download_models.md

@@ -1,37 +1,37 @@
-#### Install Git LFS
-Before you begin, make sure Git Large File Storage (Git LFS) is installed on your system. Install it using the following command:
-
-```bash
-git lfs install
-```
-
-#### Download the Model from Hugging Face
-To download the `PDF-Extract-Kit` model from Hugging Face, use the following command:
-
-```bash
-git lfs clone https://huggingface.co/wanderkid/PDF-Extract-Kit
-```
-
-Ensure that Git LFS is enabled during the clone to properly download all large files.
-
-
-
-Put [model files]() here:
-
-```
-./
-├── Layout
-│   ├── config.json
-│   └── model_final.pth
-├── MFD
-│   └── weights.pt
-├── MFR
-│   └── UniMERNet
-│       ├── config.json
-│       ├── preprocessor_config.json
-│       ├── pytorch_model.bin
-│       ├── README.md
-│       ├── tokenizer_config.json
-│       └── tokenizer.json
-└── README.md
+#### Install Git LFS
+Before you begin, make sure Git Large File Storage (Git LFS) is installed on your system. Install it using the following command:
+
+```bash
+git lfs install
+```
+
+#### Download the Model from Hugging Face
+To download the `PDF-Extract-Kit` model from Hugging Face, use the following command:
+
+```bash
+git lfs clone https://huggingface.co/wanderkid/PDF-Extract-Kit
+```
+
+Ensure that Git LFS is enabled during the clone to properly download all large files.
+
+
+
+Put [model files]() here:
+
+```
+./
+├── Layout
+│   ├── config.json
+│   └── model_final.pth
+├── MFD
+│   └── weights.pt
+├── MFR
+│   └── UniMERNet
+│       ├── config.json
+│       ├── preprocessor_config.json
+│       ├── pytorch_model.bin
+│       ├── README.md
+│       ├── tokenizer_config.json
+│       └── tokenizer.json
+└── README.md
 ```

+ 3 - 1
magic-pdf.template.json

@@ -3,5 +3,7 @@
         "bucket-name-1":["ak", "sk", "endpoint"],
         "bucket-name-2":["ak", "sk", "endpoint"]
     },
-    "temp-output-dir":"/tmp"
+    "temp-output-dir":"/tmp",
+    "models-dir":"/tmp/models",
+    "device-mode":"cpu"
 }

+ 4 - 2
magic_pdf/cli/magicpdf.py

@@ -33,13 +33,15 @@ from magic_pdf.libs.draw_bbox import draw_layout_bbox, draw_span_bbox
 from magic_pdf.pipe.UNIPipe import UNIPipe
 from magic_pdf.pipe.OCRPipe import OCRPipe
 from magic_pdf.pipe.TXTPipe import TXTPipe
-from magic_pdf.libs.config_reader import get_s3_config
 from magic_pdf.libs.path_utils import (
     parse_s3path,
     parse_s3_range_params,
     remove_non_official_s3_args,
 )
-from magic_pdf.libs.config_reader import get_local_dir
+from magic_pdf.libs.config_reader import (
+    get_local_dir,
+    get_s3_config,
+)
 from magic_pdf.rw.S3ReaderWriter import S3ReaderWriter
 from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
 from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter

+ 10 - 0
magic_pdf/libs/config_reader.py

@@ -59,5 +59,15 @@ def get_local_dir():
     return config.get("temp-output-dir", "/tmp")
 
 
+def get_local_models_dir():
+    config = read_config()
+    return config.get("models-dir", "/tmp/models")
+
+
+def get_device():
+    config = read_config()
+    return config.get("device-mode", "cpu")
+
+
 if __name__ == "__main__":
     ak, sk, endpoint = get_s3_config("llm-raw")

+ 2 - 2
magic_pdf/model/__init__.py

@@ -1,2 +1,2 @@
-__use_inside_model__ = False
-__model_mode__ = "lite"
+__use_inside_model__ = True
+__model_mode__ = "full"

+ 6 - 1
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -3,6 +3,8 @@ import time
 import fitz
 import numpy as np
 from loguru import logger
+
+from magic_pdf.libs.config_reader import get_local_models_dir, get_device
 from magic_pdf.model.model_list import MODEL
 import magic_pdf.model as model_config
 
@@ -61,7 +63,10 @@ def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False):
             custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log)
         elif model == MODEL.PEK:
             from magic_pdf.model.pdf_extract_kit import CustomPEKModel
-            custom_model = CustomPEKModel(ocr=ocr, show_log=show_log)
+            # 从配置文件读取model-dir和device
+            local_models_dir = get_local_models_dir()
+            device = get_device()
+            custom_model = CustomPEKModel(ocr=ocr, show_log=show_log, models_dir=local_models_dir, device=device)
         else:
             logger.error("Not allow model_name!")
             exit(1)

+ 9 - 5
magic_pdf/model/pdf_extract_kit.py

@@ -7,6 +7,7 @@ import yaml
 from PIL import Image
 from ultralytics import YOLO
 from loguru import logger
+
 from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
 from unimernet.common.config import Config
 import unimernet.tasks as tasks
@@ -84,23 +85,26 @@ class CustomPEKModel:
         )
         assert self.apply_layout, "DocAnalysis must contain layout model."
         # 初始化解析方案
-        self.device = self.configs["config"]["device"]
+        self.device = kwargs.get("device", self.configs["config"]["device"])
         logger.info("using device: {}".format(self.device))
+        models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models"))
         # 初始化layout模型
         self.layout_model = layout_model_init(
-            os.path.join(root_dir, self.configs['weights']['layout']),
+            os.path.join(models_dir, self.configs['weights']['layout']),
             os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml"),
             device=self.device
         )
         # 初始化公式识别
         if self.apply_formula:
             # 初始化公式检测模型
-            self.mfd_model = YOLO(model=str(os.path.join(root_dir, self.configs["weights"]["mfd"])))
+            self.mfd_model = YOLO(model=str(os.path.join(models_dir, self.configs["weights"]["mfd"])))
             # 初始化公式解析模型
             mfr_config_path = os.path.join(model_config_dir, 'UniMERNet', 'demo.yaml')
             self.mfr_model, mfr_vis_processors = mfr_model_init(
-                os.path.join(root_dir, self.configs["weights"]["mfr"]), mfr_config_path,
-                device=self.device)
+                os.path.join(models_dir, self.configs["weights"]["mfr"]),
+                mfr_config_path,
+                device=self.device
+            )
             self.mfr_transform = transforms.Compose([mfr_vis_processors, ])
         # 初始化ocr
         if self.apply_ocr:

+ 3 - 3
magic_pdf/resources/model_config/model_configs.yaml

@@ -4,6 +4,6 @@ config:
   formula: True
 
 weights:
-  layout: resources/models/Layout/model_final.pth
-  mfd: resources/models/MFD/weights.pt
-  mfr: resources/models/MFR/UniMERNet
+  layout: Layout/model_final.pth
+  mfd: MFD/weights.pt
+  mfr: MFR/UniMERNet