engine.py 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. # !/usr/bin/env python3
  2. # -*- coding: UTF-8 -*-
  3. ################################################################################
  4. #
  5. # Copyright (c) 2024 Baidu.com, Inc. All Rights Reserved
  6. #
  7. ################################################################################
  8. """
  9. Author: PaddlePaddle Authors
  10. """
  11. import os
  12. from .modules.base import build_dataset_checker, build_trainer, build_evaluater, build_predictor
  13. from .utils.result_saver import try_except_decorator
  14. from .utils import config
  15. from .utils.errors import raise_unsupported_api_error
  16. class Engine(object):
  17. """ Engine """
  18. def __init__(self):
  19. args = config.parse_args()
  20. self.config = config.get_config(
  21. args.config, overrides=args.override, show=False)
  22. self.mode = self.config.Global.mode
  23. self.output = self.config.Global.output
  24. @try_except_decorator
  25. def run(self):
  26. """ the main function """
  27. if self.config.Global.mode == "check_dataset":
  28. dataset_checker = build_dataset_checker(self.config)
  29. return dataset_checker.check()
  30. elif self.config.Global.mode == "train":
  31. trainer = build_trainer(self.config)
  32. trainer.train()
  33. elif self.config.Global.mode == "evaluate":
  34. evaluator = build_evaluater(self.config)
  35. return evaluator.evaluate()
  36. elif self.config.Global.mode == "export":
  37. raise_unsupported_api_error("export", self.__class__)
  38. elif self.config.Global.mode == "predict":
  39. predictor = build_predictor(self.config)
  40. return predictor.predict()
  41. else:
  42. raise_unsupported_api_error(f"{self.config.Global.mode}",
  43. self.__class__)