瀏覽代碼

modify Padding in seg_transforms to support multi-channel input

FlyingQianMM 5 年之前
父節點
當前提交
9746160e31
共有 1 個文件被更改,包括 21 次插入17 次删除
  1. 21 17
      paddlex/cv/transforms/seg_transforms.py

+ 21 - 17
paddlex/cv/transforms/seg_transforms.py

@@ -725,23 +725,27 @@ class Padding(SegTransform):
         pad_width = target_width - im_width
         pad_height = max(pad_height, 0)
         pad_width = max(pad_width, 0)
-        im = cv2.copyMakeBorder(
-            im,
-            0,
-            pad_height,
-            0,
-            pad_width,
-            cv2.BORDER_CONSTANT,
-            value=self.im_padding_value)
-        if label is not None:
-            label = cv2.copyMakeBorder(
-                label,
-                0,
-                pad_height,
-                0,
-                pad_width,
-                cv2.BORDER_CONSTANT,
-                value=self.label_padding_value)
+        if (pad_height > 0 or pad_width > 0):
+            im_channel = im.shape[2]
+            import copy
+            orig_im = copy.deepcopy(im)
+            im = np.zeros((im_height + pad_height, im_width + pad_width,
+                           im_channel)).astype(orig_im.dtype)
+            for i in range(im_channel):
+                im[:, :, i] = np.pad(
+                    orig_im[:, :, i],
+                    pad_width=((0, pad_height), (0, pad_width)),
+                    mode='constant',
+                    constant_values=(self.im_padding_value[i],
+                                     self.im_padding_value[i]))
+
+            if label is not None:
+                label = np.pad(label,
+                               pad_width=((0, pad_height), (0, pad_width)),
+                               mode='constant',
+                               constant_values=(self.label_padding_value,
+                                                self.label_padding_value))
+
         if label is None:
             return (im, im_info)
         else: