浏览代码

refactor(model): update init methods and improve model loading logic

zhaoxiaomeng 1 年之前
父节点
当前提交
4101c35796
共有 2 个文件被更改,包括 36 次插入和 29 次删除
  1. + 1 - 1
      magic_pdf/model/__init__.py
  2. + 35 - 28
      magic_pdf/model/pdf_extract_kit.py

+ 1 - 1
magic_pdf/model/__init__.py

@@ -1,2 +1,2 @@
-__use_inside_model__ = True
+__use_inside_model__ = False
 __model_mode__ = "full"

+ 35 - 28
magic_pdf/model/pdf_extract_kit.py

@@ -1,31 +1,32 @@
 import os
-import time
-
 import cv2
-import numpy as np
 import yaml
-from PIL import Image
-from ultralytics import YOLO
+import time
+import argparse
+import numpy as np
+import torch
 from loguru import logger
 
-from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
+from paddleocr import draw_ocr
+from PIL import Image
+from torchvision import transforms
+from torch.utils.data import Dataset, DataLoader
+from ultralytics import YOLO
 from unimernet.common.config import Config
 import unimernet.tasks as tasks
 from unimernet.processors import load_processor
-import argparse
-from torchvision import transforms
-from torch.utils.data import Dataset, DataLoader
 
+from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
 from magic_pdf.model.pek_sub_modules.post_process import get_croped_image, latex_rm_whitespace
 from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
 
 
-def layout_model_init(weight, config_file, device):
-    model = Layoutlmv3_Predictor(weight, config_file, device)
-    return model
+def mfd_model_init(weight):
+    mfd_model = YOLO(weight)
+    return mfd_model
 
 
-def mfr_model_init(weight_dir, cfg_path, device='cpu'):
+def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
     args = argparse.Namespace(cfg_path=cfg_path, options=None)
     cfg = Config(args)
     cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.bin")
@@ -33,11 +34,16 @@ def mfr_model_init(weight_dir, cfg_path, device='cpu'):
     cfg.config.model.tokenizer_config.path = weight_dir
     task = tasks.setup_task(cfg)
     model = task.build_model(cfg)
-    model = model.to(device)
+    model = model.to(_device_)
     vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
     return model, vis_processor
 
 
+def layout_model_init(weight, config_file, device):
+    model = Layoutlmv3_Predictor(weight, config_file, device)
+    return model
+
+
 class MathDataset(Dataset):
     def __init__(self, image_paths, transform=None):
         self.image_paths = image_paths
@@ -54,10 +60,11 @@ class MathDataset(Dataset):
             raw_image = self.image_paths[idx]
         if self.transform:
             image = self.transform(raw_image)
-        return image
+            return image
 
 
 class CustomPEKModel:
+
     def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
         """
         ======== model init ========
@@ -88,24 +95,24 @@ class CustomPEKModel:
         self.device = kwargs.get("device", self.configs["config"]["device"])
         logger.info("using device: {}".format(self.device))
         models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models"))
-        # 初始化layout模型
-        self.layout_model = layout_model_init(
-            os.path.join(models_dir, self.configs['weights']['layout']),
-            os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml"),
-            device=self.device
-        )
+
         # 初始化公式识别
         if self.apply_formula:
             # 初始化公式检测模型
-            self.mfd_model = YOLO(model=str(os.path.join(models_dir, self.configs["weights"]["mfd"])))
+            self.mfd_model = mfd_model_init(str(os.path.join(models_dir, self.configs["weights"]["mfd"])))
+
             # 初始化公式解析模型
-            mfr_config_path = os.path.join(model_config_dir, 'UniMERNet', 'demo.yaml')
-            self.mfr_model, mfr_vis_processors = mfr_model_init(
-                os.path.join(models_dir, self.configs["weights"]["mfr"]),
-                mfr_config_path,
-                device=self.device
-            )
+            mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"]["mfr"]))
+            mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
+            self.mfr_model, mfr_vis_processors = mfr_model_init(mfr_weight_dir, mfr_cfg_path, _device_=self.device)
             self.mfr_transform = transforms.Compose([mfr_vis_processors, ])
+
+        # 初始化layout模型
+        self.layout_model = Layoutlmv3_Predictor(
+            str(os.path.join(models_dir, self.configs['weights']['layout'])),
+            str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
+            device=self.device
+        )
         # 初始化ocr
         if self.apply_ocr:
             self.ocr_model = ModifiedPaddleOCR(show_log=show_log)