base.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from pathlib import Path
  15. from typing import Any, Dict, Optional
  16. from abc import ABC, abstractmethod
  17. import yaml
  18. import codecs
  19. from ...utils.subclass_register import AutoRegisterABCMetaClass
  20. from ..utils.pp_option import PaddlePredictorOption
  21. from ..models import BasePredictor
  class BasePipeline(ABC, metaclass=AutoRegisterABCMetaClass):
      """Base class for all pipelines.

      This class serves as a foundation for creating various pipelines.
      It stores the inference settings shared by every pipeline (device,
      predictor options, high-performance-inference switches) and offers
      helpers that build models and nested pipelines with those settings.
      Subclasses must implement `predict`.
      """

      # Marker consumed by the registering metaclass.
      # NOTE(review): presumably flags this class as a base that should not be
      # registered itself -- confirm against AutoRegisterABCMetaClass.
      __is_base = True

      def __init__(
          self,
          device: Optional[str] = None,
          pp_option: Optional[PaddlePredictorOption] = None,
          use_hpip: bool = False,
          hpi_params: Optional[Dict[str, Any]] = None,
          *args,
          **kwargs,
      ) -> None:
          """
          Initializes the class with specified parameters.

          Args:
              device (str, optional): The device to use for prediction. Defaults to None.
              pp_option (PaddlePredictorOption, optional): The options for PaddlePredictor. Defaults to None.
              use_hpip (bool, optional): Whether to use high-performance inference (hpip) for prediction. Defaults to False.
              hpi_params (Dict[str, Any], optional): Additional parameters for hpip. Defaults to None.
              *args: Accepted for signature compatibility; not used by this base class.
              **kwargs: Accepted for signature compatibility; not used by this base class.
          """
          super().__init__()
          self.device = device
          self.pp_option = pp_option
          self.use_hpip = use_hpip
          self.hpi_params = hpi_params

      @abstractmethod
      def predict(self, input, **kwargs):
          """
          Declaration of an abstract method. Subclasses are expected to
          provide a concrete implementation of predict.

          Args:
              input: The input data to predict.
              **kwargs: Additional keyword arguments.

          Raises:
              NotImplementedError: Always, when the base implementation is invoked.
          """
          raise NotImplementedError("The method `predict` has not been implemented yet.")

      def create_model(self, config: Dict, **kwargs) -> BasePredictor:
          """
          Create a model instance based on the given configuration.

          The predictor is built with this pipeline's `device`, `pp_option`,
          `use_hpip` and `hpi_params`.

          Args:
              config (Dict): A dictionary containing configuration settings.
                  `model_dir` is used when present and not None; otherwise
                  `model_name` is used instead.
              **kwargs: Additional model arguments that are passed through to
                  `create_predictor`.

          Returns:
              BasePredictor: An instance of the model.

          Raises:
              ValueError: If `config` contains a `model_config_error` entry; its
                  value is used as the error message.
              KeyError: If no usable `model_dir` is given and `model_name` is
                  missing from `config`.
          """
          if "model_config_error" in config:
              raise ValueError(config["model_config_error"])

          # A missing `model_dir` key is treated the same as an explicit None:
          # fall back to the model name.
          model_dir = config.get("model_dir")
          if model_dir is None:
              model_dir = config["model_name"]

          # Imported at call time rather than module import time
          # (presumably to avoid a circular import -- TODO confirm).
          from .. import create_predictor

          return create_predictor(
              model=model_dir,
              device=self.device,
              pp_option=self.pp_option,
              use_hpip=self.use_hpip,
              hpi_params=self.hpi_params,
              **kwargs,
          )

      def create_pipeline(self, config: Dict) -> "BasePipeline":
          """
          Creates a pipeline based on the provided configuration.

          The pipeline is built with this pipeline's `device`, `pp_option`,
          `use_hpip` and `hpi_params`.

          Args:
              config (Dict): A dictionary containing the pipeline configuration.
                  Must contain `pipeline_name`.

          Returns:
              BasePipeline: An instance of the created pipeline.

          Raises:
              ValueError: If `config` contains a `pipeline_config_error` entry;
                  its value is used as the error message.
              KeyError: If `pipeline_name` is missing from `config`.
          """
          if "pipeline_config_error" in config:
              raise ValueError(config["pipeline_config_error"])

          # Imported at call time rather than module import time
          # (presumably to avoid a circular import -- TODO confirm).
          from . import create_pipeline

          pipeline_name = config["pipeline_name"]
          return create_pipeline(
              pipeline_name,
              config=config,
              device=self.device,
              pp_option=self.pp_option,
              use_hpip=self.use_hpip,
              hpi_params=self.hpi_params,
          )

      def __call__(self, input, **kwargs):
          """
          Calls the predict method with the given input and keyword arguments.

          Args:
              input: The input data to be predicted.
              **kwargs: Additional keyword arguments to be passed to the predict method.

          Returns:
              The prediction result from the predict method.
          """
          return self.predict(input, **kwargs)