فهرست منبع

add change detection example

FlyingQianMM 5 سال پیش
والد
کامیت
c01726ca43

+ 106 - 0
examples/change_detection/prepara_data.py

@@ -0,0 +1,106 @@
import os
import os.path as osp
import numpy as np
import cv2
import shutil
from PIL import Image
import paddlex as pdx

# Sliding-window size and stride used to tile the training images, as (W, H).
train_tile_size = (512, 512)
train_stride = (256, 256)
# Sliding-window size and stride used to tile the validation images, as (W, H).
val_tile_size = (256, 256)
val_stride = (256, 256)

## Download and extract the benchmark imagery.
# NOTE(review): the variable is named after the SZTAKI AirChange Benchmark but
# the URL points at a "ccf_remote_dataset" archive -- confirm which dataset is
# intended before enabling the download.
#SZTAKI_AirChange_Benchmark = 'https://bj.bcebos.com/paddlex/examples/remote_sensing/datasets/ccf_remote_dataset.tar.gz'
#pdx.utils.download_and_decompress(SZTAKI_AirChange_Benchmark, path='./')


def _read_binary_label(path):
    """Read a ground-truth BMP and binarize it: 0 = unchanged, 1 = changed."""
    label = cv2.imread(path)
    # cv2 loads the mask as 3 channels; keep only the first one.
    label = label[:, :, 0]
    return (label != 0).astype(np.uint8)


if not osp.exists('./dataset/JPEGImages'):
    os.makedirs('./dataset/JPEGImages')
if not osp.exists('./dataset/Annotations'):
    os.makedirs('./dataset/Annotations')

# Truncate the split list files up front so that re-running the script does
# not append duplicate entries (everything below opens them in append mode).
for _name in ('train_list.txt', 'val_list.txt'):
    open('./dataset/{}'.format(_name), 'w').close()

# Image ids assigned to the training and validation splits.
train_list = {'Szada': [2, 3, 4, 5, 6, 7], 'Tiszadob': [1, 2, 4, 5]}
val_list = {'Szada': [1], 'Tiszadob': [3]}
all_list = [train_list, val_list]

for i, data_list in enumerate(all_list):
    # Full-size training images are copied into the dataset unchanged and
    # recorded in train_list.txt; validation images are only used as tiles.
    if i == 0:
        for key, value in data_list.items():
            for v in value:
                shutil.copyfile(
                    "SZTAKI_AirChange_Benchmark/{}/{}/im1.bmp".format(key, v),
                    "./dataset/JPEGImages/{}_{}_im1.bmp".format(key, v))
                shutil.copyfile(
                    "SZTAKI_AirChange_Benchmark/{}/{}/im2.bmp".format(key, v),
                    "./dataset/JPEGImages/{}_{}_im2.bmp".format(key, v))
                label = _read_binary_label(
                    "SZTAKI_AirChange_Benchmark/{}/{}/gt.bmp".format(key, v))
                cv2.imwrite("./dataset/Annotations/{}_{}_gt.png".format(
                    key, v), label)
                with open('./dataset/train_list.txt', 'a') as f:
                    f.write(
                        "JPEGImages/{}_{}_im1.bmp JPEGImages/{}_{}_im2.bmp Annotations/{}_{}_gt.png\n".
                        format(key, v, key, v, key, v))

    if i == 0:
        stride = train_stride
        tile_size = train_tile_size
        list_path = './dataset/train_list.txt'
    else:
        stride = val_stride
        tile_size = val_tile_size
        list_path = './dataset/val_list.txt'
    # Slide a window over every image pair and write out the crops together
    # with a matching entry in the split's list file.  Windows are clipped at
    # the image border, so edge tiles may be smaller than tile_size.
    for key, value in data_list.items():
        for v in value:
            im1 = cv2.imread("SZTAKI_AirChange_Benchmark/{}/{}/im1.bmp".format(
                key, v))
            im2 = cv2.imread("SZTAKI_AirChange_Benchmark/{}/{}/im2.bmp".format(
                key, v))
            label = _read_binary_label(
                "SZTAKI_AirChange_Benchmark/{}/{}/gt.bmp".format(key, v))
            H, W, C = im1.shape
            tile_id = 1
            for h in range(0, H, stride[1]):
                for w in range(0, W, stride[0]):
                    left = w
                    upper = h
                    right = min(w + tile_size[0], W)
                    lower = min(h + tile_size[1], H)
                    tile_im1 = im1[upper:lower, left:right, :]
                    tile_im2 = im2[upper:lower, left:right, :]
                    cv2.imwrite("./dataset/JPEGImages/{}_{}_{}_im1.bmp".format(
                        key, v, tile_id), tile_im1)
                    cv2.imwrite("./dataset/JPEGImages/{}_{}_{}_im2.bmp".format(
                        key, v, tile_id), tile_im2)
                    cut_label = label[upper:lower, left:right]
                    cv2.imwrite("./dataset/Annotations/{}_{}_{}_gt.png".format(
                        key, v, tile_id), cut_label)
                    with open(list_path, 'a') as f:
                        f.write(
                            "JPEGImages/{}_{}_{}_im1.bmp JPEGImages/{}_{}_{}_im2.bmp Annotations/{}_{}_{}_gt.png\n".
                            format(key, v, tile_id, key, v, tile_id, key, v,
                                   tile_id))
                    tile_id += 1

# Write labels.txt: one class name per line with no trailing newline.  The
# original opened the file in append mode on every iteration (the computed
# 'w'/'a' mode variable was never used), duplicating the labels on re-runs.
label_list = ['unchanged', 'changed']
with open('./dataset/labels.txt', 'w') as f:
    f.write('\n'.join(label_list))

+ 57 - 0
examples/change_detection/unet.py

@@ -0,0 +1,57 @@
# Environment variable configuration used to control GPU usage.
# Docs: https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html#gpu
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

import paddlex as pdx
from paddlex.seg import transforms

# Transforms for training and evaluation.  The change-detection samples are
# bi-temporal image pairs stacked channel-wise by the dataset reader, so each
# per-channel constant is repeated 6 times (3 channels per image x 2 images).
# API: https://paddlex.readthedocs.io/zh_CN/develop/apis/transforms/seg_transforms.html
train_transforms = transforms.Compose([
    transforms.RandomPaddingCrop(
        crop_size=256, im_padding_value=[127.5] * 6),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.Normalize(
        mean=[0.5] * 6, std=[0.5] * 6, min_val=[0] * 6, max_val=[255] * 6),
])

eval_transforms = transforms.Compose([
    transforms.Padding(
        target_size=256, im_padding_value=[127.5] * 6),
    transforms.Normalize(
        mean=[0.5] * 6, std=[0.5] * 6, min_val=[0] * 6, max_val=[255] * 6),
])

# Datasets used for training and validation.
# API: https://paddlex.readthedocs.io/zh_CN/develop/apis/datasets.html
train_dataset = pdx.datasets.ChangeDetDataset(
    data_dir='dataset',
    file_list='dataset/train_list.txt',
    label_list='dataset/labels.txt',
    transforms=train_transforms,
    shuffle=True)
eval_dataset = pdx.datasets.ChangeDetDataset(
    data_dir='dataset',
    file_list='dataset/val_list.txt',
    label_list='dataset/labels.txt',
    transforms=eval_transforms)

# Build the model and train it.  Training metrics can be inspected with
# VisualDL: https://paddlex.readthedocs.io/zh_CN/develop/train/visualdl.html
num_classes = len(train_dataset.labels)

# input_channel=6 matches the concatenated image pair produced by the dataset.
# API: https://paddlex.readthedocs.io/zh_CN/develop/apis/models/semantic_segmentation.html
model = pdx.seg.UNet(num_classes=num_classes, input_channel=6)

# Training parameters are documented at:
# https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html
model.train(
    num_epochs=400,
    train_dataset=train_dataset,
    train_batch_size=16,
    eval_dataset=eval_dataset,
    learning_rate=0.01,
    save_interval_epochs=10,
    pretrain_weights='CITYSCAPES',
    save_dir='output/unet',
    use_vdl=True)

+ 1 - 0
paddlex/cv/datasets/__init__.py

@@ -21,3 +21,4 @@ from .easydata_det import EasyDataDet
 from .easydata_seg import EasyDataSeg
 from .dataset import generate_minibatch
 from .analysis import Seg
+from .change_det_dataet import ChangeDetDataset

+ 108 - 0
paddlex/cv/datasets/change_det_dataet.py

@@ -0,0 +1,108 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+import os.path as osp
+import random
+import copy
+import numpy as np
+import paddlex.utils.logging as logging
+from paddlex.cv.transforms import seg_transforms
+from paddlex.utils import path_normalization
+from .dataset import Dataset
+from .dataset import get_encoding
+
+
class ChangeDetDataset(Dataset):
    """Dataset reader for the change-detection task.

    Every sample is a bi-temporal image pair plus a ground-truth mask; in
    ``iterator`` the two images are read and concatenated along the channel
    axis before being handed to the transforms.

    Args:
        data_dir (str): Directory containing the dataset.
        file_list (str): Path of a text file describing the samples; each
            non-empty line holds three whitespace-separated paths
            (image 1, image 2, label), all relative to ``data_dir``.
        label_list (str): Path of a text file listing the class names, one per
            line. Defaults to None.
        transforms (list): Preprocessing/augmentation operators applied to
            every sample.
        num_workers (int|str): Number of threads or processes used to
            preprocess samples. Defaults to 'auto'.
        buffer_size (int): Length, in samples, of the preprocessing queue.
            Defaults to 100.
        parallel_method (str): Parallel mode used for preprocessing, either
            'thread' or 'process'. Defaults to 'process' (on Windows and Mac
            'thread' is forced and this argument has no effect).
        shuffle (bool): Whether to shuffle the samples. Defaults to False.

    Raises:
        Exception: If a non-empty line of ``file_list`` does not contain
            exactly three whitespace-separated fields.
        IOError: If a listed image or label file does not exist.
    """

    def __init__(self,
                 data_dir,
                 file_list,
                 label_list=None,
                 transforms=None,
                 num_workers='auto',
                 buffer_size=100,
                 parallel_method='process',
                 shuffle=False):
        super(ChangeDetDataset, self).__init__(
            transforms=transforms,
            num_workers=num_workers,
            buffer_size=buffer_size,
            parallel_method=parallel_method,
            shuffle=shuffle)
        self.file_list = list()
        self.labels = list()
        self._epoch = 0

        # Class names, one per line of label_list (order defines class ids).
        if label_list is not None:
            with open(label_list, encoding=get_encoding(label_list)) as f:
                for line in f:
                    item = line.strip()
                    self.labels.append(item)
        with open(file_list, encoding=get_encoding(file_list)) as f:
            for line in f:
                items = line.strip().split()
                if not items:
                    # Tolerate blank lines instead of crashing on them.
                    continue
                # The original check only rejected lines with MORE than three
                # fields and crashed with a bare IndexError on lines with
                # fewer; require exactly three fields instead.
                if len(items) != 3:
                    raise Exception(
                        "A space is defined as the separator; each line must "
                        "contain exactly 3 fields (image1 image2 label), but "
                        "got {}.".format(line))
                full_paths = []
                for item in items:
                    full_path = osp.join(data_dir, path_normalization(item))
                    if not osp.exists(full_path):
                        raise IOError('The image file {} is not exist!'.format(
                            full_path))
                    full_paths.append(full_path)
                # Stored as [image1_path, image2_path, label_path].
                self.file_list.append(full_paths)
        self.num_samples = len(self.file_list)
        logging.info("{} samples in file {}".format(
            len(self.file_list), file_list))

    def iterator(self):
        """Yield [image, None, label_path] samples for one epoch.

        The image is the channel-wise concatenation of the bi-temporal pair;
        the file list is reshuffled each epoch when ``shuffle`` is enabled.
        """
        self._epoch += 1
        self._pos = 0
        files = copy.deepcopy(self.file_list)
        if self.shuffle:
            random.shuffle(files)
        files = files[:self.num_samples]
        self.num_samples = len(files)
        for f in files:
            label_path = f[2]
            image1 = seg_transforms.Compose.read_img(f[0])
            image2 = seg_transforms.Compose.read_img(f[1])
            # Stack the two temporal images along the last (channel) axis,
            # e.g. 3 + 3 = 6 channels.
            image = np.concatenate((image1, image2), axis=-1)
            sample = [image, None, label_path]
            yield sample