model_init.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. import os
  2. import torch
  3. from loguru import logger
  4. from .model_list import AtomicModel
  5. from ...model.layout.doclayout_yolo import DocLayoutYOLOModel
  6. from ...model.mfd.yolo_v8 import YOLOv8MFDModel
  7. from ...model.mfr.unimernet.Unimernet import UnimernetModel
  8. from ...model.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
  9. from ...model.ori_cls.paddle_ori_cls import PaddleOrientationClsModel
  10. from ...model.table.cls.paddle_table_cls import PaddleTableClsModel
  11. from ...model.table.rec.rapid_table import RapidTableModel
  12. from ...model.table.rec.unet_table.unet_table import UnetTableModel
  13. from ...utils.enum_class import ModelPath
  14. from ...utils.models_download_utils import auto_download_and_get_model_root_path
  15. def img_orientation_cls_model_init():
  16. atom_model_manager = AtomModelSingleton()
  17. ocr_engine = atom_model_manager.get_atom_model(
  18. atom_model_name="ocr", det_db_box_thresh=0.5, det_db_unclip_ratio=1.6, lang="ch_lite"
  19. )
  20. cls_model = PaddleOrientationClsModel(ocr_engine)
  21. return cls_model
  22. def table_cls_model_init():
  23. return PaddleTableClsModel()
  24. def wired_table_model_init(lang=None):
  25. atom_model_manager = AtomModelSingleton()
  26. ocr_engine = atom_model_manager.get_atom_model(
  27. atom_model_name="ocr", det_db_box_thresh=0.5, det_db_unclip_ratio=1.6, lang=lang
  28. )
  29. table_model = UnetTableModel(ocr_engine)
  30. return table_model
  31. def wireless_table_model_init(lang=None):
  32. atom_model_manager = AtomModelSingleton()
  33. ocr_engine = atom_model_manager.get_atom_model(
  34. atom_model_name="ocr", det_db_box_thresh=0.5, det_db_unclip_ratio=1.6, lang=lang
  35. )
  36. table_model = RapidTableModel(ocr_engine)
  37. return table_model
  38. def mfd_model_init(weight, device='cpu'):
  39. if str(device).startswith('npu'):
  40. device = torch.device(device)
  41. mfd_model = YOLOv8MFDModel(weight, device)
  42. return mfd_model
  43. def mfr_model_init(weight_dir, device='cpu'):
  44. mfr_model = UnimernetModel(weight_dir, device)
  45. return mfr_model
  46. def doclayout_yolo_model_init(weight, device='cpu'):
  47. if str(device).startswith('npu'):
  48. device = torch.device(device)
  49. model = DocLayoutYOLOModel(weight, device)
  50. return model
  51. def ocr_model_init(det_db_box_thresh=0.3,
  52. lang=None,
  53. use_dilation=True,
  54. det_db_unclip_ratio=1.8,
  55. ):
  56. if lang is not None and lang != '':
  57. model = PytorchPaddleOCR(
  58. det_db_box_thresh=det_db_box_thresh,
  59. lang=lang,
  60. use_dilation=use_dilation,
  61. det_db_unclip_ratio=det_db_unclip_ratio,
  62. )
  63. else:
  64. model = PytorchPaddleOCR(
  65. det_db_box_thresh=det_db_box_thresh,
  66. use_dilation=use_dilation,
  67. det_db_unclip_ratio=det_db_unclip_ratio,
  68. )
  69. return model
  70. class AtomModelSingleton:
  71. _instance = None
  72. _models = {}
  73. def __new__(cls, *args, **kwargs):
  74. if cls._instance is None:
  75. cls._instance = super().__new__(cls)
  76. return cls._instance
  77. def get_atom_model(self, atom_model_name: str, **kwargs):
  78. lang = kwargs.get('lang', None)
  79. if atom_model_name in [AtomicModel.OCR, AtomicModel.WiredTable, AtomicModel.WirelessTable]:
  80. key = (atom_model_name, lang)
  81. else:
  82. key = atom_model_name
  83. if key not in self._models:
  84. self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
  85. return self._models[key]
  86. def atom_model_init(model_name: str, **kwargs):
  87. atom_model = None
  88. if model_name == AtomicModel.Layout:
  89. atom_model = doclayout_yolo_model_init(
  90. kwargs.get('doclayout_yolo_weights'),
  91. kwargs.get('device')
  92. )
  93. elif model_name == AtomicModel.MFD:
  94. atom_model = mfd_model_init(
  95. kwargs.get('mfd_weights'),
  96. kwargs.get('device')
  97. )
  98. elif model_name == AtomicModel.MFR:
  99. atom_model = mfr_model_init(
  100. kwargs.get('mfr_weight_dir'),
  101. kwargs.get('device')
  102. )
  103. elif model_name == AtomicModel.OCR:
  104. atom_model = ocr_model_init(
  105. kwargs.get('det_db_box_thresh'),
  106. kwargs.get('lang'),
  107. )
  108. elif model_name == AtomicModel.WirelessTable:
  109. atom_model = wireless_table_model_init(
  110. kwargs.get('lang'),
  111. )
  112. elif model_name == AtomicModel.WiredTable:
  113. atom_model = wired_table_model_init(
  114. kwargs.get('lang'),
  115. )
  116. elif model_name == AtomicModel.TableCls:
  117. atom_model = table_cls_model_init()
  118. elif model_name == AtomicModel.ImgOrientationCls:
  119. atom_model = img_orientation_cls_model_init()
  120. else:
  121. logger.error('model name not allow')
  122. exit(1)
  123. if atom_model is None:
  124. logger.error('model init failed')
  125. exit(1)
  126. else:
  127. return atom_model
  128. class MineruPipelineModel:
  129. def __init__(self, **kwargs):
  130. self.formula_config = kwargs.get('formula_config')
  131. self.apply_formula = self.formula_config.get('enable', True)
  132. self.table_config = kwargs.get('table_config')
  133. self.apply_table = self.table_config.get('enable', True)
  134. self.lang = kwargs.get('lang', None)
  135. self.device = kwargs.get('device', 'cpu')
  136. logger.info(
  137. 'DocAnalysis init, this may take some times......'
  138. )
  139. atom_model_manager = AtomModelSingleton()
  140. if self.apply_formula:
  141. # 初始化公式检测模型
  142. self.mfd_model = atom_model_manager.get_atom_model(
  143. atom_model_name=AtomicModel.MFD,
  144. mfd_weights=str(
  145. os.path.join(auto_download_and_get_model_root_path(ModelPath.yolo_v8_mfd), ModelPath.yolo_v8_mfd)
  146. ),
  147. device=self.device,
  148. )
  149. # 初始化公式解析模型
  150. mfr_weight_dir = os.path.join(auto_download_and_get_model_root_path(ModelPath.unimernet_small), ModelPath.unimernet_small)
  151. self.mfr_model = atom_model_manager.get_atom_model(
  152. atom_model_name=AtomicModel.MFR,
  153. mfr_weight_dir=mfr_weight_dir,
  154. device=self.device,
  155. )
  156. # 初始化layout模型
  157. self.layout_model = atom_model_manager.get_atom_model(
  158. atom_model_name=AtomicModel.Layout,
  159. doclayout_yolo_weights=str(
  160. os.path.join(auto_download_and_get_model_root_path(ModelPath.doclayout_yolo), ModelPath.doclayout_yolo)
  161. ),
  162. device=self.device,
  163. )
  164. # 初始化ocr
  165. self.ocr_model = atom_model_manager.get_atom_model(
  166. atom_model_name=AtomicModel.OCR,
  167. det_db_box_thresh=0.3,
  168. lang=self.lang
  169. )
  170. # init table model
  171. if self.apply_table:
  172. self.wired_table_model = atom_model_manager.get_atom_model(
  173. atom_model_name=AtomicModel.WiredTable,
  174. lang=self.lang,
  175. )
  176. self.wireless_table_model = atom_model_manager.get_atom_model(
  177. atom_model_name=AtomicModel.WirelessTable,
  178. lang=self.lang,
  179. )
  180. self.table_cls_model = atom_model_manager.get_atom_model(
  181. atom_model_name=AtomicModel.TableCls,
  182. )
  183. self.img_orientation_cls_model = atom_model_manager.get_atom_model(
  184. atom_model_name=AtomicModel.ImgOrientationCls,
  185. lang=self.lang,
  186. )
  187. logger.info('DocAnalysis init done!')