|
|
@@ -25,6 +25,7 @@ try:
|
|
|
from unimernet.common.config import Config
|
|
|
import unimernet.tasks as tasks
|
|
|
from unimernet.processors import load_processor
|
|
|
+ from doclayout_yolo import YOLOv10
|
|
|
|
|
|
except ImportError as e:
|
|
|
logger.exception(e)
|
|
|
@@ -41,7 +42,7 @@ from magic_pdf.model.ppTableModel import ppTableModel
|
|
|
|
|
|
|
|
|
def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
|
|
|
- if table_model_type == STRUCT_EQTABLE:
|
|
|
+ if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
|
|
|
table_model = StructTableModel(model_path, max_time=max_time, device=_device_)
|
|
|
else:
|
|
|
config = {
|
|
|
@@ -77,6 +78,11 @@ def layout_model_init(weight, config_file, device):
|
|
|
return model
|
|
|
|
|
|
|
|
|
def doclayout_yolo_model_init(weight):
    """Build the DocLayout-YOLO layout detector.

    Args:
        weight: Value handed straight to the ``YOLOv10`` constructor; the
            caller in this file passes a weights path as a string.

    Returns:
        The constructed ``YOLOv10`` instance.
    """
    return YOLOv10(weight)
|
|
|
+
|
|
|
+
|
|
|
def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3, lang=None, use_dilation=True, det_db_unclip_ratio=2.4):
|
|
|
if lang is not None:
|
|
|
model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, lang=lang, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio)
|
|
|
@@ -114,19 +120,27 @@ class AtomModelSingleton:
|
|
|
return cls._instance
|
|
|
|
|
|
def get_atom_model(self, atom_model_name: str, **kwargs):
    """Return the cached atomic model for this request, creating it on first use.

    Models are stored in ``self._models`` under a key built from the atomic
    model name plus the keyword arguments that select a concrete variant:
    ``layout_model_name``, ``lang`` and ``table_model_name``.

    Args:
        atom_model_name: Which kind of atomic model is wanted; forwarded to
            ``atom_model_init`` as ``model_name``.
        **kwargs: Forwarded unchanged to ``atom_model_init`` on a cache miss.
            Only the three selector keys listed above take part in the
            cache key; a missing selector counts as ``None``.

    Returns:
        The object ``atom_model_init`` produced for this key.
    """
    # table_model_name is part of the key so that asking for a different
    # table backend never hands back a cached model of another backend.
    key = (
        atom_model_name,
        kwargs.get("layout_model_name"),
        kwargs.get("lang"),
        kwargs.get("table_model_name"),
    )
    if key not in self._models:
        self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
    return self._models[key]
|
|
|
|
|
|
|
|
|
def atom_model_init(model_name: str, **kwargs):
|
|
|
|
|
|
if model_name == AtomicModel.Layout:
|
|
|
- atom_model = layout_model_init(
|
|
|
- kwargs.get("layout_weights"),
|
|
|
- kwargs.get("layout_config_file"),
|
|
|
- kwargs.get("device")
|
|
|
- )
|
|
|
+ if kwargs.get("layout_model_name") == MODEL_NAME.LAYOUTLMv3:
|
|
|
+ atom_model = layout_model_init(
|
|
|
+ kwargs.get("layout_weights"),
|
|
|
+ kwargs.get("layout_config_file"),
|
|
|
+ kwargs.get("device")
|
|
|
+ )
|
|
|
+ elif kwargs.get("layout_model_name") == MODEL_NAME.DocLayout_YOLO:
|
|
|
+ atom_model = doclayout_yolo_model_init(
|
|
|
+ kwargs.get("doclayout_yolo_weights"),
|
|
|
+ )
|
|
|
elif model_name == AtomicModel.MFD:
|
|
|
atom_model = mfd_model_init(
|
|
|
kwargs.get("mfd_weights")
|
|
|
@@ -145,7 +159,7 @@ def atom_model_init(model_name: str, **kwargs):
|
|
|
)
|
|
|
elif model_name == AtomicModel.Table:
|
|
|
atom_model = table_model_init(
|
|
|
- kwargs.get("table_model_type"),
|
|
|
+ kwargs.get("table_model_name"),
|
|
|
kwargs.get("table_model_path"),
|
|
|
kwargs.get("table_max_time"),
|
|
|
kwargs.get("device")
|
|
|
@@ -193,23 +207,35 @@ class CustomPEKModel:
|
|
|
with open(config_path, "r", encoding='utf-8') as f:
|
|
|
self.configs = yaml.load(f, Loader=yaml.FullLoader)
|
|
|
# 初始化解析配置
|
|
|
- self.apply_layout = kwargs.get("apply_layout", self.configs["config"]["layout"])
|
|
|
- self.apply_formula = kwargs.get("apply_formula", self.configs["config"]["formula"])
|
|
|
+
|
|
|
+ # layout config
|
|
|
+ self.layout_config = kwargs.get("layout_config")
|
|
|
+ self.layout_model_name = self.layout_config.get("model", MODEL_NAME.DocLayout_YOLO)
|
|
|
+
|
|
|
+ # formula config
|
|
|
+ self.formula_config = kwargs.get("formula_config")
|
|
|
+ self.mfd_model_name = self.formula_config.get("mfd_model", MODEL_NAME.YOLO_V8_MFD)
|
|
|
+ self.mfr_model_name = self.formula_config.get("mfr_model", MODEL_NAME.UniMerNet_v2_Small)
|
|
|
+ self.apply_formula = self.formula_config.get("enable", True)
|
|
|
+
|
|
|
# table config
|
|
|
- self.table_config = kwargs.get("table_config", self.configs["config"]["table_config"])
|
|
|
- self.apply_table = self.table_config.get("is_table_recog_enable", False)
|
|
|
+ self.table_config = kwargs.get("table_config")
|
|
|
+ self.apply_table = self.table_config.get("enable", False)
|
|
|
self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE)
|
|
|
- self.table_model_type = self.table_config.get("model", TABLE_MASTER)
|
|
|
+ self.table_model_name = self.table_config.get("model", MODEL_NAME.TABLE_MASTER)
|
|
|
+
|
|
|
+ # ocr config
|
|
|
self.apply_ocr = ocr
|
|
|
self.lang = kwargs.get("lang", None)
|
|
|
+
|
|
|
logger.info(
|
|
|
- "DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}, apply_table: {}, lang: {}".format(
|
|
|
- self.apply_layout, self.apply_formula, self.apply_ocr, self.apply_table, self.lang
|
|
|
+ "DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, "
|
|
|
+ "apply_table: {}, table_model: {}, lang: {}".format(
|
|
|
+ self.layout_model_name, self.apply_formula, self.apply_ocr, self.apply_table, self.table_model_name, self.lang
|
|
|
)
|
|
|
)
|
|
|
- assert self.apply_layout, "DocAnalysis must contain layout model."
|
|
|
# 初始化解析方案
|
|
|
- self.device = kwargs.get("device", self.configs["config"]["device"])
|
|
|
+ self.device = kwargs.get("device", "cpu")
|
|
|
logger.info("using device: {}".format(self.device))
|
|
|
models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models"))
|
|
|
logger.info("using models_dir: {}".format(models_dir))
|
|
|
@@ -218,17 +244,16 @@ class CustomPEKModel:
|
|
|
|
|
|
# 初始化公式识别
|
|
|
if self.apply_formula:
|
|
|
+
|
|
|
# 初始化公式检测模型
|
|
|
- # self.mfd_model = mfd_model_init(str(os.path.join(models_dir, self.configs["weights"]["mfd"])))
|
|
|
self.mfd_model = atom_model_manager.get_atom_model(
|
|
|
atom_model_name=AtomicModel.MFD,
|
|
|
- mfd_weights=str(os.path.join(models_dir, self.configs["weights"]["mfd"]))
|
|
|
+ mfd_weights=str(os.path.join(models_dir, self.configs["weights"][self.mfd_model_name]))
|
|
|
)
|
|
|
+
|
|
|
# 初始化公式解析模型
|
|
|
- mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"]["mfr"]))
|
|
|
+ mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"][self.mfr_model_name]))
|
|
|
mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
|
|
|
- # self.mfr_model, mfr_vis_processors = mfr_model_init(mfr_weight_dir, mfr_cfg_path, _device_=self.device)
|
|
|
- # self.mfr_transform = transforms.Compose([mfr_vis_processors, ])
|
|
|
self.mfr_model, self.mfr_transform = atom_model_manager.get_atom_model(
|
|
|
atom_model_name=AtomicModel.MFR,
|
|
|
mfr_weight_dir=mfr_weight_dir,
|
|
|
@@ -237,17 +262,20 @@ class CustomPEKModel:
|
|
|
)
|
|
|
|
|
|
# 初始化layout模型
|
|
|
- # self.layout_model = Layoutlmv3_Predictor(
|
|
|
- # str(os.path.join(models_dir, self.configs['weights']['layout'])),
|
|
|
- # str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
|
|
|
- # device=self.device
|
|
|
- # )
|
|
|
- self.layout_model = atom_model_manager.get_atom_model(
|
|
|
- atom_model_name=AtomicModel.Layout,
|
|
|
- layout_weights=str(os.path.join(models_dir, self.configs['weights']['layout'])),
|
|
|
- layout_config_file=str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
|
|
|
- device=self.device
|
|
|
- )
|
|
|
+ if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
|
|
|
+ self.layout_model = atom_model_manager.get_atom_model(
|
|
|
+ atom_model_name=AtomicModel.Layout,
|
|
|
+ layout_model_name=MODEL_NAME.LAYOUTLMv3,
|
|
|
+ layout_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name])),
|
|
|
+ layout_config_file=str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
|
|
|
+ device=self.device
|
|
|
+ )
|
|
|
+ elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
|
|
|
+ self.layout_model = atom_model_manager.get_atom_model(
|
|
|
+ atom_model_name=AtomicModel.Layout,
|
|
|
+ layout_model_name=MODEL_NAME.DocLayout_YOLO,
|
|
|
+ doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name]))
|
|
|
+ )
|
|
|
# 初始化ocr
|
|
|
if self.apply_ocr:
|
|
|
|
|
|
@@ -260,12 +288,10 @@ class CustomPEKModel:
|
|
|
)
|
|
|
# init table model
|
|
|
if self.apply_table:
|
|
|
- table_model_dir = self.configs["weights"][self.table_model_type]
|
|
|
- # self.table_model = table_model_init(self.table_model_type, str(os.path.join(models_dir, table_model_dir)),
|
|
|
- # max_time=self.table_max_time, _device_=self.device)
|
|
|
+ table_model_dir = self.configs["weights"][self.table_model_name]
|
|
|
self.table_model = atom_model_manager.get_atom_model(
|
|
|
atom_model_name=AtomicModel.Table,
|
|
|
- table_model_type=self.table_model_type,
|
|
|
+ table_model_name=self.table_model_name,
|
|
|
table_model_path=str(os.path.join(models_dir, table_model_dir)),
|
|
|
table_max_time=self.table_max_time,
|
|
|
device=self.device
|
|
|
@@ -282,7 +308,21 @@ class CustomPEKModel:
|
|
|
|
|
|
# layout检测
|
|
|
layout_start = time.time()
|
|
|
- layout_res = self.layout_model(image, ignore_catids=[])
|
|
|
+ if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
|
|
|
+ # layoutlmv3
|
|
|
+ layout_res = self.layout_model(image, ignore_catids=[])
|
|
|
+ elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
|
|
|
+ # doclayout_yolo
|
|
|
+ layout_res = []
|
|
|
+ doclayout_yolo_res = self.layout_model.predict(image, imgsz=1024, conf=0.15, iou=0.45, verbose=True, device=self.device)[0]
|
|
|
+ for xyxy, conf, cla in zip(doclayout_yolo_res.boxes.xyxy.cpu(), doclayout_yolo_res.boxes.conf.cpu(), doclayout_yolo_res.boxes.cls.cpu()):
|
|
|
+ xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
|
|
|
+ new_item = {
|
|
|
+ 'category_id': int(cla.item()),
|
|
|
+ 'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
|
|
|
+ 'score': round(float(conf.item()), 3),
|
|
|
+ }
|
|
|
+ layout_res.append(new_item)
|
|
|
layout_cost = round(time.time() - layout_start, 2)
|
|
|
logger.info(f"layout detection time: {layout_cost}")
|
|
|
|
|
|
@@ -291,7 +331,7 @@ class CustomPEKModel:
|
|
|
if self.apply_formula:
|
|
|
# 公式检测
|
|
|
mfd_start = time.time()
|
|
|
- mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
|
|
|
+ mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
|
|
|
logger.info(f"mfd time: {round(time.time() - mfd_start, 2)}")
|
|
|
for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
|
|
|
xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
|
|
|
@@ -303,7 +343,6 @@ class CustomPEKModel:
|
|
|
}
|
|
|
layout_res.append(new_item)
|
|
|
latex_filling_list.append(new_item)
|
|
|
- # bbox_img = get_croped_image(pil_img, [xmin, ymin, xmax, ymax])
|
|
|
bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
|
|
|
mf_image_list.append(bbox_img)
|
|
|
|
|
|
@@ -405,7 +444,7 @@ class CustomPEKModel:
|
|
|
# logger.info("------------------table recognition processing begins-----------------")
|
|
|
latex_code = None
|
|
|
html_code = None
|
|
|
- if self.table_model_type == STRUCT_EQTABLE:
|
|
|
+ if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
|
|
|
with torch.no_grad():
|
|
|
latex_code = self.table_model.image2latex(new_image)[0]
|
|
|
else:
|