model_init.py 8.9 KB


import os
import sys

import torch
from loguru import logger

from .model_list import AtomicModel
from ...model.layout.doclayoutyolo import DocLayoutYOLOModel
from ...model.mfd.yolo_v8 import YOLOv8MFDModel
from ...model.mfr.unimernet.Unimernet import UnimernetModel
from ...model.mfr.pp_formulanet_plus_m.predict_formula import FormulaRecognizer
from ...model.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
from ...model.ori_cls.paddle_ori_cls import PaddleOrientationClsModel
from ...model.table.cls.paddle_table_cls import PaddleTableClsModel
# from ...model.table.rec.RapidTable import RapidTableModel
from ...model.table.rec.slanet_plus.main import RapidTableModel
from ...model.table.rec.unet_table.main import UnetTableModel
from ...utils.enum_class import ModelPath
from ...utils.models_download_utils import auto_download_and_get_model_root_path
  17. MFR_MODEL = os.getenv('MINERU_MFR_MODEL', None)
  18. if MFR_MODEL is None:
  19. # MFR_MODEL = "unimernet_small"
  20. MFR_MODEL = "pp_formulanet_plus_m"
  21. def img_orientation_cls_model_init():
  22. atom_model_manager = AtomModelSingleton()
  23. ocr_engine = atom_model_manager.get_atom_model(
  24. atom_model_name=AtomicModel.OCR,
  25. det_db_box_thresh=0.5,
  26. det_db_unclip_ratio=1.6,
  27. lang="ch_lite",
  28. enable_merge_det_boxes=False
  29. )
  30. cls_model = PaddleOrientationClsModel(ocr_engine)
  31. return cls_model
  32. def table_cls_model_init():
  33. return PaddleTableClsModel()
  34. def wired_table_model_init(lang=None):
  35. atom_model_manager = AtomModelSingleton()
  36. ocr_engine = atom_model_manager.get_atom_model(
  37. atom_model_name=AtomicModel.OCR,
  38. det_db_box_thresh=0.5,
  39. det_db_unclip_ratio=1.6,
  40. lang=lang,
  41. enable_merge_det_boxes=False
  42. )
  43. table_model = UnetTableModel(ocr_engine)
  44. return table_model
  45. def wireless_table_model_init(lang=None):
  46. atom_model_manager = AtomModelSingleton()
  47. ocr_engine = atom_model_manager.get_atom_model(
  48. atom_model_name=AtomicModel.OCR,
  49. det_db_box_thresh=0.5,
  50. det_db_unclip_ratio=1.6,
  51. lang=lang,
  52. enable_merge_det_boxes=False
  53. )
  54. table_model = RapidTableModel(ocr_engine)
  55. return table_model
  56. def mfd_model_init(weight, device='cpu'):
  57. if str(device).startswith('npu'):
  58. device = torch.device(device)
  59. mfd_model = YOLOv8MFDModel(weight, device)
  60. return mfd_model
  61. def mfr_model_init(weight_dir, device='cpu'):
  62. if MFR_MODEL == "unimernet_small":
  63. mfr_model = UnimernetModel(weight_dir, device)
  64. elif MFR_MODEL == "pp_formulanet_plus_m":
  65. mfr_model = FormulaRecognizer(weight_dir, device)
  66. else:
  67. logger.error('MFR model name not allow')
  68. exit(1)
  69. return mfr_model
  70. def doclayout_yolo_model_init(weight, device='cpu'):
  71. if str(device).startswith('npu'):
  72. device = torch.device(device)
  73. model = DocLayoutYOLOModel(weight, device)
  74. return model
  75. def ocr_model_init(det_db_box_thresh=0.3,
  76. lang=None,
  77. det_db_unclip_ratio=1.8,
  78. enable_merge_det_boxes=True
  79. ):
  80. if lang is not None and lang != '':
  81. model = PytorchPaddleOCR(
  82. det_db_box_thresh=det_db_box_thresh,
  83. lang=lang,
  84. use_dilation=True,
  85. det_db_unclip_ratio=det_db_unclip_ratio,
  86. enable_merge_det_boxes=enable_merge_det_boxes,
  87. )
  88. else:
  89. model = PytorchPaddleOCR(
  90. det_db_box_thresh=det_db_box_thresh,
  91. use_dilation=True,
  92. det_db_unclip_ratio=det_db_unclip_ratio,
  93. enable_merge_det_boxes=enable_merge_det_boxes,
  94. )
  95. return model
  96. class AtomModelSingleton:
  97. _instance = None
  98. _models = {}
  99. def __new__(cls, *args, **kwargs):
  100. if cls._instance is None:
  101. cls._instance = super().__new__(cls)
  102. return cls._instance
  103. def get_atom_model(self, atom_model_name: str, **kwargs):
  104. lang = kwargs.get('lang', None)
  105. if atom_model_name in [AtomicModel.WiredTable, AtomicModel.WirelessTable]:
  106. key = (
  107. atom_model_name,
  108. lang
  109. )
  110. elif atom_model_name in [AtomicModel.OCR]:
  111. key = (
  112. atom_model_name,
  113. kwargs.get('det_db_box_thresh', 0.3),
  114. lang,
  115. kwargs.get('det_db_unclip_ratio', 1.8),
  116. kwargs.get('enable_merge_det_boxes', True)
  117. )
  118. else:
  119. key = atom_model_name
  120. if key not in self._models:
  121. self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
  122. return self._models[key]
  123. def atom_model_init(model_name: str, **kwargs):
  124. atom_model = None
  125. if model_name == AtomicModel.Layout:
  126. atom_model = doclayout_yolo_model_init(
  127. kwargs.get('doclayout_yolo_weights'),
  128. kwargs.get('device')
  129. )
  130. elif model_name == AtomicModel.MFD:
  131. atom_model = mfd_model_init(
  132. kwargs.get('mfd_weights'),
  133. kwargs.get('device')
  134. )
  135. elif model_name == AtomicModel.MFR:
  136. atom_model = mfr_model_init(
  137. kwargs.get('mfr_weight_dir'),
  138. kwargs.get('device')
  139. )
  140. elif model_name == AtomicModel.OCR:
  141. atom_model = ocr_model_init(
  142. kwargs.get('det_db_box_thresh', 0.3),
  143. kwargs.get('lang'),
  144. kwargs.get('det_db_unclip_ratio', 1.8),
  145. kwargs.get('enable_merge_det_boxes', True)
  146. )
  147. elif model_name == AtomicModel.WirelessTable:
  148. atom_model = wireless_table_model_init(
  149. kwargs.get('lang'),
  150. )
  151. elif model_name == AtomicModel.WiredTable:
  152. atom_model = wired_table_model_init(
  153. kwargs.get('lang'),
  154. )
  155. elif model_name == AtomicModel.TableCls:
  156. atom_model = table_cls_model_init()
  157. elif model_name == AtomicModel.ImgOrientationCls:
  158. atom_model = img_orientation_cls_model_init()
  159. else:
  160. logger.error('model name not allow')
  161. exit(1)
  162. if atom_model is None:
  163. logger.error('model init failed')
  164. exit(1)
  165. else:
  166. return atom_model
  167. class MineruPipelineModel:
  168. def __init__(self, **kwargs):
  169. self.formula_config = kwargs.get('formula_config')
  170. self.apply_formula = self.formula_config.get('enable', True)
  171. self.table_config = kwargs.get('table_config')
  172. self.apply_table = self.table_config.get('enable', True)
  173. self.lang = kwargs.get('lang', None)
  174. self.device = kwargs.get('device', 'cpu')
  175. logger.info(
  176. 'DocAnalysis init, this may take some times......'
  177. )
  178. atom_model_manager = AtomModelSingleton()
  179. if self.apply_formula:
  180. # 初始化公式检测模型
  181. self.mfd_model = atom_model_manager.get_atom_model(
  182. atom_model_name=AtomicModel.MFD,
  183. mfd_weights=str(
  184. os.path.join(auto_download_and_get_model_root_path(ModelPath.yolo_v8_mfd), ModelPath.yolo_v8_mfd)
  185. ),
  186. device=self.device,
  187. )
  188. # 初始化公式解析模型
  189. if MFR_MODEL == "unimernet_small":
  190. mfr_model_path = ModelPath.unimernet_small
  191. elif MFR_MODEL == "pp_formulanet_plus_m":
  192. mfr_model_path = ModelPath.pp_formulanet_plus_m
  193. else:
  194. logger.error('MFR model name not allow')
  195. exit(1)
  196. self.mfr_model = atom_model_manager.get_atom_model(
  197. atom_model_name=AtomicModel.MFR,
  198. mfr_weight_dir=str(os.path.join(auto_download_and_get_model_root_path(mfr_model_path), mfr_model_path)),
  199. device=self.device,
  200. )
  201. # 初始化layout模型
  202. self.layout_model = atom_model_manager.get_atom_model(
  203. atom_model_name=AtomicModel.Layout,
  204. doclayout_yolo_weights=str(
  205. os.path.join(auto_download_and_get_model_root_path(ModelPath.doclayout_yolo), ModelPath.doclayout_yolo)
  206. ),
  207. device=self.device,
  208. )
  209. # 初始化ocr
  210. self.ocr_model = atom_model_manager.get_atom_model(
  211. atom_model_name=AtomicModel.OCR,
  212. det_db_box_thresh=0.3,
  213. lang=self.lang
  214. )
  215. # init table model
  216. if self.apply_table:
  217. self.wired_table_model = atom_model_manager.get_atom_model(
  218. atom_model_name=AtomicModel.WiredTable,
  219. lang=self.lang,
  220. )
  221. self.wireless_table_model = atom_model_manager.get_atom_model(
  222. atom_model_name=AtomicModel.WirelessTable,
  223. lang=self.lang,
  224. )
  225. self.table_cls_model = atom_model_manager.get_atom_model(
  226. atom_model_name=AtomicModel.TableCls,
  227. )
  228. self.img_orientation_cls_model = atom_model_manager.get_atom_model(
  229. atom_model_name=AtomicModel.ImgOrientationCls,
  230. lang=self.lang,
  231. )
  232. logger.info('DocAnalysis init done!')