model_init.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. import os
  2. import torch
  3. from loguru import logger
  4. from .model_list import AtomicModel
  5. from ...model.layout.doclayoutyolo import DocLayoutYOLOModel
  6. from ...model.mfd.yolo_v8 import YOLOv8MFDModel
  7. from ...model.mfr.unimernet.Unimernet import UnimernetModel
  8. from ...model.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
  9. from ...model.ori_cls.paddle_ori_cls import PaddleOrientationClsModel
  10. from ...model.table.cls.paddle_table_cls import PaddleTableClsModel
  11. # from ...model.table.rec.RapidTable import RapidTableModel
  12. from ...model.table.rec.slanet_plus.main import RapidTableModel
  13. from ...model.table.rec.unet_table.main import UnetTableModel
  14. from ...utils.enum_class import ModelPath
  15. from ...utils.models_download_utils import auto_download_and_get_model_root_path
  16. def img_orientation_cls_model_init():
  17. atom_model_manager = AtomModelSingleton()
  18. ocr_engine = atom_model_manager.get_atom_model(
  19. atom_model_name=AtomicModel.OCR,
  20. det_db_box_thresh=0.5,
  21. det_db_unclip_ratio=1.6,
  22. lang="ch_lite",
  23. enable_merge_det_boxes=False
  24. )
  25. cls_model = PaddleOrientationClsModel(ocr_engine)
  26. return cls_model
  27. def table_cls_model_init():
  28. return PaddleTableClsModel()
  29. def wired_table_model_init(lang=None):
  30. atom_model_manager = AtomModelSingleton()
  31. ocr_engine = atom_model_manager.get_atom_model(
  32. atom_model_name=AtomicModel.OCR,
  33. det_db_box_thresh=0.5,
  34. det_db_unclip_ratio=1.6,
  35. lang=lang,
  36. enable_merge_det_boxes=False
  37. )
  38. table_model = UnetTableModel(ocr_engine)
  39. return table_model
  40. def wireless_table_model_init(lang=None):
  41. atom_model_manager = AtomModelSingleton()
  42. ocr_engine = atom_model_manager.get_atom_model(
  43. atom_model_name=AtomicModel.OCR,
  44. det_db_box_thresh=0.5,
  45. det_db_unclip_ratio=1.6,
  46. lang=lang,
  47. enable_merge_det_boxes=False
  48. )
  49. table_model = RapidTableModel(ocr_engine)
  50. return table_model
  51. def mfd_model_init(weight, device='cpu'):
  52. if str(device).startswith('npu'):
  53. device = torch.device(device)
  54. mfd_model = YOLOv8MFDModel(weight, device)
  55. return mfd_model
  56. def mfr_model_init(weight_dir, device='cpu'):
  57. mfr_model = UnimernetModel(weight_dir, device)
  58. return mfr_model
  59. def doclayout_yolo_model_init(weight, device='cpu'):
  60. if str(device).startswith('npu'):
  61. device = torch.device(device)
  62. model = DocLayoutYOLOModel(weight, device)
  63. return model
  64. def ocr_model_init(det_db_box_thresh=0.3,
  65. lang=None,
  66. det_db_unclip_ratio=1.8,
  67. enable_merge_det_boxes=True
  68. ):
  69. if lang is not None and lang != '':
  70. model = PytorchPaddleOCR(
  71. det_db_box_thresh=det_db_box_thresh,
  72. lang=lang,
  73. use_dilation=True,
  74. det_db_unclip_ratio=det_db_unclip_ratio,
  75. enable_merge_det_boxes=enable_merge_det_boxes,
  76. )
  77. else:
  78. model = PytorchPaddleOCR(
  79. det_db_box_thresh=det_db_box_thresh,
  80. use_dilation=True,
  81. det_db_unclip_ratio=det_db_unclip_ratio,
  82. enable_merge_det_boxes=enable_merge_det_boxes,
  83. )
  84. return model
  85. class AtomModelSingleton:
  86. _instance = None
  87. _models = {}
  88. def __new__(cls, *args, **kwargs):
  89. if cls._instance is None:
  90. cls._instance = super().__new__(cls)
  91. return cls._instance
  92. def get_atom_model(self, atom_model_name: str, **kwargs):
  93. lang = kwargs.get('lang', None)
  94. if atom_model_name in [AtomicModel.WiredTable, AtomicModel.WirelessTable]:
  95. key = (
  96. atom_model_name,
  97. lang
  98. )
  99. elif atom_model_name in [AtomicModel.OCR]:
  100. key = (
  101. atom_model_name,
  102. kwargs.get('det_db_box_thresh', 0.3),
  103. lang,
  104. kwargs.get('det_db_unclip_ratio', 1.8),
  105. kwargs.get('enable_merge_det_boxes', True)
  106. )
  107. else:
  108. key = atom_model_name
  109. if key not in self._models:
  110. self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
  111. return self._models[key]
  112. def atom_model_init(model_name: str, **kwargs):
  113. atom_model = None
  114. if model_name == AtomicModel.Layout:
  115. atom_model = doclayout_yolo_model_init(
  116. kwargs.get('doclayout_yolo_weights'),
  117. kwargs.get('device')
  118. )
  119. elif model_name == AtomicModel.MFD:
  120. atom_model = mfd_model_init(
  121. kwargs.get('mfd_weights'),
  122. kwargs.get('device')
  123. )
  124. elif model_name == AtomicModel.MFR:
  125. atom_model = mfr_model_init(
  126. kwargs.get('mfr_weight_dir'),
  127. kwargs.get('device')
  128. )
  129. elif model_name == AtomicModel.OCR:
  130. atom_model = ocr_model_init(
  131. kwargs.get('det_db_box_thresh', 0.3),
  132. kwargs.get('lang'),
  133. kwargs.get('det_db_unclip_ratio', 1.8),
  134. kwargs.get('enable_merge_det_boxes', True)
  135. )
  136. elif model_name == AtomicModel.WirelessTable:
  137. atom_model = wireless_table_model_init(
  138. kwargs.get('lang'),
  139. )
  140. elif model_name == AtomicModel.WiredTable:
  141. atom_model = wired_table_model_init(
  142. kwargs.get('lang'),
  143. )
  144. elif model_name == AtomicModel.TableCls:
  145. atom_model = table_cls_model_init()
  146. elif model_name == AtomicModel.ImgOrientationCls:
  147. atom_model = img_orientation_cls_model_init()
  148. else:
  149. logger.error('model name not allow')
  150. exit(1)
  151. if atom_model is None:
  152. logger.error('model init failed')
  153. exit(1)
  154. else:
  155. return atom_model
  156. class MineruPipelineModel:
  157. def __init__(self, **kwargs):
  158. self.formula_config = kwargs.get('formula_config')
  159. self.apply_formula = self.formula_config.get('enable', True)
  160. self.table_config = kwargs.get('table_config')
  161. self.apply_table = self.table_config.get('enable', True)
  162. self.lang = kwargs.get('lang', None)
  163. self.device = kwargs.get('device', 'cpu')
  164. logger.info(
  165. 'DocAnalysis init, this may take some times......'
  166. )
  167. atom_model_manager = AtomModelSingleton()
  168. if self.apply_formula:
  169. # 初始化公式检测模型
  170. self.mfd_model = atom_model_manager.get_atom_model(
  171. atom_model_name=AtomicModel.MFD,
  172. mfd_weights=str(
  173. os.path.join(auto_download_and_get_model_root_path(ModelPath.yolo_v8_mfd), ModelPath.yolo_v8_mfd)
  174. ),
  175. device=self.device,
  176. )
  177. # 初始化公式解析模型
  178. mfr_weight_dir = os.path.join(auto_download_and_get_model_root_path(ModelPath.unimernet_small), ModelPath.unimernet_small)
  179. self.mfr_model = atom_model_manager.get_atom_model(
  180. atom_model_name=AtomicModel.MFR,
  181. mfr_weight_dir=mfr_weight_dir,
  182. device=self.device,
  183. )
  184. # 初始化layout模型
  185. self.layout_model = atom_model_manager.get_atom_model(
  186. atom_model_name=AtomicModel.Layout,
  187. doclayout_yolo_weights=str(
  188. os.path.join(auto_download_and_get_model_root_path(ModelPath.doclayout_yolo), ModelPath.doclayout_yolo)
  189. ),
  190. device=self.device,
  191. )
  192. # 初始化ocr
  193. self.ocr_model = atom_model_manager.get_atom_model(
  194. atom_model_name=AtomicModel.OCR,
  195. det_db_box_thresh=0.3,
  196. lang=self.lang
  197. )
  198. # init table model
  199. if self.apply_table:
  200. self.wired_table_model = atom_model_manager.get_atom_model(
  201. atom_model_name=AtomicModel.WiredTable,
  202. lang=self.lang,
  203. )
  204. self.wireless_table_model = atom_model_manager.get_atom_model(
  205. atom_model_name=AtomicModel.WirelessTable,
  206. lang=self.lang,
  207. )
  208. self.table_cls_model = atom_model_manager.get_atom_model(
  209. atom_model_name=AtomicModel.TableCls,
  210. )
  211. self.img_orientation_cls_model = atom_model_manager.get_atom_model(
  212. atom_model_name=AtomicModel.ImgOrientationCls,
  213. lang=self.lang,
  214. )
  215. logger.info('DocAnalysis init done!')