Browse Source

adapt to mask with channel of 2 and 3

will-jl944 4 năm trước cách đây
mục cha
commit
5c22dcdfbb

+ 3 - 9
dygraph/paddlex/cv/transforms/functions.py

@@ -47,23 +47,17 @@ def center_crop(im, crop_size=224):
     h_start = (height - crop_size) // 2
     w_end = w_start + crop_size
     h_end = h_start + crop_size
-    im = im[h_start:h_end, w_start:w_end, :]
+    im = im[h_start:h_end, w_start:w_end, ...]
     return im
 
 
 def horizontal_flip(im):
-    if len(im.shape) == 3:
-        im = im[:, ::-1, :]
-    elif len(im.shape) == 2:
-        im = im[:, ::-1]
+    im = im[:, ::-1, ...]
     return im
 
 
 def vertical_flip(im):
-    if len(im.shape) == 3:
-        im = im[::-1, :, :]
-    elif len(im.shape) == 2:
-        im = im[::-1, :]
+    im = im[::-1, :, ...]
     return im
 
 

+ 1 - 1
dygraph/paddlex/cv/transforms/operators.py

@@ -814,7 +814,7 @@ class RandomCrop(Transform):
 
     def apply_mask(self, mask, crop):
         x1, y1, x2, y2 = crop
-        return mask[y1:y2, x1:x2, :]
+        return mask[y1:y2, x1:x2, ...]
 
     def apply(self, sample):
         crop_info = self._generate_crop_info(sample)