base_predictor.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import List, Dict, Any, Iterator
  15. from pathlib import Path
  16. from abc import abstractmethod, ABC
  17. from ....utils.io import YAMLReader
  18. from ....common.batch_sampler import BaseBatchSampler
  19. class PredictionWrap:
  20. """Wraps the prediction data and supports get by index."""
  21. def __init__(self, data: Dict[str, List[Any]], num: int) -> None:
  22. """Initializes the PredictionWrap with prediction data.
  23. Args:
  24. data (Dict[str, List[Any]]): A dictionary where keys are string identifiers and values are lists of predictions.
  25. num (int): The number of predictions, that is length of values per key in the data dictionary.
  26. Raises:
  27. AssertionError: If the length of any list in data does not match num.
  28. """
  29. assert isinstance(data, dict), "data must be a dictionary"
  30. for k in data:
  31. assert len(data[k]) == num, f"{len(data[k])} != {num} for key {k}!"
  32. self._data = data
  33. self._keys = data.keys()
  34. def get_by_idx(self, idx: int) -> Dict[str, Any]:
  35. """Get the prediction by specified index.
  36. Args:
  37. idx (int): The index to get predictions from.
  38. Returns:
  39. Dict[str, Any]: A dictionary with the same keys as the input data, but with the values at the specified index.
  40. """
  41. return {key: self._data[key][idx] for key in self._keys}
  42. class BasePredictor(ABC):
  43. """BasePredictor."""
  44. MODEL_FILE_PREFIX = "inference"
  45. def __init__(self, model_dir: str, config: Dict = None) -> None:
  46. """Initializes the BasePredictor.
  47. Args:
  48. model_dir (str): The directory where the static model files is stored.
  49. config (dict, optional): The configuration of model to infer. Defaults to None.
  50. """
  51. super().__init__()
  52. self.model_dir = Path(model_dir)
  53. self.config = config if config else self.load_config(self.model_dir)
  54. self.batch_sampler = self._build_batch_sampler()
  55. self.result_class = self._get_result_class()
  56. # alias predict() to the __call__()
  57. self.predict = self.__call__
  58. self.benchmark = None
  59. @property
  60. def config_path(self) -> str:
  61. """
  62. Get the path to the configuration file.
  63. Returns:
  64. str: The path to the configuration file.
  65. """
  66. return self.get_config_path(self.model_dir)
  67. @property
  68. def model_name(self) -> str:
  69. """
  70. Get the model name.
  71. Returns:
  72. str: The model name.
  73. """
  74. return self.config["Global"]["model_name"]
  75. @classmethod
  76. def get_config_path(cls, model_dir) -> str:
  77. """Get the path to the configuration file for the given model directory.
  78. Args:
  79. model_dir (Path): The directory where the static model files is stored.
  80. Returns:
  81. Path: The path to the configuration file.
  82. """
  83. return model_dir / f"{cls.MODEL_FILE_PREFIX}.yml"
  84. @classmethod
  85. def load_config(cls, model_dir) -> Dict:
  86. """Load the configuration from the specified model directory.
  87. Args:
  88. model_dir (Path): The where the static model files is stored.
  89. Returns:
  90. dict: The loaded configuration dictionary.
  91. """
  92. yaml_reader = YAMLReader()
  93. return yaml_reader.read(cls.get_config_path(model_dir))
  94. @abstractmethod
  95. def __call__(self, input: Any, **kwargs: Dict[str, Any]) -> Iterator[Any]:
  96. """Predict with the given input and additional keyword arguments."""
  97. raise NotImplementedError
  98. @abstractmethod
  99. def set_predictor(self, batch_size: int = None, device: str = None, *args) -> None:
  100. """Sets up the predictor."""
  101. raise NotImplementedError
  102. def apply(self, input: Any, **kwargs) -> Iterator[Any]:
  103. """
  104. Do predicting with the input data and yields predictions.
  105. Args:
  106. input (Any): The input data to be predicted.
  107. Yields:
  108. Iterator[Any]: An iterator yielding prediction results.
  109. """
  110. for batch_data in self.batch_sampler(input):
  111. prediction = self.process(batch_data, **kwargs)
  112. prediction = PredictionWrap(prediction, len(batch_data))
  113. for idx in range(len(batch_data)):
  114. yield self.result_class(prediction.get_by_idx(idx))
  115. @abstractmethod
  116. def process(self, batch_data: List[Any]) -> Dict[str, List[Any]]:
  117. """process the batch data sampled from BatchSampler and return the prediction result.
  118. Args:
  119. batch_data (List[Any]): The batch data sampled from BatchSampler.
  120. Returns:
  121. Dict[str, List[Any]]: The prediction result.
  122. """
  123. raise NotImplementedError
  124. @abstractmethod
  125. def _build_batch_sampler(self) -> BaseBatchSampler:
  126. """Build batch sampler.
  127. Returns:
  128. BaseBatchSampler: batch sampler object.
  129. """
  130. raise NotImplementedError
  131. @abstractmethod
  132. def _get_result_class(self) -> type:
  133. """Get the result class.
  134. Returns:
  135. type: The result class.
  136. """
  137. raise NotImplementedError