|
|
@@ -1,31 +1,32 @@
|
|
|
import os
|
|
|
-import time
|
|
|
-
|
|
|
import cv2
|
|
|
-import numpy as np
|
|
|
import yaml
|
|
|
-from PIL import Image
|
|
|
-from ultralytics import YOLO
|
|
|
+import time
|
|
|
+import argparse
|
|
|
+import numpy as np
|
|
|
+import torch
|
|
|
from loguru import logger
|
|
|
|
|
|
-from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
|
|
|
+from paddleocr import draw_ocr
|
|
|
+from PIL import Image
|
|
|
+from torchvision import transforms
|
|
|
+from torch.utils.data import Dataset, DataLoader
|
|
|
+from ultralytics import YOLO
|
|
|
from unimernet.common.config import Config
|
|
|
import unimernet.tasks as tasks
|
|
|
from unimernet.processors import load_processor
|
|
|
-import argparse
|
|
|
-from torchvision import transforms
|
|
|
-from torch.utils.data import Dataset, DataLoader
|
|
|
|
|
|
+from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
|
|
|
from magic_pdf.model.pek_sub_modules.post_process import get_croped_image, latex_rm_whitespace
|
|
|
from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
|
|
|
|
|
|
|
|
|
-def layout_model_init(weight, config_file, device):
|
|
|
- model = Layoutlmv3_Predictor(weight, config_file, device)
|
|
|
- return model
|
|
|
def mfd_model_init(weight):
    """Create the formula-detection (MFD) model.

    Args:
        weight: Location of the YOLO weights; forwarded unchanged to the
            ``YOLO`` constructor as its first positional argument.

    Returns:
        The ``YOLO`` instance built from ``weight``.
    """
    return YOLO(weight)
|
|
|
|
|
|
|
|
|
-def mfr_model_init(weight_dir, cfg_path, device='cpu'):
|
|
|
+def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
|
|
|
args = argparse.Namespace(cfg_path=cfg_path, options=None)
|
|
|
cfg = Config(args)
|
|
|
cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.bin")
|
|
|
@@ -33,11 +34,16 @@ def mfr_model_init(weight_dir, cfg_path, device='cpu'):
|
|
|
cfg.config.model.tokenizer_config.path = weight_dir
|
|
|
task = tasks.setup_task(cfg)
|
|
|
model = task.build_model(cfg)
|
|
|
- model = model.to(device)
|
|
|
+ model = model.to(_device_)
|
|
|
vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
|
|
|
return model, vis_processor
|
|
|
|
|
|
|
|
|
def layout_model_init(weight, config_file, device):
    """Create the layout-analysis predictor.

    Args:
        weight: Location of the layout model weights.
        config_file: Location of the layout model's inference config.
        device: Device identifier handed to the predictor.

    Returns:
        A ``Layoutlmv3_Predictor`` built from the three arguments, passed
        positionally in the order given.
    """
    return Layoutlmv3_Predictor(weight, config_file, device)
|
|
|
+
|
|
|
+
|
|
|
class MathDataset(Dataset):
|
|
|
def __init__(self, image_paths, transform=None):
|
|
|
self.image_paths = image_paths
|
|
|
@@ -54,10 +60,11 @@ class MathDataset(Dataset):
|
|
|
raw_image = self.image_paths[idx]
|
|
|
if self.transform:
|
|
|
image = self.transform(raw_image)
|
|
|
- return image
|
|
|
+ return image
|
|
|
|
|
|
|
|
|
class CustomPEKModel:
|
|
|
+
|
|
|
def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
|
|
|
"""
|
|
|
======== model init ========
|
|
|
@@ -88,24 +95,24 @@ class CustomPEKModel:
|
|
|
self.device = kwargs.get("device", self.configs["config"]["device"])
|
|
|
logger.info("using device: {}".format(self.device))
|
|
|
models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models"))
|
|
|
- # 初始化layout模型
|
|
|
- self.layout_model = layout_model_init(
|
|
|
- os.path.join(models_dir, self.configs['weights']['layout']),
|
|
|
- os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml"),
|
|
|
- device=self.device
|
|
|
- )
|
|
|
+
|
|
|
# 初始化公式识别
|
|
|
if self.apply_formula:
|
|
|
# 初始化公式检测模型
|
|
|
- self.mfd_model = YOLO(model=str(os.path.join(models_dir, self.configs["weights"]["mfd"])))
|
|
|
+ self.mfd_model = mfd_model_init(str(os.path.join(models_dir, self.configs["weights"]["mfd"])))
|
|
|
+
|
|
|
# 初始化公式解析模型
|
|
|
- mfr_config_path = os.path.join(model_config_dir, 'UniMERNet', 'demo.yaml')
|
|
|
- self.mfr_model, mfr_vis_processors = mfr_model_init(
|
|
|
- os.path.join(models_dir, self.configs["weights"]["mfr"]),
|
|
|
- mfr_config_path,
|
|
|
- device=self.device
|
|
|
- )
|
|
|
+ mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"]["mfr"]))
|
|
|
+ mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
|
|
|
+ self.mfr_model, mfr_vis_processors = mfr_model_init(mfr_weight_dir, mfr_cfg_path, _device_=self.device)
|
|
|
self.mfr_transform = transforms.Compose([mfr_vis_processors, ])
|
|
|
+
|
|
|
+ # 初始化layout模型
|
|
|
+ self.layout_model = Layoutlmv3_Predictor(
|
|
|
+ str(os.path.join(models_dir, self.configs['weights']['layout'])),
|
|
|
+ str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
|
|
|
+ device=self.device
|
|
|
+ )
|
|
|
# 初始化ocr
|
|
|
if self.apply_ocr:
|
|
|
self.ocr_model = ModifiedPaddleOCR(show_log=show_log)
|