Ver código fonte

yowo infer param for paddlex (#2846)

Sunflower7788 10 meses atrás
pai
commit
5461f51964

+ 4 - 1
api_examples/pipelines/test_video_detection.py

@@ -15,10 +15,13 @@
 from paddlex import create_pipeline
 
 pipeline = create_pipeline(pipeline="video_detection")
-output = pipeline.predict("./test_samples/HorseRiding.avi")
+output = pipeline.predict(
+    "./test_samples/HorseRiding.avi", nms_thresh=0.5, score_thresh=0.85
+)
 
 for res in output:
     print(res)
     res.print()  ## Print the structured prediction output
+    res.save_to_video("./output/1.mp4")  ## Save the result visualization video to the given file path
     res.save_to_video("./output/")  ## Save the result visualization video
     res.save_to_json("./output/")  ## Save the structured prediction output

+ 2 - 1
paddlex/configs/pipelines/video_detection.yaml

@@ -6,4 +6,5 @@ SubModules:
     model_name: YOWO
     model_dir: null
     batch_size: 1    
-    topk: 1
+    nms_thresh: 0.5
+    score_thresh: 0.8

+ 26 - 7
paddlex/inference/models_new/video_detection/predictor.py

@@ -33,8 +33,16 @@ class VideoDetPredictor(BasicPredictor):
     _FUNC_MAP = {}
     register = FuncRegister(_FUNC_MAP)
 
-    def __init__(self, topk: Union[int, None] = None, *args, **kwargs):
+    def __init__(
+        self,
+        nms_thresh: Union[float, None] = None,
+        score_thresh: Union[float, None] = None,
+        *args,
+        **kwargs
+    ):
         super().__init__(*args, **kwargs)
+        self.nms_thresh = nms_thresh
+        self.score_thresh = score_thresh
         self.pre_tfs, self.infer, self.post_op = self._build()
 
     def _build_batch_sampler(self):
@@ -73,7 +81,12 @@ class VideoDetPredictor(BasicPredictor):
 
         return pre_tfs, infer, post_op
 
-    def process(self, batch_data):
+    def process(
+        self,
+        batch_data,
+        nms_thresh: Union[float, None] = None,
+        score_thresh: Union[float, None] = None,
+    ):
         batch_raw_videos = self.pre_tfs["ReadVideo"](videos=batch_data)
         batch_videos = self.pre_tfs["ResizeVideo"](videos=batch_raw_videos)
         batch_videos = self.pre_tfs["Image2Array"](videos=batch_videos)
@@ -83,7 +96,11 @@ class VideoDetPredictor(BasicPredictor):
         for i in range(num_seg):
             batch_preds = self.infer(x=[x[0][i]])
             pred_seg.append(batch_preds)
-        batch_bboxes = self.post_op["DetVideoPostProcess"](preds=[pred_seg])
+        batch_bboxes = self.post_op["DetVideoPostProcess"](
+            preds=[pred_seg],
+            nms_thresh=nms_thresh or self.nms_thresh,
+            score_thresh=score_thresh or self.score_thresh,
+        )
         return {
             "input_path": batch_data,
             "result": batch_bboxes,
@@ -111,7 +128,9 @@ class VideoDetPredictor(BasicPredictor):
         return "NormalizeVideo", NormalizeVideo(scale=scale)
 
     @register("DetVideoPostProcess")
-    def build_postprocess(self, nms_thresh=0.5, score_thresh=0.4, label_list=[]):
-        return "DetVideoPostProcess", DetVideoPostProcess(
-            nms_thresh=nms_thresh, score_thresh=score_thresh, label_list=label_list
-        )
+    def build_postprocess(self, nms_thresh, score_thresh, label_list=[]):
+        if not self.nms_thresh:
+            self.nms_thresh = nms_thresh
+        if not self.score_thresh:
+            self.score_thresh = score_thresh
+        return "DetVideoPostProcess", DetVideoPostProcess(label_list=label_list)

+ 9 - 18
paddlex/inference/models_new/video_detection/processors.py

@@ -406,28 +406,17 @@ class DetVideoPostProcess:
 
     def __init__(
         self,
-        nms_thresh: float = 0.5,
-        score_thresh: float = 0.5,
         label_list: List[str] = [],
     ) -> None:
         """
         Args:
-            nms_thresh : float
-                The IoU (Intersection over Union) threshold used for Non-Maximum Suppression (NMS).
-                Detections with an IoU greater than this threshold will be suppressed.
-            score_thresh : float
-                The threshold for filtering out low-confidence detections.
-                Detections with a confidence score below this threshold will be discarded.
             label_list : List[str]
                 A list of labels or class names associated with the detection results.
         """
         super().__init__()
-
-        self.nms_thresh = nms_thresh
-        self.score_thresh = score_thresh
         self.labels = label_list
 
-    def postprocess(self, pred: List) -> List:
+    def postprocess(self, pred: List, nms_thresh: float, score_thresh: float) -> List:
         font = cv2.FONT_HERSHEY_SIMPLEX
         num_seg = len(pred)
         pred_all = []
@@ -436,11 +425,10 @@ class DetVideoPostProcess:
             for out in outputs:
                 preds = []
                 out = paddle.to_tensor(out)
-                all_boxes = get_region_boxes(out, self.score_thresh, len(self.labels))
+                all_boxes = get_region_boxes(out, 0.3, len(self.labels))
                 for i in range(out.shape[0]):
                     boxes = all_boxes[i]
-                    boxes = nms(boxes, self.nms_thresh)
-
+                    boxes = nms(boxes, nms_thresh)
                     for box in boxes:
                         x1 = round(float(box[0] - box[2] / 2.0) * 320.0)
                         y1 = round(float(box[1] - box[3] / 2.0) * 240.0)
@@ -451,9 +439,12 @@ class DetVideoPostProcess:
                         for j in range((len(box) - 5) // 2):
                             cls_conf = float(box[5 + 2 * j].item())
                             prob = det_conf * cls_conf
-                        preds.append([[x1, y1, x2, y2], prob, self.labels[int(box[6])]])
+                        if prob > score_thresh:
+                            preds.append(
+                                [[x1, y1, x2, y2], prob, self.labels[int(box[6])]]
+                            )
             pred_all.append(preds)
         return pred_all
 
-    def __call__(self, preds: List) -> List:
-        return [self.postprocess(pred) for pred in preds]
+    def __call__(self, preds: List, nms_thresh, score_thresh) -> List:
+        return [self.postprocess(pred, nms_thresh, score_thresh) for pred in preds]

+ 8 - 2
paddlex/inference/pipelines_new/video_detection/pipeline.py

@@ -52,7 +52,11 @@ class VideoDetectionPipeline(BasePipeline):
         self.video_detection_model = self.create_model(video_detection_model_config)
 
     def predict(
-        self, input: str | list[str] | np.ndarray | list[np.ndarray], **kwargs
+        self,
+        input: str | list[str] | np.ndarray | list[np.ndarray],
+        nms_thresh: float = 0.5,
+        score_thresh: float = 0.4,
+        **kwargs
     ) -> DetVideoResult:
         """Predicts video detection results for the given input.
 
@@ -64,4 +68,6 @@ class VideoDetectionPipeline(BasePipeline):
             DetVideoResult: The predicted video detection results.
         """
 
-        yield from self.video_detection_model(input)
+        yield from self.video_detection_model(
+            input, nms_thresh=nms_thresh, score_thresh=score_thresh
+        )