|
|
@@ -12,7 +12,7 @@
|
|
|
# See the License for the specific language governing permissions and
|
|
|
# limitations under the License.
|
|
|
|
|
|
-from typing import Union, Tuple, List, Dict, Any, Iterator
|
|
|
+from typing import Dict, Any, Iterator
|
|
|
from abc import abstractmethod
|
|
|
|
|
|
from .....utils.subclass_register import AutoRegisterABCMetaClass
|
|
|
@@ -23,41 +23,9 @@ from .....utils.flags import (
|
|
|
from .....utils import logging
|
|
|
from ....utils.pp_option import PaddlePredictorOption
|
|
|
from ....utils.benchmark import benchmark
|
|
|
-from ....common.batch_sampler import BaseBatchSampler
|
|
|
from .base_predictor import BasePredictor
|
|
|
|
|
|
|
|
|
-class PredictionWrap:
|
|
|
- """Wraps the prediction data and supports get by index."""
|
|
|
-
|
|
|
- def __init__(self, data: Dict[str, List[Any]], num: int) -> None:
|
|
|
- """Initializes the PredictionWrap with prediction data.
|
|
|
-
|
|
|
- Args:
|
|
|
- data (Dict[str, List[Any]]): A dictionary where keys are string identifiers and values are lists of predictions.
|
|
|
- num (int): The number of predictions, that is length of values per key in the data dictionary.
|
|
|
-
|
|
|
- Raises:
|
|
|
- AssertionError: If the length of any list in data does not match num.
|
|
|
- """
|
|
|
- assert isinstance(data, dict), "data must be a dictionary"
|
|
|
- for k in data:
|
|
|
- assert len(data[k]) == num, f"{len(data[k])} != {num} for key {k}!"
|
|
|
- self._data = data
|
|
|
- self._keys = data.keys()
|
|
|
-
|
|
|
- def get_by_idx(self, idx: int) -> Dict[str, Any]:
|
|
|
- """Get the prediction by specified index.
|
|
|
-
|
|
|
- Args:
|
|
|
- idx (int): The index to get predictions from.
|
|
|
-
|
|
|
- Returns:
|
|
|
- Dict[str, Any]: A dictionary with the same keys as the input data, but with the values at the specified index.
|
|
|
- """
|
|
|
- return {key: self._data[key][idx] for key in self._keys}
|
|
|
-
|
|
|
-
|
|
|
class BasicPredictor(
|
|
|
BasePredictor,
|
|
|
metaclass=AutoRegisterABCMetaClass,
|
|
|
@@ -124,22 +92,6 @@ class BasicPredictor(
|
|
|
else:
|
|
|
yield from self.apply(input)
|
|
|
|
|
|
- def apply(self, input: Any) -> Iterator[Any]:
|
|
|
- """
|
|
|
- Do predicting with the input data and yields predictions.
|
|
|
-
|
|
|
- Args:
|
|
|
- input (Any): The input data to be predicted.
|
|
|
-
|
|
|
- Yields:
|
|
|
- Iterator[Any]: An iterator yielding prediction results.
|
|
|
- """
|
|
|
- for batch_data in self.batch_sampler(input):
|
|
|
- prediction = self.process(batch_data)
|
|
|
- prediction = PredictionWrap(prediction, len(batch_data))
|
|
|
- for idx in range(len(batch_data)):
|
|
|
- yield self.result_class(prediction.get_by_idx(idx))
|
|
|
-
|
|
|
def set_predictor(
|
|
|
self,
|
|
|
batch_size: int = None,
|
|
|
@@ -164,33 +116,3 @@ class BasicPredictor(
|
|
|
self.pp_option.device = device
|
|
|
if pp_option and pp_option != self.pp_option:
|
|
|
self.pp_option = pp_option
|
|
|
-
|
|
|
- @abstractmethod
|
|
|
- def _build_batch_sampler(self) -> BaseBatchSampler:
|
|
|
- """Build batch sampler.
|
|
|
-
|
|
|
- Returns:
|
|
|
- BaseBatchSampler: batch sampler object.
|
|
|
- """
|
|
|
- raise NotImplementedError
|
|
|
-
|
|
|
- @abstractmethod
|
|
|
- def process(self, batch_data: List[Any]) -> Dict[str, List[Any]]:
|
|
|
- """process the batch data sampled from BatchSampler and return the prediction result.
|
|
|
-
|
|
|
- Args:
|
|
|
- batch_data (List[Any]): The batch data sampled from BatchSampler.
|
|
|
-
|
|
|
- Returns:
|
|
|
- Dict[str, List[Any]]: The prediction result.
|
|
|
- """
|
|
|
- raise NotImplementedError
|
|
|
-
|
|
|
- @abstractmethod
|
|
|
- def _get_result_class(self) -> type:
|
|
|
- """Get the result class.
|
|
|
-
|
|
|
- Returns:
|
|
|
- type: The result class.
|
|
|
- """
|
|
|
- raise NotImplementedError
|