
support det_threshold for kpt pipeline (#2931)

* support det_threshold for kpt pipeline

* Refactor type annotations for compatibility

* Update type annotations to use 'List' and 'Dict'
学卿 · 10 months ago
commit b98a9d7c50

+ 3 - 1
paddlex/inference/models_new/keypoint_detection/result.py

@@ -174,7 +174,9 @@ class KptResult(BaseCVResult):
             keypoints = [
                 obj["keypoints"] for obj in self["boxes"]
             ]  # for top-down pipeline result
-        image = draw_keypoints(self["input_img"], dict(keypoints=np.stack(keypoints)))
+        image = self["input_img"]
+        if keypoints:
+            image = draw_keypoints(image, dict(keypoints=np.stack(keypoints)))
         return {"res": image}
 
     def _to_str(self, *args, **kwargs):

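The guard added above matters because np.stack rejects an empty sequence: when the detector returns no boxes, keypoints is an empty list and the old one-line call raised ValueError. A minimal sketch of that failure mode, with a dummy array standing in for self["input_img"]:

# Minimal sketch of the failure the new guard avoids (assumption: no
# objects were detected, so the keypoints list is empty).
import numpy as np

keypoints = []                              # top-down pipeline found no boxes
image = np.zeros((64, 64, 3), np.uint8)     # stand-in for self["input_img"]

try:
    np.stack(keypoints)                     # old behaviour: always stacked
except ValueError as exc:
    print(exc)                              # "need at least one array to stack"

if keypoints:                               # new behaviour: draw only when there is data
    pass  # draw_keypoints(image, dict(keypoints=np.stack(keypoints))) would run here
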
+ 2 - 2
paddlex/inference/pipelines_new/instance_segmentation/pipeline.py

@@ -54,8 +54,8 @@ class InstanceSegmentationPipeline(BasePipeline):
 
     def predict(
         self,
-        input: str | list[str] | np.ndarray | list[np.ndarray],
-        threshold: float | None = None,
+        input: Union[str, List[str], np.ndarray, List[np.ndarray]],
+        threshold: Union[float, None] = None,
         **kwargs
     ) -> InstanceSegResult:
         """Predicts instance segmentation results for the given input.

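This hunk, like the rotated_object_detection, semantic_segmentation, and small_object_detection ones below, is annotation-only: it swaps PEP 604 unions (str | list[str]) for typing.Union/List/Dict/Tuple, which is what the "Refactor type annotations for compatibility" commit message refers to. The | operator between types is only evaluated at runtime on Python 3.10+, so the old annotations break on 3.8/3.9 unless evaluation is deferred with from __future__ import annotations. A short sketch of the difference:

# Sketch of the compatibility issue behind the annotation rewrite
# (assumes Python 3.8/3.9 without "from __future__ import annotations").
from typing import List, Union

import numpy as np

def predict(input: Union[str, List[str], np.ndarray, List[np.ndarray]],
            threshold: Union[float, None] = None):
    ...                                     # imports and runs on Python 3.8+

# The PEP 604 spelling raises TypeError at definition time on Python < 3.10:
# def predict(input: str | list[str] | np.ndarray | list[np.ndarray]): ...
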
+ 13 - 6
paddlex/inference/pipelines_new/keypoint_detection/pipeline.py

@@ -49,10 +49,12 @@ class KeypointDetectionPipeline(BasePipeline):
         # create object detection model
         model_cfg = config["SubModules"]["ObjectDetection"]
         model_kwargs = {}
+        self.det_threshold = None
         if "threshold" in model_cfg:
             model_kwargs["threshold"] = model_cfg["threshold"]
-        if "img_size" in model_cfg:
-            model_kwargs["img_size"] = model_cfg["img_size"]
+            self.det_threshold = model_cfg["threshold"]
+        if "imgsz" in model_cfg:
+            model_kwargs["imgsz"] = model_cfg["imgsz"]
         self.det_model = self.create_model(model_cfg, **model_kwargs)
 
         # create keypoint detection model
@@ -95,19 +97,23 @@ class KeypointDetectionPipeline(BasePipeline):
         return center, scale
 
     def predict(
-        self, input: Union[str, List[str], np.ndarray, List[np.ndarray]], **kwargs
+        self,
+        input: Union[str, List[str], np.ndarray, List[np.ndarray]],
+        det_threshold: Optional[float] = None,
+        **kwargs,
     ) -> KptResult:
         """Predicts image classification results for the given input.
 
         Args:
-            input (Union[str, list[str], np.ndarray, list[np.ndarray]]): The input image(s) or path(s) to the images.
+            input (str | list[str] | np.ndarray | list[np.ndarray]): The input image(s) or path(s) to the images.
+            det_threshold (Optional[float]): Detection threshold for the object detection stage. Defaults to None, which falls back to the threshold from the pipeline config.
             **kwargs: Additional keyword arguments that can be passed to the function.
 
         Returns:
             KptResult: The predicted KeyPoint Detection results.
         """
-
-        for det_res in self.det_model(input):
+        det_threshold = self.det_threshold if det_threshold is None else det_threshold
+        for det_res in self.det_model(input, threshold=det_threshold):
             ori_img, img_path = det_res["input_img"], det_res["input_path"]
             single_img_res = {"input_path": img_path, "input_img": ori_img, "boxes": []}
             for box in det_res["boxes"]:
@@ -126,6 +132,7 @@ class KeypointDetectionPipeline(BasePipeline):
                         "coordinate": box["coordinate"],
                         "det_score": box["score"],
                         "keypoints": kpt_res["kpts"][0]["keypoints"],
+                        "kpt_score": kpt_res["kpts"][0]["kpt_score"],
                     }
                 )
             yield KptResult(single_img_res)

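A hedged usage sketch of the new argument. The pipeline name and image path are placeholder assumptions; det_threshold and the kpt_score field are the parts introduced by this diff:

# Usage sketch (assumptions: PaddleX 3.x, the keypoint pipeline is registered
# as "human_keypoint_detection", and demo.jpg exists locally).
from paddlex import create_pipeline

pipeline = create_pipeline(pipeline="human_keypoint_detection")

# det_threshold overrides the ObjectDetection threshold from the pipeline
# config for this call; passing None keeps the configured value.
for res in pipeline.predict("demo.jpg", det_threshold=0.5):
    for box in res["boxes"]:
        # det_score comes from the detector; kpt_score is new in this commit
        print(box["coordinate"], box["det_score"], box["kpt_score"])
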
+ 2 - 2
paddlex/inference/pipelines_new/rotated_object_detection/pipeline.py

@@ -54,8 +54,8 @@ class RotatedObjectDetectionPipeline(BasePipeline):
 
     def predict(
         self,
-        input: str | list[str] | np.ndarray | list[np.ndarray],
-        threshold: None | dict[int, float] | float = None,
+        input: Union[str, List[str], np.ndarray, List[np.ndarray]],
+        threshold: Union[None, Dict[int, float], float] = None,
         **kwargs
     ) -> DetResult:
         """Predicts rotated object detection results for the given input.

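The threshold annotation above admits two concrete shapes besides None; a short sketch (pipeline name and class ids are illustrative assumptions):

# threshold forms matching Union[None, Dict[int, float], float]:
#   0.4              -> one score threshold applied to every class
#   {0: 0.3, 1: 0.6} -> per-class thresholds keyed by class id
from paddlex import create_pipeline

pipeline = create_pipeline(pipeline="rotated_object_detection")
for res in pipeline.predict("demo.jpg", threshold=0.4):
    ...
for res in pipeline.predict("demo.jpg", threshold={0: 0.3, 1: 0.6}):
    ...
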
+ 3 - 3
paddlex/inference/pipelines_new/semantic_segmentation/pipeline.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, Optional, Literal
+from typing import Any, Dict, Optional, Literal, Union, List, Tuple
 import numpy as np
 from ...utils.pp_option import PaddlePredictorOption
 from ..base import BasePipeline
@@ -54,8 +54,8 @@ class SemanticSegmentationPipeline(BasePipeline):
 
     def predict(
         self,
-        input: str | list[str] | np.ndarray | list[np.ndarray],
-        target_size: Literal[-1] | None | int | tuple[int] = None,
+        input: Union[str, List[str], np.ndarray, List[np.ndarray]],
+        target_size: Union[Literal[-1], None, int, Tuple[int]] = None,
         **kwargs
     ) -> SegResult:
         """Predicts semantic segmentation results for the given input.

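target_size above accepts several shapes. A hedged sketch; the pipeline name is assumed, and reading -1 as "keep the original size" follows the usual PaddleX convention rather than anything in this diff:

# target_size forms matching Union[Literal[-1], None, int, Tuple[int]]:
#   None       -> fall back to the size configured for the pipeline (assumption)
#   -1         -> no resizing, i.e. run at the original image size (assumption)
#   512        -> a single int
#   (512, 512) -> an explicit tuple
from paddlex import create_pipeline

pipeline = create_pipeline(pipeline="semantic_segmentation")
for res in pipeline.predict("demo.jpg", target_size=-1):
    ...
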
+ 2 - 2
paddlex/inference/pipelines_new/small_object_detection/pipeline.py

@@ -54,8 +54,8 @@ class SmallObjectDetectionPipeline(BasePipeline):
 
     def predict(
         self,
-        input: str | list[str] | np.ndarray | list[np.ndarray],
-        threshold: None | dict[int, float] | float = None,
+        input: Union[str, List[str], np.ndarray, List[np.ndarray]],
+        threshold: Union[None, Dict[int, float], float] = None,
         **kwargs
     ) -> DetResult:
         """Predicts small object detection results for the given input.