base.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from pathlib import Path
  15. from typing import Any, Dict, Optional
  16. from abc import ABC, abstractmethod
  17. import yaml
  18. import codecs
  19. from ...utils.subclass_register import AutoRegisterABCMetaClass
  20. from ..utils.pp_option import PaddlePredictorOption
  21. from ..models import BasePredictor
  class BasePipeline(ABC, metaclass=AutoRegisterABCMetaClass):
      """Base class for all pipelines.

      This class serves as a foundation for creating various pipelines.
      It stores the inference settings shared by all pipeline
      implementations (``device``, ``pp_option`` and ``use_hpip``) and
      provides helpers that build models and sub-pipelines with those
      same settings. Subclasses must implement :meth:`predict`.
      """

      # Name-mangled to ``_BasePipeline__is_base``. NOTE(review): presumably
      # read by AutoRegisterABCMetaClass so that this abstract base itself is
      # not registered as a concrete pipeline -- confirm against the metaclass.
      __is_base = True

      def __init__(
          self,
          device: Optional[str] = None,
          pp_option: Optional[PaddlePredictorOption] = None,
          use_hpip: bool = False,
          *args,
          **kwargs,
      ) -> None:
          """
          Initializes the class with specified parameters.

          Args:
              device (str, optional): The device to use for prediction.
                  Defaults to None.
              pp_option (PaddlePredictorOption, optional): The options for
                  PaddlePredictor. Defaults to None.
              use_hpip (bool, optional): Whether to use high-performance
                  inference (hpip) for prediction. Defaults to False.
              *args: Extra positional arguments; accepted but not used by
                  this base class.
              **kwargs: Extra keyword arguments; accepted but not used by
                  this base class.
          """
          super().__init__()
          self.device = device
          self.pp_option = pp_option
          self.use_hpip = use_hpip

      @abstractmethod
      def predict(self, input, **kwargs):
          """
          Run the pipeline on ``input``.

          Declaration of an abstract method. Subclasses are expected to
          provide a concrete implementation of predict.

          Args:
              input: The input data to predict.
              **kwargs: Additional keyword arguments.

          Raises:
              NotImplementedError: Whenever this base implementation is
                  reached (e.g. via ``super().predict(...)``).
          """
          raise NotImplementedError("The method `predict` has not been implemented yet.")

      def create_model(self, config: Dict, **kwargs) -> BasePredictor:
          """
          Create a model instance based on the given configuration.

          The model is created with this pipeline's ``device``, ``pp_option``
          and ``use_hpip`` settings.

          Args:
              config (Dict): A dictionary containing configuration settings.
                  ``model_name`` is required; ``model_dir`` (default None)
                  and ``batch_size`` (default 1) are optional. If the key
                  ``model_config_error`` is present, its value is raised as
                  a ``ValueError`` instead of creating a model.
              **kwargs: Additional model arguments, forwarded unchanged to
                  ``create_predictor``.

          Returns:
              BasePredictor: An instance of the model.

          Raises:
              ValueError: If ``config`` contains ``model_config_error``.
              KeyError: If ``config`` has no ``model_name`` entry.
          """
          if "model_config_error" in config:
              raise ValueError(config["model_config_error"])

          # Imported at call time rather than at module level; presumably to
          # avoid a circular import with the parent package -- confirm.
          from .. import create_predictor

          # NOTE(review): a ``hpi_params`` entry in ``config`` is not forwarded
          # to ``create_predictor``; confirm whether it should be.
          return create_predictor(
              model_name=config["model_name"],
              model_dir=config.get("model_dir"),
              device=self.device,
              batch_size=config.get("batch_size", 1),
              pp_option=self.pp_option,
              use_hpip=self.use_hpip,
              **kwargs,
          )

      def create_pipeline(self, config: Dict) -> "BasePipeline":
          """
          Creates a pipeline based on the provided configuration.

          The sub-pipeline is created with this pipeline's ``device``,
          ``pp_option`` and ``use_hpip`` settings.

          Args:
              config (Dict): A dictionary containing the pipeline
                  configuration. If the key ``pipeline_config_error`` is
                  present, its value is raised as a ``ValueError`` instead of
                  creating a pipeline.

          Returns:
              BasePipeline: An instance of the created pipeline.

          Raises:
              ValueError: If ``config`` contains ``pipeline_config_error``.
          """
          if "pipeline_config_error" in config:
              raise ValueError(config["pipeline_config_error"])

          # Aliased so the package-level factory is not confused with this
          # method of the same name.
          from . import create_pipeline as _create_pipeline

          return _create_pipeline(
              config=config,
              device=self.device,
              pp_option=self.pp_option,
              use_hpip=self.use_hpip,
          )

      def __call__(self, input, **kwargs):
          """
          Calls the predict method with the given input and keyword arguments.

          Args:
              input: The input data to be predicted.
              **kwargs: Additional keyword arguments to be passed to the
                  predict method.

          Returns:
              The prediction result from the predict method.
          """
          return self.predict(input, **kwargs)