# !/usr/bin/env python3
# -*- coding: UTF-8 -*-
################################################################################
#
# Copyright (c) 2024 Baidu.com, Inc. All Rights Reserved
#
################################################################################
"""
Author: PaddlePaddle Authors
"""
import os

from .modules.base import build_dataset_checker, build_trainer, build_evaluater, build_predictor
from .utils.result_saver import try_except_decorator
from .utils import config
from .utils.errors import raise_unsupported_api_error


class Engine(object):
    """Entry point that runs one pipeline stage chosen by ``Global.mode``.

    The configuration is loaded from the parsed command-line arguments when
    the object is constructed; ``run`` then dispatches on the configured mode.
    """

    def __init__(self):
        """Parse the command line and load the resulting configuration.

        Stores the loaded configuration on ``self.config`` and copies its
        ``Global.mode`` and ``Global.output`` values onto ``self.mode`` and
        ``self.output``.
        """
        cli_args = config.parse_args()
        self.config = config.get_config(
            cli_args.config, overrides=cli_args.override, show=False
        )

        global_cfg = self.config.Global
        self.mode = global_cfg.mode
        self.output = global_cfg.output

    @try_except_decorator
    def run(self):
        """Execute the stage selected by ``self.config.Global.mode``.

        Supported modes and what is returned:

        * ``"check_dataset"`` -- result of the dataset checker's ``check()``.
        * ``"train"`` -- runs the trainer's ``train()``; its result is not
          returned.
        * ``"evaluate"`` -- result of the evaluator's ``evaluate()``.
        * ``"predict"`` -- result of the predictor's ``predict()``.

        ``"export"`` and any other mode are handed to
        ``raise_unsupported_api_error`` together with this class.
        """
        mode = self.config.Global.mode

        if mode == "check_dataset":
            return build_dataset_checker(self.config).check()

        if mode == "train":
            # The training result is intentionally not propagated to the caller.
            build_trainer(self.config).train()
            return None

        if mode == "evaluate":
            return build_evaluater(self.config).evaluate()

        if mode == "export":
            raise_unsupported_api_error("export", self.__class__)
            return None

        if mode == "predict":
            return build_predictor(self.config).predict()

        # Any mode not handled above is reported as unsupported.
        raise_unsupported_api_error(f"{mode}", self.__class__)
        return None