소스 검색

adapt to mask with channel of 2 and 3

will-jl944 4 년 전
부모
커밋
5c22dcdfbb
2개의 변경된 파일4개의 추가작업 그리고 10개의 파일을 삭제
  1. 3 9
      dygraph/paddlex/cv/transforms/functions.py
  2. 1 1
      dygraph/paddlex/cv/transforms/operators.py

+ 3 - 9
dygraph/paddlex/cv/transforms/functions.py

@@ -47,23 +47,17 @@ def center_crop(im, crop_size=224):
     h_start = (height - crop_size) // 2
     w_end = w_start + crop_size
     h_end = h_start + crop_size
-    im = im[h_start:h_end, w_start:w_end, :]
+    im = im[h_start:h_end, w_start:w_end, ...]
     return im
 
 
 def horizontal_flip(im):
-    if len(im.shape) == 3:
-        im = im[:, ::-1, :]
-    elif len(im.shape) == 2:
-        im = im[:, ::-1]
+    im = im[:, ::-1, ...]
     return im
 
 
 def vertical_flip(im):
-    if len(im.shape) == 3:
-        im = im[::-1, :, :]
-    elif len(im.shape) == 2:
-        im = im[::-1, :]
+    im = im[::-1, :, ...]
     return im
 
 

+ 1 - 1
dygraph/paddlex/cv/transforms/operators.py

@@ -814,7 +814,7 @@ class RandomCrop(Transform):
 
     def apply_mask(self, mask, crop):
         x1, y1, x2, y2 = crop
-        return mask[y1:y2, x1:x2, :]
+        return mask[y1:y2, x1:x2, ...]
 
     def apply(self, sample):
         crop_info = self._generate_crop_info(sample)