
refactor(pdf_extract_kit): implement singleton pattern for atomic models (#533)

Refactor the pdf_extract_kit module to utilize a singleton pattern when initializing
atomic models. This change ensures that atomic models are instantiated at most once,
optimizing memory usage and reducing redundant initialization steps. The AtomModelSingleton
class now manages the instantiation and retrieval of atomic models, improving the
overall structure and efficiency of the codebase.
Xiaomeng Zhao · 1 year ago
commit aac9109414
2 files changed, 111 insertions and 16 deletions
  1. magic_pdf/model/model_list.py (+8, -0)
  2. magic_pdf/model/pdf_extract_kit.py (+103, -16)
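
Below is a minimal, self-contained sketch of the mechanism the commit message describes (the class and method names here are illustrative stand-ins, not the module's API): overriding __new__ makes every construction return one shared manager object, and a per-name dictionary guarantees that each model factory runs at most once.

class SingletonCache:
    _instance = None
    _items = {}

    def __new__(cls):
        # Every construction attempt hands back the one shared instance.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get(self, name, factory):
        # Build lazily on the first request, then serve the cached object.
        if name not in self._items:
            self._items[name] = factory()
        return self._items[name]


assert SingletonCache() is SingletonCache()        # one manager instance
first = SingletonCache().get("layout", object)
second = SingletonCache().get("layout", object)
assert first is second                             # the factory ran only once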

+ 8 - 0
magic_pdf/model/model_list.py

@@ -1,3 +1,11 @@
 class MODEL:
     Paddle = "pp_structure_v2"
     PEK = "pdf_extract_kit"
+
+
+class AtomicModel:
+    Layout = "layout"
+    MFD = "mfd"
+    MFR = "mfr"
+    OCR = "ocr"
+    Table = "table"

+ 103 - 16
magic_pdf/model/pdf_extract_kit.py

@@ -3,6 +3,7 @@ import os
 import time
 
 from magic_pdf.libs.Constants import *
+from magic_pdf.model.model_list import AtomicModel
 
 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # disable the albumentations update check
 try:
@@ -64,7 +65,8 @@ def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
     model = task.build_model(cfg)
     model = model.to(_device_)
     vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
-    return model, vis_processor
+    mfr_transform = transforms.Compose([vis_processor, ])
+    return [model, mfr_transform]
 
 
 def layout_model_init(weight, config_file, device):
@@ -72,6 +74,11 @@ def layout_model_init(weight, config_file, device):
     return model
 
 
+def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3):
+    model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh)
+    return model
+
+
 class MathDataset(Dataset):
     def __init__(self, image_paths, transform=None):
         self.image_paths = image_paths
@@ -91,6 +98,58 @@ class MathDataset(Dataset):
             return image
 
 
+class AtomModelSingleton:
+    _instance = None
+    _models = {}
+
+    def __new__(cls, *args, **kwargs):
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+        return cls._instance
+
+    def get_atom_model(self, atom_model_name: str, **kwargs):
+        if atom_model_name not in self._models:
+            self._models[atom_model_name] = atom_model_init(model_name=atom_model_name, **kwargs)
+        return self._models[atom_model_name]
+
+
+def atom_model_init(model_name: str, **kwargs):
+
+    if model_name == AtomicModel.Layout:
+        atom_model = layout_model_init(
+            kwargs.get("layout_weights"),
+            kwargs.get("layout_config_file"),
+            kwargs.get("device")
+        )
+    elif model_name == AtomicModel.MFD:
+        atom_model = mfd_model_init(
+            kwargs.get("mfd_weights")
+        )
+    elif model_name == AtomicModel.MFR:
+        atom_model = mfr_model_init(
+            kwargs.get("mfr_weight_dir"),
+            kwargs.get("mfr_cfg_path"),
+            kwargs.get("device")
+        )
+    elif model_name == AtomicModel.OCR:
+        atom_model = ocr_model_init(
+            kwargs.get("ocr_show_log"),
+            kwargs.get("det_db_box_thresh")
+        )
+    elif model_name == AtomicModel.Table:
+        atom_model = table_model_init(
+            kwargs.get("table_model_type"),
+            kwargs.get("table_model_path"),
+            kwargs.get("table_max_time"),
+            kwargs.get("device")
+        )
+    else:
+        logger.error("model name not allowed")
+        exit(1)
+
+    return atom_model
+
+
 class CustomPEKModel:
 
     def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
@@ -130,32 +189,60 @@ class CustomPEKModel:
         models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models"))
         logger.info("using models_dir: {}".format(models_dir))
 
+        atom_model_manager = AtomModelSingleton()
+
         # Initialize formula recognition
         if self.apply_formula:
             # Initialize the formula detection (MFD) model
-            self.mfd_model = mfd_model_init(str(os.path.join(models_dir, self.configs["weights"]["mfd"])))
-
+            # self.mfd_model = mfd_model_init(str(os.path.join(models_dir, self.configs["weights"]["mfd"])))
+            self.mfd_model = atom_model_manager.get_atom_model(
+                atom_model_name=AtomicModel.MFD,
+                mfd_weights=str(os.path.join(models_dir, self.configs["weights"]["mfd"]))
+            )
             # Initialize the formula parsing (MFR) model
             mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"]["mfr"]))
             mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
-            self.mfr_model, mfr_vis_processors = mfr_model_init(mfr_weight_dir, mfr_cfg_path, _device_=self.device)
-            self.mfr_transform = transforms.Compose([mfr_vis_processors, ])
+            # self.mfr_model, mfr_vis_processors = mfr_model_init(mfr_weight_dir, mfr_cfg_path, _device_=self.device)
+            # self.mfr_transform = transforms.Compose([mfr_vis_processors, ])
+            self.mfr_model, self.mfr_transform = atom_model_manager.get_atom_model(
+                atom_model_name=AtomicModel.MFR,
+                mfr_weight_dir=mfr_weight_dir,
+                mfr_cfg_path=mfr_cfg_path,
+                device=self.device
+            )
 
         # Initialize the layout model
-        self.layout_model = Layoutlmv3_Predictor(
-            str(os.path.join(models_dir, self.configs['weights']['layout'])),
-            str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
+        # self.layout_model = Layoutlmv3_Predictor(
+        #     str(os.path.join(models_dir, self.configs['weights']['layout'])),
+        #     str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
+        #     device=self.device
+        # )
+        self.layout_model = atom_model_manager.get_atom_model(
+            atom_model_name=AtomicModel.Layout,
+            layout_weights=str(os.path.join(models_dir, self.configs['weights']['layout'])),
+            layout_config_file=str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
             device=self.device
         )
         # Initialize OCR
         if self.apply_ocr:
-            self.ocr_model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=0.3)
-
+            # self.ocr_model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=0.3)
+            self.ocr_model = atom_model_manager.get_atom_model(
+                atom_model_name=AtomicModel.OCR,
+                ocr_show_log=show_log,
+                det_db_box_thresh=0.3
+            )
         # init table model
         if self.apply_table:
             table_model_dir = self.configs["weights"][self.table_model_type]
-            self.table_model = table_model_init(self.table_model_type, str(os.path.join(models_dir, table_model_dir)),
-                                                max_time=self.table_max_time, _device_=self.device)
+            # self.table_model = table_model_init(self.table_model_type, str(os.path.join(models_dir, table_model_dir)),
+            #                                     max_time=self.table_max_time, _device_=self.device)
+            self.table_model = atom_model_manager.get_atom_model(
+                atom_model_name=AtomicModel.Table,
+                table_model_type=self.table_model_type,
+                table_model_path=str(os.path.join(models_dir, table_model_dir)),
+                table_max_time=self.table_max_time,
+                device=self.device
+            )
         logger.info('DocAnalysis init done!')
 
     def __call__(self, image):
@@ -291,11 +378,11 @@ class CustomPEKModel:
                 logger.info("------------------table recognition processing begins-----------------")
                 latex_code = None
                 html_code = None
-                with torch.no_grad():
-                    if self.table_model_type == STRUCT_EQTABLE:
+                if self.table_model_type == STRUCT_EQTABLE:
+                    with torch.no_grad():
                         latex_code = self.table_model.image2latex(new_image)[0]
-                    else:
-                        html_code = self.table_model.img2html(new_image)
+                else:
+                    html_code = self.table_model.img2html(new_image)
                 run_time = time.time() - single_table_start_time
                 logger.info(f"------------table recognition processing ends within {run_time}s-----")
                 if run_time > self.table_max_time:
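
One consequence of the cache above is worth noting: get_atom_model keys the cache on atom_model_name alone, so kwargs passed on later calls are silently ignored and the first initialization wins. A short usage sketch of that behaviour, assuming magic_pdf and its PaddleOCR dependency are installed (running this really does load the OCR model):

from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.pdf_extract_kit import AtomModelSingleton

manager = AtomModelSingleton()

# First call builds ModifiedPaddleOCR with these kwargs and caches it under "ocr".
ocr_a = manager.get_atom_model(
    atom_model_name=AtomicModel.OCR,
    ocr_show_log=False,
    det_db_box_thresh=0.3,
)

# Second call hits the cache; the different threshold is ignored.
ocr_b = AtomModelSingleton().get_atom_model(
    atom_model_name=AtomicModel.OCR,
    ocr_show_log=True,
    det_db_box_thresh=0.5,
)

assert ocr_a is ocr_b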