|
|
@@ -611,7 +611,7 @@ class DetPostProcess:
|
|
|
self.labels = labels
|
|
|
self.layout_postprocess = layout_postprocess
|
|
|
|
|
|
- def apply(self, boxes: ndarray, img_size) -> Boxes:
|
|
|
+ def apply(self, boxes: ndarray, img_size, threshold: Union[float, dict]) -> Boxes:
|
|
|
"""Apply post-processing to the detection boxes.
|
|
|
|
|
|
Args:
|
|
|
@@ -621,15 +621,15 @@ class DetPostProcess:
|
|
|
Returns:
|
|
|
Boxes: The post-processed detection boxes.
|
|
|
"""
|
|
|
- if isinstance(self.threshold, float):
|
|
|
- expect_boxes = (boxes[:, 1] > self.threshold) & (boxes[:, 0] > -1)
|
|
|
+ if isinstance(threshold, float):
|
|
|
+ expect_boxes = (boxes[:, 1] > threshold) & (boxes[:, 0] > -1)
|
|
|
boxes = boxes[expect_boxes, :]
|
|
|
- elif isinstance(self.threshold, dict):
|
|
|
+ elif isinstance(threshold, dict):
|
|
|
category_filtered_boxes = []
|
|
|
for cat_id in np.unique(boxes[:, 0]):
|
|
|
category_boxes = boxes[boxes[:, 0] == cat_id]
|
|
|
category_scores = category_boxes[:, 1]
|
|
|
- category_threshold = self.threshold.get(int(cat_id), 0.5)
|
|
|
+ category_threshold = threshold.get(int(cat_id), 0.5)
|
|
|
selected_indices = category_scores > category_threshold
|
|
|
category_filtered_boxes.append(category_boxes[selected_indices])
|
|
|
boxes = (
|
|
|
@@ -714,7 +714,12 @@ class DetPostProcess:
|
|
|
)
|
|
|
return boxes
|
|
|
|
|
|
- def __call__(self, batch_outputs: List[dict], datas: List[dict]) -> List[Boxes]:
|
|
|
+ def __call__(
|
|
|
+ self,
|
|
|
+ batch_outputs: List[dict],
|
|
|
+ datas: List[dict],
|
|
|
+ threshold: Optional[Union[float, dict]] = None,
|
|
|
+ ) -> List[Boxes]:
|
|
|
"""Apply the post-processing to a batch of outputs.
|
|
|
|
|
|
Args:
|
|
|
@@ -726,6 +731,10 @@ class DetPostProcess:
|
|
|
"""
|
|
|
outputs = []
|
|
|
for data, output in zip(datas, batch_outputs):
|
|
|
- boxes = self.apply(output["boxes"], data["ori_img_size"])
|
|
|
+ boxes = self.apply(
|
|
|
+ output["boxes"],
|
|
|
+ data["ori_img_size"],
|
|
|
+ threshold if threshold is not None else self.threshold,
|
|
|
+ )
|
|
|
outputs.append(boxes)
|
|
|
return outputs
|