# prepara_data.py

import os
import os.path as osp
import shutil

import cv2
import numpy as np
import paddlex as pdx
from PIL import Image
  # Sliding-window tile size and stride used to cut the training images,
# given as (W, H). The stride is smaller than the tile size, so training
# tiles overlap.
train_tile_size = (512, 512)
train_stride = (256, 256)
# Sliding-window tile size and stride used to cut the validation images,
# given as (W, H). Stride equals tile size, so validation tiles do not overlap.
val_tile_size = (256, 256)
val_stride = (256, 256)

# Optional download step (disabled). The rest of the script reads the
# benchmark from ./SZTAKI_AirChange_Benchmark/, which must already exist.
# NOTE(review): the URL below points to the 2015 CCF big-data competition
# remote-sensing archive (ccf_remote_dataset), not to the SZTAKI AirChange
# benchmark; confirm the correct archive before re-enabling.
#SZTAKI_AirChange_Benchmark = 'https://bj.bcebos.com/paddlex/examples/remote_sensing/datasets/ccf_remote_dataset.tar.gz'
#pdx.utils.download_and_decompress(SZTAKI_AirChange_Benchmark, path='./')

# Output folders for the image files and the label masks. exist_ok=True
# replaces the racy "check exists, then create" pattern.
os.makedirs('./dataset/JPEGImages', exist_ok=True)
os.makedirs('./dataset/Annotations', exist_ok=True)
  # Split the SZTAKI AirChange image pairs into training and validation sets,
# cut every pair into sliding-window tiles, and write the file lists
# ./dataset/train_list.txt and ./dataset/val_list.txt. Each line has the form
# "<im1 path> <im2 path> <label path>", relative to ./dataset.
#
# Training set: Szada scenes 2-7 and Tiszadob scenes 1, 2, 4, 5 (10 pairs).
# Each full-size training pair is copied into the dataset as-is AND cut into
# tiles. Validation pairs (Szada 1, Tiszadob 3) are only cut into tiles.
train_list = {'Szada': [2, 3, 4, 5, 6, 7], 'Tiszadob': [1, 2, 4, 5]}
val_list = {'Szada': [1], 'Tiszadob': [3]}
all_list = [train_list, val_list]


def _src_path(key, v, name):
    """Return the benchmark path of file ``name`` for scene ``key``/``v``."""
    return "SZTAKI_AirChange_Benchmark/{}/{}/{}".format(key, v, name)


def _imread_checked(path):
    """Read ``path`` with cv2.imread, raising instead of returning None.

    cv2.imread reports a missing or unreadable file by returning None. Without
    this check, the failure would surface later as an unrelated error.

    Raises:
        FileNotFoundError: if OpenCV could not read ``path``.
    """
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError("cannot read image: {}".format(path))
    return img


def _read_binary_label(key, v):
    """Load gt.bmp for scene ``key``/``v`` as a uint8 mask.

    The mask is 1 where the first channel of the image returned by cv2.imread
    is non-zero, and 0 elsewhere.
    """
    gt = _imread_checked(_src_path(key, v, 'gt.bmp'))
    return (gt[:, :, 0] != 0).astype(np.uint8)


for i, data_list in enumerate(all_list):
    is_train = i == 0
    if is_train:
        stride, tile_size = train_stride, train_tile_size
    else:
        stride, tile_size = val_stride, val_tile_size
    # Open the split's list file once, in 'w' mode, so re-running the script
    # rewrites it. Previously val_list.txt was only ever appended to, which
    # duplicated every entry on each run, and the file was reopened per tile.
    list_path = './dataset/{}_list.txt'.format('train' if is_train else 'val')
    with open(list_path, 'w') as list_file:
        if is_train:
            # Full-size training pairs are listed first, before any tile.
            for key, value in data_list.items():
                for v in value:
                    shutil.copyfile(
                        _src_path(key, v, 'im1.bmp'),
                        "./dataset/JPEGImages/{}_{}_im1.bmp".format(key, v))
                    shutil.copyfile(
                        _src_path(key, v, 'im2.bmp'),
                        "./dataset/JPEGImages/{}_{}_im2.bmp".format(key, v))
                    cv2.imwrite(
                        "./dataset/Annotations/{}_{}_gt.png".format(key, v),
                        _read_binary_label(key, v))
                    list_file.write(
                        "JPEGImages/{}_{}_im1.bmp JPEGImages/{}_{}_im2.bmp Annotations/{}_{}_gt.png\n"
                        .format(key, v, key, v, key, v))
        # Cut every pair of this split into windows of at most tile_size,
        # stepping by stride. Windows reaching past the right/bottom border
        # are clipped, so edge tiles can be smaller than tile_size.
        for key, value in data_list.items():
            for v in value:
                im1 = _imread_checked(_src_path(key, v, 'im1.bmp'))
                im2 = _imread_checked(_src_path(key, v, 'im2.bmp'))
                label = _read_binary_label(key, v)
                H, W = im1.shape[:2]
                tile_id = 1  # per-image counter, part of each tile's file name
                for upper in range(0, H, stride[1]):
                    lower = min(upper + tile_size[1], H)
                    for left in range(0, W, stride[0]):
                        right = min(left + tile_size[0], W)
                        stem = "{}_{}_{}".format(key, v, tile_id)
                        cv2.imwrite(
                            "./dataset/JPEGImages/{}_im1.bmp".format(stem),
                            im1[upper:lower, left:right, :])
                        cv2.imwrite(
                            "./dataset/JPEGImages/{}_im2.bmp".format(stem),
                            im2[upper:lower, left:right, :])
                        cv2.imwrite(
                            "./dataset/Annotations/{}_gt.png".format(stem),
                            label[upper:lower, left:right])
                        list_file.write(
                            "JPEGImages/{0}_im1.bmp JPEGImages/{0}_im2.bmp "
                            "Annotations/{0}_gt.png\n".format(stem))
                        tile_id += 1
  # Write ./dataset/labels.txt with one class name per line and no trailing
# newline after the last name. Presumably the line order gives the class id
# stored in the 0/1 annotation masks (0 = unchanged, 1 = changed); confirm
# against the PaddleX dataset reader.
# The file is written once in 'w' mode. The previous loop computed a 'w'/'a'
# mode but always opened the file with 'a', so every re-run appended another
# copy of the labels.
label_list = ['unchanged', 'changed']
with open('./dataset/labels.txt', 'w') as f:
    f.write('\n'.join(label_list))