__init__.py 1.2 KB

123456789101112131415161718192021222324252627282930313233
  1. from __future__ import absolute_import
  2. from __future__ import division
  3. from __future__ import print_function
  4. from __future__ import unicode_literals
  5. import copy
  6. __all__ = ['build_post_process']
  7. def build_post_process(config, global_config=None):
  8. from .db_postprocess import DBPostProcess
  9. from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, TableLabelDecode, \
  10. NRTRLabelDecode, SARLabelDecode, ViTSTRLabelDecode, RFLLabelDecode
  11. from .cls_postprocess import ClsPostProcess
  12. from .rec_postprocess import CANLabelDecode
  13. support_dict = [
  14. 'DBPostProcess', 'CTCLabelDecode',
  15. 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode',
  16. 'TableLabelDecode', 'NRTRLabelDecode', 'SARLabelDecode',
  17. 'ViTSTRLabelDecode','CANLabelDecode', 'RFLLabelDecode'
  18. ]
  19. config = copy.deepcopy(config)
  20. module_name = config.pop('name')
  21. if global_config is not None:
  22. config.update(global_config)
  23. assert module_name in support_dict, Exception(
  24. 'post process only support {}, but got {}'.format(support_dict, module_name))
  25. module_class = eval(module_name)(**config)
  26. return module_class