# prepara_data_cd.py: cut paired change-detection images into train/val tiles.
import os
import os.path as osp
import random
import shutil

import numpy as np
import cv2
from PIL import Image
import paddlex as pdx

# Fixed seed so the random.shuffle() of the image list is reproducible.
random.seed(0)
  10. # 定义训练集切分时的滑动窗口大小和步长,格式为(W, H)
  11. train_tile_size = (1024, 1024)
  12. train_stride = (512, 512)
  13. # 定义验证集切分时的滑动窗口大小和步长,格式(W, H)
  14. val_tile_size = (769, 769)
  15. val_stride = (769, 769)
  16. # 训练集和验证集比例
  17. train_ratio = 0.8
  18. val_ratio = 0.2
  19. change_det_dataset = './change_det_data'
  20. tiled_dataset = './tiled_dataset'
  21. origin_dataset = './origin_dataset'
  22. tiled_image_dir = osp.join(tiled_dataset, 'JPEGImages')
  23. tiled_anno_dir = osp.join(tiled_dataset, 'Annotations')
  24. if not osp.exists(tiled_image_dir):
  25. os.makedirs(tiled_image_dir)
  26. if not osp.exists(tiled_anno_dir):
  27. os.makedirs(tiled_anno_dir)
  28. # 划分数据集
  29. im1_file_list = os.listdir(osp.join(change_det_dataset, 'T1'))
  30. im2_file_list = os.listdir(osp.join(change_det_dataset, 'T2'))
  31. label_file_list = os.listdir(osp.join(change_det_dataset, 'labels_change'))
  32. im1_file_list = sorted(
  33. im1_file_list, key=lambda k: int(k.split('test')[-1].split('_')[0]))
  34. im2_file_list = sorted(
  35. im2_file_list, key=lambda k: int(k.split('test')[-1].split('_')[0]))
  36. label_file_list = sorted(
  37. label_file_list, key=lambda k: int(k.split('test')[-1].split('_')[0]))
  38. file_list = list()
  39. for im1_file, im2_file, label_file in zip(im1_file_list, im2_file_list,
  40. label_file_list):
  41. im1_file = osp.join(osp.join(change_det_dataset, 'T1'), im1_file)
  42. im2_file = osp.join(osp.join(change_det_dataset, 'T2'), im2_file)
  43. label_file = osp.join(
  44. osp.join(change_det_dataset, 'labels_change'), label_file)
  45. file_list.append((im1_file, im2_file, label_file))
  46. random.shuffle(file_list)
  47. train_num = int(len(file_list) * train_ratio)
  48. for i, item in enumerate(file_list):
  49. im1_file, im2_file, label_file = item[:]
  50. if i < train_num:
  51. stride = train_stride
  52. tile_size = train_tile_size
  53. else:
  54. stride = val_stride
  55. tile_size = val_tile_size
  56. i += 1
  57. set_name = 'train' if i < train_num else 'val'
  58. im1 = cv2.imread(im1_file)
  59. im2 = cv2.imread(im2_file)
  60. label = cv2.imread(label_file, cv2.IMREAD_GRAYSCALE)
  61. label = label != 0
  62. label = label.astype(np.uint8)
  63. H, W, C = im1.shape
  64. tile_id = 1
  65. im1_name = osp.split(im1_file)[-1].split('.')[0]
  66. im2_name = osp.split(im2_file)[-1].split('.')[0]
  67. label_name = osp.split(label_file)[-1].split('.')[0]
  68. for h in range(0, H, stride[1]):
  69. for w in range(0, W, stride[0]):
  70. left = w
  71. upper = h
  72. right = min(w + tile_size[0], W)
  73. lower = min(h + tile_size[1], H)
  74. tile_im1 = im1[upper:lower, left:right, :]
  75. tile_im2 = im2[upper:lower, left:right, :]
  76. cv2.imwrite(
  77. osp.join(tiled_image_dir,
  78. "{}_{}.bmp".format(im1_name, tile_id)), tile_im1)
  79. cv2.imwrite(
  80. osp.join(tiled_image_dir,
  81. "{}_{}.bmp".format(im2_name, tile_id)), tile_im2)
  82. cut_label = label[upper:lower, left:right]
  83. cv2.imwrite(
  84. osp.join(tiled_anno_dir,
  85. "{}_{}.png".format(label_name, tile_id)), cut_label)
  86. mode = 'w' if i in [0, train_num] and tile_id == 1 else 'a'
  87. with open(
  88. osp.join(tiled_dataset, '{}_list.txt'.format(set_name)),
  89. mode) as f:
  90. f.write(
  91. "JPEGImages/{}_{}.bmp JPEGImages/{}_{}.bmp Annotations/{}_{}.png\n".
  92. format(im1_name, tile_id, im2_name, tile_id, label_name,
  93. tile_id))
  94. tile_id += 1
  95. # 生成labels.txt
  96. label_list = ['unchanged', 'changed']
  97. for i, label in enumerate(label_list):
  98. mode = 'w' if i == 0 else 'a'
  99. with open(osp.join(tiled_dataset, 'labels.txt'), 'a') as f:
  100. name = "{}\n".format(label) if i < len(
  101. label_list) - 1 else "{}".format(label)
  102. f.write(name)