zhangyue66 11 tháng trước cách đây
mục cha
commit
0217c377ee

+ 79 - 1
paddlex/inference/models_new/base/predictor/base_predictor.py

@@ -12,11 +12,43 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Union, Tuple, List, Dict, Any, Iterator
+from typing import List, Dict, Any, Iterator
 from pathlib import Path
 from abc import abstractmethod, ABC
 
 from ....utils.io import YAMLReader
+from ....common.batch_sampler import BaseBatchSampler
+
+
+class PredictionWrap:
+    """Wraps the prediction data and supports get by index."""
+
+    def __init__(self, data: Dict[str, List[Any]], num: int) -> None:
+        """Initializes the PredictionWrap with prediction data.
+
+        Args:
+            data (Dict[str, List[Any]]): A dictionary where keys are string identifiers and values are lists of predictions.
+            num (int): The number of predictions, that is length of values per key in the data dictionary.
+
+        Raises:
+            AssertionError: If the length of any list in data does not match num.
+        """
+        assert isinstance(data, dict), "data must be a dictionary"
+        for k in data:
+            assert len(data[k]) == num, f"{len(data[k])} != {num} for key {k}!"
+        self._data = data
+        self._keys = data.keys()
+
+    def get_by_idx(self, idx: int) -> Dict[str, Any]:
+        """Get the prediction by specified index.
+
+        Args:
+            idx (int): The index to get predictions from.
+
+        Returns:
+            Dict[str, Any]: A dictionary with the same keys as the input data, but with the values at the specified index.
+        """
+        return {key: self._data[key][idx] for key in self._keys}
 
 
 class BasePredictor(ABC):
@@ -98,3 +130,49 @@ class BasePredictor(ABC):
     def set_predictor(self) -> None:
         """Sets up the predictor."""
         raise NotImplementedError
+
+    def apply(self, input: Any) -> Iterator[Any]:
+        """
+        Do predicting with the input data and yields predictions.
+
+        Args:
+            input (Any): The input data to be predicted.
+
+        Yields:
+            Iterator[Any]: An iterator yielding prediction results.
+        """
+        for batch_data in self.batch_sampler(input):
+            prediction = self.process(batch_data)
+            prediction = PredictionWrap(prediction, len(batch_data))
+            for idx in range(len(batch_data)):
+                yield self.result_class(prediction.get_by_idx(idx))
+
+    @abstractmethod
+    def process(self, batch_data: List[Any]) -> Dict[str, List[Any]]:
+        """process the batch data sampled from BatchSampler and return the prediction result.
+
+        Args:
+            batch_data (List[Any]): The batch data sampled from BatchSampler.
+
+        Returns:
+            Dict[str, List[Any]]: The prediction result.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def _build_batch_sampler(self) -> BaseBatchSampler:
+        """Build batch sampler.
+
+        Returns:
+            BaseBatchSampler: batch sampler object.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def _get_result_class(self) -> type:
+        """Get the result class.
+
+        Returns:
+            type: The result class.
+        """
+        raise NotImplementedError

+ 1 - 79
paddlex/inference/models_new/base/predictor/basic_predictor.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Union, Tuple, List, Dict, Any, Iterator
+from typing import Dict, Any, Iterator
 from abc import abstractmethod
 
 from .....utils.subclass_register import AutoRegisterABCMetaClass
@@ -23,41 +23,9 @@ from .....utils.flags import (
 from .....utils import logging
 from ....utils.pp_option import PaddlePredictorOption
 from ....utils.benchmark import benchmark
-from ....common.batch_sampler import BaseBatchSampler
 from .base_predictor import BasePredictor
 
 
-class PredictionWrap:
-    """Wraps the prediction data and supports get by index."""
-
-    def __init__(self, data: Dict[str, List[Any]], num: int) -> None:
-        """Initializes the PredictionWrap with prediction data.
-
-        Args:
-            data (Dict[str, List[Any]]): A dictionary where keys are string identifiers and values are lists of predictions.
-            num (int): The number of predictions, that is length of values per key in the data dictionary.
-
-        Raises:
-            AssertionError: If the length of any list in data does not match num.
-        """
-        assert isinstance(data, dict), "data must be a dictionary"
-        for k in data:
-            assert len(data[k]) == num, f"{len(data[k])} != {num} for key {k}!"
-        self._data = data
-        self._keys = data.keys()
-
-    def get_by_idx(self, idx: int) -> Dict[str, Any]:
-        """Get the prediction by specified index.
-
-        Args:
-            idx (int): The index to get predictions from.
-
-        Returns:
-            Dict[str, Any]: A dictionary with the same keys as the input data, but with the values at the specified index.
-        """
-        return {key: self._data[key][idx] for key in self._keys}
-
-
 class BasicPredictor(
     BasePredictor,
     metaclass=AutoRegisterABCMetaClass,
@@ -124,22 +92,6 @@ class BasicPredictor(
         else:
             yield from self.apply(input)
 
-    def apply(self, input: Any) -> Iterator[Any]:
-        """
-        Do predicting with the input data and yields predictions.
-
-        Args:
-            input (Any): The input data to be predicted.
-
-        Yields:
-            Iterator[Any]: An iterator yielding prediction results.
-        """
-        for batch_data in self.batch_sampler(input):
-            prediction = self.process(batch_data)
-            prediction = PredictionWrap(prediction, len(batch_data))
-            for idx in range(len(batch_data)):
-                yield self.result_class(prediction.get_by_idx(idx))
-
     def set_predictor(
         self,
         batch_size: int = None,
@@ -164,33 +116,3 @@ class BasicPredictor(
             self.pp_option.device = device
         if pp_option and pp_option != self.pp_option:
             self.pp_option = pp_option
-
-    @abstractmethod
-    def _build_batch_sampler(self) -> BaseBatchSampler:
-        """Build batch sampler.
-
-        Returns:
-            BaseBatchSampler: batch sampler object.
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def process(self, batch_data: List[Any]) -> Dict[str, List[Any]]:
-        """process the batch data sampled from BatchSampler and return the prediction result.
-
-        Args:
-            batch_data (List[Any]): The batch data sampled from BatchSampler.
-
-        Returns:
-            Dict[str, List[Any]]: The prediction result.
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def _get_result_class(self) -> type:
-        """Get the result class.
-
-        Returns:
-            type: The result class.
-        """
-        raise NotImplementedError