ソースを参照

add warmup for classifier

jiangjiajun 5 年 前
コミット
aefe18c7a3
3 ファイル変更59 行追加16 行削除
  1. 24 0
      docs/appendix/parameters.md
  2. 28 8
      paddlex/cv/models/classifier.py
  3. 7 8
      paddlex/utils/logging.py

+ 24 - 0
docs/appendix/parameters.md

@@ -23,3 +23,27 @@ Batch Size指模型在训练过程中,一次性处理的样本数量, 如若
 - [实例分割MaskRCNN-train](https://paddlex.readthedocs.io/zh_CN/latest/apis/models/instance_segmentation.html#train)
 - [语义分割DeepLabv3p-train](https://paddlex.readthedocs.io/zh_CN/latest/apis/models/semantic_segmentation.html#train)
 - [语义分割UNet](https://paddlex.readthedocs.io/zh_CN/latest/apis/models/semantic_segmentation.html#id2)
+
+## 关于lr_decay_epoch, warmup_steps等参数的说明
+
+在PaddleX或其它深度学习模型的训练过程中,经常见到lr_decay_epoch, warmup_steps, warmup_start_lr等参数设置,下面介绍一些这些参数的作用。  
+
+首先这些参数都是用于控制模型训练过程中学习率的变化方式,例如我们在训练时将learning_rate设为0.1, 通常情况,在模型的训练过程中,学习率一直以0.1不变训练下去, 但为了调出更好的模型效果,我们往往不希望学习率一直保持不变。
+
+### warmup_steps和warmup_start_lr
+
+我们在训练模型时,一般都会使用预训练模型,例如检测模型在训练时使用backbone在ImageNet数据集上的预训练权重。但由于在自行训练时,自己的数据与ImageNet数据集存在较大的差异,可能会一开始由于梯度过大使得训练出现问题,因此可以在刚开始训练时,让学习率以一个较小的值,慢慢增长到设定的学习率。因此`warmup_steps`和`warmup_start_lr`就是这个作用,模型开始训练时,学习率会从`warmup_start_lr`开始,在`warmup_steps`内线性增长到设定的学习率。
+
+### lr_decay_epochs和lr_decay_gamma
+
+`lr_decay_epochs`用于让学习率在模型训练后期逐步衰减,它一般是一个list,如[6, 8, 10],表示学习率在第6个epoch时衰减一次,第8个epoch时再衰减一次,第10个epoch时再衰减一次。每次学习率衰减为之前的学习率*lr_decay_gamma
+
+### PaddleX中对warmup_steps和lr_decay_epochs的约束限制
+
+在PaddleX中,限制warmup需要在第一个学习率decay衰减前结束,因此要满足下面的公式
+```
+warmup_steps <= lr_decay_epochs[0] * num_steps_each_epoch
+```
+其中公式中`num_steps_each_epoch = num_samples_in_train_dataset // train_batch_size`。  
+
+>  因此如若在训练时PaddleX提示`warmup_steps should be less than xxx`时,即可根据上述公式来调整你的`lr_decay_epochs`或者是`warmup_steps`使得两个参数满足上面的条件

+ 28 - 8
paddlex/cv/models/classifier.py

@@ -52,8 +52,7 @@ class BaseClassifier(BaseAPI):
             input_shape = [
                 None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
             ]
-            image = fluid.data(
-                dtype='float32', shape=input_shape, name='image')
+            image = fluid.data(dtype='float32', shape=input_shape, name='image')
         else:
             image = fluid.data(
                 dtype='float32', shape=[None, 3, None, None], name='image')
@@ -81,7 +80,8 @@ class BaseClassifier(BaseAPI):
             del outputs['loss']
         return inputs, outputs
 
-    def default_optimizer(self, learning_rate, lr_decay_epochs, lr_decay_gamma,
+    def default_optimizer(self, learning_rate, warmup_steps, warmup_start_lr,
+                          lr_decay_epochs, lr_decay_gamma,
                           num_steps_each_epoch):
         boundaries = [b * num_steps_each_epoch for b in lr_decay_epochs]
         values = [
@@ -90,6 +90,22 @@ class BaseClassifier(BaseAPI):
         ]
         lr_decay = fluid.layers.piecewise_decay(
             boundaries=boundaries, values=values)
+        if warmup_steps > 0:
+            if warmup_steps > lr_decay_epochs[0] * num_steps_each_epoch:
+                logging.error(
+                    "In function train(), parameters should satisfy: warmup_steps <= lr_decay_epochs[0]*num_samples_in_train_dataset",
+                    exit=False)
+                logging.error(
+                    "See this doc for more information: xxxx", exit=False)
+                logging.error(
+                    "warmup_steps should less than {}, please modify 'lr_decay_epochs' or 'warmup_steps' in train function".
+                    format(lr_decay_epochs[0] * num_steps_each_epoch))
+
+            lr_decay = fluid.layers.linear_lr_warmup(
+                learning_rate=lr_decay,
+                warmup_steps=warmup_steps,
+                start_lr=warmup_start_lr,
+                end_lr=learning_rate)
         optimizer = fluid.optimizer.Momentum(
             lr_decay,
             momentum=0.9,
@@ -107,6 +123,8 @@ class BaseClassifier(BaseAPI):
               pretrain_weights='IMAGENET',
               optimizer=None,
               learning_rate=0.025,
+              warmup_steps=0,
+              warmup_start_lr=0.0,
               lr_decay_epochs=[30, 60, 90],
               lr_decay_gamma=0.1,
               use_vdl=False,
@@ -129,6 +147,8 @@ class BaseClassifier(BaseAPI):
             optimizer (paddle.fluid.optimizer): 优化器。当该参数为None时,使用默认优化器:
                 fluid.layers.piecewise_decay衰减策略,fluid.optimizer.Momentum优化方法。
             learning_rate (float): 默认优化器的初始学习率。默认为0.025。
+            warmup_steps(int): 学习率从warmup_start_lr上升至设定的learning_rate,所需的步数,默认为0
+            warmup_start_lr(float): 学习率在warmup阶段时的起始值,默认为0.0
             lr_decay_epochs (list): 默认优化器的学习率衰减轮数。默认为[30, 60, 90]。
             lr_decay_gamma (float): 默认优化器的学习率衰减率。默认为0.1。
             use_vdl (bool): 是否使用VisualDL进行可视化。默认值为False。
@@ -149,6 +169,8 @@ class BaseClassifier(BaseAPI):
             num_steps_each_epoch = train_dataset.num_samples // train_batch_size
             optimizer = self.default_optimizer(
                 learning_rate=learning_rate,
+                warmup_steps=warmup_steps,
+                warmup_start_lr=warmup_start_lr,
                 lr_decay_epochs=lr_decay_epochs,
                 lr_decay_gamma=lr_decay_gamma,
                 num_steps_each_epoch=num_steps_each_epoch)
@@ -193,8 +215,7 @@ class BaseClassifier(BaseAPI):
           tuple (metrics, eval_details): 当return_details为True时,增加返回dict,
               包含关键字:'true_labels'、'pred_scores',分别代表真实类别id、每个类别的预测得分。
         """
-        self.arrange_transforms(
-            transforms=eval_dataset.transforms, mode='eval')
+        self.arrange_transforms(transforms=eval_dataset.transforms, mode='eval')
         data_generator = eval_dataset.generator(
             batch_size=batch_size, drop_last=False)
         k = min(5, self.num_classes)
@@ -206,9 +227,8 @@ class BaseClassifier(BaseAPI):
                 self.test_prog).with_data_parallel(
                     share_vars_from=self.parallel_train_prog)
         batch_size_each_gpu = self._get_single_card_bs(batch_size)
-        logging.info(
-            "Start to evaluating(total_samples={}, total_steps={})...".format(
-                eval_dataset.num_samples, total_steps))
+        logging.info("Start to evaluating(total_samples={}, total_steps={})...".
+                     format(eval_dataset.num_samples, total_steps))
         for step, data in tqdm.tqdm(
                 enumerate(data_generator()), total=total_steps):
             images = np.array([d[0] for d in data]).astype('float32')

+ 7 - 8
paddlex/utils/logging.py

@@ -29,13 +29,11 @@ def log(level=2, message="", use_color=False):
     current_time = time.strftime("%Y-%m-%d %H:%M:%S", time_array)
     if paddlex.log_level >= level:
         if use_color:
-            print("\033[1;31;40m{} [{}]\t{}\033[0m".format(
-                current_time, levels[level],
-                message).encode("utf-8").decode("latin1"))
+            print("\033[1;31;40m{} [{}]\t{}\033[0m".format(current_time, levels[
+                level], message).encode("utf-8").decode("latin1"))
         else:
-            print(
-                "{} [{}]\t{}".format(current_time, levels[level],
-                                     message).encode("utf-8").decode("latin1"))
+            print("{} [{}]\t{}".format(current_time, levels[level], message)
+                  .encode("utf-8").decode("latin1"))
         sys.stdout.flush()
 
 
@@ -51,6 +49,7 @@ def warning(message="", use_color=True):
     log(level=1, message=message, use_color=use_color)
 
 
-def error(message="", use_color=True):
+def error(message="", use_color=True, exit=True):
     log(level=0, message=message, use_color=use_color)
-    sys.exit(-1)
+    if exit:
+        sys.exit(-1)