# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..base import BasePipeline
from ...modules import create_model, PaddleInferenceOption
from ...modules.image_classification import transforms as T


class ClsPipeline(BasePipeline):
    """Image classification pipeline.

    Wraps a single model built by ``create_model`` together with an inference
    kernel option and a fixed list of post-processing transforms.
    """

    # NOTE(review): presumably used by BasePipeline / a pipeline registry to
    # match this pipeline to a model family — confirm against BasePipeline.
    support_models = "image_classification"

    def __init__(self,
                 model_name=None,
                 model_dir=None,
                 kernel_option=None,
                 **kwargs):
        """Build the pipeline and load the model if a name is given.

        Args:
            model_name: model name forwarded to ``create_model``. If None, no
                model is loaded here; set one with ``update_model_name`` and
                call ``load_model`` before ``predict``.
            model_dir: forwarded to ``create_model`` as ``model_dir``.
            kernel_option: inference option forwarded to ``create_model``;
                when None, ``get_kernel_option()`` supplies the default.
            **kwargs: accepted for interface compatibility; not used here.
        """
        super().__init__()
        self.model_name = model_name
        self.model_dir = model_dir
        self.kernel_option = (self.get_kernel_option()
                              if kernel_option is None else kernel_option)
        self.post_transforms = self.get_post_transforms()
        if self.model_name is not None:
            self.load_model()

    def predict(self, input_path):
        """Run the loaded model on one input.

        ``self.model`` is only assigned by ``load_model``; calling ``predict``
        before a model has been loaded raises ``AttributeError``.

        Args:
            input_path: the input to classify (presumably an image file path —
                confirm against the model's predictor), passed to the model as
                ``{"input_path": input_path}``.

        Returns:
            Whatever ``self.model.predict`` returns.
        """
        return self.model.predict({"input_path": input_path})

    def load_model(self):
        """Create the model predictor from the current settings.

        Uses ``self.model_name``, ``self.model_dir``, ``self.kernel_option`` and
        ``self.post_transforms`` and stores the result in ``self.model``.

        Raises:
            AssertionError: if ``self.model_name`` is None (only when
                assertions are enabled).
        """
        assert self.model_name is not None, (
            "model_name must be set before calling load_model()")
        self.model = create_model(
            self.model_name,
            model_dir=self.model_dir,
            kernel_option=self.kernel_option,
            post_transforms=self.post_transforms)

    def get_post_transforms(self, topk=1):
        """Return the post-processing ops applied to the model output.

        Args:
            topk (int): value passed to ``T.Topk``. Defaults to 1, the value
                previously hard-coded here.

        Returns:
            list: ``[T.Topk(topk=topk), T.PrintResult()]``.
        """
        return [T.Topk(topk=topk), T.PrintResult()]

    def get_kernel_option(self, device="gpu"):
        """Return the default inference kernel option.

        Args:
            device (str): device name passed to ``set_device``. Defaults to
                ``"gpu"``, the value previously hard-coded here.

        Returns:
            PaddleInferenceOption: a new option with the device set.
        """
        kernel_option = PaddleInferenceOption()
        kernel_option.set_device(device)
        return kernel_option

    def update_model_name(self, model_name_list):
        """Replace the model name used by this pipeline.

        Only ``self.model_name`` is updated; an already-loaded model is NOT
        reloaded. Call ``load_model`` afterwards to apply the new name.

        Args:
            model_name_list (list): list containing exactly one model name.

        Raises:
            AssertionError: if ``model_name_list`` does not contain exactly one
                element (only when assertions are enabled).
        """
        assert len(model_name_list) == 1, (
            f"ClsPipeline expects exactly one model name, "
            f"got {len(model_name_list)}")
        self.model_name = model_name_list[0]