unet_train.py 1.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  # Train a PaddleX UNet semantic-segmentation model on the optic disc dataset.
import os

# Default to GPU 0, but respect a device selection the caller already made
# (e.g. `CUDA_VISIBLE_DEVICES=1 python unet_train.py`); the previous
# unconditional assignment silently overrode it. This is done before paddlex
# is imported, presumably so the device mask is in place before CUDA is
# initialized -- keep this ordering.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0')

import paddlex as pdx  # noqa: E402 -- must follow the environment setup above
from paddlex.seg import transforms  # noqa: E402

# Dataset archive, where it is unpacked, and the directory the code below
# reads from. DATASET_DIR is assumed to be the top-level folder inside the
# archive -- confirm if the archive layout changes.
optic_dataset = 'https://bj.bcebos.com/paddlex/datasets/optic_disc_seg.tar.gz'
DOWNLOAD_DIR = './'
DATASET_DIR = 'optic_disc_seg'
LABEL_LIST = DATASET_DIR + '/labels.txt'

# Spatial size shared by the training crop and the evaluation resize/pad, so
# both pipelines feed the model inputs of the same size.
INPUT_SIZE = 512

pdx.utils.download_and_decompress(optic_dataset, path=DOWNLOAD_DIR)

# Training pipeline: random horizontal flip, ResizeRangeScaling and
# RandomPaddingCrop (library defaults apart from the crop size), then Normalize.
train_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ResizeRangeScaling(),
    transforms.RandomPaddingCrop(crop_size=INPUT_SIZE),
    transforms.Normalize(),
])

# Evaluation pipeline: no Random* transforms; resize by the long side, pad to
# the target size, then Normalize.
eval_transforms = transforms.Compose([
    transforms.ResizeByLong(long_size=INPUT_SIZE),
    transforms.Padding(target_size=INPUT_SIZE),
    transforms.Normalize(),
])

train_dataset = pdx.datasets.SegDataset(
    data_dir=DATASET_DIR,
    file_list=DATASET_DIR + '/train_list.txt',
    label_list=LABEL_LIST,
    transforms=train_transforms,
    shuffle=True,
)
eval_dataset = pdx.datasets.SegDataset(
    data_dir=DATASET_DIR,
    file_list=DATASET_DIR + '/val_list.txt',
    label_list=LABEL_LIST,
    transforms=eval_transforms,
)

# The class count comes from the dataset's labels (presumably read from
# LABEL_LIST), so the model head matches the label file.
num_classes = len(train_dataset.labels)
model = pdx.seg.UNet(num_classes=num_classes)
model.train(
    num_epochs=20,
    train_dataset=train_dataset,
    train_batch_size=4,
    eval_dataset=eval_dataset,
    learning_rate=0.01,
    save_dir='output/unet',
    use_vdl=True,  # VisualDL logging, per the PaddleX train() API
)