base_predictor.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import List, Dict, Any, Iterator
  15. from pathlib import Path
  16. from abc import abstractmethod, ABC
  17. from .....utils.flags import INFER_BENCHMARK
  18. from ....utils.io import YAMLReader
  19. from ....common.batch_sampler import BaseBatchSampler
  20. class PredictionWrap:
  21. """Wraps the prediction data and supports get by index."""
  22. def __init__(self, data: Dict[str, List[Any]], num: int) -> None:
  23. """Initializes the PredictionWrap with prediction data.
  24. Args:
  25. data (Dict[str, List[Any]]): A dictionary where keys are string identifiers and values are lists of predictions.
  26. num (int): The number of predictions, that is length of values per key in the data dictionary.
  27. Raises:
  28. AssertionError: If the length of any list in data does not match num.
  29. """
  30. assert isinstance(data, dict), "data must be a dictionary"
  31. for k in data:
  32. assert len(data[k]) == num, f"{len(data[k])} != {num} for key {k}!"
  33. self._data = data
  34. self._keys = data.keys()
  35. def get_by_idx(self, idx: int) -> Dict[str, Any]:
  36. """Get the prediction by specified index.
  37. Args:
  38. idx (int): The index to get predictions from.
  39. Returns:
  40. Dict[str, Any]: A dictionary with the same keys as the input data, but with the values at the specified index.
  41. """
  42. return {key: self._data[key][idx] for key in self._keys}
  43. class BasePredictor(ABC):
  44. """BasePredictor."""
  45. MODEL_FILE_PREFIX = "inference"
  46. def __init__(self, model_dir: str, config: Dict = None) -> None:
  47. """Initializes the BasePredictor.
  48. Args:
  49. model_dir (str): The directory where the static model files is stored.
  50. config (dict, optional): The configuration of model to infer. Defaults to None.
  51. """
  52. super().__init__()
  53. self.model_dir = Path(model_dir)
  54. self.config = config if config else self.load_config(self.model_dir)
  55. self.batch_sampler = self._build_batch_sampler()
  56. self.result_class = self._get_result_class()
  57. # alias predict() to the __call__()
  58. self.predict = self.__call__
  59. @property
  60. def config_path(self) -> str:
  61. """
  62. Get the path to the configuration file.
  63. Returns:
  64. str: The path to the configuration file.
  65. """
  66. return self.get_config_path(self.model_dir)
  67. @property
  68. def model_name(self) -> str:
  69. """
  70. Get the model name.
  71. Returns:
  72. str: The model name.
  73. """
  74. return self.config["Global"]["model_name"]
  75. @classmethod
  76. def get_config_path(cls, model_dir) -> str:
  77. """Get the path to the configuration file for the given model directory.
  78. Args:
  79. model_dir (Path): The directory where the static model files is stored.
  80. Returns:
  81. Path: The path to the configuration file.
  82. """
  83. return model_dir / f"{cls.MODEL_FILE_PREFIX}.yml"
  84. @classmethod
  85. def load_config(cls, model_dir) -> Dict:
  86. """Load the configuration from the specified model directory.
  87. Args:
  88. model_dir (Path): The where the static model files is stored.
  89. Returns:
  90. dict: The loaded configuration dictionary.
  91. """
  92. yaml_reader = YAMLReader()
  93. return yaml_reader.read(cls.get_config_path(model_dir))
  94. @abstractmethod
  95. def __call__(self, input: Any, **kwargs: Dict[str, Any]) -> Iterator[Any]:
  96. """Predict with the given input and additional keyword arguments."""
  97. raise NotImplementedError
  98. @abstractmethod
  99. def set_predictor(self, batch_size: int = None, device: str = None, *args) -> None:
  100. """Sets up the predictor."""
  101. raise NotImplementedError
  102. def apply(self, input: Any, **kwargs) -> Iterator[Any]:
  103. """
  104. Do predicting with the input data and yields predictions.
  105. Args:
  106. input (Any): The input data to be predicted.
  107. Yields:
  108. Iterator[Any]: An iterator yielding prediction results.
  109. """
  110. if INFER_BENCHMARK:
  111. if not isinstance(input, list):
  112. raise TypeError("In benchmark mode, `input` must be a list")
  113. batches = list(self.batch_sampler(input))
  114. if len(batches) != 1 or len(batches[0]) != len(input):
  115. raise ValueError("Unexpected number of instances")
  116. else:
  117. batches = self.batch_sampler(input)
  118. for batch_data in batches:
  119. prediction = self.process(batch_data, **kwargs)
  120. prediction = PredictionWrap(prediction, len(batch_data))
  121. for idx in range(len(batch_data)):
  122. yield self.result_class(prediction.get_by_idx(idx))
  123. @abstractmethod
  124. def process(self, batch_data: List[Any]) -> Dict[str, List[Any]]:
  125. """process the batch data sampled from BatchSampler and return the prediction result.
  126. Args:
  127. batch_data (List[Any]): The batch data sampled from BatchSampler.
  128. Returns:
  129. Dict[str, List[Any]]: The prediction result.
  130. """
  131. raise NotImplementedError
  132. @abstractmethod
  133. def _build_batch_sampler(self) -> BaseBatchSampler:
  134. """Build batch sampler.
  135. Returns:
  136. BaseBatchSampler: batch sampler object.
  137. """
  138. raise NotImplementedError
  139. @abstractmethod
  140. def _get_result_class(self) -> type:
  141. """Get the result class.
  142. Returns:
  143. type: The result class.
  144. """
  145. raise NotImplementedError