|
|
@@ -41,9 +41,9 @@ def extract_masks_from_boxes(boxes, masks):
|
|
|
class InstanceSegPostProcess(BaseComponent):
|
|
|
"""Save Result Transform"""
|
|
|
|
|
|
- INPUT_KEYS = ["boxes", "masks"]
|
|
|
+ INPUT_KEYS = ["boxes", "masks", "img_size"]
|
|
|
OUTPUT_KEYS = ["img_path", "boxes", "masks"]
|
|
|
- DEAULT_INPUTS = {"boxes": "boxes", "masks": "masks"}
|
|
|
+ DEAULT_INPUTS = {"boxes": "boxes", "masks": "masks", "img_size": "ori_img_size"}
|
|
|
DEAULT_OUTPUTS = {
|
|
|
"boxes": "boxes",
|
|
|
"masks": "masks",
|
|
|
@@ -54,11 +54,11 @@ class InstanceSegPostProcess(BaseComponent):
|
|
|
self.threshold = threshold
|
|
|
self.labels = labels
|
|
|
|
|
|
- def apply(self, boxes, masks):
|
|
|
+ def apply(self, boxes, masks, img_size):
|
|
|
"""apply"""
|
|
|
expect_boxes = (boxes[:, 1] > self.threshold) & (boxes[:, 0] > -1)
|
|
|
boxes = boxes[expect_boxes, :]
|
|
|
- boxes = restructured_boxes(boxes, self.labels)
|
|
|
+ boxes = restructured_boxes(boxes, self.labels, img_size)
|
|
|
masks = masks[expect_boxes, :, :]
|
|
|
masks = extract_masks_from_boxes(boxes, masks)
|
|
|
result = {"boxes": boxes, "masks": masks}
|