소스 검색

Merge pull request #529 from FlyingQianMM/develop_qh

add 1-channel image requirement for the segmentation label file
Jason 4 년 전
부모
커밋
d5395ee62a
2개의 변경된 파일5개의 추가작업 그리고 1개의 파일을 삭제
  1. 1 1
      paddlex/cv/models/slim/post_quantization.py
  2. 4 0
      paddlex/cv/transforms/seg_transforms.py

+ 1 - 1
paddlex/cv/models/slim/post_quantization.py

@@ -263,7 +263,7 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
                 start = time.time()
                 sampling_data = []
                 file_name = os.path.join(self._cache_dir, var_name)
-                cache_dir, var_name_ = os.path.split(file_name) 
+                cache_dir, var_name_ = os.path.split(file_name)
                 filenames = [f for f in os.listdir(cache_dir) \
                     if re.match(var_name_ + '_[0-9]+.npy', f)]
                 for filename in filenames:

+ 4 - 0
paddlex/cv/transforms/seg_transforms.py

@@ -119,6 +119,10 @@ class Compose(SegTransform):
                     label = np.asarray(Image.open(label))
                 except:
                     ValueError('Can\'t read The label file {}!'.format(label))
+                if len(label.shape) != 2:
+                    raise Exception(
+                        "label should be a 1-channel image, but recevied a {}-channel image.".
+                        format(label.shape[2]))
             im_height, im_width, _ = im.shape
             label_height, label_width = label.shape
             if im_height != label_height or im_width != label_width: