will-jl944 4 years ago
parent commit
0fe92f3b26
1 changed file with 146 additions and 20 deletions

paddlex/cv/models/classifier.py

@@ -144,40 +144,78 @@ class BaseClassifier(BaseModel):
 
         return outputs
 
-    def default_optimizer(self, parameters, learning_rate, warmup_steps,
-                          warmup_start_lr, lr_decay_epochs, lr_decay_gamma,
-                          num_steps_each_epoch):
-        boundaries = [b * num_steps_each_epoch for b in lr_decay_epochs]
-        values = [
-            learning_rate * (lr_decay_gamma**i)
-            for i in range(len(lr_decay_epochs) + 1)
-        ]
-        scheduler = paddle.optimizer.lr.PiecewiseDecay(boundaries, values)
-        if warmup_steps > 0:
-            if warmup_steps > lr_decay_epochs[0] * num_steps_each_epoch:
+    def default_optimizer(self,
+                          parameters,
+                          learning_rate,
+                          warmup_steps,
+                          warmup_start_lr,
+                          lr_decay_epochs,
+                          lr_decay_gamma,
+                          num_steps_each_epoch,
+                          reg_coeff=1e-04,
+                          scheduler='Piecewise',
+                          num_epochs=None):
+        if scheduler.lower() == 'piecewise':
+            if warmup_steps > 0 and warmup_steps > lr_decay_epochs[
+                    0] * num_steps_each_epoch:
                 logging.error(
-                    "In function train(), parameters should satisfy: "
-                    "warmup_steps <= lr_decay_epochs[0]*num_samples_in_train_dataset.",
+                    "In function train(), parameters must satisfy: "
+                    "warmup_steps <= lr_decay_epochs[0] * num_samples_in_train_dataset. "
+                    "See this doc for more information: "
+                    "https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/parameters.md",
                     exit=False)
                 logging.error(
+                    "Either `warmup_steps` be less than {} or lr_decay_epochs[0] be greater than {} "
+                    "must be satisfied, please modify 'warmup_steps' or 'lr_decay_epochs' in train function".
+                    format(lr_decay_epochs[0] * num_steps_each_epoch,
+                           warmup_steps // num_steps_each_epoch),
+                    exit=True)
+            boundaries = [b * num_steps_each_epoch for b in lr_decay_epochs]
+            values = [
+                learning_rate * (lr_decay_gamma**i)
+                for i in range(len(lr_decay_epochs) + 1)
+            ]
+            scheduler = paddle.optimizer.lr.PiecewiseDecay(boundaries, values)
+        elif scheduler.lower() == 'cosine':
+            if num_epochs is None:
+                logging.error(
+                    "`num_epochs` must be set when using the cosine annealing decay scheduler, but received {}".
+                    format(num_epochs),
+                    exit=True)
+            if warmup_steps > 0 and warmup_steps > num_epochs * num_steps_each_epoch:
+                logging.error(
+                    "In function train(), parameters must satisfy: "
+                    "warmup_steps <= num_epochs * num_samples_in_train_dataset. "
                     "See this doc for more information: "
                     "https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/parameters.md",
                     exit=False)
                 logging.error(
-                    "warmup_steps should less than {} or lr_decay_epochs[0] greater than {}, "
-                    "please modify 'lr_decay_epochs' or 'warmup_steps' in train function".
-                    format(lr_decay_epochs[0] * num_steps_each_epoch,
-                           warmup_steps // num_steps_each_epoch))
+                    "`warmup_steps` must be less than the total number of steps({}), "
+                    "please modify 'num_epochs' or 'warmup_steps' in train function".
+                    format(num_epochs * num_steps_each_epoch),
+                    exit=True)
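+            # Cosine annealing runs over the steps that remain after warm-up.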
+            T_max = num_epochs * num_steps_each_epoch - warmup_steps
+            scheduler = paddle.optimizer.lr.CosineAnnealingDecay(
+                learning_rate=learning_rate,
+                T_max=T_max,
+                eta_min=0.0,
+                last_epoch=-1)
+        else:
+            logging.error(
+                "Invalid learning rate scheduler: {}!".format(scheduler),
+                exit=True)
 
+        if warmup_steps > 0:
             scheduler = paddle.optimizer.lr.LinearWarmup(
                 learning_rate=scheduler,
                 warmup_steps=warmup_steps,
                 start_lr=warmup_start_lr,
                 end_lr=learning_rate)
+
         optimizer = paddle.optimizer.Momentum(
             scheduler,
             momentum=.9,
-            weight_decay=paddle.regularizer.L2Decay(coeff=1e-04),
+            weight_decay=paddle.regularizer.L2Decay(coeff=reg_coeff),
             parameters=parameters)
         return optimizer
 
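For reference, here is a minimal sketch (assuming PaddlePaddle 2.x; the numbers are
illustrative placeholders, not values from this commit) of how the two scheduler
branches above compose with LinearWarmup:

import paddle

learning_rate = 0.1
num_steps_each_epoch = 100  # len(train_dataset) // train_batch_size
warmup_steps = 500

# 'Piecewise' branch: step decay at the given epoch boundaries.
piecewise = paddle.optimizer.lr.PiecewiseDecay(
    boundaries=[e * num_steps_each_epoch for e in (30, 60, 90)],
    values=[learning_rate * (0.1**i) for i in range(4)])

# 'Cosine' branch: anneal over the steps that remain after warm-up.
num_epochs = 120
cosine = paddle.optimizer.lr.CosineAnnealingDecay(
    learning_rate=learning_rate,
    T_max=num_epochs * num_steps_each_epoch - warmup_steps)

# With warmup_steps > 0, either scheduler is wrapped so the learning rate
# ramps linearly from warmup_start_lr to learning_rate before decay begins.
scheduler = paddle.optimizer.lr.LinearWarmup(
    learning_rate=cosine,
    warmup_steps=warmup_steps,
    start_lr=0.0,
    end_lr=learning_rate)

for step in range(3):
    print(step, scheduler.get_lr())  # 0.0, then 0.0002, 0.0004, ...
    scheduler.step()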
@@ -849,8 +887,96 @@ class PPLCNet(BaseClassifier):
         super(PPLCNet, self).__init__(
             model_name=model_name, num_classes=num_classes, **params)
 
+    def train(self,
+              num_epochs,
+              train_dataset,
+              train_batch_size=64,
+              eval_dataset=None,
+              optimizer=None,
+              save_interval_epochs=1,
+              log_interval_steps=10,
+              save_dir='output',
+              pretrain_weights='IMAGENET',
+              learning_rate=.1,
+              warmup_steps=0,
+              warmup_start_lr=0.0,
+              lr_decay_epochs=(30, 60, 90),
+              lr_decay_gamma=0.1,
+              label_smoothing=None,
+              early_stop=False,
+              early_stop_patience=5,
+              use_vdl=True,
+              resume_checkpoint=None):
+        """
+        Train the model.
+        Args:
+            num_epochs(int): The number of epochs.
+            train_dataset(paddlex.dataset): Training dataset.
+            train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
+            eval_dataset(paddlex.dataset, optional):
+                Evaluation dataset. If None, the model will not be evaluated during training process. Defaults to None.
+            optimizer(paddle.optimizer.Optimizer or None, optional):
+                Optimizer used for training. If None, a default optimizer with cosine annealing decay is used. Defaults to None.
+            save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
+            log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
+            save_dir(str, optional): Directory to save the model. Defaults to 'output'.
+            pretrain_weights(str or None, optional):
+                None or name/path of pretrained weights. If None, no pretrained weights will be loaded.
+                At most one of `resume_checkpoint` and `pretrain_weights` can be set simultaneously.
+                Defaults to 'IMAGENET'.
+            learning_rate(float, optional): Learning rate for training. Defaults to .1.
+            warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
+            warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0.0.
+            lr_decay_epochs(List[int] or Tuple[int], optional):
+                Epoch milestones for learning rate decay. Defaults to (30, 60, 90).
+            lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
+            label_smoothing(float, bool or None, optional): Whether to adopt label smoothing.
+                If a float, the value is the epsilon coefficient of label smoothing. If False or None, label
+                smoothing is not adopted. Otherwise, label smoothing is adopted with epsilon equal to 0.1. Defaults to None.
+            early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
+            early_stop_patience(int, optional): Early stop patience. Defaults to 5.
+            use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
+            resume_checkpoint(str or None, optional): The path of the checkpoint to resume training from.
+                If None, no training checkpoint will be resumed. At most one of `resume_checkpoint` and
+                `pretrain_weights` can be set simultaneously. Defaults to None.
+
+        """
+        if optimizer is None:
+            num_steps_each_epoch = len(train_dataset) // train_batch_size
+            optimizer = self.default_optimizer(
+                parameters=self.net.parameters(),
+                learning_rate=learning_rate,
+                warmup_steps=warmup_steps,
+                warmup_start_lr=warmup_start_lr,
+                lr_decay_epochs=lr_decay_epochs,
+                lr_decay_gamma=lr_decay_gamma,
+                num_steps_each_epoch=num_steps_each_epoch,
+                reg_coeff=3e-5,
+                scheduler='Cosine',
+                num_epochs=num_epochs)
+        super(PPLCNet, self).train(
+            num_epochs=num_epochs,
+            train_dataset=train_dataset,
+            train_batch_size=train_batch_size,
+            eval_dataset=eval_dataset,
+            optimizer=optimizer,
+            save_interval_epochs=save_interval_epochs,
+            log_interval_steps=log_interval_steps,
+            save_dir=save_dir,
+            pretrain_weights=pretrain_weights,
+            learning_rate=learning_rate,
+            warmup_steps=warmup_steps,
+            warmup_start_lr=warmup_start_lr,
+            lr_decay_epochs=lr_decay_epochs,
+            lr_decay_gamma=lr_decay_gamma,
+            label_smoothing=label_smoothing,
+            early_stop=early_stop,
+            early_stop_patience=early_stop_patience,
+            use_vdl=use_vdl,
+            resume_checkpoint=resume_checkpoint)
+
 
-class PPLCNet_ssld(BaseClassifier):
+class PPLCNet_ssld(PPLCNet):
     def __init__(self, num_classes=1000, scale=1., **params):
         supported_scale = [.5, 1., 2.5]
         if scale not in supported_scale:
@@ -858,6 +984,6 @@ class PPLCNet_ssld(BaseClassifier):
                             "scale is forcibly set to 1.0".format(scale))
             scale = 1.0
         model_name = 'PPLCNet_x' + str(float(scale)).replace('.', '_')
-        super(PPLCNet_ssld, self).__init__(
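+        # super(PPLCNet, self) skips PPLCNet.__init__ in the MRO and calls
+        # BaseClassifier.__init__ directly; scale was already validated above.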
+        super(PPLCNet, self).__init__(
             model_name=model_name, num_classes=num_classes, **params)
         self.model_name = model_name + '_ssld'
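
With this change, PPLCNet (and, through inheritance, PPLCNet_ssld) trains with cosine
annealing and an L2 coefficient of 3e-5 by default. A minimal usage sketch (dataset
paths are placeholders; assumes a PaddleX build that includes this commit):

import paddlex as pdx
from paddlex import transforms as T

train_transforms = T.Compose([
    T.RandomCrop(crop_size=224),
    T.RandomHorizontalFlip(),
    T.Normalize()])

train_dataset = pdx.datasets.ImageNet(
    data_dir='dataset',                  # placeholder path
    file_list='dataset/train_list.txt',  # placeholder path
    label_list='dataset/labels.txt',     # placeholder path
    transforms=train_transforms,
    shuffle=True)

model = pdx.cls.PPLCNet(num_classes=len(train_dataset.labels))
# Leaving optimizer=None makes the overridden train() build the default
# optimizer with scheduler='Cosine' and reg_coeff=3e-5.
model.train(
    num_epochs=120,
    train_dataset=train_dataset,
    train_batch_size=64,
    learning_rate=0.1,
    save_dir='output/pplcnet')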