FlyingQianMM 5 ani în urmă
părinte
comite
3fa48cb303

+ 4 - 4
paddlex/cv/nets/detection/faster_rcnn.py

@@ -122,8 +122,8 @@ class FasterRCNN(object):
                     test_nms_thresh=test_nms_thresh,
                     rpn_cls_loss=rpn_cls_loss,
                     rpn_focal_loss_alpha=rpn_focal_loss_alpha,
-                    rpn_focal_loss_gamma=rpn_focal_loss_gamma,
-                    use_random=False)
+                    rpn_focal_loss_gamma=rpn_focal_loss_gamma)
+                #use_random=False)
             else:
                 rpn_head = FPNRPNHead(
                     anchor_start_size=anchor_sizes[0],
@@ -143,8 +143,8 @@ class FasterRCNN(object):
                     test_nms_thresh=test_nms_thresh,
                     rpn_cls_loss=rpn_cls_loss,
                     rpn_focal_loss_alpha=rpn_focal_loss_alpha,
-                    rpn_focal_loss_gamma=rpn_focal_loss_gamma,
-                    use_random=False)
+                    rpn_focal_loss_gamma=rpn_focal_loss_gamma)
+                #use_random=False)
         self.rpn_head = rpn_head
         if roi_extractor is None:
             if self.fpn is None:

+ 63 - 0
tutorials/train/object_detection/guang_2_r2_dcn_libra.py

@@ -0,0 +1,63 @@
+# 环境变量配置,用于控制是否使用GPU
+# 说明文档:https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html#gpu
+import os
+os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
+
+from paddlex.det import transforms
+import paddlex as pdx
+
+# API说明 https://paddlex.readthedocs.io/zh_CN/develop/apis/transforms/det_transforms.html
+train_transforms = transforms.Compose([
+    transforms.RandomHorizontalFlip(), transforms.Normalize(),
+    transforms.ResizeByShort(
+        short_size=800, max_size=1333), transforms.Padding(coarsest_stride=32)
+])
+
+eval_transforms = transforms.Compose([
+    transforms.Normalize(),
+    transforms.ResizeByShort(
+        short_size=800, max_size=1333),
+    transforms.Padding(coarsest_stride=32),
+])
+
+# 定义训练和验证所用的数据集
+# API说明:https://paddlex.readthedocs.io/zh_CN/develop/apis/datasets.html#paddlex-datasets-vocdetection
+train_dataset = pdx.datasets.VOCDetection(
+    data_dir='dataset',
+    file_list='dataset/train_list.txt',
+    label_list='dataset/labels.txt',
+    transforms=train_transforms,
+    num_workers=8,
+    shuffle=True)
+eval_dataset = pdx.datasets.VOCDetection(
+    data_dir='dataset',
+    file_list='dataset/val_list.txt',
+    label_list='dataset/labels.txt',
+    num_workers=8,
+    transforms=eval_transforms)
+
+# 初始化模型,并进行训练
+# 可使用VisualDL查看训练指标,参考https://paddlex.readthedocs.io/zh_CN/develop/train/visualdl.html
+# num_classes 需要设置为包含背景类的类别数,即: 目标类别数量 + 1
+num_classes = len(train_dataset.labels) + 1
+
+# API说明: https://paddlex.readthedocs.io/zh_CN/develop/apis/models/detection.html#paddlex-det-fasterrcnn
+model = pdx.det.FasterRCNN(
+    num_classes=num_classes,
+    backbone='ResNet50_vd',
+    with_dcn=True,
+    bbox_assigner='LibraBBoxAssigner')
+
+# API说明: https://paddlex.readthedocs.io/zh_CN/develop/apis/models/detection.html#id1
+# 各参数介绍与调整说明:https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html
+model.train(
+    num_epochs=55,
+    train_dataset=train_dataset,
+    train_batch_size=8,
+    eval_dataset=eval_dataset,
+    learning_rate=0.01,
+    lr_decay_epochs=[40, 50],
+    warmup_start_lr=0.001,
+    pretrain_weights='ResNet50_vd_ssld_pretrained',
+    save_dir='output/guan_2_r3_dcn_libra',
+    use_vdl=False)