model_init.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. from loguru import logger
  2. from magic_pdf.libs.Constants import MODEL_NAME
  3. from magic_pdf.model.model_list import AtomicModel
  4. from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import DocLayoutYOLOModel
  5. from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import Layoutlmv3_Predictor
  6. from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
  7. from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
  8. from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR
  9. # from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
  10. from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import StructTableModel
  11. from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import TableMasterPaddleModel
  12. from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
  def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
      """Instantiate the table model selected by ``table_model_type``.

      Args:
          table_model_type: One of ``MODEL_NAME.STRUCT_EQTABLE``,
              ``MODEL_NAME.TABLE_MASTER`` or ``MODEL_NAME.RAPID_TABLE``.
          model_path: Passed positionally to ``StructTableModel``, or used as
              ``"model_dir"`` in the ``TableMasterPaddleModel`` config.
              Not used for ``RapidTableModel``.
          max_time: Forwarded as ``max_time`` to ``StructTableModel`` only.
          _device_: Stored under ``"device"`` in the ``TableMasterPaddleModel``
              config; the other two backends do not receive it.

      Returns:
          The constructed table model instance.

      Raises:
          SystemExit: With exit code 1 when ``table_model_type`` matches none
              of the supported types (an error is logged first).
      """
      if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
          return StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)

      if table_model_type == MODEL_NAME.TABLE_MASTER:
          config = {
              "model_dir": model_path,
              "device": _device_,
          }
          return TableMasterPaddleModel(config)

      if table_model_type == MODEL_NAME.RAPID_TABLE:
          return RapidTableModel()

      logger.error("table model type not allow")
      # Raise SystemExit directly: the bare ``exit`` helper is injected by the
      # ``site`` module for interactive use and may be missing (e.g. ``python -S``).
      raise SystemExit(1)
  def mfd_model_init(weight, device='cpu'):
      """Build a ``YOLOv8MFDModel`` from ``weight`` on ``device``.

      Both arguments are forwarded positionally and unchanged to the
      ``YOLOv8MFDModel`` constructor; the new instance is returned.
      """
      return YOLOv8MFDModel(weight, device)
  def mfr_model_init(weight_dir, cfg_path, device='cpu'):
      """Build a ``UnimernetModel`` from a weight directory and config path.

      ``weight_dir``, ``cfg_path`` and ``device`` are forwarded positionally,
      in that order, to the ``UnimernetModel`` constructor.
      """
      return UnimernetModel(weight_dir, cfg_path, device)
  def layout_model_init(weight, config_file, device):
      """Build a ``Layoutlmv3_Predictor`` layout model.

      ``weight``, ``config_file`` and ``device`` are forwarded positionally,
      in that order, to the predictor's constructor. Unlike the sibling
      initialisers, ``device`` has no default here.
      """
      return Layoutlmv3_Predictor(weight, config_file, device)
  def doclayout_yolo_model_init(weight, device='cpu'):
      """Build a ``DocLayoutYOLOModel`` from ``weight`` on ``device``.

      Both arguments are forwarded positionally and unchanged to the
      ``DocLayoutYOLOModel`` constructor; the new instance is returned.
      """
      return DocLayoutYOLOModel(weight, device)
  def ocr_model_init(show_log: bool = False,
                     det_db_box_thresh=0.3,
                     lang=None,
                     use_dilation=True,
                     det_db_unclip_ratio=1.8,
                     ):
      """Build a ``ModifiedPaddleOCR`` instance.

      All parameters are forwarded as keyword arguments of the same name.
      ``lang`` is only forwarded when it is not ``None``, so that
      ``ModifiedPaddleOCR`` falls back to its own default otherwise.
      ``use_angle_cls`` is never passed.
      """
      ocr_kwargs = {
          "show_log": show_log,
          "det_db_box_thresh": det_db_box_thresh,
      }
      if lang is not None:
          # Inserted here so the keyword order matches an explicit call.
          ocr_kwargs["lang"] = lang
      ocr_kwargs["use_dilation"] = use_dilation
      ocr_kwargs["det_db_unclip_ratio"] = det_db_unclip_ratio
      return ModifiedPaddleOCR(**ocr_kwargs)
  class AtomModelSingleton:
      """Singleton that lazily builds and caches atomic models.

      Every instantiation returns the same object, and ``_models`` is a
      class-level dict, so all callers share one cache. Models are created on
      first request through ``atom_model_init`` and reused afterwards.
      """

      _instance = None
      # Cache of initialised models, keyed by the tuple from ``_cache_key``.
      _models = {}

      def __new__(cls, *args, **kwargs):
          # Create the single instance on first use; constructor args are ignored.
          if cls._instance is None:
              cls._instance = super().__new__(cls)
          return cls._instance

      @staticmethod
      def _cache_key(atom_model_name, **kwargs):
          """Return the cache key for a model request.

          The key covers the options that select a different model
          implementation: layout backend, OCR language and table backend.
          ``table_model_name`` is part of the key so that requesting a second
          table backend does not return the first backend's cached instance.
          Other kwargs (weights, device, ...) are not part of the key.
          """
          return (
              atom_model_name,
              kwargs.get("layout_model_name"),
              kwargs.get("lang"),
              kwargs.get("table_model_name"),
          )

      def get_atom_model(self, atom_model_name: str, **kwargs):
          """Return the cached model for this request, building it if needed.

          Args:
              atom_model_name: Model category, passed to ``atom_model_init``
                  as ``model_name``.
              **kwargs: Forwarded unchanged to ``atom_model_init`` on a cache
                  miss; ``layout_model_name``, ``lang`` and
                  ``table_model_name`` also take part in the cache key.

          Returns:
              The cached (or newly initialised) model object.
          """
          key = self._cache_key(atom_model_name, **kwargs)
          if key not in self._models:
              self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
          return self._models[key]
  def atom_model_init(model_name: str, **kwargs):
      """Initialise the atomic model identified by ``model_name``.

      Args:
          model_name: One of the ``AtomicModel`` members ``Layout``, ``MFD``,
              ``MFR``, ``OCR`` or ``Table``.
          **kwargs: Options read per model type:
              Layout: ``layout_model_name`` plus either ``layout_weights``,
              ``layout_config_file``, ``device`` (LayoutLMv3) or
              ``doclayout_yolo_weights``, ``device`` (DocLayout-YOLO).
              MFD: ``mfd_weights``, ``device``.
              MFR: ``mfr_weight_dir``, ``mfr_cfg_path``, ``device``.
              OCR: ``ocr_show_log``, ``det_db_box_thresh``, ``lang``.
              Table: ``table_model_name``, ``table_model_path``,
              ``table_max_time``, ``device``.

      Returns:
          The initialised model object.

      Raises:
          SystemExit: With exit code 1 when ``model_name`` is not recognised,
              or when no model was produced (e.g. an unknown
              ``layout_model_name``). An error is logged first.
      """
      atom_model = None

      if model_name == AtomicModel.Layout:
          layout_model_name = kwargs.get("layout_model_name")
          if layout_model_name == MODEL_NAME.LAYOUTLMv3:
              atom_model = layout_model_init(
                  kwargs.get("layout_weights"),
                  kwargs.get("layout_config_file"),
                  kwargs.get("device"),
              )
          elif layout_model_name == MODEL_NAME.DocLayout_YOLO:
              atom_model = doclayout_yolo_model_init(
                  kwargs.get("doclayout_yolo_weights"),
                  kwargs.get("device"),
              )
          # Any other layout name leaves atom_model as None; handled below.
      elif model_name == AtomicModel.MFD:
          atom_model = mfd_model_init(
              kwargs.get("mfd_weights"),
              kwargs.get("device"),
          )
      elif model_name == AtomicModel.MFR:
          atom_model = mfr_model_init(
              kwargs.get("mfr_weight_dir"),
              kwargs.get("mfr_cfg_path"),
              kwargs.get("device"),
          )
      elif model_name == AtomicModel.OCR:
          # Forward only the options the caller actually supplied. Passing a
          # missing option as None would override ocr_model_init's defaults
          # (show_log=False, det_db_box_thresh=0.3) with None.
          ocr_kwargs = {}
          for source_key, target_key in (
              ("ocr_show_log", "show_log"),
              ("det_db_box_thresh", "det_db_box_thresh"),
              ("lang", "lang"),
          ):
              value = kwargs.get(source_key)
              if value is not None:
                  ocr_kwargs[target_key] = value
          atom_model = ocr_model_init(**ocr_kwargs)
      elif model_name == AtomicModel.Table:
          atom_model = table_model_init(
              kwargs.get("table_model_name"),
              kwargs.get("table_model_path"),
              kwargs.get("table_max_time"),
              kwargs.get("device"),
          )
      else:
          logger.error("model name not allow")
          # SystemExit is raised directly; the bare ``exit`` helper comes from
          # the ``site`` module and is not guaranteed to be available.
          raise SystemExit(1)

      if atom_model is None:
          logger.error("model init failed")
          raise SystemExit(1)
      return atom_model