```python
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Dict, Any, Iterator
from pathlib import Path
from abc import abstractmethod, ABC

from ....utils.io import YAMLReader
from ....common.batch_sampler import BaseBatchSampler
```
```python
class PredictionWrap:
    """Wraps batch prediction data and supports retrieving per-sample results by index."""

    def __init__(self, data: Dict[str, List[Any]], num: int) -> None:
        """Initializes the PredictionWrap with prediction data.

        Args:
            data (Dict[str, List[Any]]): A dictionary where keys are string identifiers and values are lists of predictions.
            num (int): The number of predictions, i.e. the length of the value list for every key in `data`.

        Raises:
            AssertionError: If `data` is not a dictionary, or if the length of any list in `data` does not match `num`.
        """
        assert isinstance(data, dict), "data must be a dictionary"
        for k in data:
            assert len(data[k]) == num, f"{len(data[k])} != {num} for key {k}!"
        self._data = data
        self._keys = data.keys()

    def get_by_idx(self, idx: int) -> Dict[str, Any]:
        """Get the prediction at the specified index.

        Args:
            idx (int): The index of the sample to retrieve.

        Returns:
            Dict[str, Any]: A dictionary with the same keys as the input data, where each value is the element at `idx`.
        """
        return {key: self._data[key][idx] for key in self._keys}
```
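`PredictionWrap` converts a batch result stored column-wise (one list per key) into per-sample dictionaries. The following sketch reproduces that transposition as a standalone function, so it can run without importing the package. The function name `split_batch` and the sample keys `input_path`/`scores` are hypothetical.

```python
from typing import Any, Dict, List


def split_batch(data: Dict[str, List[Any]], num: int) -> List[Dict[str, Any]]:
    """Hypothetical helper mirroring PredictionWrap: columns -> per-sample rows."""
    # Same validation as PredictionWrap.__init__: every column must have `num` entries.
    for key, values in data.items():
        assert len(values) == num, f"{len(values)} != {num} for key {key}!"
    # Same lookup as PredictionWrap.get_by_idx, applied to every index in turn.
    return [{key: data[key][idx] for key in data} for idx in range(num)]


batch = {"input_path": ["a.jpg", "b.jpg"], "scores": [0.9, 0.4]}
rows = split_batch(batch, 2)
print(rows)  # [{'input_path': 'a.jpg', 'scores': 0.9}, {'input_path': 'b.jpg', 'scores': 0.4}]
```

Storing the batch column-wise suits a model that emits one array per output head. Unpacking it per index gives callers one self-contained record per input.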
```python
class BasePredictor(ABC):
    """Abstract base class for predictors that run inference with a static model."""

    MODEL_FILE_PREFIX = "inference"

    def __init__(self, model_dir: str, config: dict = None) -> None:
        """Initializes the BasePredictor.

        Args:
            model_dir (str): The directory where the static model files are stored.
            config (dict, optional): The configuration of the model used for inference.
                If not provided (or empty), it is loaded from the `inference.yml` file in `model_dir`.
                Defaults to None.
        """
        super().__init__()
        self.model_dir = Path(model_dir)
        self.config = config if config else self.load_config(self.model_dir)
        self.batch_sampler = self._build_batch_sampler()
        self.result_class = self._get_result_class()
        # alias predict() to __call__()
        self.predict = self.__call__
        self.benchmark = None
```
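Because `__init__` calls `_build_batch_sampler()` and `_get_result_class()` itself, those hooks run before any code that a subclass places after `super().__init__()`. The following sketch shows this ordering using hypothetical stand-in classes (`TinyBase`, `TinyPredictor`, `LatePredictor`). They reproduce only the hook calls and the `predict` alias, with the model and config handling left out. The practical consequence is that any attribute a hook reads must be assigned before `super().__init__()` is called.

```python
from typing import Any, List


class TinyBase:
    """Hypothetical stand-in for BasePredictor's constructor flow (config loading omitted)."""

    def __init__(self) -> None:
        # As in BasePredictor.__init__, the hooks are invoked from the base constructor.
        self.batch_sampler = self._build_batch_sampler()
        self.result_class = self._get_result_class()
        # Same aliasing as BasePredictor: predict() forwards to __call__().
        self.predict = self.__call__

    def _build_batch_sampler(self) -> Any:
        raise NotImplementedError

    def _get_result_class(self) -> type:
        raise NotImplementedError

    def __call__(self, x: Any) -> List[Any]:
        raise NotImplementedError


class TinyPredictor(TinyBase):
    """Assigns what the hooks need before calling super().__init__()."""

    def __init__(self, batch_size: int) -> None:
        self.calls: List[str] = []
        self.batch_size = batch_size
        super().__init__()
        self.calls.append("after super().__init__()")

    def _build_batch_sampler(self) -> str:
        self.calls.append("_build_batch_sampler")
        return f"sampler(batch_size={self.batch_size})"

    def _get_result_class(self) -> type:
        self.calls.append("_get_result_class")
        return dict

    def __call__(self, x: Any) -> List[Any]:
        return [x]


class LatePredictor(TinyBase):
    """Assigns batch_size too late: the hook runs before the attribute exists."""

    def __init__(self, batch_size: int) -> None:
        super().__init__()  # hooks run here, before batch_size is set
        self.batch_size = batch_size

    def _build_batch_sampler(self) -> str:
        return f"sampler(batch_size={self.batch_size})"

    def _get_result_class(self) -> type:
        return dict

    def __call__(self, x: Any) -> List[Any]:
        return [x]


p = TinyPredictor(batch_size=4)
print(p.calls)  # ['_build_batch_sampler', '_get_result_class', 'after super().__init__()']
print(p.batch_sampler)  # sampler(batch_size=4)
print(p.predict("img"))  # ['img']

try:
    LatePredictor(batch_size=4)
except AttributeError as err:
    print("AttributeError:", err)
```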
```python
    @property
    def config_path(self) -> Path:
        """Get the path to the configuration file.

        Returns:
            Path: The path to the configuration file.
        """
        return self.get_config_path(self.model_dir)

    @property
    def model_name(self) -> str:
        """Get the model name.

        Returns:
            str: The model name, read from `config["Global"]["model_name"]`.
        """
        return self.config["Global"]["model_name"]

    @classmethod
    def get_config_path(cls, model_dir) -> Path:
        """Get the path to the configuration file for the given model directory.

        Args:
            model_dir (str | Path): The directory where the static model files are stored.

        Returns:
            Path: The path to the configuration file.
        """
        # Wrap in Path() so the classmethod also accepts a plain string when called directly.
        return Path(model_dir) / f"{cls.MODEL_FILE_PREFIX}.yml"

    @classmethod
    def load_config(cls, model_dir) -> dict:
        """Load the configuration from the specified model directory.

        Args:
            model_dir (str | Path): The directory where the static model files are stored.

        Returns:
            dict: The loaded configuration dictionary.
        """
        yaml_reader = YAMLReader()
        return yaml_reader.read(cls.get_config_path(model_dir))
```
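Resolving the config involves two steps. First, join `model_dir` with `MODEL_FILE_PREFIX` and the `.yml` suffix. Second, read `Global.model_name` from the parsed dictionary. `YAMLReader` is not shown here, so the following sketch assumes its `read()` returns a plain nested dict. A dict literal with hypothetical values stands in for the parsed file, which means no YAML library is needed. The helper names `get_config_path` and `model_name_from` are illustrative.

```python
from pathlib import Path
from typing import Any, Dict, Union

MODEL_FILE_PREFIX = "inference"  # same value as BasePredictor.MODEL_FILE_PREFIX


def get_config_path(model_dir: Union[str, Path]) -> Path:
    # Mirrors BasePredictor.get_config_path: <model_dir>/inference.yml
    return Path(model_dir) / f"{MODEL_FILE_PREFIX}.yml"


def model_name_from(config: Dict[str, Any]) -> str:
    # Mirrors BasePredictor.model_name: the name lives under the "Global" section.
    return config["Global"]["model_name"]


# Assumed shape of the dict that YAMLReader.read() would return (hypothetical values).
config = {"Global": {"model_name": "MyDetector"}}

print(get_config_path("models/my_detector"))  # models/my_detector/inference.yml on POSIX
print(model_name_from(config))  # MyDetector
```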
```python
    @abstractmethod
    def __call__(self, input: Any, **kwargs: Any) -> Iterator[Any]:
        """Predict with the given input and additional keyword arguments."""
        raise NotImplementedError

    @abstractmethod
    def set_predictor(self) -> None:
        """Set up the predictor."""
        raise NotImplementedError

    def apply(self, input: Any, **kwargs: Any) -> Iterator[Any]:
        """Run prediction on the input data and yield results one sample at a time.

        Args:
            input (Any): The input data to be predicted.
            **kwargs: Additional keyword arguments forwarded to `process()`.

        Yields:
            Any: One `result_class` instance per input sample.
        """
        for batch_data in self.batch_sampler(input):
            prediction = self.process(batch_data, **kwargs)
            # Split the per-batch dict of lists into per-sample dicts.
            prediction = PredictionWrap(prediction, len(batch_data))
            for idx in range(len(batch_data)):
                yield self.result_class(prediction.get_by_idx(idx))
```
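`apply()` chains three steps:

1. The batch sampler splits the input into batches.
2. `process()` returns one dict of lists per batch.
3. Each batch is unpacked into one result object per sample.

The standalone generator below reproduces that flow. `BaseBatchSampler`'s interface is not shown here, so the sketch substitutes hypothetical stand-ins: the list-chunking `chunk_sampler`, the toy `fake_process` model, and plain `dict` as the result class.

```python
from typing import Any, Callable, Dict, Iterator, List


def chunk_sampler(inputs: List[Any], batch_size: int) -> Iterator[List[Any]]:
    """Hypothetical sampler: yields consecutive slices of at most batch_size items."""
    for start in range(0, len(inputs), batch_size):
        yield inputs[start:start + batch_size]


def apply_sketch(
    inputs: List[Any],
    batch_size: int,
    process: Callable[..., Dict[str, List[Any]]],
    result_class: Callable[[Dict[str, Any]], Any],
    **kwargs: Any,
) -> Iterator[Any]:
    for batch_data in chunk_sampler(inputs, batch_size):
        prediction = process(batch_data, **kwargs)
        # Every value list must have one entry per sample, as PredictionWrap asserts.
        for key, values in prediction.items():
            assert len(values) == len(batch_data), key
        # Emit one result object per sample in the batch.
        for idx in range(len(batch_data)):
            yield result_class({key: values[idx] for key, values in prediction.items()})


def fake_process(batch_data: List[int], scale: int = 1) -> Dict[str, List[Any]]:
    # Hypothetical model: "predicts" each input multiplied by scale.
    return {"input": list(batch_data), "pred": [x * scale for x in batch_data]}


results = list(apply_sketch([1, 2, 3, 4, 5], 2, fake_process, dict, scale=10))
print(results)
```

Because `apply()` is a generator, no batch is processed until the caller iterates, and results stream out batch by batch. Keyword arguments such as `scale` pass straight through to `process()`.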
```python
    @abstractmethod
    def process(self, batch_data: List[Any], **kwargs: Any) -> Dict[str, List[Any]]:
        """Process the batch data sampled by the batch sampler and return the prediction result.

        Args:
            batch_data (List[Any]): The batch data sampled by the batch sampler.
            **kwargs: Additional keyword arguments passed through from `apply()`.

        Returns:
            Dict[str, List[Any]]: The prediction result, mapping each key to a list with one entry per sample in the batch.
        """
        raise NotImplementedError

    @abstractmethod
    def _build_batch_sampler(self) -> BaseBatchSampler:
        """Build the batch sampler.

        Returns:
            BaseBatchSampler: The batch sampler object.
        """
        raise NotImplementedError

    @abstractmethod
    def _get_result_class(self) -> type:
        """Get the result class.

        Returns:
            type: The class used to wrap each per-sample prediction.
        """
        raise NotImplementedError
```
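`BasePredictor` derives from `ABC`, so a subclass cannot be instantiated until it overrides every `@abstractmethod`: `__call__`, `set_predictor`, `process`, `_build_batch_sampler`, and `_get_result_class`. The following sketch copies only those abstract signatures into a hypothetical stand-in (`AbstractPredictor`), with no model loading. Instantiating a subclass that omits one of them raises `TypeError`, and a minimal subclass that supplies all five satisfies the contract. The subclass names are illustrative.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterator, List


class AbstractPredictor(ABC):
    """Hypothetical stand-in holding only BasePredictor's abstract methods."""

    @abstractmethod
    def __call__(self, input: Any, **kwargs: Any) -> Iterator[Any]: ...

    @abstractmethod
    def set_predictor(self) -> None: ...

    @abstractmethod
    def process(self, batch_data: List[Any], **kwargs: Any) -> Dict[str, List[Any]]: ...

    @abstractmethod
    def _build_batch_sampler(self) -> Any: ...

    @abstractmethod
    def _get_result_class(self) -> type: ...


class Incomplete(AbstractPredictor):
    # Overrides everything except process(), so the class stays abstract.
    def __call__(self, input, **kwargs):
        yield from ()

    def set_predictor(self):
        pass

    def _build_batch_sampler(self):
        return None

    def _get_result_class(self):
        return dict


class EchoPredictor(Incomplete):
    # Supplying process() completes the abstract contract.
    def process(self, batch_data, **kwargs):
        return {"echo": list(batch_data)}


try:
    Incomplete()
except TypeError as err:
    print("TypeError:", err)

print(EchoPredictor().process([1, 2]))  # {'echo': [1, 2]}
```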