Bläddra i källkod

change train code

FlyingQianMM 5 år sedan
förälder
incheckning
377ea08ecb
1 ändrade filer med 22 tillägg och 8 borttagningar
  1. 22 8
      examples/change_detection/unet.py

+ 22 - 8
examples/change_detection/unet.py

@@ -1,7 +1,7 @@
 # 环境变量配置,用于控制是否使用GPU
 # 说明文档:https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html#gpu
 import os
-os.environ['CUDA_VISIBLE_DEVICES'] = '1'
+os.environ['CUDA_VISIBLE_DEVICES'] = '2'
 
 import paddlex as pdx
 from paddlex.seg import transforms
@@ -9,16 +9,24 @@ from paddlex.seg import transforms
 # 定义训练和验证时的transforms
 # API说明 https://paddlex.readthedocs.io/zh_CN/develop/apis/transforms/seg_transforms.html
 train_transforms = transforms.Compose([
+    #transforms.ResizeStepScaling(
+    #    min_scale_factor=0.5,
+    #    max_scale_factor=2.,
+    #    scale_step_size=0.25),
     transforms.RandomPaddingCrop(
-        crop_size=256,
-        im_padding_value=[127.5] * 6), transforms.RandomHorizontalFlip(),
-    transforms.RandomVerticalFlip(), transforms.Normalize(
+        crop_size=769, im_padding_value=[127.5] * 6),
+    #transforms.ResizeByLong(long_size=512),
+    #transforms.Padding(
+    #    target_size=512, im_padding_value=[127.5] * 6),
+    transforms.RandomHorizontalFlip(),
+    transforms.RandomVerticalFlip(),
+    transforms.Normalize(
         mean=[0.5] * 6, std=[0.5] * 6, min_val=[0] * 6, max_val=[255] * 6)
 ])
 
 eval_transforms = transforms.Compose([
     transforms.Padding(
-        target_size=256, im_padding_value=[127.5] * 6), transforms.Normalize(
+        target_size=1000, im_padding_value=[127.5] * 6), transforms.Normalize(
             mean=[0.5] * 6, std=[0.5] * 6, min_val=[0] * 6, max_val=[255] * 6)
 ])
 
@@ -29,11 +37,13 @@ train_dataset = pdx.datasets.ChangeDetDataset(
     file_list='dataset/train_list.txt',
     label_list='dataset/labels.txt',
     transforms=train_transforms,
+    num_workers=4,
     shuffle=True)
 eval_dataset = pdx.datasets.ChangeDetDataset(
     data_dir='dataset',
     file_list='dataset/val_list.txt',
     label_list='dataset/labels.txt',
+    num_workers=4,
     transforms=eval_transforms)
 
 # 初始化模型,并进行训练
@@ -41,17 +51,21 @@ eval_dataset = pdx.datasets.ChangeDetDataset(
 num_classes = len(train_dataset.labels)
 
 # API说明:https://paddlex.readthedocs.io/zh_CN/develop/apis/models/semantic_segmentation.html#paddlex-seg-deeplabv3p
-model = pdx.seg.UNet(num_classes=num_classes, input_channel=6)
+model = pdx.seg.UNet(
+    num_classes=num_classes,
+    input_channel=6,
+    use_bce_loss=True,
+    use_dice_loss=True)
 
 # API说明:https://paddlex.readthedocs.io/zh_CN/develop/apis/models/semantic_segmentation.html#train
 # 各参数介绍与调整说明:https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html
 model.train(
     num_epochs=400,
     train_dataset=train_dataset,
-    train_batch_size=16,
+    train_batch_size=4,
     eval_dataset=eval_dataset,
     learning_rate=0.01,
     save_interval_epochs=10,
     pretrain_weights='CITYSCAPES',
-    save_dir='output/unet',
+    save_dir='output/unet_3',
     use_vdl=True)