model_init.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. import torch
  2. from loguru import logger
  3. from magic_pdf.config.constants import MODEL_NAME
  4. from magic_pdf.model.model_list import AtomicModel
  5. from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv11LangDetModel
  6. from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import DocLayoutYOLOModel
  7. from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
  8. from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
  9. from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
  10. from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
  11. # try:
  12. # from magic_pdf_ascend_plugin.libs.license_verifier import (
  13. # LicenseExpiredError, LicenseFormatError, LicenseSignatureError,
  14. # load_license)
  15. # from magic_pdf_ascend_plugin.model_plugin.ocr.paddleocr.ppocr_273_npu import ModifiedPaddleOCR
  16. # from magic_pdf_ascend_plugin.model_plugin.table.rapidtable.rapid_table_npu import RapidTableModel
  17. # license_key = load_license()
  18. # logger.info(f'Using Ascend Plugin Success, License id is {license_key["payload"]["id"]},'
  19. # f' License expired at {license_key["payload"]["date"]["end_date"]}')
  20. # except Exception as e:
  21. # if isinstance(e, ImportError):
  22. # pass
  23. # elif isinstance(e, LicenseFormatError):
  24. # logger.error('Ascend Plugin: Invalid license format. Please check the license file.')
  25. # elif isinstance(e, LicenseSignatureError):
  26. # logger.error('Ascend Plugin: Invalid signature. The license may be tampered with.')
  27. # elif isinstance(e, LicenseExpiredError):
  28. # logger.error('Ascend Plugin: License has expired. Please renew your license.')
  29. # elif isinstance(e, FileNotFoundError):
  30. # logger.error('Ascend Plugin: Not found License file.')
  31. # else:
  32. # logger.error(f'Ascend Plugin: {e}')
  33. # from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR
  34. # # from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
  35. # from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
  def table_model_init(table_model_type, model_path, max_time, _device_='cpu', lang=None, table_sub_model_name=None):
      """Build the table-recognition model selected by ``table_model_type``.

      Args:
          table_model_type: ``MODEL_NAME.STRUCT_EQTABLE``,
              ``MODEL_NAME.TABLE_MASTER`` or ``MODEL_NAME.RAPID_TABLE``.
          model_path: Passed to ``StructTableModel`` and, as ``'model_dir'``,
              to ``TableMasterPaddleModel``; not used for RapidTable.
          max_time: Forwarded to ``StructTableModel`` as ``max_time``; not
              used by the other backends.
          _device_: Forwarded to ``TableMasterPaddleModel`` as ``'device'``;
              not used by the other backends.
          lang: Language requested for the OCR engine that RapidTable uses.
          table_sub_model_name: Second argument of ``RapidTableModel``.

      Returns:
          The constructed table model.

      Raises:
          SystemExit: With status 1 when ``table_model_type`` is not one of
              the supported values (an error is logged first).
      """
      # ``sys.exit`` instead of the bare ``exit`` builtin: ``exit`` is only
      # injected by the ``site`` module for interactive use and is missing
      # under ``python -S`` or in frozen applications.
      import sys

      if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
          # Imported lazily so the backend is only loaded when selected.
          from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import StructTableModel
          table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
      elif table_model_type == MODEL_NAME.TABLE_MASTER:
          from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import TableMasterPaddleModel
          config = {
              'model_dir': model_path,
              'device': _device_,
          }
          table_model = TableMasterPaddleModel(config)
      elif table_model_type == MODEL_NAME.RAPID_TABLE:
          # RapidTable takes an OCR engine, obtained from the shared model cache.
          atom_model_manager = AtomModelSingleton()
          ocr_engine = atom_model_manager.get_atom_model(
              atom_model_name='ocr',
              ocr_show_log=False,
              det_db_box_thresh=0.5,
              det_db_unclip_ratio=1.6,
              lang=lang,
          )
          table_model = RapidTableModel(ocr_engine, table_sub_model_name)
      else:
          logger.error('table model type not allow')
          sys.exit(1)

      return table_model
  def mfd_model_init(weight, device='cpu'):
      """Create the YOLOv8 formula-detection (MFD) model.

      Args:
          weight: First argument of ``YOLOv8MFDModel``.
          device: Target device. A value whose string form starts with
              ``'npu'`` is wrapped in ``torch.device``; anything else is
              passed through unchanged.

      Returns:
          The constructed ``YOLOv8MFDModel``.
      """
      is_npu = str(device).startswith('npu')
      target_device = torch.device(device) if is_npu else device
      return YOLOv8MFDModel(weight, target_device)
  def mfr_model_init(weight_dir, cfg_path, device='cpu'):
      """Create the Unimernet formula-recognition (MFR) model.

      All three arguments are handed to ``UnimernetModel`` positionally, in
      the order ``weight_dir``, ``cfg_path``, ``device``.
      """
      return UnimernetModel(weight_dir, cfg_path, device)
  def layout_model_init(weight, config_file, device):
      """Create the LayoutLMv3 layout predictor.

      The predictor class is imported inside the function, so the layoutlmv3
      sub-module is only loaded when this initializer is actually called.

      Returns:
          ``Layoutlmv3_Predictor(weight, config_file, device)``.
      """
      from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import Layoutlmv3_Predictor
      return Layoutlmv3_Predictor(weight, config_file, device)
  def doclayout_yolo_model_init(weight, device='cpu'):
      """Create the DocLayout-YOLO layout model.

      Args:
          weight: First argument of ``DocLayoutYOLOModel``.
          device: Target device. A value whose string form starts with
              ``'npu'`` is wrapped in ``torch.device``; anything else is
              passed through unchanged.

      Returns:
          The constructed ``DocLayoutYOLOModel``.
      """
      is_npu = str(device).startswith('npu')
      target_device = torch.device(device) if is_npu else device
      return DocLayoutYOLOModel(weight, target_device)
  def langdetect_model_init(langdetect_model_weight, device='cpu'):
      """Create the YOLOv11 language-detection model.

      Args:
          langdetect_model_weight: First argument of ``YOLOv11LangDetModel``.
          device: Target device. A value whose string form starts with
              ``'npu'`` is wrapped in ``torch.device``; anything else is
              passed through unchanged.

      Returns:
          The constructed ``YOLOv11LangDetModel``.
      """
      is_npu = str(device).startswith('npu')
      target_device = torch.device(device) if is_npu else device
      return YOLOv11LangDetModel(langdetect_model_weight, target_device)
  def ocr_model_init(show_log: bool = False,
                     det_db_box_thresh=0.3,
                     lang=None,
                     use_dilation=True,
                     det_db_unclip_ratio=1.8,
                     ):
      """Create a ``PytorchPaddleOCR`` engine.

      ``show_log``, ``det_db_box_thresh``, ``use_dilation`` and
      ``det_db_unclip_ratio`` are always forwarded as keyword arguments.
      ``lang`` is forwarded only when it is neither ``None`` nor the empty
      string; otherwise the keyword is omitted from the call entirely.

      Returns:
          The constructed ``PytorchPaddleOCR`` instance.
      """
      # An empty dict adds no keyword, so both call shapes share one call site.
      if lang is None or lang == '':
          lang_option = {}
      else:
          lang_option = {'lang': lang}

      return PytorchPaddleOCR(
          show_log=show_log,
          det_db_box_thresh=det_db_box_thresh,
          **lang_option,
          use_dilation=use_dilation,
          det_db_unclip_ratio=det_db_unclip_ratio,
      )
  class AtomModelSingleton:
      """Single shared cache of initialized atomic models.

      Every ``AtomModelSingleton()`` call returns the same object, and models
      are stored in the class-level ``_models`` dict, so a model is built only
      once per distinct cache key.
      """

      _instance = None
      # Maps a cache key (see ``get_atom_model``) to the initialized model.
      _models = {}

      def __new__(cls, *args, **kwargs):
          """Return the one shared instance, creating it on first use."""
          if cls._instance is None:
              cls._instance = super().__new__(cls)
          return cls._instance

      def get_atom_model(self, atom_model_name: str, **kwargs):
          """Return the cached model for this request, building it if needed.

          The cache key depends on the model kind:

          * OCR: name, ``lang``, ``det_db_box_thresh`` and
            ``det_db_unclip_ratio``. The detection settings are part of the
            key so that a request with different settings is not answered
            with an engine that was built for another caller's settings.
          * Layout: name and ``layout_model_name``.
          * Table: name, ``table_model_name`` and ``lang``.
          * Anything else: the name alone.

          Args:
              atom_model_name: Model kind, compared against ``AtomicModel``
                  members.
              **kwargs: Options forwarded unchanged to ``atom_model_init``
                  when the model has to be created.

          Returns:
              The cached or newly created model.
          """
          lang = kwargs.get('lang')

          if atom_model_name == AtomicModel.OCR:
              key = (
                  atom_model_name,
                  lang,
                  kwargs.get('det_db_box_thresh'),
                  kwargs.get('det_db_unclip_ratio'),
              )
          elif atom_model_name == AtomicModel.Layout:
              key = (atom_model_name, kwargs.get('layout_model_name'))
          elif atom_model_name == AtomicModel.Table:
              key = (atom_model_name, kwargs.get('table_model_name'), lang)
          else:
              key = atom_model_name

          if key not in self._models:
              self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
          return self._models[key]
  def atom_model_init(model_name: str, **kwargs):
      """Construct the atomic model named ``model_name`` from keyword options.

      Options read for each model kind (``kwargs.get`` is used, so an absent
      option is forwarded as ``None`` unless stated otherwise):

      * ``AtomicModel.Layout``: ``layout_model_name`` selects the backend.
        LayoutLMv3 reads ``layout_weights``, ``layout_config_file`` and
        ``device``; DocLayout-YOLO reads ``doclayout_yolo_weights`` and
        ``device``.
      * ``AtomicModel.MFD``: ``mfd_weights``, ``device``.
      * ``AtomicModel.MFR``: ``mfr_weight_dir``, ``mfr_cfg_path``, ``device``.
      * ``AtomicModel.OCR``: ``ocr_show_log``, ``det_db_box_thresh``, ``lang``
        and ``det_db_unclip_ratio``. Only options that are present and not
        ``None`` are forwarded, so ``ocr_model_init`` keeps its own defaults
        for the rest.
      * ``AtomicModel.Table``: ``table_model_name``, ``table_model_path``,
        ``table_max_time``, ``device``, ``lang``, ``table_sub_model_name``.
      * ``AtomicModel.LangDetect``: ``langdetect_model_name`` must be
        ``MODEL_NAME.YOLO_V11_LangDetect``; reads ``langdetect_model_weight``
        and ``device``.

      Returns:
          The initialized model.

      Raises:
          SystemExit: With status 1 (after logging an error) when the model
              kind or backend name is not supported, or when the initializer
              returned ``None``.
      """
      # ``sys.exit`` instead of the bare ``exit`` builtin: ``exit`` is only
      # injected by the ``site`` module for interactive use and is missing
      # under ``python -S`` or in frozen applications.
      import sys

      atom_model = None
      if model_name == AtomicModel.Layout:
          layout_model_name = kwargs.get('layout_model_name')
          if layout_model_name == MODEL_NAME.LAYOUTLMv3:
              atom_model = layout_model_init(
                  kwargs.get('layout_weights'),
                  kwargs.get('layout_config_file'),
                  kwargs.get('device'),
              )
          elif layout_model_name == MODEL_NAME.DocLayout_YOLO:
              atom_model = doclayout_yolo_model_init(
                  kwargs.get('doclayout_yolo_weights'),
                  kwargs.get('device'),
              )
          else:
              logger.error('layout model name not allow')
              sys.exit(1)
      elif model_name == AtomicModel.MFD:
          atom_model = mfd_model_init(
              kwargs.get('mfd_weights'),
              kwargs.get('device'),
          )
      elif model_name == AtomicModel.MFR:
          atom_model = mfr_model_init(
              kwargs.get('mfr_weight_dir'),
              kwargs.get('mfr_cfg_path'),
              kwargs.get('device'),
          )
      elif model_name == AtomicModel.OCR:
          # Map caller option names to ocr_model_init parameter names and skip
          # anything absent/None, so a missing option no longer overrides the
          # initializer's default with None. det_db_unclip_ratio is forwarded
          # as well; callers pass it and it was previously dropped.
          option_to_param = (
              ('ocr_show_log', 'show_log'),
              ('det_db_box_thresh', 'det_db_box_thresh'),
              ('lang', 'lang'),
              ('det_db_unclip_ratio', 'det_db_unclip_ratio'),
          )
          ocr_kwargs = {
              param: kwargs[option]
              for option, param in option_to_param
              if kwargs.get(option) is not None
          }
          atom_model = ocr_model_init(**ocr_kwargs)
      elif model_name == AtomicModel.Table:
          atom_model = table_model_init(
              kwargs.get('table_model_name'),
              kwargs.get('table_model_path'),
              kwargs.get('table_max_time'),
              kwargs.get('device'),
              kwargs.get('lang'),
              kwargs.get('table_sub_model_name'),
          )
      elif model_name == AtomicModel.LangDetect:
          if kwargs.get('langdetect_model_name') == MODEL_NAME.YOLO_V11_LangDetect:
              atom_model = langdetect_model_init(
                  kwargs.get('langdetect_model_weight'),
                  kwargs.get('device'),
              )
          else:
              logger.error('langdetect model name not allow')
              sys.exit(1)
      else:
          logger.error('model name not allow')
          sys.exit(1)

      if atom_model is None:
          logger.error('model init failed')
          sys.exit(1)
      return atom_model