main.py 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from ...modules.base import create_model
  15. from ...modules.image_classification.predictor import transforms as T
  16. from ...modules.base.predictor.utils.paddle_inference_predictor import PaddleInferenceOption
  17. class ClsPipeline(object):
  18. """Cls Pipeline
  19. """
  20. def __init__(self,
  21. model_name,
  22. model_dir=None,
  23. output_dir="./output",
  24. kernel_option=None):
  25. self.output_dir = output_dir
  26. post_transforms = self.get_post_transforms(model_dir)
  27. kernel_option = self.get_kernel_option(
  28. ) if kernel_option is None else kernel_option
  29. self.model = create_model(
  30. model_name,
  31. model_dir=model_dir,
  32. kernel_option=kernel_option,
  33. post_transforms=post_transforms)
  34. def __call__(self, input_path):
  35. return self.model.predict({"input_path": input_path})
  36. def get_post_transforms(self, model_dir):
  37. """get post transform ops
  38. """
  39. return [T.Topk(topk=1), T.PrintResult()]
  40. def get_kernel_option(self):
  41. """get kernel option
  42. """
  43. kernel_option = PaddleInferenceOption()
  44. kernel_option.set_device("gpu")