瀏覽代碼

support to set topk (#2678)

Tingquan Gao 11 月之前
父節點
當前提交
4a5cff3ba6

+ 3 - 3
paddlex/inference/models_new/base/predictor/base_predictor.py

@@ -68,7 +68,7 @@ class BasePredictor(ABC):
         self.config = config if config else self.load_config(self.model_dir)
         self.batch_sampler = self._build_batch_sampler()
         self.result_class = self._get_result_class()
-        
+
         # alias predict() to the __call__()
         self.predict = self.__call__
         self.benchmark = None
@@ -128,7 +128,7 @@ class BasePredictor(ABC):
         """Sets up the predictor."""
         raise NotImplementedError
 
-    def apply(self, input: Any) -> Iterator[Any]:
+    def apply(self, input: Any, **kwargs) -> Iterator[Any]:
         """
         Do predicting with the input data and yields predictions.
 
@@ -139,7 +139,7 @@ class BasePredictor(ABC):
             Iterator[Any]: An iterator yielding prediction results.
         """
         for batch_data in self.batch_sampler(input):
-            prediction = self.process(batch_data)
+            prediction = self.process(batch_data, **kwargs)
             prediction = PredictionWrap(prediction, len(batch_data))
             for idx in range(len(batch_data)):
                 yield self.result_class(prediction.get_by_idx(idx))

+ 15 - 5
paddlex/inference/models_new/base/predictor/basic_predictor.py

@@ -59,22 +59,32 @@ class BasicPredictor(
         logging.debug(f"{self.__class__.__name__}: {self.model_dir}")
         self.benchmark = benchmark
 
-    def __call__(self, input: Any, **kwargs: Dict[str, Any]) -> Iterator[Any]:
+    def __call__(
+        self,
+        input: Any,
+        batch_size: int = None,
+        device: str = None,
+        pp_option: PaddlePredictorOption = None,
+        **kwargs: Dict[str, Any],
+    ) -> Iterator[Any]:
         """
         Predict with the input data.
 
         Args:
             input (Any): The input data to be predicted.
+            batch_size (int, optional): The batch size to use. Defaults to None.
+            device (str, optional): The device to run the predictor on. Defaults to None.
+            pp_option (PaddlePredictorOption, optional): The predictor options to set. Defaults to None.
             **kwargs (Dict[str, Any]): Additional keyword arguments to set up predictor.
 
         Returns:
             Iterator[Any]: An iterator yielding the prediction output.
         """
-        self.set_predictor(**kwargs)
+        self.set_predictor(batch_size, device, pp_option)
         if self.benchmark:
             self.benchmark.start()
             if INFER_BENCHMARK_WARMUP > 0:
-                output = self.apply(input)
+                output = self.apply(input, **kwargs)
                 warmup_num = 0
                 for _ in range(INFER_BENCHMARK_WARMUP):
                     try:
@@ -86,10 +96,10 @@ class BasicPredictor(
                         )
                         break
                 self.benchmark.warmup_stop(warmup_num)
-            output = list(self.apply(input))
+            output = list(self.apply(input, **kwargs))
             self.benchmark.collect(len(output))
         else:
-            yield from self.apply(input)
+            yield from self.apply(input, **kwargs)
 
     def set_predictor(
         self,

+ 13 - 4
paddlex/inference/models_new/image_classification/predictor.py

@@ -40,14 +40,18 @@ class ClasPredictor(BasicPredictor):
     _FUNC_MAP = {}
     register = FuncRegister(_FUNC_MAP)
 
-    def __init__(self, *args: List, **kwargs: Dict) -> None:
+    def __init__(
+        self, topk: Union[int, None] = None, *args: List, **kwargs: Dict
+    ) -> None:
         """Initializes ClasPredictor.
 
         Args:
+            topk (int, optional): The number of top-k predictions to return. If None, it will be depending on config of inference or predict. Defaults to None.
             *args: Arbitrary positional arguments passed to the superclass.
             **kwargs: Arbitrary keyword arguments passed to the superclass.
         """
         super().__init__(*args, **kwargs)
+        self.topk = topk
         self.preprocessors, self.infer, self.postprocessors = self._build()
 
     def _build_batch_sampler(self) -> ImageBatchSampler:
@@ -95,12 +99,15 @@ class ClasPredictor(BasicPredictor):
             postprocessors[name] = op
         return preprocessors, infer, postprocessors
 
-    def process(self, batch_data: List[Union[str, np.ndarray]]) -> Dict[str, Any]:
+    def process(
+        self, batch_data: List[Union[str, np.ndarray]], topk: Union[int, None] = None
+    ) -> Dict[str, Any]:
         """
         Process a batch of data through the preprocessing, inference, and postprocessing.
 
         Args:
             batch_data (List[Union[str, np.ndarray], ...]): A batch of input data (e.g., image file paths).
+            topk: The number of top predictions to keep. If None, it will be depending on `self.topk`. Defaults to None.
 
         Returns:
             dict: A dictionary containing the input path, raw image, class IDs, scores, and label names for every instance of the batch. Keys include 'input_path', 'input_img', 'class_ids', 'scores', and 'label_names'.
@@ -113,7 +120,7 @@ class ClasPredictor(BasicPredictor):
         x = self.preprocessors["ToBatch"](imgs=batch_imgs)
         batch_preds = self.infer(x=x)
         batch_class_ids, batch_scores, batch_label_names = self.postprocessors["Topk"](
-            batch_preds
+            batch_preds, topk=topk or self.topk
         )
         return {
             "input_path": batch_data,
@@ -160,4 +167,6 @@ class ClasPredictor(BasicPredictor):
 
     @register("Topk")
     def build_topk(self, topk, label_list=None):
-        return "Topk", Topk(topk=int(topk), class_ids=label_list)
+        if not self.topk:
+            self.topk = int(topk)
+        return "Topk", Topk(class_ids=label_list)

+ 3 - 5
paddlex/inference/models_new/image_classification/processors.py

@@ -67,10 +67,8 @@ class Crop:
 class Topk:
     """Topk Transform"""
 
-    def __init__(self, topk, class_ids=None):
+    def __init__(self, class_ids=None):
         super().__init__()
-        assert isinstance(topk, (int,))
-        self.topk = topk
         self.class_id_map = self._parse_class_id_map(class_ids)
 
     def _parse_class_id_map(self, class_ids):
@@ -80,8 +78,8 @@ class Topk:
         class_id_map = {id: str(lb) for id, lb in enumerate(class_ids)}
         return class_id_map
 
-    def __call__(self, preds):
-        indexes = preds[0].argsort(axis=1)[:, -self.topk :][:, ::-1].astype("int32")
+    def __call__(self, preds, topk=5):
+        indexes = preds[0].argsort(axis=1)[:, -topk:][:, ::-1].astype("int32")
         scores = [
             np.around(pred[index], decimals=5) for pred, index in zip(preds[0], indexes)
         ]