model_init.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. import torch
  2. from loguru import logger
  3. from magic_pdf.config.constants import MODEL_NAME
  4. from magic_pdf.model.model_list import AtomicModel
  5. from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv11LangDetModel
  6. from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import DocLayoutYOLOModel
  7. from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import Layoutlmv3_Predictor
  8. from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
  9. from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
  10. from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import StructTableModel
  11. from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import TableMasterPaddleModel
  12. try:
  13. from magic_pdf_ascend_plugin.model_plugin.ocr.paddleocr.ppocr_273_npu import ModifiedPaddleOCR
  14. from magic_pdf_ascend_plugin.model_plugin.table.rapidtable.rapid_table_npu import RapidTableModel
  15. logger.info('Using Ascend Plugin')
  16. except ImportError:
  17. from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR
  18. # from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
  19. from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
  20. def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None, table_sub_model_name=None):
  21. if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
  22. table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
  23. elif table_model_type == MODEL_NAME.TABLE_MASTER:
  24. config = {
  25. 'model_dir': model_path,
  26. 'device': _device_
  27. }
  28. table_model = TableMasterPaddleModel(config)
  29. elif table_model_type == MODEL_NAME.RAPID_TABLE:
  30. table_model = RapidTableModel(ocr_engine, table_sub_model_name)
  31. else:
  32. logger.error('table model type not allow')
  33. exit(1)
  34. return table_model
  35. def mfd_model_init(weight, device='cpu'):
  36. if str(device).startswith("npu"):
  37. device = torch.device(device)
  38. mfd_model = YOLOv8MFDModel(weight, device)
  39. return mfd_model
  40. def mfr_model_init(weight_dir, cfg_path, device='cpu'):
  41. mfr_model = UnimernetModel(weight_dir, cfg_path, device)
  42. return mfr_model
  43. def layout_model_init(weight, config_file, device):
  44. model = Layoutlmv3_Predictor(weight, config_file, device)
  45. return model
  46. def doclayout_yolo_model_init(weight, device='cpu'):
  47. if str(device).startswith("npu"):
  48. device = torch.device(device)
  49. model = DocLayoutYOLOModel(weight, device)
  50. return model
  51. def langdetect_model_init(langdetect_model_weight, device='cpu'):
  52. if str(device).startswith("npu"):
  53. device = torch.device(device)
  54. model = YOLOv11LangDetModel(langdetect_model_weight, device)
  55. return model
  56. def ocr_model_init(show_log: bool = False,
  57. det_db_box_thresh=0.3,
  58. lang=None,
  59. use_dilation=True,
  60. det_db_unclip_ratio=1.8,
  61. ):
  62. if lang is not None and lang != '':
  63. model = ModifiedPaddleOCR(
  64. show_log=show_log,
  65. det_db_box_thresh=det_db_box_thresh,
  66. lang=lang,
  67. use_dilation=use_dilation,
  68. det_db_unclip_ratio=det_db_unclip_ratio,
  69. )
  70. else:
  71. model = ModifiedPaddleOCR(
  72. show_log=show_log,
  73. det_db_box_thresh=det_db_box_thresh,
  74. use_dilation=use_dilation,
  75. det_db_unclip_ratio=det_db_unclip_ratio,
  76. )
  77. return model
  78. class AtomModelSingleton:
  79. _instance = None
  80. _models = {}
  81. def __new__(cls, *args, **kwargs):
  82. if cls._instance is None:
  83. cls._instance = super().__new__(cls)
  84. return cls._instance
  85. def get_atom_model(self, atom_model_name: str, **kwargs):
  86. lang = kwargs.get('lang', None)
  87. layout_model_name = kwargs.get('layout_model_name', None)
  88. table_model_name = kwargs.get('table_model_name', None)
  89. if atom_model_name in [AtomicModel.OCR]:
  90. key = (atom_model_name, lang)
  91. elif atom_model_name in [AtomicModel.Layout]:
  92. key = (atom_model_name, layout_model_name)
  93. elif atom_model_name in [AtomicModel.Table]:
  94. key = (atom_model_name, table_model_name)
  95. else:
  96. key = atom_model_name
  97. if key not in self._models:
  98. self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
  99. return self._models[key]
  100. def atom_model_init(model_name: str, **kwargs):
  101. atom_model = None
  102. if model_name == AtomicModel.Layout:
  103. if kwargs.get('layout_model_name') == MODEL_NAME.LAYOUTLMv3:
  104. atom_model = layout_model_init(
  105. kwargs.get('layout_weights'),
  106. kwargs.get('layout_config_file'),
  107. kwargs.get('device')
  108. )
  109. elif kwargs.get('layout_model_name') == MODEL_NAME.DocLayout_YOLO:
  110. atom_model = doclayout_yolo_model_init(
  111. kwargs.get('doclayout_yolo_weights'),
  112. kwargs.get('device')
  113. )
  114. else:
  115. logger.error('layout model name not allow')
  116. exit(1)
  117. elif model_name == AtomicModel.MFD:
  118. atom_model = mfd_model_init(
  119. kwargs.get('mfd_weights'),
  120. kwargs.get('device')
  121. )
  122. elif model_name == AtomicModel.MFR:
  123. atom_model = mfr_model_init(
  124. kwargs.get('mfr_weight_dir'),
  125. kwargs.get('mfr_cfg_path'),
  126. kwargs.get('device')
  127. )
  128. elif model_name == AtomicModel.OCR:
  129. atom_model = ocr_model_init(
  130. kwargs.get('ocr_show_log'),
  131. kwargs.get('det_db_box_thresh'),
  132. kwargs.get('lang'),
  133. )
  134. elif model_name == AtomicModel.Table:
  135. atom_model = table_model_init(
  136. kwargs.get('table_model_name'),
  137. kwargs.get('table_model_path'),
  138. kwargs.get('table_max_time'),
  139. kwargs.get('device'),
  140. kwargs.get('ocr_engine'),
  141. kwargs.get('table_sub_model_name')
  142. )
  143. elif model_name == AtomicModel.LangDetect:
  144. if kwargs.get('langdetect_model_name') == MODEL_NAME.YOLO_V11_LangDetect:
  145. atom_model = langdetect_model_init(
  146. kwargs.get('langdetect_model_weight'),
  147. kwargs.get('device')
  148. )
  149. else:
  150. logger.error('langdetect model name not allow')
  151. exit(1)
  152. else:
  153. logger.error('model name not allow')
  154. exit(1)
  155. if atom_model is None:
  156. logger.error('model init failed')
  157. exit(1)
  158. else:
  159. return atom_model