prepara_data.py 4.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. import os
  2. import os.path as osp
  3. import numpy as np
  4. import cv2
  5. import shutil
  6. from PIL import Image
  7. import paddlex as pdx
  8. # 定义训练集切分时的滑动窗口大小和步长,格式为(W, H)
  9. train_tile_size = (1024, 1024)
  10. train_stride = (512, 512)
  11. # 定义验证集切分时的滑动窗口大小和步长,格式(W, H)
  12. val_tile_size = (769, 769)
  13. val_stride = (769, 769)
  14. # 下载并解压2015 CCF大数据比赛提供的高清遥感影像
  15. ccf_remote_dataset = 'https://bj.bcebos.com/paddlex/examples/remote_sensing/datasets/ccf_remote_dataset.tar.gz'
  16. pdx.utils.download_and_decompress(ccf_remote_dataset, path='./')
  17. if not osp.exists('./dataset/JPEGImages'):
  18. os.makedirs('./dataset/JPEGImages')
  19. if not osp.exists('./dataset/Annotations'):
  20. os.makedirs('./dataset/Annotations')
  21. # 将前4张图片划分入训练集,并切分成小块之后加入到训练集中
  22. # 并生成train_list.txt
  23. for train_id in range(1, 5):
  24. shutil.copyfile("ccf_remote_dataset/{}.png".format(train_id),
  25. "./dataset/JPEGImages/{}.png".format(train_id))
  26. shutil.copyfile("ccf_remote_dataset/{}_class.png".format(train_id),
  27. "./dataset/Annotations/{}_class.png".format(train_id))
  28. mode = 'w' if train_id == 1 else 'a'
  29. with open('./dataset/train_list.txt', mode) as f:
  30. f.write("JPEGImages/{}.png Annotations/{}_class.png\n".format(
  31. train_id, train_id))
  32. for train_id in range(1, 5):
  33. image = cv2.imread('ccf_remote_dataset/{}.png'.format(train_id))
  34. label = Image.open('ccf_remote_dataset/{}_class.png'.format(train_id))
  35. H, W, C = image.shape
  36. train_tile_id = 1
  37. for h in range(0, H, train_stride[1]):
  38. for w in range(0, W, train_stride[0]):
  39. left = w
  40. upper = h
  41. right = min(w + train_tile_size[0] * 2, W)
  42. lower = min(h + train_tile_size[1] * 2, H)
  43. tile_image = image[upper:lower, left:right, :]
  44. cv2.imwrite("./dataset/JPEGImages/{}_{}.png".format(
  45. train_id, train_tile_id), tile_image)
  46. cut_label = label.crop((left, upper, right, lower))
  47. cut_label.save("./dataset/Annotations/{}_class_{}.png".format(
  48. train_id, train_tile_id))
  49. with open('./dataset/train_list.txt', 'a') as f:
  50. f.write("JPEGImages/{}_{}.png Annotations/{}_class_{}.png\n".
  51. format(train_id, train_tile_id, train_id,
  52. train_tile_id))
  53. train_tile_id += 1
  54. # 将第5张图片切分成小块之后加入到验证集中
  55. val_id = 5
  56. val_tile_id = 1
  57. shutil.copyfile("ccf_remote_dataset/{}.png".format(val_id),
  58. "./dataset/JPEGImages/{}.png".format(val_id))
  59. shutil.copyfile("ccf_remote_dataset/{}_class.png".format(val_id),
  60. "./dataset/Annotations/{}_class.png".format(val_id))
  61. image = cv2.imread('ccf_remote_dataset/{}.png'.format(val_id))
  62. label = Image.open('ccf_remote_dataset/{}_class.png'.format(val_id))
  63. H, W, C = image.shape
  64. for h in range(0, H, val_stride[1]):
  65. for w in range(0, W, val_stride[0]):
  66. left = w
  67. upper = h
  68. right = min(w + val_tile_size[0], W)
  69. lower = min(h + val_tile_size[1], H)
  70. cut_image = image[upper:lower, left:right, :]
  71. cv2.imwrite("./dataset/JPEGImages/{}_{}.png".format(
  72. val_id, val_tile_id), cut_image)
  73. cut_label = label.crop((left, upper, right, lower))
  74. cut_label.save("./dataset/Annotations/{}_class_{}.png".format(
  75. val_id, val_tile_id))
  76. mode = 'w' if val_tile_id == 1 else 'a'
  77. with open('./dataset/val_list.txt', mode) as f:
  78. f.write("JPEGImages/{}_{}.png Annotations/{}_class_{}.png\n".
  79. format(val_id, val_tile_id, val_id, val_tile_id))
  80. val_tile_id += 1
  81. # 生成labels.txt
  82. label_list = ['background', 'vegetation', 'road', 'building', 'water']
  83. for i, label in enumerate(label_list):
  84. mode = 'w' if i == 0 else 'a'
  85. with open('./dataset/labels.txt', 'a') as f:
  86. name = "{}\n".format(label) if i < len(
  87. label_list) - 1 else "{}".format(label)
  88. f.write(name)