Browse Source

modify transforms

jiangjiajun 5 years ago
parent
commit
d2729260f4

+ 1 - 1
paddlex/__init__.py

@@ -53,4 +53,4 @@ log_level = 2
 
 from . import interpret
 
-__version__ = '1.0.2.github'
+__version__ = '1.0.4'

+ 20 - 19
paddlex/cv/models/deeplabv3p.py

@@ -190,11 +190,6 @@ class DeepLabv3p(BaseAPI):
         if mode == 'train':
             self.optimizer.minimize(model_out)
             outputs['loss'] = model_out
-        elif mode == 'eval':
-            outputs['loss'] = model_out[0]
-            outputs['pred'] = model_out[1]
-            outputs['label'] = model_out[2]
-            outputs['mask'] = model_out[3]
         else:
             outputs['pred'] = model_out[0]
             outputs['logit'] = model_out[1]
@@ -336,18 +331,26 @@ class DeepLabv3p(BaseAPI):
         for step, data in tqdm.tqdm(
                 enumerate(data_generator()), total=total_steps):
             images = np.array([d[0] for d in data])
-            labels = np.array([d[1] for d in data])
+
+            _, _, im_h, im_w = images.shape
+            labels = list()
+            for d in data:
+                padding_label = np.zeros(
+                    (1, im_h, im_w)).astype('int64') + self.ignore_index
+                padding_label[:, :im_h, :im_w] = d[1]
+                labels.append(padding_label)
+            labels = np.array(labels)
+
             num_samples = images.shape[0]
             if num_samples < batch_size:
                 num_pad_samples = batch_size - num_samples
                 pad_images = np.tile(images[0:1], (num_pad_samples, 1, 1, 1))
                 images = np.concatenate([images, pad_images])
             feed_data = {'image': images}
-            outputs = self.exe.run(
-                self.parallel_test_prog,
-                feed=feed_data,
-                fetch_list=list(self.test_outputs.values()),
-                return_numpy=True)
+            outputs = self.exe.run(self.parallel_test_prog,
+                                   feed=feed_data,
+                                   fetch_list=list(self.test_outputs.values()),
+                                   return_numpy=True)
             pred = outputs[0]
             if num_samples < batch_size:
                 pred = pred[0:num_samples]
@@ -364,8 +367,7 @@ class DeepLabv3p(BaseAPI):
 
         metrics = OrderedDict(
             zip(['miou', 'category_iou', 'macc', 'category_acc', 'kappa'],
-                [miou, category_iou, macc, category_acc,
-                 conf_mat.kappa()]))
+                [miou, category_iou, macc, category_acc, conf_mat.kappa()]))
         if return_details:
             eval_details = {
                 'confusion_matrix': conf_mat.confusion_matrix.tolist()
@@ -394,10 +396,9 @@ class DeepLabv3p(BaseAPI):
                 transforms=self.test_transforms, mode='test')
             im, im_info = self.test_transforms(im_file)
         im = np.expand_dims(im, axis=0)
-        result = self.exe.run(
-            self.test_prog,
-            feed={'image': im},
-            fetch_list=list(self.test_outputs.values()))
+        result = self.exe.run(self.test_prog,
+                              feed={'image': im},
+                              fetch_list=list(self.test_outputs.values()))
         pred = result[0]
         pred = np.squeeze(pred).astype('uint8')
         logit = result[1]
@@ -413,6 +414,6 @@ class DeepLabv3p(BaseAPI):
                 pred = pred[0:h, 0:w]
                 logit = logit[0:h, 0:w, :]
             else:
-                raise Exception("Unexpected info '{}' in im_info".format(
-                    info[0]))
+                raise Exception("Unexpected info '{}' in im_info".format(info[
+                    0]))
         return {'label_map': pred, 'score_map': logit}

+ 8 - 8
paddlex/cv/nets/segmentation/deeplabv3p.py

@@ -135,7 +135,8 @@ class DeepLabv3p(object):
         param_attr = fluid.ParamAttr(
             name=name_scope + 'weights',
             regularizer=None,
-            initializer=fluid.initializer.TruncatedNormal(loc=0.0, scale=0.06))
+            initializer=fluid.initializer.TruncatedNormal(
+                loc=0.0, scale=0.06))
         with scope('encoder'):
             channel = 256
             with scope("image_pool"):
@@ -151,8 +152,8 @@ class DeepLabv3p(object):
                         padding=0,
                         param_attr=param_attr))
                 input_shape = fluid.layers.shape(input)
-                image_avg = fluid.layers.resize_bilinear(
-                    image_avg, input_shape[2:])
+                image_avg = fluid.layers.resize_bilinear(image_avg,
+                                                         input_shape[2:])
 
             with scope("aspp0"):
                 aspp0 = bn_relu(
@@ -244,7 +245,8 @@ class DeepLabv3p(object):
         param_attr = fluid.ParamAttr(
             name=name_scope + 'weights',
             regularizer=None,
-            initializer=fluid.initializer.TruncatedNormal(loc=0.0, scale=0.06))
+            initializer=fluid.initializer.TruncatedNormal(
+                loc=0.0, scale=0.06))
         with scope('decoder'):
             with scope('concat'):
                 decode_shortcut = bn_relu(
@@ -326,9 +328,6 @@ class DeepLabv3p(object):
         if self.mode == 'train':
             inputs['label'] = fluid.data(
                 dtype='int32', shape=[None, 1, None, None], name='label')
-        elif self.mode == 'eval':
-            inputs['label'] = fluid.data(
-                dtype='int32', shape=[None, 1, None, None], name='label')
         return inputs
 
     def build_net(self, inputs):
@@ -351,7 +350,8 @@ class DeepLabv3p(object):
             name=name_scope + 'weights',
             regularizer=fluid.regularizer.L2DecayRegularizer(
                 regularization_coeff=0.0),
-            initializer=fluid.initializer.TruncatedNormal(loc=0.0, scale=0.01))
+            initializer=fluid.initializer.TruncatedNormal(
+                loc=0.0, scale=0.01))
         with scope('logit'):
             with fluid.name_scope('last_conv'):
                 logit = conv(

+ 5 - 4
paddlex/cv/transforms/cls_transforms.py

@@ -494,13 +494,14 @@ class ComposedClsTransforms(Compose):
                  std=[0.229, 0.224, 0.225]):
         width = crop_size
         if isinstance(crop_size, list):
-            if shape[0] != shape[1]:
+            if crop_size[0] != crop_size[1]:
                 raise Exception(
-                    "In classifier model, width and height should be equal")
+                    "In classifier model, width and height should be equal, please modify your parameter `crop_size`"
+                )
             width = crop_size[0]
         if width % 32 != 0:
             raise Exception(
-                "In classifier model, width and height should be multiple of 32, e.g 224、256、320...."
+                "In classifier model, width and height should be multiple of 32, e.g 224、256、320...., please modify your parameter `crop_size`"
             )
 
         if mode == 'train':
@@ -513,7 +514,7 @@ class ComposedClsTransforms(Compose):
         else:
             # 验证/预测时的transforms
             transforms = [
-                ReiszeByShort(short_size=int(width * 1.14)),
+                ResizeByShort(short_size=int(width * 1.14)),
                 CenterCrop(crop_size=width), Normalize(
                     mean=mean, std=std)
             ]

+ 2 - 2
paddlex/cv/transforms/det_transforms.py

@@ -1280,7 +1280,7 @@ class ComposedRCNNTransforms(Compose):
                 Padding(coarsest_stride=32)
             ]
 
-        super(RCNNTransforms, self).__init__(transforms)
+        super(ComposedRCNNTransforms, self).__init__(transforms)
 
 
 class ComposedYOLOTransforms(Compose):
@@ -1338,4 +1338,4 @@ class ComposedYOLOTransforms(Compose):
                     target_size=width, interp='CUBIC'), Normalize(
                         mean=mean, std=std)
             ]
-        super(YOLOTransforms, self).__init__(transforms)
+        super(ComposedYOLOTransforms, self).__init__(transforms)

+ 3 - 3
paddlex/cv/transforms/seg_transforms.py

@@ -1091,7 +1091,7 @@ class ArrangeSegmenter(SegTransform):
             return (im, )
 
 
-class ComposedTransforms(Compose):
+class ComposedSegTransforms(Compose):
     """ 语义分割模型(UNet/DeepLabv3p)的图像处理流程,具体如下
         训练阶段:
         1. 随机对图像以0.5的概率水平翻转
@@ -1113,7 +1113,7 @@ class ComposedTransforms(Compose):
                  train_crop_size=[769, 769],
                  mean=[0.5, 0.5, 0.5],
                  std=[0.5, 0.5, 0.5]):
-        if self.mode == 'train':
+        if mode == 'train':
             # 训练时的transforms,包含数据增强
             transforms = [
                 RandomHorizontalFlip(prob=0.5), ResizeStepScaling(),
@@ -1122,6 +1122,6 @@ class ComposedTransforms(Compose):
             ]
         else:
             # 验证/预测时的transforms
-            transforms = [transforms.Normalize(mean=mean, std=std)]
+            transforms = [Resize(512), Normalize(mean=mean, std=std)]
 
         super(ComposedSegTransforms, self).__init__(transforms)

+ 1 - 1
setup.py

@@ -19,7 +19,7 @@ long_description = "PaddleX. A end-to-end deeplearning model development toolkit
 
 setuptools.setup(
     name="paddlex",
-    version='1.0.2',
+    version='1.0.4',
     author="paddlex",
     author_email="paddlex@baidu.com",
     description=long_description,