
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleX into develop_ui

FlyingQianMM 4 years ago
commit afdafa89c3

+ 1 - 1
docs/apis/models/detection.md

@@ -264,7 +264,7 @@ paddlex.det.PPYOLOTiny(num_classes=80, backbone='MobileNetV3', anchors=[[10, 15]
 > - quant_aware_train: the online quantization API is the same as the [PPYOLOv2 quant_aware_train API](#quant_aware_train)
 
 
-## <h2 id="4">paddlex.det.YOLOv3</h2>
+## <h2 id="4">paddlex.det.PicoDet</h2>
 
 ```python
 paddlex.det.PicoDet(num_classes=80, backbone='ESNet_m', nms_score_threshold=.025, nms_topk=1000, nms_keep_topk=100, nms_iou_threshold=.6)

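For reference, a hedged instantiation of the renamed class, using only the defaults documented in the signature above (assumes `paddlex` is importable as `pdx`; not part of this commit):

```python
import paddlex as pdx

# All keyword values below mirror the documented PicoDet defaults.
model = pdx.det.PicoDet(
    num_classes=80,
    backbone='ESNet_m',
    nms_score_threshold=.025,
    nms_topk=1000,
    nms_keep_topk=100,
    nms_iou_threshold=.6)
```
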
+ 140 - 13
paddlex/cv/models/detector.py

@@ -112,25 +112,66 @@ class BaseDetector(BaseModel):
 
         return outputs
 
-    def default_optimizer(self, parameters, learning_rate, warmup_steps,
-                          warmup_start_lr, lr_decay_epochs, lr_decay_gamma,
-                          num_steps_each_epoch):
-        boundaries = [b * num_steps_each_epoch for b in lr_decay_epochs]
-        values = [(lr_decay_gamma**i) * learning_rate
-                  for i in range(len(lr_decay_epochs) + 1)]
-        scheduler = paddle.optimizer.lr.PiecewiseDecay(
-            boundaries=boundaries, values=values)
-        if warmup_steps > 0:
-            if warmup_steps > lr_decay_epochs[0] * num_steps_each_epoch:
+    def default_optimizer(self,
+                          parameters,
+                          learning_rate,
+                          warmup_steps,
+                          warmup_start_lr,
+                          lr_decay_epochs,
+                          lr_decay_gamma,
+                          num_steps_each_epoch,
+                          reg_coeff=1e-04,
+                          scheduler='Piecewise',
+                          num_epochs=None):
+        if scheduler.lower() == 'piecewise':
+            if warmup_steps > 0 and warmup_steps > lr_decay_epochs[
+                    0] * num_steps_each_epoch:
+                logging.error(
+                    "In function train(), parameters must satisfy: "
+                    "warmup_steps <= lr_decay_epochs[0] * num_steps_each_epoch. "
+                    "See this doc for more information: "
+                    "https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/parameters.md",
+                    exit=False)
+                logging.error(
+                    "Either `warmup_steps` must be less than {} or `lr_decay_epochs[0]` must be greater than {}; "
+                    "please modify 'warmup_steps' or 'lr_decay_epochs' in the train() function.".
+                    format(lr_decay_epochs[0] * num_steps_each_epoch,
+                           warmup_steps // num_steps_each_epoch),
+                    exit=True)
+            boundaries = [b * num_steps_each_epoch for b in lr_decay_epochs]
+            values = [(lr_decay_gamma**i) * learning_rate
+                      for i in range(len(lr_decay_epochs) + 1)]
+            scheduler = paddle.optimizer.lr.PiecewiseDecay(boundaries, values)
+        elif scheduler.lower() == 'cosine':
+            if num_epochs is None:
                 logging.error(
-                    "In function train(), parameters should satisfy: "
-                    "warmup_steps <= lr_decay_epochs[0]*num_samples_in_train_dataset",
+                    "`num_epochs` must be set when using the cosine annealing decay scheduler, but received {}".
+                    format(num_epochs),
                     exit=False)
+            if warmup_steps > 0 and warmup_steps > num_epochs * num_steps_each_epoch:
                 logging.error(
+                    "In function train(), parameters must satisfy: "
+                    "warmup_steps <= num_epochs * num_steps_each_epoch. "
                     "See this doc for more information: "
                     "https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/parameters.md",
                     exit=False)
+                logging.error(
+                    "`warmup_steps` must not exceed the total number of steps ({}); "
+                    "please modify 'num_epochs' or 'warmup_steps' in the train() function".
+                    format(num_epochs * num_steps_each_epoch),
+                    exit=True)
+            T_max = num_epochs * num_steps_each_epoch - warmup_steps
+            scheduler = paddle.optimizer.lr.CosineAnnealingDecay(
+                learning_rate=learning_rate,
+                T_max=T_max,
+                eta_min=0.0,
+                last_epoch=-1)
+        else:
+            logging.error(
+                "Invalid learning rate scheduler: {}!".format(scheduler),
+                exit=True)
 
+        if warmup_steps > 0:
             scheduler = paddle.optimizer.lr.LinearWarmup(
                 learning_rate=scheduler,
                 warmup_steps=warmup_steps,
@@ -139,7 +180,7 @@ class BaseDetector(BaseModel):
         optimizer = paddle.optimizer.Momentum(
             scheduler,
             momentum=.9,
-            weight_decay=paddle.regularizer.L2Decay(coeff=1e-04),
+            weight_decay=paddle.regularizer.L2Decay(coeff=reg_coeff),
             parameters=parameters)
         return optimizer
 
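Pulled out of the class, the new cosine branch composes with the warm-up wrapper as sketched below. The numbers are illustrative and the stand-in network is hypothetical, but the `paddle.optimizer.lr` calls match those used in `default_optimizer` (not part of this commit):

```python
import paddle

# Illustrative values; in default_optimizer these come from train()'s arguments.
learning_rate = 0.05
num_epochs, num_steps_each_epoch, warmup_steps = 20, 100, 24
net = paddle.nn.Linear(8, 8)  # hypothetical stand-in for a detector network

# Cosine annealing covers the steps that remain after warm-up.
T_max = num_epochs * num_steps_each_epoch - warmup_steps
scheduler = paddle.optimizer.lr.CosineAnnealingDecay(
    learning_rate=learning_rate, T_max=T_max, eta_min=0.0, last_epoch=-1)

# LinearWarmup wraps whichever base scheduler was chosen, ramping the LR
# from warmup_start_lr up to the configured learning_rate.
if warmup_steps > 0:
    scheduler = paddle.optimizer.lr.LinearWarmup(
        learning_rate=scheduler,
        warmup_steps=warmup_steps,
        start_lr=0.0,
        end_lr=learning_rate)

optimizer = paddle.optimizer.Momentum(
    scheduler,
    momentum=.9,
    weight_decay=paddle.regularizer.L2Decay(coeff=4e-05),
    parameters=net.parameters())
```
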
@@ -782,6 +823,92 @@ class PicoDet(BaseDetector):
         self.fixed_input_shape = image_shape
         return self._define_input_spec(image_shape)
 
+    def train(self,
+              num_epochs,
+              train_dataset,
+              train_batch_size=64,
+              eval_dataset=None,
+              optimizer=None,
+              save_interval_epochs=1,
+              log_interval_steps=10,
+              save_dir='output',
+              pretrain_weights='IMAGENET',
+              learning_rate=.001,
+              warmup_steps=0,
+              warmup_start_lr=0.0,
+              lr_decay_epochs=(216, 243),
+              lr_decay_gamma=0.1,
+              metric=None,
+              use_ema=False,
+              early_stop=False,
+              early_stop_patience=5,
+              use_vdl=True,
+              resume_checkpoint=None):
+        """
+        Train the model.
+        Args:
+            num_epochs(int): The number of epochs.
+            train_dataset(paddlex.dataset): Training dataset.
+            train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
+            eval_dataset(paddlex.dataset, optional):
+                Evaluation dataset. If None, the model will not be evaluated during the training process. Defaults to None.
+            optimizer(paddle.optimizer.Optimizer or None, optional):
+                Optimizer used for training. If None, a default optimizer is used. Defaults to None.
+            save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
+            log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
+            save_dir(str, optional): Directory to save the model. Defaults to 'output'.
+            pretrain_weights(str or None, optional):
+                None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to 'IMAGENET'.
+            learning_rate(float, optional): Learning rate for training. Defaults to .001.
+            warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
+            warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0.0.
+            lr_decay_epochs(list or tuple, optional): Epoch milestones for learning rate decay. Defaults to (216, 243).
+            lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
+            metric({'VOC', 'COCO', None}, optional):
+                Evaluation metric. If None, the metric is determined according to the dataset format. Defaults to None.
+            use_ema(bool, optional): Whether to use exponential moving average strategy. Defaults to False.
+            early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
+            early_stop_patience(int, optional): Early stop patience. Defaults to 5.
+            use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
+            resume_checkpoint(str or None, optional): The path of the checkpoint to resume training from.
+                If None, no training checkpoint will be resumed. At most one of `resume_checkpoint` and
+                `pretrain_weights` can be set simultaneously. Defaults to None.
+        """
+        if optimizer is None:
+            num_steps_each_epoch = len(train_dataset) // train_batch_size
+            optimizer = self.default_optimizer(
+                parameters=self.net.parameters(),
+                learning_rate=learning_rate,
+                warmup_steps=warmup_steps,
+                warmup_start_lr=warmup_start_lr,
+                lr_decay_epochs=lr_decay_epochs,
+                lr_decay_gamma=lr_decay_gamma,
+                num_steps_each_epoch=num_steps_each_epoch,
+                reg_coeff=4e-05,
+                scheduler='Cosine',
+                num_epochs=num_epochs)
+        super(PicoDet, self).train(
+            num_epochs=num_epochs,
+            train_dataset=train_dataset,
+            train_batch_size=train_batch_size,
+            eval_dataset=eval_dataset,
+            optimizer=optimizer,
+            save_interval_epochs=save_interval_epochs,
+            log_interval_steps=log_interval_steps,
+            save_dir=save_dir,
+            pretrain_weights=pretrain_weights,
+            learning_rate=learning_rate,
+            warmup_steps=warmup_steps,
+            warmup_start_lr=warmup_start_lr,
+            lr_decay_epochs=lr_decay_epochs,
+            lr_decay_gamma=lr_decay_gamma,
+            metric=metric,
+            use_ema=use_ema,
+            early_stop=early_stop,
+            early_stop_patience=early_stop_patience,
+            use_vdl=use_vdl,
+            resume_checkpoint=resume_checkpoint)
+
 
 class YOLOv3(BaseDetector):
     def __init__(self,

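Because the new train() only builds this cosine-decay Momentum optimizer when `optimizer` is None, callers can still supply their own schedule. A minimal sketch under that assumption (the dataset line is hypothetical and left commented out):

```python
import paddle
import paddlex as pdx

model = pdx.det.PicoDet(num_classes=80, backbone='ESNet_m')
# train_dataset = pdx.datasets.CocoDetection(...)  # hypothetical dataset

# Any custom optimizer passed to train() bypasses the cosine default above.
custom_scheduler = paddle.optimizer.lr.PiecewiseDecay(
    boundaries=[6000, 8000], values=[0.05, 0.005, 0.0005])
custom_optimizer = paddle.optimizer.Momentum(
    custom_scheduler,
    momentum=.9,
    weight_decay=paddle.regularizer.L2Decay(coeff=4e-05),
    parameters=model.net.parameters())

# model.train(num_epochs=20, train_dataset=train_dataset,
#             optimizer=custom_optimizer, save_dir='output/picodet')
```
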
+ 1 - 1
paddlex/paddleseg/cvlibs/param_init.py

@@ -58,7 +58,7 @@ def normal_init(param, **kwargs):
 
 
 def kaiming_normal_init(param, **kwargs):
-    """
+    r"""
     Initialize the input tensor with Kaiming Normal initialization.
 
     This function implements the `param` initialization from the paper

+ 1 - 1
paddlex/paddleseg/models/losses/binary_cross_entropy_loss.py

@@ -21,7 +21,7 @@ from paddlex.paddleseg.cvlibs import manager
 
 @manager.LOSSES.add_component
 class BCELoss(nn.Layer):
-    """
+    r"""
     This operator combines the sigmoid layer and the :ref:`api_nn_loss_BCELoss` layer.
     Also, we can see it as the combine of ``sigmoid_cross_entropy_with_logits``
     layer and some reduce operations.

+ 3 - 3
paddlex/paddleseg/models/losses/lovasz_loss.py

@@ -41,7 +41,7 @@ class LovaszSoftmaxLoss(nn.Layer):
         self.classes = classes
 
     def forward(self, logits, labels):
-        """
+        r"""
         Forward computation.
 
         Args:
@@ -68,7 +68,7 @@ class LovaszHingeLoss(nn.Layer):
         self.ignore_index = ignore_index
 
     def forward(self, logits, labels):
-        """
+        r"""
         Forward computation.
 
         Args:
@@ -111,7 +111,7 @@ def binary_channel_to_unary(logits, eps=1e-9):
 
 
 def lovasz_hinge_flat(logits, labels):
-    """
+    r"""
     Binary Lovasz hinge loss.
 
     Args:

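The `r` prefix added across these files matters because the docstrings embed backslash notation (reST/LaTeX-style math in kaiming_normal_init, formula text in the loss docs); in a plain string Python treats sequences such as `\f` as escapes and warns on invalid ones like `\s`. A minimal, hypothetical illustration:

```python
def kaiming_doc_demo():
    r"""Raw docstring keeps \sqrt{\frac{2}{fan\_in}} literal in rendered docs.

    Without the r prefix, \f would silently become a form-feed character
    and \s would trigger a DeprecationWarning as an invalid escape.
    """


# Every backslash survives verbatim in the raw docstring.
assert '\\sqrt' in kaiming_doc_demo.__doc__
```
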
+ 2 - 1
setup.py

@@ -39,7 +39,8 @@ setuptools.setup(
     install_requires=[
         "pycocotools", 'pyyaml', 'colorama', 'tqdm', 'paddleslim==2.2.1',
         'visualdl>=2.2.2', 'shapely>=1.7.0', 'opencv-python', 'scipy', 'lap',
-        'motmetrics', 'scikit-learn==0.23.2', 'chardet', 'flask_cors'
+        'motmetrics', 'scikit-learn==0.23.2', 'chardet', 'flask_cors',
+        'openpyxl'
     ],
     classifiers=[
         "Programming Language :: Python :: 3",

+ 5 - 5
tutorials/train/object_detection/picodet.py

@@ -44,16 +44,16 @@ model = pdx.det.PicoDet(num_classes=num_classes, backbone='ESNet_l')
 # API reference: https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/apis/models/detection.md
 # Parameter descriptions and tuning guide: https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/parameters.md
 model.train(
-    num_epochs=40,
+    num_epochs=20,
     train_dataset=train_dataset,
     train_batch_size=14,
     eval_dataset=eval_dataset,
     pretrain_weights='COCO',
     learning_rate=.05,
-    warmup_steps=30,
-    warmup_start_lr=0.0,
-    save_interval_epochs=2,
-    lr_decay_epochs=[8, 13],
+    warmup_steps=24,
+    warmup_start_lr=0.005,
+    save_interval_epochs=1,
+    lr_decay_epochs=[6, 8, 11],
     use_ema=True,
     save_dir='output/picodet_esnet_l',
     use_vdl=True)
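
Since PicoDet.train() now defaults to the cosine scheduler, the constraint that matters for this tutorial is `warmup_steps <= num_epochs * num_steps_each_epoch` rather than the piecewise milestone check. A quick sanity check under a hypothetical dataset size (substitute `len(train_dataset)` in practice):

```python
# Hypothetical sample count; use len(train_dataset) for the real dataset.
num_samples = 170
train_batch_size = 14
num_steps_each_epoch = num_samples // train_batch_size  # 12

num_epochs = 20
warmup_steps = 24

# Mirrors the cosine-branch check in BaseDetector.default_optimizer:
# warm-up must fit inside the total step budget.
assert warmup_steps <= num_epochs * num_steps_each_epoch  # 24 <= 240
```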