Эх сурвалжийг харах

Merge pull request #454 from FlyingQianMM/develop_qh

return multi-channel image of decode in seg_transforms
Jason 4 жил өмнө
parent
commit
03c8227b75

+ 2 - 1
paddlex/cv/transforms/ops.py

@@ -43,9 +43,10 @@ def resize_long(im, long_size=224, interpolation=cv2.INTER_LINEAR):
     resized_width = int(round(im.shape[1] * scale))
     resized_height = int(round(im.shape[0] * scale))
 
+    im_dims = im.ndim
     im = cv2.resize(
         im, (resized_width, resized_height), interpolation=interpolation)
-    if im.ndim < 3:
+    if im_dims >= 3 and im.ndim < 3:
         im = np.expand_dims(im, axis=-1)
     return im
 

+ 1 - 1
paddlex/cv/transforms/seg_transforms.py

@@ -86,7 +86,7 @@ class Compose(SegTransform):
             if input_channel == 3:
                 return cv2.imread(img_path)
             else:
-                im = cv2.imread(im_file, cv2.IMREAD_UNCHANGED)
+                return cv2.imread(im_file, cv2.IMREAD_UNCHANGED)
         elif ext == '.npy':
             return np.load(img_path)
         else: