model_init.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. import torch
  2. from loguru import logger
  3. from magic_pdf.config.constants import MODEL_NAME
  4. from magic_pdf.libs.config_reader import get_device
  5. from magic_pdf.model.model_list import AtomicModel
  6. from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import \
  7. DocLayoutYOLOModel
  8. from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import \
  9. Layoutlmv3_Predictor
  10. from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
  11. from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
  12. from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import \
  13. ModifiedPaddleOCR
  14. from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import \
  15. RapidTableModel
  16. # from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
  17. from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import \
  18. StructTableModel
  19. from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \
  20. TableMasterPaddleModel
  def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None):
      """Build the table-recognition model selected by ``table_model_type``.

      Args:
          table_model_type: One of ``MODEL_NAME.STRUCT_EQTABLE``,
              ``MODEL_NAME.TABLE_MASTER`` or ``MODEL_NAME.RAPID_TABLE``.
          model_path: Passed positionally to StructTableModel and as
              ``model_dir`` to TableMasterPaddleModel; unused for RapidTable.
          max_time: Forwarded to StructTableModel only.
          _device_: Forwarded (as ``device``) to TableMasterPaddleModel only.
          ocr_engine: Handed to RapidTableModel only.

      Returns:
          The constructed table model instance.

      Raises:
          SystemExit: with status 1 if ``table_model_type`` is not recognised.
      """
      # Local import keeps the module-level import block untouched.
      import sys

      if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
          return StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
      if table_model_type == MODEL_NAME.TABLE_MASTER:
          config = {
              'model_dir': model_path,
              'device': _device_,
          }
          return TableMasterPaddleModel(config)
      if table_model_type == MODEL_NAME.RAPID_TABLE:
          return RapidTableModel(ocr_engine)

      logger.error('table model type not allow')
      # sys.exit instead of the site-injected builtin ``exit``: the latter is meant
      # for the interactive shell and is absent under ``python -S``/embedded runtimes.
      sys.exit(1)
  def mfd_model_init(weight, device='cpu'):
      """Create the YOLOv8-based MFD model from ``weight``.

      A ``device`` whose string form starts with ``"npu"`` is converted to a
      ``torch.device`` first; any other value is passed through unchanged.
      """
      target_device = torch.device(device) if str(device).startswith("npu") else device
      return YOLOv8MFDModel(weight, target_device)
  def mfr_model_init(weight_dir, cfg_path, device='cpu'):
      """Create the UniMERNet-based MFR model.

      ``weight_dir``, ``cfg_path`` and ``device`` are handed straight to
      ``UnimernetModel`` in that order.
      """
      return UnimernetModel(weight_dir, cfg_path, device)
  def layout_model_init(weight, config_file, device):
      """Create the LayoutLMv3 layout predictor.

      ``weight``, ``config_file`` and ``device`` are forwarded unchanged to
      ``Layoutlmv3_Predictor``.
      """
      return Layoutlmv3_Predictor(weight, config_file, device)
  def doclayout_yolo_model_init(weight, device='cpu'):
      """Create the DocLayout-YOLO layout model from ``weight``.

      A ``device`` whose string form starts with ``"npu"`` is converted to a
      ``torch.device`` first; any other value is passed through unchanged.
      """
      target_device = torch.device(device) if str(device).startswith("npu") else device
      return DocLayoutYOLOModel(weight, target_device)
  def ocr_model_init(show_log: bool = False,
                     det_db_box_thresh=0.3,
                     lang=None,
                     use_dilation=True,
                     det_db_unclip_ratio=1.8,
                     ):
      """Create a ModifiedPaddleOCR engine.

      All detection options are always forwarded. ``lang`` is forwarded only
      when it is neither ``None`` nor the empty string; otherwise the engine is
      built without a ``lang`` argument, so whatever default ModifiedPaddleOCR
      has applies.
      """
      options = dict(
          show_log=show_log,
          det_db_box_thresh=det_db_box_thresh,
          use_dilation=use_dilation,
          det_db_unclip_ratio=det_db_unclip_ratio,
      )
      if lang is not None and lang != '':
          options['lang'] = lang
      return ModifiedPaddleOCR(**options)
  class AtomModelSingleton:
      """Process-wide cache of initialised atomic models.

      Every instantiation returns the same object. Models are stored in the
      class-level ``_models`` dict and built via ``atom_model_init`` the first
      time their cache key is requested. No locking is done, so concurrent
      first requests for the same key may each build a model.
      """

      _instance = None
      # Cache key -> constructed model; shared by the single instance.
      _models = {}

      def __new__(cls, *args, **kwargs):
          if cls._instance is None:
              cls._instance = super().__new__(cls)
          return cls._instance

      def get_atom_model(self, atom_model_name: str, **kwargs):
          """Return the cached model for ``atom_model_name``, creating it on first use.

          On a cache miss, ``kwargs`` are forwarded unchanged to ``atom_model_init``.
          The cache key contains the kwargs that change the constructed model:

          * OCR: ``lang`` and ``det_db_box_thresh`` (both feed the OCR engine),
            so a request with a different threshold no longer gets a stale engine.
          * Layout: ``layout_model_name``.
          * Table: ``table_model_name`` and ``lang``. RapidTable wraps the
            ``ocr_engine`` it is given, and OCR engines are cached per language,
            so a table model built around one language's engine must not be
            reused for another. NOTE(review): assumes callers pass the ``lang``
            matching ``ocr_engine``; confirm at call sites.
          * Any other model: the model name alone.
          """
          lang = kwargs.get('lang', None)
          if atom_model_name in [AtomicModel.OCR]:
              key = (atom_model_name, lang, kwargs.get('det_db_box_thresh', None))
          elif atom_model_name in [AtomicModel.Layout]:
              key = (atom_model_name, kwargs.get('layout_model_name', None))
          elif atom_model_name in [AtomicModel.Table]:
              key = (atom_model_name, kwargs.get('table_model_name', None), lang)
          else:
              key = atom_model_name
          if key not in self._models:
              self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
          return self._models[key]
  def atom_model_init(model_name: str, **kwargs):
      """Construct the atomic model ``model_name`` from keyword options.

      Keys read per model (missing keys are read as ``None``):
        * Layout: ``layout_model_name`` selects LayoutLMv3 (``layout_weights``,
          ``layout_config_file``, ``device``) or DocLayout-YOLO
          (``doclayout_yolo_weights``, ``device``).
        * MFD: ``mfd_weights``, ``device``.
        * MFR: ``mfr_weight_dir``, ``mfr_cfg_path``, ``device``.
        * OCR: ``ocr_show_log``, ``det_db_box_thresh``, ``lang``. Values that are
          missing or ``None`` are not forwarded, so ``ocr_model_init``'s own
          defaults apply.
        * Table: ``table_model_name``, ``table_model_path``, ``table_max_time``,
          ``device``, ``ocr_engine``.

      Returns:
          The constructed model.

      Raises:
          SystemExit: with status 1 if ``model_name`` is unknown, or if no model
              was produced (e.g. an unrecognised ``layout_model_name``).
      """
      # Local import keeps the module-level import block untouched.
      import sys

      atom_model = None
      if model_name == AtomicModel.Layout:
          if kwargs.get('layout_model_name') == MODEL_NAME.LAYOUTLMv3:
              atom_model = layout_model_init(
                  kwargs.get('layout_weights'),
                  kwargs.get('layout_config_file'),
                  kwargs.get('device')
              )
          elif kwargs.get('layout_model_name') == MODEL_NAME.DocLayout_YOLO:
              atom_model = doclayout_yolo_model_init(
                  kwargs.get('doclayout_yolo_weights'),
                  kwargs.get('device')
              )
      elif model_name == AtomicModel.MFD:
          atom_model = mfd_model_init(
              kwargs.get('mfd_weights'),
              kwargs.get('device')
          )
      elif model_name == AtomicModel.MFR:
          atom_model = mfr_model_init(
              kwargs.get('mfr_weight_dir'),
              kwargs.get('mfr_cfg_path'),
              kwargs.get('device')
          )
      elif model_name == AtomicModel.OCR:
          # Forward only supplied values. Passing kwargs.get(...) positionally
          # replaced ocr_model_init's defaults (show_log=False,
          # det_db_box_thresh=0.3) with None whenever a caller omitted them.
          ocr_options = {
              'show_log': kwargs.get('ocr_show_log'),
              'det_db_box_thresh': kwargs.get('det_db_box_thresh'),
              'lang': kwargs.get('lang'),
          }
          atom_model = ocr_model_init(
              **{name: value for name, value in ocr_options.items() if value is not None}
          )
      elif model_name == AtomicModel.Table:
          atom_model = table_model_init(
              kwargs.get('table_model_name'),
              kwargs.get('table_model_path'),
              kwargs.get('table_max_time'),
              kwargs.get('device'),
              kwargs.get('ocr_engine')
          )
      else:
          logger.error('model name not allow')
          # sys.exit instead of the site-injected builtin ``exit``.
          sys.exit(1)

      if atom_model is None:
          logger.error('model init failed')
          sys.exit(1)
      return atom_model