model_init.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. import os
  2. import torch
  3. from loguru import logger
  4. from .model_list import AtomicModel
  5. from ...model.layout.doclayout_yolo import DocLayoutYOLOModel
  6. from ...model.mfd.yolo_v8 import YOLOv8MFDModel
  7. from ...model.mfr.unimernet.Unimernet import UnimernetModel
  8. from ...model.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
  9. from ...model.table.rapid_table import RapidTableModel
  10. from ...utils.enum_class import ModelPath
  11. from ...utils.models_download_utils import get_file_from_repos
  def table_model_init(lang=None):
      """Build the table model on top of an OCR engine from the shared cache.

      The OCR engine is requested through ``AtomModelSingleton`` with a
      detection box threshold of 0.5 and an unclip ratio of 1.6, so an engine
      that is already cached under the same key is reused instead of rebuilt.

      Args:
          lang: Language passed on to the OCR engine request; ``None`` leaves
              the choice to the OCR model.

      Returns:
          A ``RapidTableModel`` wrapping the OCR engine.
      """
      # Use the AtomicModel constant (as the rest of this module does) rather
      # than a bare 'ocr' string, so the request is matched by the OCR
      # branches of get_atom_model / atom_model_init.
      ocr_engine = AtomModelSingleton().get_atom_model(
          atom_model_name=AtomicModel.OCR,
          det_db_box_thresh=0.5,
          det_db_unclip_ratio=1.6,
          lang=lang,
      )
      return RapidTableModel(ocr_engine)
  def mfd_model_init(weight, device='cpu'):
      """Create the formula-detection (MFD) model.

      Args:
          weight: Weights argument handed to ``YOLOv8MFDModel``.
          device: Target device. A value whose string form starts with
              ``'npu'`` is wrapped in ``torch.device``; anything else is
              passed through unchanged.

      Returns:
          The constructed ``YOLOv8MFDModel``.
      """
      on_npu = str(device).startswith('npu')
      target = torch.device(device) if on_npu else device
      return YOLOv8MFDModel(weight, target)
  def mfr_model_init(weight_dir, device='cpu'):
      """Create the formula-recognition (MFR) model.

      Args:
          weight_dir: Weights directory argument handed to ``UnimernetModel``.
          device: Device argument handed to ``UnimernetModel`` as-is.

      Returns:
          The constructed ``UnimernetModel``.
      """
      return UnimernetModel(weight_dir, device)
  def doclayout_yolo_model_init(weight, device='cpu'):
      """Create the layout-detection model.

      Args:
          weight: Weights argument handed to ``DocLayoutYOLOModel``.
          device: Target device. NPU devices (string form starting with
              ``'npu'``) are converted to ``torch.device`` first; other
              values are forwarded untouched.

      Returns:
          The constructed ``DocLayoutYOLOModel``.
      """
      if not str(device).startswith('npu'):
          return DocLayoutYOLOModel(weight, device)
      return DocLayoutYOLOModel(weight, torch.device(device))
  def ocr_model_init(det_db_box_thresh=0.3,
                     lang=None,
                     use_dilation=True,
                     det_db_unclip_ratio=1.8,
                     ):
      """Create the OCR model.

      Args:
          det_db_box_thresh: Forwarded to ``PytorchPaddleOCR``.
          lang: Forwarded to ``PytorchPaddleOCR`` only when it is neither
              ``None`` nor the empty string; otherwise the argument is
              omitted so the model uses its own default.
          use_dilation: Forwarded to ``PytorchPaddleOCR``.
          det_db_unclip_ratio: Forwarded to ``PytorchPaddleOCR``.

      Returns:
          The constructed ``PytorchPaddleOCR``.
      """
      ocr_kwargs = {
          'det_db_box_thresh': det_db_box_thresh,
          'use_dilation': use_dilation,
          'det_db_unclip_ratio': det_db_unclip_ratio,
      }
      lang_missing = lang is None or lang == ''
      if not lang_missing:
          ocr_kwargs['lang'] = lang
      return PytorchPaddleOCR(**ocr_kwargs)
  class AtomModelSingleton:
      """Single shared instance that caches initialized atomic models.

      Every ``AtomModelSingleton()`` call returns the same object, and the
      model cache is a class attribute, so models are built once and reused.
      """

      _instance = None
      # Cache of initialized models, keyed as computed in get_atom_model().
      _models = {}

      def __new__(cls, *args, **kwargs):
          """Return the one shared instance, creating it on first use."""
          if cls._instance is None:
              cls._instance = super().__new__(cls)
          return cls._instance

      def get_atom_model(self, atom_model_name: str, **kwargs):
          """Return the cached model for this request, building it if needed.

          Args:
              atom_model_name: One of the ``AtomicModel`` names.
              **kwargs: Options forwarded to ``atom_model_init`` when the
                  model has to be built. ``lang``, ``table_model_name``,
                  ``det_db_box_thresh`` and ``det_db_unclip_ratio`` also take
                  part in the cache key where relevant.

          Returns:
              The cached (or newly created) model object.
          """
          lang = kwargs.get('lang', None)
          table_model_name = kwargs.get('table_model_name', None)
          if atom_model_name in [AtomicModel.OCR]:
              # The detection parameters are part of the key: without them a
              # request for different thresholds would silently receive an
              # engine built with whichever settings were cached first.
              key = (
                  atom_model_name,
                  lang,
                  kwargs.get('det_db_box_thresh'),
                  kwargs.get('det_db_unclip_ratio'),
              )
          elif atom_model_name in [AtomicModel.Table]:
              key = (atom_model_name, table_model_name, lang)
          else:
              key = atom_model_name
          if key not in self._models:
              self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
          return self._models[key]
  def _supplied_kwargs(kwargs, *names):
      """Return only the entries of *kwargs* named in *names* that were actually passed."""
      return {name: kwargs[name] for name in names if name in kwargs}


  def atom_model_init(model_name: str, **kwargs):
      """Build the atomic model selected by *model_name*.

      Optional settings (``device``, ``det_db_box_thresh``, ``lang``,
      ``det_db_unclip_ratio``) are forwarded only when the caller supplied
      them, so the defaults declared by the individual ``*_model_init``
      functions apply instead of being overridden with ``None``.

      Args:
          model_name: One of the ``AtomicModel`` names.
          **kwargs: Model-specific options: ``doclayout_yolo_weights``,
              ``mfd_weights``, ``mfr_weight_dir``, ``device``,
              ``det_db_box_thresh``, ``det_db_unclip_ratio``, ``lang``.

      Returns:
          The initialized model.

      Raises:
          SystemExit: With exit code 1 when *model_name* is not recognised or
              the init function returned ``None``.
      """
      if model_name == AtomicModel.Layout:
          atom_model = doclayout_yolo_model_init(
              kwargs.get('doclayout_yolo_weights'),
              **_supplied_kwargs(kwargs, 'device'),
          )
      elif model_name == AtomicModel.MFD:
          atom_model = mfd_model_init(
              kwargs.get('mfd_weights'),
              **_supplied_kwargs(kwargs, 'device'),
          )
      elif model_name == AtomicModel.MFR:
          atom_model = mfr_model_init(
              kwargs.get('mfr_weight_dir'),
              **_supplied_kwargs(kwargs, 'device'),
          )
      elif model_name == AtomicModel.OCR:
          # det_db_unclip_ratio is forwarded too; callers in this module pass
          # it and it was previously dropped here.
          atom_model = ocr_model_init(
              **_supplied_kwargs(
                  kwargs, 'det_db_box_thresh', 'lang', 'det_db_unclip_ratio'
              )
          )
      elif model_name == AtomicModel.Table:
          atom_model = table_model_init(
              kwargs.get('lang'),
          )
      else:
          logger.error('model name not allow')
          # Raise SystemExit directly: the bare exit() helper is injected by
          # the site module and is not guaranteed to exist in every runtime.
          raise SystemExit(1)
      if atom_model is None:
          logger.error('model init failed')
          raise SystemExit(1)
      return atom_model
  class MineruPipelineModel:
      """Bundle of the atomic models used by the document-analysis pipeline.

      All models are obtained through ``AtomModelSingleton``, so building
      several pipeline objects reuses models that are already cached.
      """

      def __init__(self, **kwargs):
          """Load the layout and OCR models plus the optional formula/table models.

          Keyword Args:
              formula_config: Mapping whose ``'enable'`` entry (default
                  ``True``) switches the formula detection and recognition
                  models on. A missing or ``None`` value is treated as empty.
              table_config: Mapping whose ``'enable'`` entry (default
                  ``True``) switches the table model on. A missing or
                  ``None`` value is treated as empty.
              lang: Language passed to the OCR and table model requests.
              device: Device passed to the formula and layout model requests;
                  defaults to ``'cpu'``.
          """
          # Fall back to an empty mapping so a missing config means "use the
          # defaults" instead of failing with AttributeError on None.get().
          formula_config = kwargs.get('formula_config')
          self.formula_config = {} if formula_config is None else formula_config
          self.apply_formula = self.formula_config.get('enable', True)
          table_config = kwargs.get('table_config')
          self.table_config = {} if table_config is None else table_config
          self.apply_table = self.table_config.get('enable', True)
          self.lang = kwargs.get('lang', None)
          self.device = kwargs.get('device', 'cpu')
          logger.info(
              'DocAnalysis init, this may take some times......'
          )
          atom_model_manager = AtomModelSingleton()
          if self.apply_formula:
              # Initialize the formula detection model
              self.mfd_model = atom_model_manager.get_atom_model(
                  atom_model_name=AtomicModel.MFD,
                  mfd_weights=str(
                      get_file_from_repos(ModelPath.yolo_v8_mfd)
                  ),
                  device=self.device,
              )
              # Initialize the formula recognition model
              mfr_weight_dir = str(
                  get_file_from_repos(ModelPath.unimernet_small)
              )
              self.mfr_model = atom_model_manager.get_atom_model(
                  atom_model_name=AtomicModel.MFR,
                  mfr_weight_dir=mfr_weight_dir,
                  device=self.device,
              )
          # Initialize the layout model
          self.layout_model = atom_model_manager.get_atom_model(
              atom_model_name=AtomicModel.Layout,
              doclayout_yolo_weights=str(
                  get_file_from_repos(ModelPath.doclayout_yolo)
              ),
              device=self.device,
          )
          # Initialize the OCR model
          self.ocr_model = atom_model_manager.get_atom_model(
              atom_model_name=AtomicModel.OCR,
              det_db_box_thresh=0.3,
              lang=self.lang
          )
          # Initialize the table model
          if self.apply_table:
              self.table_model = atom_model_manager.get_atom_model(
                  atom_model_name=AtomicModel.Table,
                  lang=self.lang,
              )
          logger.info('DocAnalysis init done!')