|
|
@@ -3,6 +3,7 @@ import os
|
|
|
import time
|
|
|
|
|
|
from magic_pdf.libs.Constants import *
|
|
|
+from magic_pdf.model.model_list import AtomicModel
|
|
|
|
|
|
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # Disable the albumentations update check
|
|
|
try:
|
|
|
@@ -64,7 +65,8 @@ def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
|
|
|
model = task.build_model(cfg)
|
|
|
model = model.to(_device_)
|
|
|
vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
|
|
|
- return model, vis_processor
|
|
|
+ mfr_transform = transforms.Compose([vis_processor, ])
|
|
|
+ return [model, mfr_transform]
|
|
|
|
|
|
|
|
|
def layout_model_init(weight, config_file, device):
|
|
|
@@ -72,6 +74,11 @@ def layout_model_init(weight, config_file, device):
|
|
|
return model
|
|
|
|
|
|
|
|
|
def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3):
    """Create the OCR model.

    Args:
        show_log: Forwarded to ``ModifiedPaddleOCR`` as ``show_log``.
        det_db_box_thresh: Forwarded to ``ModifiedPaddleOCR`` as
            ``det_db_box_thresh``.

    Returns:
        The newly constructed ``ModifiedPaddleOCR`` instance.
    """
    return ModifiedPaddleOCR(
        show_log=show_log,
        det_db_box_thresh=det_db_box_thresh,
    )
|
|
|
+
|
|
|
+
|
|
|
class MathDataset(Dataset):
|
|
|
def __init__(self, image_paths, transform=None):
|
|
|
self.image_paths = image_paths
|
|
|
@@ -91,6 +98,58 @@ class MathDataset(Dataset):
|
|
|
return image
|
|
|
|
|
|
|
|
|
class AtomModelSingleton:
    """Process-wide holder that builds each atomic model at most once.

    Every instantiation returns the same object, and the model cache lives on
    the class, so all users share the models that have already been built.
    """

    # The one shared instance; created on first instantiation.
    _instance = None
    # Cache of built models, keyed by atomic model name.
    _models = {}

    def __new__(cls, *args, **kwargs):
        """Return the shared instance, creating it on the first call."""
        instance = cls._instance
        if instance is None:
            instance = cls._instance = super().__new__(cls)
        return instance

    def get_atom_model(self, atom_model_name: str, **kwargs):
        """Return the model registered under *atom_model_name*.

        On the first request for a name the model is built with
        ``atom_model_init(model_name=atom_model_name, **kwargs)`` and cached.

        NOTE: the cache key is the name alone, so ``kwargs`` are only used
        when the model is first built; later calls with different ``kwargs``
        get the already-cached model back.
        """
        cache = self._models
        if atom_model_name in cache:
            return cache[atom_model_name]

        cache[atom_model_name] = atom_model_init(model_name=atom_model_name, **kwargs)
        return cache[atom_model_name]
|
|
|
+
|
|
|
+
|
|
|
def atom_model_init(model_name: str, **kwargs):
    """Build the atomic model selected by *model_name*.

    ``model_name`` is compared against the ``AtomicModel`` members and the
    matching ``*_model_init`` factory is called with values taken from
    ``kwargs``.

    Keyword arguments read, per model:
        AtomicModel.Layout: ``layout_weights``, ``layout_config_file``,
            ``device``
        AtomicModel.MFD: ``mfd_weights``
        AtomicModel.MFR: ``mfr_weight_dir``, ``mfr_cfg_path``, ``device``
        AtomicModel.OCR: ``ocr_show_log``, ``det_db_box_thresh``
        AtomicModel.Table: ``table_model_type``, ``table_model_path``,
            ``table_max_time``, ``device``

    For every model except OCR, a key that is absent from ``kwargs`` is
    passed to the factory as ``None``. For OCR, absent (or ``None``) values
    are left out so that ``ocr_model_init`` applies its own defaults.

    Returns:
        Whatever the selected factory returns.

    Raises:
        SystemExit: With exit code 1, after logging an error, when
            ``model_name`` matches none of the known models.
    """
    # Local import: this function cannot assume ``sys`` is imported at
    # module level.
    import sys

    if model_name == AtomicModel.Layout:
        atom_model = layout_model_init(
            kwargs.get("layout_weights"),
            kwargs.get("layout_config_file"),
            kwargs.get("device")
        )
    elif model_name == AtomicModel.MFD:
        atom_model = mfd_model_init(
            kwargs.get("mfd_weights")
        )
    elif model_name == AtomicModel.MFR:
        atom_model = mfr_model_init(
            kwargs.get("mfr_weight_dir"),
            kwargs.get("mfr_cfg_path"),
            kwargs.get("device")
        )
    elif model_name == AtomicModel.OCR:
        # Forward only the values that were actually supplied; passing None
        # positionally would override ocr_model_init's defaults.
        ocr_kwargs = {}
        if kwargs.get("ocr_show_log") is not None:
            ocr_kwargs["show_log"] = kwargs["ocr_show_log"]
        if kwargs.get("det_db_box_thresh") is not None:
            ocr_kwargs["det_db_box_thresh"] = kwargs["det_db_box_thresh"]
        atom_model = ocr_model_init(**ocr_kwargs)
    elif model_name == AtomicModel.Table:
        atom_model = table_model_init(
            kwargs.get("table_model_type"),
            kwargs.get("table_model_path"),
            kwargs.get("table_max_time"),
            kwargs.get("device")
        )
    else:
        logger.error("model name not allow")
        # sys.exit instead of the site-provided exit(): the latter is meant
        # for the interactive shell and is not guaranteed to exist (e.g.
        # under ``python -S`` or in frozen builds). Both raise SystemExit.
        sys.exit(1)

    return atom_model
|
|
|
+
|
|
|
+
|
|
|
class CustomPEKModel:
|
|
|
|
|
|
def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
|
|
|
@@ -130,32 +189,60 @@ class CustomPEKModel:
|
|
|
models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models"))
|
|
|
logger.info("using models_dir: {}".format(models_dir))
|
|
|
|
|
|
+ atom_model_manager = AtomModelSingleton()
|
|
|
+
|
|
|
# 初始化公式识别
|
|
|
if self.apply_formula:
|
|
|
# 初始化公式检测模型
|
|
|
- self.mfd_model = mfd_model_init(str(os.path.join(models_dir, self.configs["weights"]["mfd"])))
|
|
|
-
|
|
|
+ # self.mfd_model = mfd_model_init(str(os.path.join(models_dir, self.configs["weights"]["mfd"])))
|
|
|
+ self.mfd_model = atom_model_manager.get_atom_model(
|
|
|
+ atom_model_name=AtomicModel.MFD,
|
|
|
+ mfd_weights=str(os.path.join(models_dir, self.configs["weights"]["mfd"]))
|
|
|
+ )
|
|
|
# 初始化公式解析模型
|
|
|
mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"]["mfr"]))
|
|
|
mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
|
|
|
- self.mfr_model, mfr_vis_processors = mfr_model_init(mfr_weight_dir, mfr_cfg_path, _device_=self.device)
|
|
|
- self.mfr_transform = transforms.Compose([mfr_vis_processors, ])
|
|
|
+ # self.mfr_model, mfr_vis_processors = mfr_model_init(mfr_weight_dir, mfr_cfg_path, _device_=self.device)
|
|
|
+ # self.mfr_transform = transforms.Compose([mfr_vis_processors, ])
|
|
|
+ self.mfr_model, self.mfr_transform = atom_model_manager.get_atom_model(
|
|
|
+ atom_model_name=AtomicModel.MFR,
|
|
|
+ mfr_weight_dir=mfr_weight_dir,
|
|
|
+ mfr_cfg_path=mfr_cfg_path,
|
|
|
+ device=self.device
|
|
|
+ )
|
|
|
|
|
|
# 初始化layout模型
|
|
|
- self.layout_model = Layoutlmv3_Predictor(
|
|
|
- str(os.path.join(models_dir, self.configs['weights']['layout'])),
|
|
|
- str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
|
|
|
+ # self.layout_model = Layoutlmv3_Predictor(
|
|
|
+ # str(os.path.join(models_dir, self.configs['weights']['layout'])),
|
|
|
+ # str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
|
|
|
+ # device=self.device
|
|
|
+ # )
|
|
|
+ self.layout_model = atom_model_manager.get_atom_model(
|
|
|
+ atom_model_name=AtomicModel.Layout,
|
|
|
+ layout_weights=str(os.path.join(models_dir, self.configs['weights']['layout'])),
|
|
|
+ layout_config_file=str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
|
|
|
device=self.device
|
|
|
)
|
|
|
# 初始化ocr
|
|
|
if self.apply_ocr:
|
|
|
- self.ocr_model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=0.3)
|
|
|
-
|
|
|
+ # self.ocr_model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=0.3)
|
|
|
+ self.ocr_model = atom_model_manager.get_atom_model(
|
|
|
+ atom_model_name=AtomicModel.OCR,
|
|
|
+ ocr_show_log=show_log,
|
|
|
+ det_db_box_thresh=0.3
|
|
|
+ )
|
|
|
# init table model
|
|
|
if self.apply_table:
|
|
|
table_model_dir = self.configs["weights"][self.table_model_type]
|
|
|
- self.table_model = table_model_init(self.table_model_type, str(os.path.join(models_dir, table_model_dir)),
|
|
|
- max_time=self.table_max_time, _device_=self.device)
|
|
|
+ # self.table_model = table_model_init(self.table_model_type, str(os.path.join(models_dir, table_model_dir)),
|
|
|
+ # max_time=self.table_max_time, _device_=self.device)
|
|
|
+ self.table_model = atom_model_manager.get_atom_model(
|
|
|
+ atom_model_name=AtomicModel.Table,
|
|
|
+ table_model_type=self.table_model_type,
|
|
|
+ table_model_path=str(os.path.join(models_dir, table_model_dir)),
|
|
|
+ table_max_time=self.table_max_time,
|
|
|
+ device=self.device
|
|
|
+ )
|
|
|
logger.info('DocAnalysis init done!')
|
|
|
|
|
|
def __call__(self, image):
|
|
|
@@ -291,11 +378,11 @@ class CustomPEKModel:
|
|
|
logger.info("------------------table recognition processing begins-----------------")
|
|
|
latex_code = None
|
|
|
html_code = None
|
|
|
- with torch.no_grad():
|
|
|
- if self.table_model_type == STRUCT_EQTABLE:
|
|
|
+ if self.table_model_type == STRUCT_EQTABLE:
|
|
|
+ with torch.no_grad():
|
|
|
latex_code = self.table_model.image2latex(new_image)[0]
|
|
|
- else:
|
|
|
- html_code = self.table_model.img2html(new_image)
|
|
|
+ else:
|
|
|
+ html_code = self.table_model.img2html(new_image)
|
|
|
run_time = time.time() - single_table_start_time
|
|
|
logger.info(f"------------table recognition processing ends within {run_time}s-----")
|
|
|
if run_time > self.table_max_time:
|