Эх сурвалжийг харах

Merge pull request #696 from FlyingQianMM/develop_qh

add cv2.IMREAD_COLOR to decode when input_channel is 3
FlyingQianMM 4 жил өмнө
parent
commit
3853730624

+ 3 - 2
paddlex/cv/transforms/cls_transforms.py

@@ -77,9 +77,10 @@ class Compose(ClsTransform):
             try:
                 if input_channel == 3:
                     im = cv2.imread(im_file, cv2.IMREAD_ANYDEPTH |
-                                    cv2.IMREAD_ANYCOLOR)
+                                    cv2.IMREAD_ANYCOLOR | cv2.IMREAD_COLOR)
                 else:
-                    im = cv2.imread(im_file, cv2.IMREAD_UNCHANGED)
+                    im = cv2.imread(im_file, cv2.IMREAD_ANYDEPTH |
+                                    cv2.IMREAD_ANYCOLOR)
                     if im.ndim < 3:
                         im = np.expand_dims(im, axis=-1)
             except:

+ 3 - 2
paddlex/cv/transforms/det_transforms.py

@@ -112,9 +112,10 @@ class Compose(DetTransform):
                 try:
                     if input_channel == 3:
                         im = cv2.imread(im_file, cv2.IMREAD_ANYDEPTH |
-                                        cv2.IMREAD_ANYCOLOR)
+                                        cv2.IMREAD_ANYCOLOR | cv2.IMREAD_COLOR)
                     else:
-                        im = cv2.imread(im_file, cv2.IMREAD_UNCHANGED)
+                        im = cv2.imread(im_file, cv2.IMREAD_ANYDEPTH |
+                                        cv2.IMREAD_ANYCOLOR)
                         if im.ndim < 3:
                             im = np.expand_dims(im, axis=-1)
                 except:

+ 3 - 2
paddlex/cv/transforms/seg_transforms.py

@@ -89,9 +89,10 @@ class Compose(SegTransform):
         elif img_format in ['jpeg', 'bmp', 'png', 'jpg']:
             if input_channel == 3:
                 return cv2.imread(img_path, cv2.IMREAD_ANYDEPTH |
-                                  cv2.IMREAD_ANYCOLOR)
+                                  cv2.IMREAD_ANYCOLOR | cv2.IMREAD_COLOR)
             else:
-                return cv2.imread(im_file, cv2.IMREAD_UNCHANGED)
+                return cv2.imread(im_file, cv2.IMREAD_ANYDEPTH |
+                                  cv2.IMREAD_ANYCOLOR)
         elif ext == '.npy':
             return np.load(img_path)
         else: