|
|
@@ -7,6 +7,7 @@ import yaml
|
|
|
from PIL import Image
|
|
|
from ultralytics import YOLO
|
|
|
from loguru import logger
|
|
|
+
|
|
|
from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
|
|
|
from unimernet.common.config import Config
|
|
|
import unimernet.tasks as tasks
|
|
|
@@ -84,23 +85,26 @@ class CustomPEKModel:
|
|
|
)
|
|
|
assert self.apply_layout, "DocAnalysis must contain layout model."
|
|
|
# 初始化解析方案
|
|
|
- self.device = self.configs["config"]["device"]
|
|
|
+ self.device = kwargs.get("device", self.configs["config"]["device"])
|
|
|
logger.info("using device: {}".format(self.device))
|
|
|
+ models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models"))
|
|
|
# 初始化layout模型
|
|
|
self.layout_model = layout_model_init(
|
|
|
- os.path.join(root_dir, self.configs['weights']['layout']),
|
|
|
+ os.path.join(models_dir, self.configs['weights']['layout']),
|
|
|
os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml"),
|
|
|
device=self.device
|
|
|
)
|
|
|
# 初始化公式识别
|
|
|
if self.apply_formula:
|
|
|
# 初始化公式检测模型
|
|
|
- self.mfd_model = YOLO(model=str(os.path.join(root_dir, self.configs["weights"]["mfd"])))
|
|
|
+ self.mfd_model = YOLO(model=str(os.path.join(models_dir, self.configs["weights"]["mfd"])))
|
|
|
# 初始化公式解析模型
|
|
|
mfr_config_path = os.path.join(model_config_dir, 'UniMERNet', 'demo.yaml')
|
|
|
self.mfr_model, mfr_vis_processors = mfr_model_init(
|
|
|
- os.path.join(root_dir, self.configs["weights"]["mfr"]), mfr_config_path,
|
|
|
- device=self.device)
|
|
|
+ os.path.join(models_dir, self.configs["weights"]["mfr"]),
|
|
|
+ mfr_config_path,
|
|
|
+ device=self.device
|
|
|
+ )
|
|
|
self.mfr_transform = transforms.Compose([mfr_vis_processors, ])
|
|
|
# 初始化ocr
|
|
|
if self.apply_ocr:
|