Browse Source

add cosine learning rate method for clas

FlyingQianMM 4 years ago
parent
commit
7d392e5078
1 changed files with 62 additions and 25 deletions
  1. 62 25
      paddlex/cv/models/classifier.py

+ 62 - 25
paddlex/cv/models/classifier.py

@@ -144,40 +144,77 @@ class BaseClassifier(BaseModel):
 
         return outputs
 
-    def default_optimizer(self, parameters, learning_rate, warmup_steps,
-                          warmup_start_lr, lr_decay_epochs, lr_decay_gamma,
-                          num_steps_each_epoch):
-        boundaries = [b * num_steps_each_epoch for b in lr_decay_epochs]
-        values = [
-            learning_rate * (lr_decay_gamma**i)
-            for i in range(len(lr_decay_epochs) + 1)
-        ]
-        scheduler = paddle.optimizer.lr.PiecewiseDecay(boundaries, values)
+    def default_optimizer(self,
+                          parameters,
+                          learning_rate,
+                          warmup_steps,
+                          warmup_start_lr,
+                          lr_decay_epochs,
+                          lr_decay_gamma,
+                          num_steps_each_epoch,
+                          decay_coff=1e-04,
+                          lr_method='Linear',
+                          num_epochs=None):
+        if warmup_steps > 0:
+            if lr_method == 'Linear':
+                if warmup_steps > lr_decay_epochs[0] * num_steps_each_epoch:
+                    logging.error(
+                        "In function train(), parameters should satisfy: "
+                        "warmup_steps <= lr_decay_epochs[0]*num_samples_in_train_dataset.",
+                        exit=False)
+                    logging.error(
+                        "See this doc for more information: "
+                        "https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/parameters.md",
+                        exit=False)
+                    logging.error(
+                        "warmup_steps should be less than {} or lr_decay_epochs[0] greater than {}, "
+                        "please modify 'lr_decay_epochs' or 'warmup_steps' in train function".
+                        format(lr_decay_epochs[0] * num_steps_each_epoch,
+                               warmup_steps // num_steps_each_epoch))
+            elif lr_method == 'Cosine':
+                if num_epochs is None:
+                    logging.error(
+                        "num_epochs must be set when using cosine learning rate method, but received is {}".
+                        format(num_epochs),
+                        exit=False)
+                if warmup_steps > num_epochs * num_steps_each_epoch:
+                    logging.error(
+                        "In function train(), parameters should satisfy: "
+                        "warmup_steps <= num_epochs*num_samples_in_train_dataset.",
+                        exit=False)
+                    logging.error(
+                        "See this doc for more information: "
+                        "https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/parameters.md",
+                        exit=False)
+                    logging.error(
+                        "warmup_steps should be less than {}, "
+                        "please modify 'num_epochs' or 'warmup_steps' in train function".
+                        format(num_epochs * num_steps_each_epoch))
+        if lr_method == 'Linear':
+            boundaries = [b * num_steps_each_epoch for b in lr_decay_epochs]
+            values = [
+                learning_rate * (lr_decay_gamma**i)
+                for i in range(len(lr_decay_epochs) + 1)
+            ]
+            scheduler = paddle.optimizer.lr.PiecewiseDecay(boundaries, values)
+        elif lr_method == 'Cosine':
+            T_max = num_epochs * num_steps_each_epoch - warmup_steps
+            scheduler = paddle.lr.CosineAnnealingDecay(
+                learning_rate=learning_rate,
+                T_max=T_max,
+                eta_min=0.0,
+                last_epoch=-1)
         if warmup_steps > 0:
-            if warmup_steps > lr_decay_epochs[0] * num_steps_each_epoch:
-                logging.error(
-                    "In function train(), parameters should satisfy: "
-                    "warmup_steps <= lr_decay_epochs[0]*num_samples_in_train_dataset.",
-                    exit=False)
-                logging.error(
-                    "See this doc for more information: "
-                    "https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/parameters.md",
-                    exit=False)
-                logging.error(
-                    "warmup_steps should less than {} or lr_decay_epochs[0] greater than {}, "
-                    "please modify 'lr_decay_epochs' or 'warmup_steps' in train function".
-                    format(lr_decay_epochs[0] * num_steps_each_epoch,
-                           warmup_steps // num_steps_each_epoch))
-
             scheduler = paddle.optimizer.lr.LinearWarmup(
                 learning_rate=scheduler,
                 warmup_steps=warmup_steps,
                 start_lr=warmup_start_lr,
                 end_lr=learning_rate)
+
         optimizer = paddle.optimizer.Momentum(
             scheduler,
             momentum=.9,
-            weight_decay=paddle.regularizer.L2Decay(coeff=1e-04),
+            weight_decay=paddle.regularizer.L2Decay(coeff=decay_coff),
             parameters=parameters)
         return optimizer