model_init.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. import os
  2. import torch
  3. from loguru import logger
  4. from magic_pdf.config.constants import MODEL_NAME
  5. from magic_pdf.model.model_list import AtomicModel
  6. from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import \
  7. YOLOv11LangDetModel
  8. from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import \
  9. DocLayoutYOLOModel
  10. from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import \
  11. Layoutlmv3_Predictor
  12. from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
  13. from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
  14. try:
  15. from magic_pdf_ascend_plugin.libs.license_verifier import (
  16. LicenseExpiredError, LicenseFormatError, LicenseSignatureError,
  17. load_license)
  18. from magic_pdf_ascend_plugin.model_plugin.ocr.paddleocr.ppocr_273_npu import \
  19. ModifiedPaddleOCR
  20. from magic_pdf_ascend_plugin.model_plugin.table.rapidtable.rapid_table_npu import \
  21. RapidTableModel
  22. license_key = load_license()
  23. logger.info(f'Using Ascend Plugin Success, License id is {license_key["payload"]["id"]},'
  24. f' License expired at {license_key["payload"]["date"]["end_date"]}')
  25. except Exception as e:
  26. if isinstance(e, ImportError):
  27. pass
  28. elif isinstance(e, LicenseFormatError):
  29. logger.error('Ascend Plugin: Invalid license format. Please check the license file.')
  30. elif isinstance(e, LicenseSignatureError):
  31. logger.error('Ascend Plugin: Invalid signature. The license may be tampered with.')
  32. elif isinstance(e, LicenseExpiredError):
  33. logger.error('Ascend Plugin: License has expired. Please renew your license.')
  34. elif isinstance(e, FileNotFoundError):
  35. logger.error('Ascend Plugin: Not found License file.')
  36. else:
  37. logger.error(f'Ascend Plugin: {e}')
  38. from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR
  39. # from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
  40. from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
  41. from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import \
  42. StructTableModel
  43. from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \
  44. TableMasterPaddleModel
  45. def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None, table_sub_model_name=None):
  46. if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
  47. table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
  48. elif table_model_type == MODEL_NAME.TABLE_MASTER:
  49. config = {
  50. 'model_dir': model_path,
  51. 'device': _device_
  52. }
  53. table_model = TableMasterPaddleModel(config)
  54. elif table_model_type == MODEL_NAME.RAPID_TABLE:
  55. table_model = RapidTableModel(ocr_engine, table_sub_model_name)
  56. else:
  57. logger.error('table model type not allow')
  58. exit(1)
  59. return table_model
  60. def mfd_model_init(weight, device='cpu'):
  61. if str(device).startswith('npu'):
  62. device = torch.device(device)
  63. mfd_model = YOLOv8MFDModel(weight, device)
  64. return mfd_model
  65. def mfr_model_init(weight_dir, cfg_path, device='cpu'):
  66. mfr_model = UnimernetModel(weight_dir, cfg_path, device)
  67. return mfr_model
  68. def layout_model_init(weight, config_file, device):
  69. model = Layoutlmv3_Predictor(weight, config_file, device)
  70. return model
  71. def doclayout_yolo_model_init(weight, device='cpu'):
  72. if str(device).startswith('npu'):
  73. device = torch.device(device)
  74. model = DocLayoutYOLOModel(weight, device)
  75. return model
  76. def langdetect_model_init(langdetect_model_weight, device='cpu'):
  77. if str(device).startswith('npu'):
  78. device = torch.device(device)
  79. model = YOLOv11LangDetModel(langdetect_model_weight, device)
  80. return model
  81. def ocr_model_init(show_log: bool = False,
  82. det_db_box_thresh=0.3,
  83. lang=None,
  84. use_dilation=True,
  85. det_db_unclip_ratio=1.8,
  86. ):
  87. if lang is not None and lang != '':
  88. model = ModifiedPaddleOCR(
  89. show_log=show_log,
  90. det_db_box_thresh=det_db_box_thresh,
  91. lang=lang,
  92. use_dilation=use_dilation,
  93. det_db_unclip_ratio=det_db_unclip_ratio,
  94. )
  95. else:
  96. model = ModifiedPaddleOCR(
  97. show_log=show_log,
  98. det_db_box_thresh=det_db_box_thresh,
  99. use_dilation=use_dilation,
  100. det_db_unclip_ratio=det_db_unclip_ratio,
  101. )
  102. return model
  103. class AtomModelSingleton:
  104. _instance = None
  105. _models = {}
  106. def __new__(cls, *args, **kwargs):
  107. if cls._instance is None:
  108. cls._instance = super().__new__(cls)
  109. return cls._instance
  110. def get_atom_model(self, atom_model_name: str, **kwargs):
  111. lang = kwargs.get('lang', None)
  112. layout_model_name = kwargs.get('layout_model_name', None)
  113. table_model_name = kwargs.get('table_model_name', None)
  114. if atom_model_name in [AtomicModel.OCR]:
  115. key = (atom_model_name, lang)
  116. elif atom_model_name in [AtomicModel.Layout]:
  117. key = (atom_model_name, layout_model_name)
  118. elif atom_model_name in [AtomicModel.Table]:
  119. key = (atom_model_name, table_model_name)
  120. else:
  121. key = atom_model_name
  122. if key not in self._models:
  123. self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
  124. return self._models[key]
  125. def atom_model_init(model_name: str, **kwargs):
  126. atom_model = None
  127. if model_name == AtomicModel.Layout:
  128. if kwargs.get('layout_model_name') == MODEL_NAME.LAYOUTLMv3:
  129. atom_model = layout_model_init(
  130. kwargs.get('layout_weights'),
  131. kwargs.get('layout_config_file'),
  132. kwargs.get('device')
  133. )
  134. elif kwargs.get('layout_model_name') == MODEL_NAME.DocLayout_YOLO:
  135. atom_model = doclayout_yolo_model_init(
  136. kwargs.get('doclayout_yolo_weights'),
  137. kwargs.get('device')
  138. )
  139. else:
  140. logger.error('layout model name not allow')
  141. exit(1)
  142. elif model_name == AtomicModel.MFD:
  143. atom_model = mfd_model_init(
  144. kwargs.get('mfd_weights'),
  145. kwargs.get('device')
  146. )
  147. elif model_name == AtomicModel.MFR:
  148. atom_model = mfr_model_init(
  149. kwargs.get('mfr_weight_dir'),
  150. kwargs.get('mfr_cfg_path'),
  151. kwargs.get('device')
  152. )
  153. elif model_name == AtomicModel.OCR:
  154. atom_model = ocr_model_init(
  155. kwargs.get('ocr_show_log'),
  156. kwargs.get('det_db_box_thresh'),
  157. kwargs.get('lang'),
  158. )
  159. elif model_name == AtomicModel.Table:
  160. atom_model = table_model_init(
  161. kwargs.get('table_model_name'),
  162. kwargs.get('table_model_path'),
  163. kwargs.get('table_max_time'),
  164. kwargs.get('device'),
  165. kwargs.get('ocr_engine'),
  166. kwargs.get('table_sub_model_name')
  167. )
  168. elif model_name == AtomicModel.LangDetect:
  169. if kwargs.get('langdetect_model_name') == MODEL_NAME.YOLO_V11_LangDetect:
  170. atom_model = langdetect_model_init(
  171. kwargs.get('langdetect_model_weight'),
  172. kwargs.get('device')
  173. )
  174. else:
  175. logger.error('langdetect model name not allow')
  176. exit(1)
  177. else:
  178. logger.error('model name not allow')
  179. exit(1)
  180. if atom_model is None:
  181. logger.error('model init failed')
  182. exit(1)
  183. else:
  184. return atom_model