zhangyue66 11 months ago
parent
commit
0217c377ee

+ 79 - 1
paddlex/inference/models_new/base/predictor/base_predictor.py

@@ -12,11 +12,43 @@
 # See the License for the specific language governing permissions and
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # limitations under the License.
 
 
-from typing import Union, Tuple, List, Dict, Any, Iterator
+from typing import List, Dict, Any, Iterator
 from pathlib import Path
 from pathlib import Path
 from abc import abstractmethod, ABC
 from abc import abstractmethod, ABC
 
 
 from ....utils.io import YAMLReader
 from ....utils.io import YAMLReader
+from ....common.batch_sampler import BaseBatchSampler
+
+
+class PredictionWrap:
+    """Wraps the prediction data and supports get by index."""
+
+    def __init__(self, data: Dict[str, List[Any]], num: int) -> None:
+        """Initializes the PredictionWrap with prediction data.
+
+        Args:
+            data (Dict[str, List[Any]]): A dictionary where keys are string identifiers and values are lists of predictions.
+            num (int): The number of predictions, that is length of values per key in the data dictionary.
+
+        Raises:
+            AssertionError: If the length of any list in data does not match num.
+        """
+        assert isinstance(data, dict), "data must be a dictionary"
+        for k in data:
+            assert len(data[k]) == num, f"{len(data[k])} != {num} for key {k}!"
+        self._data = data
+        self._keys = data.keys()
+
+    def get_by_idx(self, idx: int) -> Dict[str, Any]:
+        """Get the prediction by specified index.
+
+        Args:
+            idx (int): The index to get predictions from.
+
+        Returns:
+            Dict[str, Any]: A dictionary with the same keys as the input data, but with the values at the specified index.
+        """
+        return {key: self._data[key][idx] for key in self._keys}
 
 
 
 
 class BasePredictor(ABC):
 class BasePredictor(ABC):
@@ -98,3 +130,49 @@ class BasePredictor(ABC):
     def set_predictor(self) -> None:
     def set_predictor(self) -> None:
         """Sets up the predictor."""
         """Sets up the predictor."""
         raise NotImplementedError
         raise NotImplementedError
+
+    def apply(self, input: Any) -> Iterator[Any]:
+        """
+        Do predicting with the input data and yields predictions.
+
+        Args:
+            input (Any): The input data to be predicted.
+
+        Yields:
+            Iterator[Any]: An iterator yielding prediction results.
+        """
+        for batch_data in self.batch_sampler(input):
+            prediction = self.process(batch_data)
+            prediction = PredictionWrap(prediction, len(batch_data))
+            for idx in range(len(batch_data)):
+                yield self.result_class(prediction.get_by_idx(idx))
+
+    @abstractmethod
+    def process(self, batch_data: List[Any]) -> Dict[str, List[Any]]:
+        """process the batch data sampled from BatchSampler and return the prediction result.
+
+        Args:
+            batch_data (List[Any]): The batch data sampled from BatchSampler.
+
+        Returns:
+            Dict[str, List[Any]]: The prediction result.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def _build_batch_sampler(self) -> BaseBatchSampler:
+        """Build batch sampler.
+
+        Returns:
+            BaseBatchSampler: batch sampler object.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def _get_result_class(self) -> type:
+        """Get the result class.
+
+        Returns:
+            type: The result class.
+        """
+        raise NotImplementedError

+ 1 - 79
paddlex/inference/models_new/base/predictor/basic_predictor.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # limitations under the License.
 
 
-from typing import Union, Tuple, List, Dict, Any, Iterator
+from typing import Dict, Any, Iterator
 from abc import abstractmethod
 from abc import abstractmethod
 
 
 from .....utils.subclass_register import AutoRegisterABCMetaClass
 from .....utils.subclass_register import AutoRegisterABCMetaClass
@@ -23,41 +23,9 @@ from .....utils.flags import (
 from .....utils import logging
 from .....utils import logging
 from ....utils.pp_option import PaddlePredictorOption
 from ....utils.pp_option import PaddlePredictorOption
 from ....utils.benchmark import benchmark
 from ....utils.benchmark import benchmark
-from ....common.batch_sampler import BaseBatchSampler
 from .base_predictor import BasePredictor
 from .base_predictor import BasePredictor
 
 
 
 
-class PredictionWrap:
-    """Wraps the prediction data and supports get by index."""
-
-    def __init__(self, data: Dict[str, List[Any]], num: int) -> None:
-        """Initializes the PredictionWrap with prediction data.
-
-        Args:
-            data (Dict[str, List[Any]]): A dictionary where keys are string identifiers and values are lists of predictions.
-            num (int): The number of predictions, that is length of values per key in the data dictionary.
-
-        Raises:
-            AssertionError: If the length of any list in data does not match num.
-        """
-        assert isinstance(data, dict), "data must be a dictionary"
-        for k in data:
-            assert len(data[k]) == num, f"{len(data[k])} != {num} for key {k}!"
-        self._data = data
-        self._keys = data.keys()
-
-    def get_by_idx(self, idx: int) -> Dict[str, Any]:
-        """Get the prediction by specified index.
-
-        Args:
-            idx (int): The index to get predictions from.
-
-        Returns:
-            Dict[str, Any]: A dictionary with the same keys as the input data, but with the values at the specified index.
-        """
-        return {key: self._data[key][idx] for key in self._keys}
-
-
 class BasicPredictor(
 class BasicPredictor(
     BasePredictor,
     BasePredictor,
     metaclass=AutoRegisterABCMetaClass,
     metaclass=AutoRegisterABCMetaClass,
@@ -124,22 +92,6 @@ class BasicPredictor(
         else:
         else:
             yield from self.apply(input)
             yield from self.apply(input)
 
 
-    def apply(self, input: Any) -> Iterator[Any]:
-        """
-        Do predicting with the input data and yields predictions.
-
-        Args:
-            input (Any): The input data to be predicted.
-
-        Yields:
-            Iterator[Any]: An iterator yielding prediction results.
-        """
-        for batch_data in self.batch_sampler(input):
-            prediction = self.process(batch_data)
-            prediction = PredictionWrap(prediction, len(batch_data))
-            for idx in range(len(batch_data)):
-                yield self.result_class(prediction.get_by_idx(idx))
-
     def set_predictor(
     def set_predictor(
         self,
         self,
         batch_size: int = None,
         batch_size: int = None,
@@ -164,33 +116,3 @@ class BasicPredictor(
             self.pp_option.device = device
             self.pp_option.device = device
         if pp_option and pp_option != self.pp_option:
         if pp_option and pp_option != self.pp_option:
             self.pp_option = pp_option
             self.pp_option = pp_option
-
-    @abstractmethod
-    def _build_batch_sampler(self) -> BaseBatchSampler:
-        """Build batch sampler.
-
-        Returns:
-            BaseBatchSampler: batch sampler object.
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def process(self, batch_data: List[Any]) -> Dict[str, List[Any]]:
-        """process the batch data sampled from BatchSampler and return the prediction result.
-
-        Args:
-            batch_data (List[Any]): The batch data sampled from BatchSampler.
-
-        Returns:
-            Dict[str, List[Any]]: The prediction result.
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def _get_result_class(self) -> type:
-        """Get the result class.
-
-        Returns:
-            type: The result class.
-        """
-        raise NotImplementedError