engine.py 1.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. # !/usr/bin/env python3
  2. # -*- coding: UTF-8 -*-
  3. ################################################################################
  4. #
  5. # Copyright (c) 2024 Baidu.com, Inc. All Rights Reserved
  6. #
  7. ################################################################################
  8. """
  9. Author: PaddlePaddle Authors
  10. """
  11. import os
  12. from .modules.base import build_dataset_checker, build_trainer, build_evaluater
  13. from .utils.result_saver import try_except_decorator
  14. from .utils import config
  15. from .utils.errors import raise_unsupported_api_error
  class Engine(object):
      """Entry point that dispatches work according to ``Global.mode``.

      The configuration is loaded from the command line when the engine is
      constructed; :meth:`run` then builds and invokes the component selected
      by the ``Global.mode`` config value.
      """

      def __init__(self):
          """Parse CLI arguments and load the configuration they refer to.

          Sets ``self.config`` to the object returned by ``config.get_config``,
          and copies ``Global.mode`` / ``Global.output`` from it into
          ``self.mode`` / ``self.output_dir``.
          """
          cli_args = config.parse_args()
          self.config = config.get_config(
              cli_args.config, overrides=cli_args.override, show=False)
          global_section = self.config.Global
          self.mode = global_section.mode
          self.output_dir = global_section.output

      @try_except_decorator
      def run(self):
          """Build and invoke the component selected by ``Global.mode``.

          Dispatch (mode is re-read from ``self.config`` at call time):
            * ``"check_dataset"`` -- builds the dataset checker, calls it and
              returns its result.
            * ``"train"`` -- builds the trainer and calls it; its result is
              NOT returned, so ``run`` yields ``None`` for this mode.
            * ``"evaluate"`` -- builds the evaluator, calls it and returns its
              result.
            * ``"export"`` or any other value -- handed to
              ``raise_unsupported_api_error`` (presumably raises; it is
              defined elsewhere -- verify).

          How exceptions and return values are post-processed is up to
          ``try_except_decorator``, which is defined outside this file.
          """
          cfg = self.config
          mode = cfg.Global.mode

          if mode == "check_dataset":
              return build_dataset_checker(cfg)()

          if mode == "train":
              trainer = build_trainer(cfg)
              trainer()
              # NOTE(review): unlike the other modes, the trainer's result is
              # discarded here; confirm whether callers should receive it.
              return None

          if mode == "evaluate":
              return build_evaluater(cfg)()

          if mode == "export":
              # Export is a recognised mode but not supported by this engine.
              raise_unsupported_api_error("export", self.__class__)
          else:
              raise_unsupported_api_error(f"{mode}", self.__class__)