model_init.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. from loguru import logger
  2. from magic_pdf.config.constants import MODEL_NAME
  3. from magic_pdf.model.model_list import AtomicModel
  4. from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import \
  5. DocLayoutYOLOModel
  6. from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import \
  7. Layoutlmv3_Predictor
  8. from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
  9. from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
  10. from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import \
  11. ModifiedPaddleOCR
  12. from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import \
  13. RapidTableModel
  14. # from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
  15. from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import \
  16. StructTableModel
  17. from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \
  18. TableMasterPaddleModel
  19. def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
  20. if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
  21. table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
  22. elif table_model_type == MODEL_NAME.TABLE_MASTER:
  23. config = {
  24. 'model_dir': model_path,
  25. 'device': _device_
  26. }
  27. table_model = TableMasterPaddleModel(config)
  28. elif table_model_type == MODEL_NAME.RAPID_TABLE:
  29. table_model = RapidTableModel()
  30. else:
  31. logger.error('table model type not allow')
  32. exit(1)
  33. return table_model
  34. def mfd_model_init(weight, device='cpu'):
  35. mfd_model = YOLOv8MFDModel(weight, device)
  36. return mfd_model
  37. def mfr_model_init(weight_dir, cfg_path, device='cpu'):
  38. mfr_model = UnimernetModel(weight_dir, cfg_path, device)
  39. return mfr_model
  40. def layout_model_init(weight, config_file, device):
  41. model = Layoutlmv3_Predictor(weight, config_file, device)
  42. return model
  43. def doclayout_yolo_model_init(weight, device='cpu'):
  44. model = DocLayoutYOLOModel(weight, device)
  45. return model
  46. def ocr_model_init(show_log: bool = False,
  47. det_db_box_thresh=0.3,
  48. lang=None,
  49. use_dilation=True,
  50. det_db_unclip_ratio=1.8,
  51. ):
  52. if lang is not None and lang != '':
  53. model = ModifiedPaddleOCR(
  54. show_log=show_log,
  55. det_db_box_thresh=det_db_box_thresh,
  56. lang=lang,
  57. use_dilation=use_dilation,
  58. det_db_unclip_ratio=det_db_unclip_ratio,
  59. )
  60. else:
  61. model = ModifiedPaddleOCR(
  62. show_log=show_log,
  63. det_db_box_thresh=det_db_box_thresh,
  64. use_dilation=use_dilation,
  65. det_db_unclip_ratio=det_db_unclip_ratio,
  66. # use_angle_cls=True,
  67. )
  68. return model
  69. from threading import Lock
  70. class AtomModelSingleton:
  71. _instance = None
  72. _models = {}
  73. _lock = Lock()
  74. def __new__(cls, *args, **kwargs):
  75. if cls._instance is None:
  76. cls._instance = super().__new__(cls)
  77. return cls._instance
  78. def get_atom_model(self, atom_model_name: str, **kwargs):
  79. lang = kwargs.get('lang', None)
  80. layout_model_name = kwargs.get('layout_model_name', None)
  81. key = (atom_model_name, layout_model_name, lang)
  82. if atom_model_name == AtomicModel.OCR:
  83. with self._lock:
  84. if key not in self._models:
  85. self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
  86. else:
  87. return self._models[key]
  88. else:
  89. if key not in self._models:
  90. self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
  91. else:
  92. return self._models[key]
  93. def atom_model_init(model_name: str, **kwargs):
  94. atom_model = None
  95. if model_name == AtomicModel.Layout:
  96. if kwargs.get('layout_model_name') == MODEL_NAME.LAYOUTLMv3:
  97. atom_model = layout_model_init(
  98. kwargs.get('layout_weights'),
  99. kwargs.get('layout_config_file'),
  100. kwargs.get('device')
  101. )
  102. elif kwargs.get('layout_model_name') == MODEL_NAME.DocLayout_YOLO:
  103. atom_model = doclayout_yolo_model_init(
  104. kwargs.get('doclayout_yolo_weights'),
  105. kwargs.get('device')
  106. )
  107. elif model_name == AtomicModel.MFD:
  108. atom_model = mfd_model_init(
  109. kwargs.get('mfd_weights'),
  110. kwargs.get('device')
  111. )
  112. elif model_name == AtomicModel.MFR:
  113. atom_model = mfr_model_init(
  114. kwargs.get('mfr_weight_dir'),
  115. kwargs.get('mfr_cfg_path'),
  116. kwargs.get('device')
  117. )
  118. elif model_name == AtomicModel.OCR:
  119. atom_model = ocr_model_init(
  120. kwargs.get('ocr_show_log'),
  121. kwargs.get('det_db_box_thresh'),
  122. kwargs.get('lang')
  123. )
  124. elif model_name == AtomicModel.Table:
  125. atom_model = table_model_init(
  126. kwargs.get('table_model_name'),
  127. kwargs.get('table_model_path'),
  128. kwargs.get('table_max_time'),
  129. kwargs.get('device')
  130. )
  131. else:
  132. logger.error('model name not allow')
  133. exit(1)
  134. if atom_model is None:
  135. logger.error('model init failed')
  136. exit(1)
  137. else:
  138. return atom_model