| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242 |
import os
import sys

import torch
from loguru import logger

from .model_list import AtomicModel
from ...model.layout.doclayout_yolo import DocLayoutYOLOModel
from ...model.mfd.yolo_v8 import YOLOv8MFDModel
from ...model.mfr.unimernet.Unimernet import UnimernetModel
from ...model.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
from ...model.ori_cls.paddle_ori_cls import PaddleOrientationClsModel
from ...model.table.cls.paddle_table_cls import PaddleTableClsModel
from ...model.table.rec.rapid_table import RapidTableModel
from ...model.table.rec.unet_table.main import UnetTableModel
from ...utils.enum_class import ModelPath
from ...utils.models_download_utils import auto_download_and_get_model_root_path
def img_orientation_cls_model_init():
    """Build the image orientation classifier.

    The classifier wraps an OCR engine obtained from ``AtomModelSingleton``
    with fixed detection settings and ``lang="ch_lite"``. Because the engine
    comes from the singleton's cache, repeated calls with these settings
    share one engine.

    Returns:
        A new ``PaddleOrientationClsModel`` instance.
    """
    shared_ocr = AtomModelSingleton().get_atom_model(
        atom_model_name=AtomicModel.OCR,
        det_db_box_thresh=0.5,
        det_db_unclip_ratio=1.6,
        lang="ch_lite",
        enable_merge_det_boxes=False,
    )
    return PaddleOrientationClsModel(shared_ocr)
def table_cls_model_init():
    """Create and return a new table classification model.

    Returns:
        A freshly constructed ``PaddleTableClsModel`` (no arguments).
    """
    classifier = PaddleTableClsModel()
    return classifier
def wired_table_model_init(lang=None):
    """Build the wired-table recognizer (``UnetTableModel``).

    Args:
        lang: OCR language for the underlying OCR engine. ``None`` is
            forwarded as-is to the OCR request.

    Returns:
        A ``UnetTableModel`` wrapping an OCR engine fetched from
        ``AtomModelSingleton`` with table-specific detection settings.
    """
    ocr_options = {
        'det_db_box_thresh': 0.5,
        'det_db_unclip_ratio': 1.6,
        'lang': lang,
        'enable_merge_det_boxes': False,
    }
    engine = AtomModelSingleton().get_atom_model(
        atom_model_name=AtomicModel.OCR, **ocr_options
    )
    return UnetTableModel(engine)
def wireless_table_model_init(lang=None):
    """Build the wireless-table recognizer (``RapidTableModel``).

    Args:
        lang: OCR language for the underlying OCR engine. ``None`` is
            forwarded as-is to the OCR request.

    Returns:
        A ``RapidTableModel`` wrapping an OCR engine fetched from
        ``AtomModelSingleton`` with table-specific detection settings.
    """
    manager = AtomModelSingleton()
    return RapidTableModel(
        manager.get_atom_model(
            atom_model_name=AtomicModel.OCR,
            det_db_box_thresh=0.5,
            det_db_unclip_ratio=1.6,
            lang=lang,
            enable_merge_det_boxes=False,
        )
    )
def mfd_model_init(weight, device='cpu'):
    """Load the YOLOv8 formula detection (MFD) model.

    Args:
        weight: Path to the model weights.
        device: Target device. A value whose string form starts with
            ``'npu'`` is first converted to a ``torch.device`` (presumably
            what the NPU backend expects). Any other value is passed
            through unchanged.

    Returns:
        A ``YOLOv8MFDModel`` instance.
    """
    is_npu = str(device).startswith('npu')
    resolved_device = torch.device(device) if is_npu else device
    return YOLOv8MFDModel(weight, resolved_device)
def mfr_model_init(weight_dir, device='cpu'):
    """Load the UniMERNet formula recognition (MFR) model.

    Args:
        weight_dir: Directory holding the model weights.
        device: Target device, forwarded unchanged. Unlike the MFD and
            layout initializers, no NPU conversion is done here.

    Returns:
        A ``UnimernetModel`` instance.
    """
    return UnimernetModel(weight_dir, device)
def doclayout_yolo_model_init(weight, device='cpu'):
    """Load the DocLayout-YOLO layout detection model.

    Args:
        weight: Path to the model weights.
        device: Target device. A value whose string form starts with
            ``'npu'`` is converted to a ``torch.device``. Anything else is
            passed through unchanged.

    Returns:
        A ``DocLayoutYOLOModel`` instance.
    """
    if not str(device).startswith('npu'):
        return DocLayoutYOLOModel(weight, device)
    return DocLayoutYOLOModel(weight, torch.device(device))
def ocr_model_init(det_db_box_thresh=0.3,
                   lang=None,
                   det_db_unclip_ratio=1.8,
                   enable_merge_det_boxes=True):
    """Create a ``PytorchPaddleOCR`` engine (always with ``use_dilation=True``).

    Args:
        det_db_box_thresh: Forwarded as ``det_db_box_thresh``.
        lang: OCR language. When it is ``None`` or ``''``, the ``lang``
            keyword is not passed at all. The choice is then left to how
            ``PytorchPaddleOCR`` handles a missing ``lang``.
        det_db_unclip_ratio: Forwarded as ``det_db_unclip_ratio``.
        enable_merge_det_boxes: Forwarded as ``enable_merge_det_boxes``.

    Returns:
        A new ``PytorchPaddleOCR`` instance.
    """
    # Build the keyword arguments in the same order the constructor has
    # always received them, inserting ``lang`` only when it is meaningful.
    ocr_options = {'det_db_box_thresh': det_db_box_thresh}
    if lang is not None and lang != '':
        ocr_options['lang'] = lang
    ocr_options.update(
        use_dilation=True,
        det_db_unclip_ratio=det_db_unclip_ratio,
        enable_merge_det_boxes=enable_merge_det_boxes,
    )
    return PytorchPaddleOCR(**ocr_options)
class AtomModelSingleton:
    """Single shared instance that caches constructed atomic models.

    ``AtomModelSingleton()`` always returns the same instance. Its
    ``get_atom_model`` method builds each model at most once per cache key
    and delegates construction to ``atom_model_init``.
    """

    _instance = None
    # Class-level cache shared by every caller: cache key -> built model.
    _models = {}

    def __new__(cls, *args, **kwargs):
        # Reuse the existing instance; create it only on the first call.
        if cls._instance is not None:
            return cls._instance
        cls._instance = super().__new__(cls)
        return cls._instance

    def get_atom_model(self, atom_model_name: str, **kwargs):
        """Return the cached model for these arguments, building it on a miss.

        Cache keys:

        * WiredTable / WirelessTable: ``(name, lang)``.
        * OCR: ``(name, det_db_box_thresh, lang, det_db_unclip_ratio,
          enable_merge_det_boxes)``, using the defaults 0.3 / None / 1.8 /
          True when a value is omitted.
        * Every other model: the name alone. Other kwargs (for example
          ``device`` or weight paths) therefore only take effect on the
          first request for that name.

        All kwargs are forwarded to ``atom_model_init`` on a cache miss.
        """
        cache_key = self._cache_key(atom_model_name, kwargs)
        if cache_key not in self._models:
            self._models[cache_key] = atom_model_init(model_name=atom_model_name, **kwargs)
        return self._models[cache_key]

    @staticmethod
    def _cache_key(atom_model_name, options):
        """Compute the key under which the requested model is cached."""
        lang = options.get('lang', None)
        if atom_model_name in (AtomicModel.WiredTable, AtomicModel.WirelessTable):
            return (atom_model_name, lang)
        if atom_model_name == AtomicModel.OCR:
            return (
                atom_model_name,
                options.get('det_db_box_thresh', 0.3),
                lang,
                options.get('det_db_unclip_ratio', 1.8),
                options.get('enable_merge_det_boxes', True),
            )
        return atom_model_name
def atom_model_init(model_name: str, **kwargs):
    """Construct the atomic model identified by ``model_name``.

    Args:
        model_name: One of the ``AtomicModel`` identifiers.
        **kwargs: Model-specific options.
            - Layout reads ``doclayout_yolo_weights`` and ``device``.
            - MFD reads ``mfd_weights`` and ``device``.
            - MFR reads ``mfr_weight_dir`` and ``device``.
            - OCR reads ``det_db_box_thresh`` (default 0.3), ``lang``,
              ``det_db_unclip_ratio`` (default 1.8) and
              ``enable_merge_det_boxes`` (default True).
            - WirelessTable and WiredTable read ``lang``.
            - TableCls and ImgOrientationCls read nothing.

    Returns:
        The newly constructed model.

    Raises:
        SystemExit: With status 1, raised when ``model_name`` is unknown or
            when the selected initializer returned ``None``. Both cases are
            logged before exiting.
    """
    # NOTE(review): ``device`` has no default here. When it is omitted,
    # ``None`` is passed to the Layout/MFD/MFR initializers and overrides
    # their ``'cpu'`` default. Confirm whether that is intended.
    atom_model = None
    if model_name == AtomicModel.Layout:
        atom_model = doclayout_yolo_model_init(
            kwargs.get('doclayout_yolo_weights'),
            kwargs.get('device')
        )
    elif model_name == AtomicModel.MFD:
        atom_model = mfd_model_init(
            kwargs.get('mfd_weights'),
            kwargs.get('device')
        )
    elif model_name == AtomicModel.MFR:
        atom_model = mfr_model_init(
            kwargs.get('mfr_weight_dir'),
            kwargs.get('device')
        )
    elif model_name == AtomicModel.OCR:
        atom_model = ocr_model_init(
            kwargs.get('det_db_box_thresh', 0.3),
            kwargs.get('lang'),
            kwargs.get('det_db_unclip_ratio', 1.8),
            kwargs.get('enable_merge_det_boxes', True)
        )
    elif model_name == AtomicModel.WirelessTable:
        atom_model = wireless_table_model_init(
            kwargs.get('lang'),
        )
    elif model_name == AtomicModel.WiredTable:
        atom_model = wired_table_model_init(
            kwargs.get('lang'),
        )
    elif model_name == AtomicModel.TableCls:
        atom_model = table_cls_model_init()
    elif model_name == AtomicModel.ImgOrientationCls:
        atom_model = img_orientation_cls_model_init()
    else:
        logger.error('model name not allow')
        # Use sys.exit rather than the ``exit`` builtin. ``exit`` is injected
        # by the ``site`` module for interactive use and is missing when
        # Python runs with ``-S``. Both raise SystemExit(1).
        sys.exit(1)

    if atom_model is None:
        logger.error('model init failed')
        sys.exit(1)
    return atom_model
class MineruPipelineModel:
    """Holder that loads every model the pipeline backend needs.

    All sub-models are requested from ``AtomModelSingleton``. They are
    therefore cached and shared with any other request that maps to the
    same cache key.
    """

    def __init__(self, **kwargs):
        """Load the layout and OCR models, plus optional formula and table models.

        Keyword Args:
            formula_config: Mapping with an optional ``enable`` flag
                (default True). When enabled, the formula detection (MFD)
                and formula recognition (MFR) models are loaded. A missing
                or ``None`` value is treated as an empty mapping.
            table_config: Mapping with an optional ``enable`` flag (default
                True). When enabled, four models are loaded: the wired and
                wireless table models, the table classifier and the image
                orientation classifier. A missing or ``None`` value is
                treated as an empty mapping.
            lang: OCR language forwarded to the OCR and table model
                requests (default None).
            device: Device for the MFD, MFR and layout models
                (default ``'cpu'``).
        """
        # Treat an omitted/None config as "all defaults". Previously it
        # crashed with AttributeError on ``None.get``. Non-None values are
        # stored unchanged.
        formula_config = kwargs.get('formula_config')
        self.formula_config = {} if formula_config is None else formula_config
        self.apply_formula = self.formula_config.get('enable', True)
        table_config = kwargs.get('table_config')
        self.table_config = {} if table_config is None else table_config
        self.apply_table = self.table_config.get('enable', True)
        self.lang = kwargs.get('lang', None)
        self.device = kwargs.get('device', 'cpu')
        logger.info(
            'DocAnalysis init, this may take some times......'
        )
        atom_model_manager = AtomModelSingleton()
        if self.apply_formula:
            # Initialize the formula detection (MFD) model.
            self.mfd_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.MFD,
                mfd_weights=str(
                    os.path.join(
                        auto_download_and_get_model_root_path(ModelPath.yolo_v8_mfd),
                        ModelPath.yolo_v8_mfd,
                    )
                ),
                device=self.device,
            )
            # Initialize the formula recognition (MFR) model.
            mfr_weight_dir = os.path.join(
                auto_download_and_get_model_root_path(ModelPath.unimernet_small),
                ModelPath.unimernet_small,
            )
            self.mfr_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.MFR,
                mfr_weight_dir=mfr_weight_dir,
                device=self.device,
            )
        # Initialize the layout model.
        self.layout_model = atom_model_manager.get_atom_model(
            atom_model_name=AtomicModel.Layout,
            doclayout_yolo_weights=str(
                os.path.join(
                    auto_download_and_get_model_root_path(ModelPath.doclayout_yolo),
                    ModelPath.doclayout_yolo,
                )
            ),
            device=self.device,
        )
        # Initialize OCR.
        self.ocr_model = atom_model_manager.get_atom_model(
            atom_model_name=AtomicModel.OCR,
            det_db_box_thresh=0.3,
            lang=self.lang
        )
        # Initialize the table models. The image orientation classifier is
        # loaded together with them.
        if self.apply_table:
            self.wired_table_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.WiredTable,
                lang=self.lang,
            )
            self.wireless_table_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.WirelessTable,
                lang=self.lang,
            )
            self.table_cls_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.TableCls,
            )
            self.img_orientation_cls_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.ImgOrientationCls,
                lang=self.lang,
            )
        logger.info('DocAnalysis init done!')
|