瀏覽代碼

Merge pull request #845 from will-jl944/develop_jf

Adapt to mask with channel of 2 and 3
FlyingQianMM 4 年之前
父節點
當前提交
0111bc2cdb
共有 2 個文件被更改,包括 4 次插入10 次删除
  1. 3 9
      dygraph/paddlex/cv/transforms/functions.py
  2. 1 1
      dygraph/paddlex/cv/transforms/operators.py

+ 3 - 9
dygraph/paddlex/cv/transforms/functions.py

@@ -47,23 +47,17 @@ def center_crop(im, crop_size=224):
     h_start = (height - crop_size) // 2
     w_end = w_start + crop_size
     h_end = h_start + crop_size
-    im = im[h_start:h_end, w_start:w_end, :]
+    im = im[h_start:h_end, w_start:w_end, ...]
     return im
 
 
 def horizontal_flip(im):
-    if len(im.shape) == 3:
-        im = im[:, ::-1, :]
-    elif len(im.shape) == 2:
-        im = im[:, ::-1]
+    im = im[:, ::-1, ...]
     return im
 
 
 def vertical_flip(im):
-    if len(im.shape) == 3:
-        im = im[::-1, :, :]
-    elif len(im.shape) == 2:
-        im = im[::-1, :]
+    im = im[::-1, :, ...]
     return im
 
 

+ 1 - 1
dygraph/paddlex/cv/transforms/operators.py

@@ -814,7 +814,7 @@ class RandomCrop(Transform):
 
     def apply_mask(self, mask, crop):
         x1, y1, x2, y2 = crop
-        return mask[y1:y2, x1:x2, :]
+        return mask[y1:y2, x1:x2, ...]
 
     def apply(self, sample):
         crop_info = self._generate_crop_info(sample)