Browse Source

RandomDistort supports multi-channel input

FlyingQianMM committed 5 years ago
parent
commit
908254c8a8

+ 65 - 0
examples/change_detection/hrnet_1.py

@@ -0,0 +1,65 @@
# Environment variable configuration, used to control whether to use the GPU.
# Docs: https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html#gpu
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,6,7'

import paddlex as pdx
from paddlex.seg import transforms

# Transforms applied during training and evaluation.
# API reference: https://paddlex.readthedocs.io/zh_CN/develop/apis/transforms/seg_transforms.html
# NOTE(review): every per-channel list (padding value, mean, std, min/max)
# has length 6 because the change-detection input is two 3-channel images
# concatenated along the channel axis.
train_transforms = transforms.Compose([
    transforms.ResizeStepScaling(
        min_scale_factor=0.5, max_scale_factor=2.,
        scale_step_size=0.25), transforms.RandomRotate(
            rotate_range=180,
            im_padding_value=[127.5] * 6), transforms.RandomPaddingCrop(
                crop_size=769, im_padding_value=[127.5] * 6),
    transforms.RandomDistort(), transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(), transforms.Normalize(
        mean=[0.5] * 6, std=[0.5] * 6, min_val=[0] * 6, max_val=[255] * 6)
])

eval_transforms = transforms.Compose([
    transforms.Padding(
        target_size=769, im_padding_value=[127.5] * 6), transforms.Normalize(
            mean=[0.5] * 6, std=[0.5] * 6, min_val=[0] * 6, max_val=[255] * 6)
])

# Datasets used for training and evaluation.
# API reference: https://paddlex.readthedocs.io/zh_CN/develop/apis/datasets.html#paddlex-datasets-segdataset
train_dataset = pdx.datasets.ChangeDetDataset(
    data_dir='tiled_dataset',
    file_list='tiled_dataset/train_list.txt',
    label_list='tiled_dataset/labels.txt',
    transforms=train_transforms,
    num_workers=4,
    parallel_method='thread',
    shuffle=True)
eval_dataset = pdx.datasets.ChangeDetDataset(
    data_dir='tiled_dataset',
    file_list='tiled_dataset/val_list.txt',
    label_list='tiled_dataset/labels.txt',
    num_workers=4,
    parallel_method='thread',
    transforms=eval_transforms)

# Initialize the model and start training.
# Training metrics can be inspected with VisualDL, see
# https://paddlex.readthedocs.io/zh_CN/develop/train/visualdl.html
num_classes = len(train_dataset.labels)

# API reference: https://paddlex.readthedocs.io/zh_CN/develop/apis/models/semantic_segmentation.html#paddlex-seg-deeplabv3p
# input_channel=6: two bi-temporal RGB images stacked channel-wise.
model = pdx.seg.HRNet(num_classes=num_classes, input_channel=6)

# API reference: https://paddlex.readthedocs.io/zh_CN/develop/apis/models/semantic_segmentation.html#train
# Parameter descriptions and tuning notes:
# https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html
model.train(
    num_epochs=400,
    train_dataset=train_dataset,
    train_batch_size=16,
    eval_dataset=eval_dataset,
    learning_rate=0.1,
    save_interval_epochs=10,
    pretrain_weights='CITYSCAPES',
    save_dir='output/hrnet_1',
    use_vdl=True)

+ 108 - 0
examples/change_detection/prepara_data_cd.py

@@ -0,0 +1,108 @@
import os
import os.path as osp
import numpy as np
import cv2
import shutil
import random
random.seed(0)
from PIL import Image
import paddlex as pdx

# Sliding-window tile size and stride used to cut the training split, format (W, H).
train_tile_size = (1024, 1024)
train_stride = (512, 512)
# Sliding-window tile size and stride used to cut the validation split, format (W, H).
val_tile_size = (769, 769)
val_stride = (769, 769)
# Train / validation split ratio.
train_ratio = 0.8
val_ratio = 0.2

change_det_dataset = './change_det_data'
tiled_dataset = './tiled_dataset'
origin_dataset = './origin_dataset'
tiled_image_dir = osp.join(tiled_dataset, 'JPEGImages')
tiled_anno_dir = osp.join(tiled_dataset, 'Annotations')

if not osp.exists(tiled_image_dir):
    os.makedirs(tiled_image_dir)
if not osp.exists(tiled_anno_dir):
    os.makedirs(tiled_anno_dir)

# Pair up the two temporal images (T1, T2) and the change label of each sample,
# sorting all three listings by the numeric id embedded in the file name so the
# triples line up.
im1_file_list = os.listdir(osp.join(change_det_dataset, 'T1'))
im2_file_list = os.listdir(osp.join(change_det_dataset, 'T2'))
label_file_list = os.listdir(osp.join(change_det_dataset, 'labels_change'))
im1_file_list = sorted(
    im1_file_list, key=lambda k: int(k.split('test')[-1].split('_')[0]))
im2_file_list = sorted(
    im2_file_list, key=lambda k: int(k.split('test')[-1].split('_')[0]))
label_file_list = sorted(
    label_file_list, key=lambda k: int(k.split('test')[-1].split('_')[0]))
file_list = list()
for im1_file, im2_file, label_file in zip(im1_file_list, im2_file_list,
                                          label_file_list):
    im1_file = osp.join(osp.join(change_det_dataset, 'T1'), im1_file)
    im2_file = osp.join(osp.join(change_det_dataset, 'T2'), im2_file)
    label_file = osp.join(
        osp.join(change_det_dataset, 'labels_change'), label_file)
    file_list.append((im1_file, im2_file, label_file))
random.shuffle(file_list)
train_num = int(len(file_list) * train_ratio)

for i, item in enumerate(file_list):
    im1_file, im2_file, label_file = item
    # BUGFIX: choose the split name together with the tiling parameters.
    # The previous version incremented `i` between the two checks, so the
    # last training image was tiled with train parameters yet written to
    # the validation list.
    if i < train_num:
        stride = train_stride
        tile_size = train_tile_size
        set_name = 'train'
    else:
        stride = val_stride
        tile_size = val_tile_size
        set_name = 'val'
    im1 = cv2.imread(im1_file)
    im2 = cv2.imread(im2_file)
    label = cv2.imread(label_file, cv2.IMREAD_GRAYSCALE)
    # Binarize the annotation: any non-zero pixel counts as "changed".
    label = label != 0
    label = label.astype(np.uint8)
    H, W, C = im1.shape
    tile_id = 1
    im1_name = osp.split(im1_file)[-1].split('.')[0]
    im2_name = osp.split(im2_file)[-1].split('.')[0]
    label_name = osp.split(label_file)[-1].split('.')[0]
    # Slide a (tile_size) window over the image with the chosen stride;
    # border tiles are clipped to the image extent.
    for h in range(0, H, stride[1]):
        for w in range(0, W, stride[0]):
            left = w
            upper = h
            right = min(w + tile_size[0], W)
            lower = min(h + tile_size[1], H)
            tile_im1 = im1[upper:lower, left:right, :]
            tile_im2 = im2[upper:lower, left:right, :]
            cv2.imwrite(
                osp.join(tiled_image_dir,
                         "{}_{}.bmp".format(im1_name, tile_id)), tile_im1)
            cv2.imwrite(
                osp.join(tiled_image_dir,
                         "{}_{}.bmp".format(im2_name, tile_id)), tile_im2)
            cut_label = label[upper:lower, left:right]
            cv2.imwrite(
                osp.join(tiled_anno_dir,
                         "{}_{}.png".format(label_name, tile_id)), cut_label)
            # BUGFIX: truncate the list file on the very first tile of each
            # split (i == 0 starts train_list.txt, i == train_num starts
            # val_list.txt). The previous version compared the incremented
            # index against [0, train_num], so train_list.txt was never
            # truncated and grew on every rerun.
            mode = 'w' if i in (0, train_num) and tile_id == 1 else 'a'
            with open(
                    osp.join(tiled_dataset, '{}_list.txt'.format(set_name)),
                    mode) as f:
                f.write(
                    "JPEGImages/{}_{}.bmp JPEGImages/{}_{}.bmp Annotations/{}_{}.png\n".
                    format(im1_name, tile_id, im2_name, tile_id, label_name,
                           tile_id))
            tile_id += 1

# Generate labels.txt: one class name per line, no trailing newline.
# BUGFIX: write the file once in 'w' mode; the previous version hard-coded
# 'a' (ignoring its computed mode), duplicating the class names on rerun.
label_list = ['unchanged', 'changed']
with open(osp.join(tiled_dataset, 'labels.txt'), 'w') as f:
    f.write('\n'.join(label_list))

+ 40 - 0
examples/change_detection/test.py

@@ -0,0 +1,40 @@
import os
import cv2
import numpy as np
from PIL import Image

import paddlex as pdx

# Trained change-detection model to evaluate, and the output directory for
# side-by-side visualizations of ground truth vs. prediction.
model_dir = "output/unet_1/best_model/"
save_dir = 'output/gt_pred'
if not os.path.exists(save_dir):
    os.makedirs(save_dir)
# Flat per-class color table: class 0 -> black, class 1 -> white.
color = [0, 0, 0, 255, 255, 255]

model = pdx.load_model(model_dir)

with open('tiled_dataset/val_list.txt', 'r') as f:
    for line in f:
        # Each line: "<T1 image> <T2 image> <label>", paths relative to the
        # dataset root.
        items = line.strip().split()
        img_file_1 = os.path.join('tiled_dataset', items[0])
        img_file_2 = os.path.join('tiled_dataset', items[1])
        label_file = os.path.join('tiled_dataset', items[2])

        # Predict on the 6-channel concatenation of the two temporal images
        # and render the result (weight=0. keeps only the mask colors).
        im1 = cv2.imread(img_file_1)
        im2 = cv2.imread(img_file_2)
        image = np.concatenate((im1, im2), axis=-1)
        pred = model.predict(image)
        vis_pred = pdx.seg.visualize(
            img_file_1, pred, weight=0., save_dir=None, color=color)

        # Render the ground-truth annotation through the same visualizer by
        # wrapping it in a prediction-shaped dict.
        label = np.asarray(Image.open(label_file))
        pred = {'label_map': label}
        vis_gt = pdx.seg.visualize(
            img_file_1, pred, weight=0., save_dir=None, color=color)

        # 2x2 grid: top row the two input images, bottom row GT | prediction.
        ims = cv2.hconcat([im1, im2])
        labels = cv2.hconcat([vis_gt, vis_pred])
        data = cv2.vconcat([ims, labels])
        cv2.imwrite("{}/{}".format(save_dir, items[0].split('/')[-1]), data)

+ 63 - 0
examples/change_detection/unet_3.py

@@ -0,0 +1,63 @@
# Environment variable configuration, used to control whether to use the GPU.
# Docs: https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html#gpu
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '2,3,4,5'

import paddlex as pdx
from paddlex.seg import transforms

# Transforms applied during training and evaluation.
# API reference: https://paddlex.readthedocs.io/zh_CN/develop/apis/transforms/seg_transforms.html
# NOTE(review): every per-channel list (padding value, mean, std, min/max)
# has length 6 because the change-detection input is two 3-channel images
# concatenated along the channel axis.
train_transforms = transforms.Compose([
    transforms.ResizeStepScaling(
        min_scale_factor=0.5, max_scale_factor=2.,
        scale_step_size=0.25), transforms.RandomRotate(
            rotate_range=180,
            im_padding_value=[127.5] * 6), transforms.RandomPaddingCrop(
                crop_size=769, im_padding_value=[127.5] * 6),
    transforms.RandomDistort(), transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(), transforms.Normalize(
        mean=[0.5] * 6, std=[0.5] * 6, min_val=[0] * 6, max_val=[255] * 6)
])

eval_transforms = transforms.Compose([
    transforms.Padding(
        target_size=769, im_padding_value=[127.5] * 6), transforms.Normalize(
            mean=[0.5] * 6, std=[0.5] * 6, min_val=[0] * 6, max_val=[255] * 6)
])

# Datasets used for training and evaluation.
# API reference: https://paddlex.readthedocs.io/zh_CN/develop/apis/datasets.html#paddlex-datasets-segdataset
train_dataset = pdx.datasets.ChangeDetDataset(
    data_dir='tiled_dataset',
    file_list='tiled_dataset/train_list.txt',
    label_list='tiled_dataset/labels.txt',
    transforms=train_transforms,
    num_workers=4,
    shuffle=True)
eval_dataset = pdx.datasets.ChangeDetDataset(
    data_dir='tiled_dataset',
    file_list='tiled_dataset/val_list.txt',
    label_list='tiled_dataset/labels.txt',
    num_workers=4,
    transforms=eval_transforms)

# Initialize the model and start training.
# Training metrics can be inspected with VisualDL, see
# https://paddlex.readthedocs.io/zh_CN/develop/train/visualdl.html
num_classes = len(train_dataset.labels)

# API reference: https://paddlex.readthedocs.io/zh_CN/develop/apis/models/semantic_segmentation.html#paddlex-seg-deeplabv3p
# input_channel=6: two bi-temporal RGB images stacked channel-wise.
model = pdx.seg.UNet(num_classes=num_classes, input_channel=6)

# API reference: https://paddlex.readthedocs.io/zh_CN/develop/apis/models/semantic_segmentation.html#train
# Parameter descriptions and tuning notes:
# https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html
model.train(
    num_epochs=400,
    train_dataset=train_dataset,
    train_batch_size=16,
    eval_dataset=eval_dataset,
    learning_rate=0.1,
    save_interval_epochs=10,
    pretrain_weights='CITYSCAPES',
    save_dir='output/unet_3',
    use_vdl=True)

+ 1 - 0
paddlex/cv/models/utils/visualize.py

@@ -93,6 +93,6 @@ def visualize_segmentation(image,
     if abs(weight) < 1e-5:
         vis_result = pseudo_img
     else:
         vis_result = cv2.addWeighted(im, weight, pseudo_img, 1 - weight, 0)
 
     if save_dir is not None:

+ 38 - 22
paddlex/cv/transforms/seg_transforms.py

@@ -840,14 +840,21 @@ class RandomPaddingCrop(SegTransform):
                 img_width = im.shape[1]
 
             if crop_height > 0 and crop_width > 0:
-                h_off = np.random.randint(img_height - crop_height + 1)
-                w_off = np.random.randint(img_width - crop_width + 1)
-
-                im = im[h_off:(crop_height + h_off), w_off:(w_off + crop_width
-                                                            ), :]
-                if label is not None:
-                    label = label[h_off:(crop_height + h_off), w_off:(
-                        w_off + crop_width)]
+                # Retry a bounded number of times to find a crop that is not
+                # completely uniform, always cropping from the ORIGINAL image.
+                # (Overwriting `im` inside the loop would make every retry
+                # crop the already-cropped tile and could spin forever.)
+                for _ in range(50):
+                    h_off = np.random.randint(img_height - crop_height + 1)
+                    w_off = np.random.randint(img_width - crop_width + 1)
+                    crop_im = im[h_off:(crop_height + h_off), w_off:(
+                        w_off + crop_width), :]
+                    if np.max(crop_im) != np.min(crop_im):
+                        break
+                im = crop_im
+                if label is not None:
+                    label = label[h_off:(crop_height + h_off), w_off:(
+                        w_off + crop_width)]
         if label is None:
             return (im, im_info)
         else:
@@ -936,7 +941,7 @@ class RandomRotate(SegTransform):
                 存储与图像相关信息的字典和标注图像np.ndarray数据。
         """
         if self.rotate_range > 0:
-            (h, w) = im.shape[:2]
+            h, w, c = im.shape
             do_rotation = np.random.uniform(-self.rotate_range,
                                             self.rotate_range)
             pc = (w // 2, h // 2)
@@ -951,13 +956,18 @@ class RandomRotate(SegTransform):
             r[0, 2] += (nw / 2) - cx
             r[1, 2] += (nh / 2) - cy
             dsize = (nw, nh)
-            im = cv2.warpAffine(
-                im,
-                r,
-                dsize=dsize,
-                flags=cv2.INTER_LINEAR,
-                borderMode=cv2.BORDER_CONSTANT,
-                borderValue=self.im_padding_value)
+            rot_ims = list()
+            for i in range(0, c, 3):
+                ori_im = im[:, :, i:i + 3]
+                rot_im = cv2.warpAffine(
+                    ori_im,
+                    r,
+                    dsize=dsize,
+                    flags=cv2.INTER_LINEAR,
+                    borderMode=cv2.BORDER_CONSTANT,
+                    borderValue=self.im_padding_value[i:i + 3])
+                rot_ims.append(rot_im)
+            im = np.concatenate(rot_ims, axis=-1)
             label = cv2.warpAffine(
                 label,
                 r,
@@ -1119,12 +1129,18 @@ class RandomDistort(SegTransform):
             'saturation': self.saturation_prob,
             'hue': self.hue_prob
         }
-        for id in range(4):
-            params = params_dict[ops[id].__name__]
-            prob = prob_dict[ops[id].__name__]
-            params['im'] = im
-            if np.random.uniform(0, 1) < prob:
-                im = ops[id](**params)
+        dis_ims = list()
+        h, w, c = im.shape
+        for i in range(0, c, 3):
+            ori_im = im[:, :, i:i + 3]
+            for id in range(4):
+                params = params_dict[ops[id].__name__]
+                prob = prob_dict[ops[id].__name__]
+                params['im'] = ori_im
+                if np.random.uniform(0, 1) < prob:
+                    ori_im = ops[id](**params)
+            dis_ims.append(ori_im)
+        im = np.concatenate(dis_ims, axis=-1)
         im = im.astype('float32')
         if label is None:
             return (im, im_info)

+ 19 - 6
paddlex/cv/transforms/visualize.py

@@ -236,16 +236,25 @@ def seg_compose(im,
                     len(im.shape)))
     else:
         try:
-            im = cv2.imread(im).astype('float32')
+            im = cv2.imread(im)
         except:
             raise ValueError('Can\'t read The image file {}!'.format(im))
-    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
+    im = im.astype('float32')
+    h, w, c = im.shape
+    # cv2.cvtColor cannot handle more than 3 channels, so flip each
+    # 3-channel (BGR) group to RGB by reversing its channel order.
+    im = np.concatenate(
+        [im[:, :, i:i + 3][:, :, ::-1] for i in range(0, c, 3)], axis=-1)
     if label is not None:
         if not isinstance(label, np.ndarray):
             label = np.asarray(Image.open(label))
     if vdl_writer is not None:
-        vdl_writer.add_image(
-            tag='0. OriginalImage' + '/' + str(step), img=im, step=0)
+        for i in range(0, c, 3):
+            if c > 3:
+                tag = '0. OriginalImage/{}_{}'.format(str(step), str(i // 3))
+            else:
+                tag = '0. OriginalImage/{}'.format(str(step))
+            vdl_writer.add_image(tag=tag, img=im[:, :, i:i + 3], step=0)
     op_id = 1
     for op in transforms:
         if isinstance(op, SegTransform):
@@ -264,8 +273,15 @@
             else:
                 outputs = (im, im_info)
         if vdl_writer is not None:
-            tag = str(op_id) + '. ' + op.__class__.__name__ + '/' + str(step)
-            vdl_writer.add_image(tag=tag, img=im, step=0)
+            for i in range(0, c, 3):
+                if c > 3:
+                    tag = str(
+                        op_id) + '. ' + op.__class__.__name__ + '/' + str(
+                            step) + '_' + str(i // 3)
+                else:
+                    tag = str(
+                        op_id) + '. ' + op.__class__.__name__ + '/' + str(step)
+                vdl_writer.add_image(tag=tag, img=im[:, :, i:i + 3], step=0)
         op_id += 1