# predict.py (726 B)

  1. import numpy as np
  2. from PIL import Image
  3. import paddlex as pdx
  4. model_dir = "saved_model/remote_sensing_unet/best_model/"
  5. img_file = "dataset/remote_sensing_seg/data/LC80150242014146LGN00_23_data.tif"
  6. label_file = "dataset/remote_sensing_seg/mask/LC80150242014146LGN00_23_mask.png"
  7. color = [255, 255, 255, 0, 0, 0, 255, 255, 0, 255, 0, 0, 150, 150, 150]
  8. # 预测并可视化预测结果
  9. model = pdx.load_model(model_dir)
  10. pred = model.predict(img_file)
  11. pdx.seg.visualize(
  12. img_file, pred, weight=0., save_dir='./output/pred', color=color)
  13. # 可视化标注文件
  14. label = np.asarray(Image.open(label_file))
  15. pred = {'label_map': label}
  16. pdx.seg.visualize(
  17. img_file, pred, weight=0., save_dir='./output/gt', color=color)