|
|
@@ -49,10 +49,12 @@ class KeypointDetectionPipeline(BasePipeline):
|
|
|
# create object detection model
|
|
|
model_cfg = config["SubModules"]["ObjectDetection"]
|
|
|
model_kwargs = {}
|
|
|
+ self.det_threshold = None
|
|
|
if "threshold" in model_cfg:
|
|
|
model_kwargs["threshold"] = model_cfg["threshold"]
|
|
|
- if "img_size" in model_cfg:
|
|
|
- model_kwargs["img_size"] = model_cfg["img_size"]
|
|
|
+ self.det_threshold = model_cfg["threshold"]
|
|
|
+ if "imgsz" in model_cfg:
|
|
|
+ model_kwargs["imgsz"] = model_cfg["imgsz"]
|
|
|
self.det_model = self.create_model(model_cfg, **model_kwargs)
|
|
|
|
|
|
# create keypoint detection model
|
|
|
@@ -95,19 +97,23 @@ class KeypointDetectionPipeline(BasePipeline):
|
|
|
return center, scale
|
|
|
|
|
|
def predict(
|
|
|
- self, input: Union[str, List[str], np.ndarray, List[np.ndarray]], **kwargs
|
|
|
+ self,
|
|
|
+ input: Union[str, List[str], np.ndarray, List[np.ndarray]],
|
|
|
+ det_threshold: Optional[float] = None,
|
|
|
+ **kwargs,
|
|
|
) -> KptResult:
|
|
|
"""Predicts image classification results for the given input.
|
|
|
|
|
|
Args:
|
|
|
- input (Union[str, list[str], np.ndarray, list[np.ndarray]]): The input image(s) or path(s) to the images.
|
|
|
+ input (str | list[str] | np.ndarray | list[np.ndarray]): The input image(s) or path(s) to the images.
|
|
|
+ det_threshold (float): The detection threshold. Defaults to None.
|
|
|
**kwargs: Additional keyword arguments that can be passed to the function.
|
|
|
|
|
|
Returns:
|
|
|
KptResult: The predicted KeyPoint Detection results.
|
|
|
"""
|
|
|
-
|
|
|
- for det_res in self.det_model(input):
|
|
|
+ det_threshold = self.det_threshold if det_threshold is None else det_threshold
|
|
|
+ for det_res in self.det_model(input, threshold=det_threshold):
|
|
|
ori_img, img_path = det_res["input_img"], det_res["input_path"]
|
|
|
single_img_res = {"input_path": img_path, "input_img": ori_img, "boxes": []}
|
|
|
for box in det_res["boxes"]:
|
|
|
@@ -126,6 +132,7 @@ class KeypointDetectionPipeline(BasePipeline):
|
|
|
"coordinate": box["coordinate"],
|
|
|
"det_score": box["score"],
|
|
|
"keypoints": kpt_res["kpts"][0]["keypoints"],
|
|
|
+ "kpt_score": kpt_res["kpts"][0]["kpt_score"],
|
|
|
}
|
|
|
)
|
|
|
yield KptResult(single_img_res)
|