# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy

from .inference import PaddlePredictorOption, create_predictor
from .modules import (
    build_dataset_checker,
    build_evaluator,
    build_exportor,
    build_trainer,
)


# TODO(gaotingquan): support _ModelBasedConfig
def create_model(model_name, model_dir=None, *args, **kwargs):
    """Create an inference-only model wrapper.

    Args:
        model_name: Forwarded to ``create_predictor`` as ``model_name``.
        model_dir: Forwarded to ``create_predictor`` as ``model_dir``.
        *args: Extra positional arguments forwarded to ``create_predictor``.
        **kwargs: Extra keyword arguments forwarded to ``create_predictor``.

    Returns:
        A ``_ModelBasedInference`` instance wrapping the created predictor.
    """
    # NOTE(review): ``model_name``/``model_dir`` are passed by keyword while
    # ``*args`` is passed positionally, so any extra positional argument may
    # collide with them inside ``create_predictor`` -- verify against its
    # signature before relying on ``*args`` here.
    return _ModelBasedInference(
        model_name=model_name, model_dir=model_dir, *args, **kwargs
    )


class _BaseModel:
    """Common interface for model wrappers.

    Every operation raises ``Exception`` by default; subclasses override the
    operations they actually support.
    """

    def check_dataset(self, *args, **kwargs):
        """Unsupported by default; always raises ``Exception``."""
        raise Exception("check_dataset is not supported!")

    def train(self, *args, **kwargs):
        """Unsupported by default; always raises ``Exception``."""
        raise Exception("train is not supported!")

    def evaluate(self, *args, **kwargs):
        """Unsupported by default; always raises ``Exception``."""
        raise Exception("evaluate is not supported!")

    def export(self, *args, **kwargs):
        """Unsupported by default; always raises ``Exception``."""
        raise Exception("export is not supported!")

    def predict(self, *args, **kwargs):
        """Unsupported by default; always raises ``Exception``."""
        raise Exception("predict is not supported!")

    # NOTE(review): this stub is named ``set_predict`` while
    # ``_ModelBasedInference`` implements ``set_predictor``; the name is kept
    # as-is so the existing interface does not change.
    def set_predict(self, *args, **kwargs):
        """Unsupported by default; always raises ``Exception``."""
        raise Exception("set_predict is not supported!")

    def __call__(self, *args, **kwargs):
        """Yield everything produced by ``self.predict(*args, **kwargs)``.

        This is a generator function, so ``predict`` is only invoked once the
        returned generator is iterated.
        """
        yield from self.predict(*args, **kwargs)


class _ModelBasedInference(_BaseModel):
    """Model wrapper that only supports prediction through a predictor."""

    def __init__(self, *args, **kwargs):
        """Build the underlying predictor.

        All arguments are forwarded unchanged to ``create_predictor``.
        """
        self._predictor = create_predictor(*args, **kwargs)

    def predict(self, *args, **kwargs):
        """Yield the results of calling the predictor with the given arguments."""
        yield from self._predictor(*args, **kwargs)

    def set_predictor(self, **kwargs):
        """Forward ``kwargs`` to the predictor's ``set_predictor`` method."""
        self._predictor.set_predictor(**kwargs)

    def __getattr__(self, name):
        """Fall back to the predictor for attributes missing on the wrapper.

        Raises:
            AttributeError: If neither the wrapper nor the predictor has
                ``name``, or if the predictor itself has not been set.
        """
        # ``__getattr__`` only runs after normal lookup has failed. If the
        # missing attribute is ``_predictor`` itself (instance created without
        # ``__init__``, e.g. by copy/pickle, or ``__init__`` failed), reading
        # ``self._predictor`` below would re-enter this method and recurse
        # until RecursionError -- so bail out with AttributeError instead.
        if name != "_predictor" and hasattr(self._predictor, name):
            return getattr(self._predictor, name)
        raise AttributeError(
            f"'{self.__class__.__name__}' object has no attribute '{name}'"
        )


class _ModelBasedConfig(_BaseModel):
    """Model wrapper driven by a configuration object.

    Supports dataset checking, training, evaluation, export and prediction by
    delegating to the builders imported from ``.modules`` and to
    ``create_predictor``.
    """

    def __init__(self, config=None, *args, **kwargs):
        """Store the config and read the model name from ``config.Global.model``.

        Args:
            config: Configuration object exposing ``Global`` and ``Predict``.
                Although it defaults to ``None``, ``config.Global`` is
                accessed immediately, so a real config must be supplied.
            *args: Accepted but unused.
            **kwargs: Accepted but unused.
        """
        super().__init__()
        self._config = config
        self._model_name = config.Global.model

    def _build_predictor(self):
        """Create a predictor from the ``Predict`` section of the config.

        Returns:
            A ``(predict_kwargs, predictor)`` tuple, where ``predict_kwargs``
            is a deep copy of ``config.Predict`` with the predictor
            construction keys (``model_dir``, ``kernel_option``, ``use_hpip``,
            ``hpi_config``) removed.
        """
        # Work on a copy so popping/mutating keys never touches the config.
        predict_kwargs = deepcopy(self._config.Predict)

        model_dir = predict_kwargs.pop("model_dir", None)

        # Sentinel to tell "key absent" apart from an explicit ``None`` value.
        UNSET = object()
        device = self._config.Global.get("device", None)
        kernel_option = predict_kwargs.pop("kernel_option", UNSET)
        hpi_config = predict_kwargs.pop("hpi_config", UNSET)

        # ``use_hpip`` is always passed, defaulting to False when absent.
        create_predictor_kwargs = {
            "use_hpip": predict_kwargs.pop("use_hpip", False),
        }
        if kernel_option is not UNSET:
            kernel_option.setdefault("model_name", self._model_name)
            create_predictor_kwargs["pp_option"] = PaddlePredictorOption(
                **kernel_option
            )
        if hpi_config is not UNSET:
            create_predictor_kwargs["hpi_config"] = hpi_config

        predictor = create_predictor(
            self._model_name,
            model_dir,
            device=device,
            **create_predictor_kwargs,
        )
        # NOTE(review): this check is skipped when Python runs with ``-O``.
        assert "input" in predict_kwargs
        return predict_kwargs, predictor

    def check_dataset(self):
        """Build a dataset checker from the config and return its ``check()`` result."""
        dataset_checker = build_dataset_checker(self._config)
        return dataset_checker.check()

    def train(self):
        """Build a trainer from the config and run ``train()`` on it."""
        trainer = build_trainer(self._config)
        trainer.train()

    def evaluate(self):
        """Build an evaluator from the config and return its ``evaluate()`` result."""
        evaluator = build_evaluator(self._config)
        return evaluator.evaluate()

    def export(self):
        """Build an exporter from the config and return its ``export()`` result."""
        exportor = build_exportor(self._config)
        return exportor.export()

    def predict(self):
        """Yield prediction results using the ``Predict`` section of the config.

        A new predictor is built each time the returned generator is started.
        """
        predict_kwargs, predictor = self._build_predictor()
        yield from predictor(**predict_kwargs)