Эх сурвалжийг харах

support threshold param for DetPredictor

leo-q8 11 сар өмнө
parent
commit
d65348f97c

+ 14 - 4
paddlex/inference/models_new/object_detection/predictor.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, List, Sequence
+from typing import Any, List, Sequence, Optional
 
 import numpy as np
 
@@ -43,8 +43,16 @@ class DetPredictor(BasicPredictor):
     _FUNC_MAP = {}
     register = FuncRegister(_FUNC_MAP)
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, threshold: Optional[float] = None, **kwargs):
+        """Initializes DetPredictor.
+        Args:
+            *args: Arbitrary positional arguments passed to the superclass.
+            threshold (Optional[float], optional): The threshold for filtering out low-confidence predictions.
+                Defaults to None.
+            **kwargs: Arbitrary keyword arguments passed to the superclass.
+        """
         super().__init__(*args, **kwargs)
+        self.threshold = threshold
         self.pre_ops, self.infer, self.post_op = self._build()
 
     def _build_batch_sampler(self):
@@ -131,7 +139,7 @@ class DetPredictor(BasicPredictor):
         else:
             return [{"boxes": np.array(res)} for res in pred_box]
 
-    def process(self, batch_data: List[Any]):
+    def process(self, batch_data: List[Any], threshold: Optional[float] = None):
         """
         Process a batch of data through the preprocessing, inference, and postprocessing.
 
@@ -157,7 +165,9 @@ class DetPredictor(BasicPredictor):
         preds_list = self._format_output(batch_preds)
 
         # postprocess
-        boxes = self.post_op(preds_list, datas)
+        boxes = self.post_op(
+            preds_list, datas, threshold if threshold is not None else self.threshold
+        )
 
         return {
             "input_path": [data.get("img_path", None) for data in datas],

+ 16 - 7
paddlex/inference/models_new/object_detection/processors.py

@@ -611,7 +611,7 @@ class DetPostProcess:
         self.labels = labels
         self.layout_postprocess = layout_postprocess
 
-    def apply(self, boxes: ndarray, img_size) -> Boxes:
+    def apply(self, boxes: ndarray, img_size, threshold: Union[float, dict]) -> Boxes:
         """Apply post-processing to the detection boxes.
 
         Args:
@@ -621,15 +621,15 @@ class DetPostProcess:
         Returns:
             Boxes: The post-processed detection boxes.
         """
-        if isinstance(self.threshold, float):
-            expect_boxes = (boxes[:, 1] > self.threshold) & (boxes[:, 0] > -1)
+        if isinstance(threshold, float):
+            expect_boxes = (boxes[:, 1] > threshold) & (boxes[:, 0] > -1)
             boxes = boxes[expect_boxes, :]
-        elif isinstance(self.threshold, dict):
+        elif isinstance(threshold, dict):
             category_filtered_boxes = []
             for cat_id in np.unique(boxes[:, 0]):
                 category_boxes = boxes[boxes[:, 0] == cat_id]
                 category_scores = category_boxes[:, 1]
-                category_threshold = self.threshold.get(int(cat_id), 0.5)
+                category_threshold = threshold.get(int(cat_id), 0.5)
                 selected_indices = category_scores > category_threshold
                 category_filtered_boxes.append(category_boxes[selected_indices])
             boxes = (
@@ -714,7 +714,12 @@ class DetPostProcess:
             )
         return boxes
 
-    def __call__(self, batch_outputs: List[dict], datas: List[dict]) -> List[Boxes]:
+    def __call__(
+        self,
+        batch_outputs: List[dict],
+        datas: List[dict],
+        threshold: Optional[Union[float, dict]] = None,
+    ) -> List[Boxes]:
         """Apply the post-processing to a batch of outputs.
 
         Args:
@@ -726,6 +731,10 @@ class DetPostProcess:
         """
         outputs = []
         for data, output in zip(datas, batch_outputs):
-            boxes = self.apply(output["boxes"], data["ori_img_size"])
+            boxes = self.apply(
+                output["boxes"],
+                data["ori_img_size"],
+                threshold if threshold is not None else self.threshold,
+            )
             outputs.append(boxes)
         return outputs