ソースを参照

fix insanceseg bug

zhangyubo0722 1 年間 前
コミット
0a1d857232
1 ファイル変更4 行追加4 行削除
  1. 4 4
      paddlex/inference/components/task_related/instance_seg.py

+ 4 - 4
paddlex/inference/components/task_related/instance_seg.py

@@ -41,9 +41,9 @@ def extract_masks_from_boxes(boxes, masks):
 class InstanceSegPostProcess(BaseComponent):
     """Save Result Transform"""
 
-    INPUT_KEYS = ["boxes", "masks"]
+    INPUT_KEYS = ["boxes", "masks", "img_size"]
     OUTPUT_KEYS = ["img_path", "boxes", "masks"]
-    DEAULT_INPUTS = {"boxes": "boxes", "masks": "masks"}
+    DEAULT_INPUTS = {"boxes": "boxes", "masks": "masks", "img_size": "ori_img_size"}
     DEAULT_OUTPUTS = {
         "boxes": "boxes",
         "masks": "masks",
@@ -54,11 +54,11 @@ class InstanceSegPostProcess(BaseComponent):
         self.threshold = threshold
         self.labels = labels
 
-    def apply(self, boxes, masks):
+    def apply(self, boxes, masks, img_size):
         """apply"""
         expect_boxes = (boxes[:, 1] > self.threshold) & (boxes[:, 0] > -1)
         boxes = boxes[expect_boxes, :]
-        boxes = restructured_boxes(boxes, self.labels)
+        boxes = restructured_boxes(boxes, self.labels, img_size)
         masks = masks[expect_boxes, :, :]
         masks = extract_masks_from_boxes(boxes, masks)
         result = {"boxes": boxes, "masks": masks}