# prepara_data.py  4.8 KB
  1. import os
  2. import os.path as osp
  3. import numpy as np
  4. import cv2
  5. import shutil
  6. import random
  7. # 为保证每次运行该脚本时划分的样本一致,故固定随机种子
  8. random.seed(0)
  9. import paddlex as pdx
  10. # 定义训练集切分时的滑动窗口大小和步长,格式为(W, H)
  11. train_tile_size = (1024, 1024)
  12. train_stride = (512, 512)
  13. # 定义验证集切分时的滑动窗口大小和步长,格式(W, H)
  14. val_tile_size = (769, 769)
  15. val_stride = (769, 769)
  16. # 训练集和验证集比例
  17. train_ratio = 0.75
  18. val_ratio = 0.25
  19. # 切分后的数据集保存路径
  20. tiled_dataset = './tiled_dataset'
  21. # 切分后的图像文件保存路径
  22. tiled_image_dir = osp.join(tiled_dataset, 'JPEGImages')
  23. # 切分后的标注文件保存路径
  24. tiled_anno_dir = osp.join(tiled_dataset, 'Annotations')
  25. # 下载和解压Google Dataset数据集
  26. change_det_dataset = 'https://bj.bcebos.com/paddlex/examples/change_detection/dataset/google_change_det_dataset.tar.gz'
  27. pdx.utils.download_and_decompress(change_det_dataset, path='./')
  28. change_det_dataset = './google_change_det_dataset'
  29. image1_dir = osp.join(change_det_dataset, 'T1')
  30. image2_dir = osp.join(change_det_dataset, 'T2')
  31. label_dir = osp.join(change_det_dataset, 'labels_change')
  32. if not osp.exists(tiled_image_dir):
  33. os.makedirs(tiled_image_dir)
  34. if not osp.exists(tiled_anno_dir):
  35. os.makedirs(tiled_anno_dir)
  36. # 划分数据集
  37. im1_file_list = os.listdir(image1_dir)
  38. im2_file_list = os.listdir(image2_dir)
  39. label_file_list = os.listdir(label_dir)
  40. im1_file_list = sorted(
  41. im1_file_list, key=lambda k: int(k.split('test')[-1].split('_')[0]))
  42. im2_file_list = sorted(
  43. im2_file_list, key=lambda k: int(k.split('test')[-1].split('_')[0]))
  44. label_file_list = sorted(
  45. label_file_list, key=lambda k: int(k.split('test')[-1].split('_')[0]))
  46. file_list = list()
  47. for im1_file, im2_file, label_file in zip(im1_file_list, im2_file_list,
  48. label_file_list):
  49. im1_file = osp.join(image1_dir, im1_file)
  50. im2_file = osp.join(image2_dir, im2_file)
  51. label_file = osp.join(label_dir, label_file)
  52. file_list.append((im1_file, im2_file, label_file))
  53. random.shuffle(file_list)
  54. train_num = int(len(file_list) * train_ratio)
  55. # 将大图切分成小图
  56. for i, item in enumerate(file_list):
  57. if i < train_num:
  58. stride = train_stride
  59. tile_size = train_tile_size
  60. else:
  61. stride = val_stride
  62. tile_size = val_tile_size
  63. set_name = 'train' if i < train_num else 'val'
  64. # 生成原图的file_list
  65. im1_file, im2_file, label_file = item[:]
  66. mode = 'w' if i in [0, train_num] else 'a'
  67. with open(
  68. osp.join(change_det_dataset, '{}_list.txt'.format(set_name)),
  69. mode) as f:
  70. f.write("T1/{} T2/{} labels_change/{}\n".format(
  71. osp.split(im1_file)[-1],
  72. osp.split(im2_file)[-1], osp.split(label_file)[-1]))
  73. im1 = cv2.imread(im1_file)
  74. im2 = cv2.imread(im2_file)
  75. # 将三通道的label图像转换成单通道的png格式图片
  76. # 且将标注0和255转换成0和1
  77. label = cv2.imread(label_file, cv2.IMREAD_GRAYSCALE)
  78. label = label != 0
  79. label = label.astype(np.uint8)
  80. H, W, C = im1.shape
  81. tile_id = 1
  82. im1_name = osp.split(im1_file)[-1].split('.')[0]
  83. im2_name = osp.split(im2_file)[-1].split('.')[0]
  84. label_name = osp.split(label_file)[-1].split('.')[0]
  85. for h in range(0, H, stride[1]):
  86. for w in range(0, W, stride[0]):
  87. left = w
  88. upper = h
  89. right = min(w + tile_size[0], W)
  90. lower = min(h + tile_size[1], H)
  91. tile_im1 = im1[upper:lower, left:right, :]
  92. tile_im2 = im2[upper:lower, left:right, :]
  93. cv2.imwrite(
  94. osp.join(tiled_image_dir,
  95. "{}_{}.bmp".format(im1_name, tile_id)), tile_im1)
  96. cv2.imwrite(
  97. osp.join(tiled_image_dir,
  98. "{}_{}.bmp".format(im2_name, tile_id)), tile_im2)
  99. cut_label = label[upper:lower, left:right]
  100. cv2.imwrite(
  101. osp.join(tiled_anno_dir,
  102. "{}_{}.png".format(label_name, tile_id)), cut_label)
  103. mode = 'w' if i in [0, train_num] and tile_id == 1 else 'a'
  104. with open(
  105. osp.join(tiled_dataset, '{}_list.txt'.format(set_name)),
  106. mode) as f:
  107. f.write(
  108. "JPEGImages/{}_{}.bmp JPEGImages/{}_{}.bmp Annotations/{}_{}.png\n".
  109. format(im1_name, tile_id, im2_name, tile_id, label_name,
  110. tile_id))
  111. tile_id += 1
  112. # 生成labels.txt
  113. label_list = ['unchanged', 'changed']
  114. for i, label in enumerate(label_list):
  115. mode = 'w' if i == 0 else 'a'
  116. with open(osp.join(tiled_dataset, 'labels.txt'), 'a') as f:
  117. name = "{}\n".format(label) if i < len(
  118. label_list) - 1 else "{}".format(label)
  119. f.write(name)