Browse Source

add early_stop policy

FlyingQianMM 5 years ago
parent
commit
1c205a1469

+ 18 - 6
docs/apis/models.md

@@ -17,7 +17,7 @@ paddlex.cls.ResNet50(num_classes=1000)
 #### 分类器训练函数接口
 
 > ```python
-> train(self, num_epochs, train_dataset, train_batch_size=64, eval_dataset=None, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.025, lr_decay_epochs=[30, 60, 90], lr_decay_gamma=0.1, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05)
+> train(self, num_epochs, train_dataset, train_batch_size=64, eval_dataset=None, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.025, lr_decay_epochs=[30, 60, 90], lr_decay_gamma=0.1, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5)
 > ```
 >
 > **参数:**
@@ -37,6 +37,8 @@ paddlex.cls.ResNet50(num_classes=1000)
 > > - **use_vdl** (bool): 是否使用VisualDL进行可视化。默认值为False。
 > > - **sensitivities_file** (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT',则自动下载在ImageNet图片数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。
 > > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。
+> > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。
+> > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
 
 #### 分类器评估函数接口
 
@@ -109,7 +111,7 @@ paddlex.det.YOLOv3(num_classes=80, backbone='MobileNetV1', anchors=None, anchor_
 #### YOLOv3训练函数接口
 
 > ```python
-> train(self, num_epochs, train_dataset, train_batch_size=8, eval_dataset=None, save_interval_epochs=20, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=1.0/8000, warmup_steps=1000, warmup_start_lr=0.0, lr_decay_epochs=[213, 240], lr_decay_gamma=0.1, metric=None, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05)
+> train(self, num_epochs, train_dataset, train_batch_size=8, eval_dataset=None, save_interval_epochs=20, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=1.0/8000, warmup_steps=1000, warmup_start_lr=0.0, lr_decay_epochs=[213, 240], lr_decay_gamma=0.1, metric=None, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5)
 > ```
 >
 > **参数:**
@@ -132,6 +134,8 @@ paddlex.det.YOLOv3(num_classes=80, backbone='MobileNetV1', anchors=None, anchor_
 > > - **use_vdl** (bool): 是否使用VisualDL进行可视化。默认值为False。
 > > - **sensitivities_file** (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT',则自动下载在PascalVOC数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。
 > > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。
+> > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。
+> > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
 
 #### YOLOv3评估函数接口
 
@@ -186,7 +190,7 @@ paddlex.det.FasterRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspec
 #### FasterRCNN训练函数接口
 
 > ```python
-> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, save_interval_epochs=1, log_interval_steps=2,save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.0025, warmup_steps=500, warmup_start_lr=1.0/1200, lr_decay_epochs=[8, 11], lr_decay_gamma=0.1, metric=None, use_vdl=False)
+> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, save_interval_epochs=1, log_interval_steps=2,save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.0025, warmup_steps=500, warmup_start_lr=1.0/1200, lr_decay_epochs=[8, 11], lr_decay_gamma=0.1, metric=None, use_vdl=False, early_stop=False, early_stop_patience=5)
 >
 > ```
 >
@@ -208,6 +212,8 @@ paddlex.det.FasterRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspec
 > > - **lr_decay_gamma** (float): 默认优化器的学习率衰减率。默认为0.1。
 > > - **metric** (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。默认值为None。
 > > - **use_vdl** (bool): 是否使用VisualDL进行可视化。默认值为False。
+> > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。
+> > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
 
 #### FasterRCNN评估函数接口
 
@@ -264,7 +270,7 @@ paddlex.det.MaskRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspect_
 #### MaskRCNN训练函数接口
 
 > ```python
-> train(self, num_epochs, train_dataset, train_batch_size=1, eval_dataset=None, save_interval_epochs=1, log_interval_steps=20, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=1.0/800, warmup_steps=500, warmup_start_lr=1.0 / 2400, lr_decay_epochs=[8, 11], lr_decay_gamma=0.1, metric=None, use_vdl=False)
+> train(self, num_epochs, train_dataset, train_batch_size=1, eval_dataset=None, save_interval_epochs=1, log_interval_steps=20, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=1.0/800, warmup_steps=500, warmup_start_lr=1.0 / 2400, lr_decay_epochs=[8, 11], lr_decay_gamma=0.1, metric=None, use_vdl=False, early_stop=False, early_stop_patience=5)
 >
 > ```
 >
@@ -286,6 +292,8 @@ paddlex.det.MaskRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspect_
 > > - **lr_decay_gamma** (float): 默认优化器的学习率衰减率。默认为0.1。
 > > - **metric** (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。默认值为None。
 > > - **use_vdl** (bool): 是否使用VisualDL进行可视化。默认值为False。
+> > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。
+> > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
 
 #### MaskRCNN评估函数接口
 
@@ -350,7 +358,7 @@ paddlex.seg.DeepLabv3p(num_classes=2, backbone='MobileNetV2_x1.0', output_stride
 #### DeepLabv3训练函数接口
 
 > ```python
-> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, eval_batch_size=1, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.01, lr_decay_power=0.9, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05):
+> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, eval_batch_size=1, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.01, lr_decay_power=0.9, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5):
 >
 > ```
 >
@@ -370,6 +378,8 @@ paddlex.seg.DeepLabv3p(num_classes=2, backbone='MobileNetV2_x1.0', output_stride
 > > - **use_vdl** (bool): 是否使用VisualDL进行可视化。默认False。
 > > - **sensitivities_file** (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT',则自动下载在ImageNet图片数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。
 > > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。
+> > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。
+> > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
 
 #### DeepLabv3评估函数接口
 
@@ -427,7 +437,7 @@ paddlex.seg.UNet(num_classes=2, upsample_mode='bilinear', use_bce_loss=False, us
 #### Unet训练函数接口
 
 > ```python
-> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, eval_batch_size=1, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='COCO', optimizer=None, learning_rate=0.01, lr_decay_power=0.9, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05):
+> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, eval_batch_size=1, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='COCO', optimizer=None, learning_rate=0.01, lr_decay_power=0.9, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5):
 > ```
 >
 > **参数:**
@@ -446,6 +456,8 @@ paddlex.seg.UNet(num_classes=2, upsample_mode='bilinear', use_bce_loss=False, us
 > > - **use_vdl** (bool): 是否使用VisualDL进行可视化。默认False。
 > > - **sensitivities_file** (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT',则自动下载在ImageNet图片数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。
 > > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。
+> > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。
+> > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
 
 #### Unet评估函数接口
 

+ 11 - 1
paddlex/cv/models/base.py

@@ -24,6 +24,7 @@ import json
 import functools
 import paddlex.utils.logging as logging
 from paddlex.utils import seconds_to_hms
+from paddlex.utils.utils import EarlyStop
 import paddlex
 from collections import OrderedDict
 from os import path as osp
@@ -334,7 +335,9 @@ class BaseAPI:
                    save_interval_epochs=1,
                    log_interval_steps=10,
                    save_dir='output',
-                   use_vdl=False):
+                   use_vdl=False,
+                   early_stop=False,
+                   early_stop_patience=5):
         if not osp.isdir(save_dir):
             if osp.exists(save_dir):
                 os.remove(save_dir)
@@ -396,6 +399,9 @@ class BaseAPI:
             train_step_component = OrderedDict()
             eval_component = OrderedDict()
 
+        thresh = 0.0001
+        if early_stop:
+            earlystop = EarlyStop(early_stop_patience, thresh)
         best_accuracy_key = ""
         best_accuracy = -1.0
         best_model_epoch = 1
@@ -507,3 +513,7 @@ class BaseAPI:
                     'Current evaluated best model in eval_dataset is epoch_{}, {}={}'
                     .format(best_model_epoch, best_accuracy_key,
                             best_accuracy))
+                if eval_dataset is not None:
+                    if early_stop:
+                        if earlystop(current_accuracy):
+                            break

+ 9 - 2
paddlex/cv/models/classifier.py

@@ -102,7 +102,9 @@ class BaseClassifier(BaseAPI):
               lr_decay_gamma=0.1,
               use_vdl=False,
               sensitivities_file=None,
-              eval_metric_loss=0.05):
+              eval_metric_loss=0.05,
+              early_stop=False,
+              early_stop_patience=5):
         """训练。
 
         Args:
@@ -124,6 +126,9 @@ class BaseClassifier(BaseAPI):
             sensitivities_file (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT',
                 则自动下载在ImageNet图片数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。
             eval_metric_loss (float): 可容忍的精度损失。默认为0.05。
+            early_stop (bool): 是否使用提前终止训练策略。默认值为False。
+            early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
+                连续下降或持平,则终止训练。默认值为5。
 
         Raises:
             ValueError: 模型从inference model进行加载。
@@ -158,7 +163,9 @@ class BaseClassifier(BaseAPI):
             save_interval_epochs=save_interval_epochs,
             log_interval_steps=log_interval_steps,
             save_dir=save_dir,
-            use_vdl=use_vdl)
+            use_vdl=use_vdl,
+            early_stop=early_stop,
+            early_stop_patience=early_stop_patience)
 
     def evaluate(self,
                  eval_dataset,

+ 9 - 2
paddlex/cv/models/deeplabv3p.py

@@ -231,7 +231,9 @@ class DeepLabv3p(BaseAPI):
               lr_decay_power=0.9,
               use_vdl=False,
               sensitivities_file=None,
-              eval_metric_loss=0.05):
+              eval_metric_loss=0.05,
+              early_stop=False,
+              early_stop_patience=5):
         """训练。
 
         Args:
@@ -252,6 +254,9 @@ class DeepLabv3p(BaseAPI):
             sensitivities_file (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT',
                 则自动下载在ImageNet图片数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。
             eval_metric_loss (float): 可容忍的精度损失。默认为0.05。
+            early_stop (bool): 是否使用提前终止训练策略。默认值为False。
+            early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
+                连续下降或持平,则终止训练。默认值为5。
 
         Raises:
             ValueError: 模型从inference model进行加载。
@@ -288,7 +293,9 @@ class DeepLabv3p(BaseAPI):
             save_interval_epochs=save_interval_epochs,
             log_interval_steps=log_interval_steps,
             save_dir=save_dir,
-            use_vdl=use_vdl)
+            use_vdl=use_vdl,
+            early_stop=early_stop,
+            early_stop_patience=early_stop_patience)
 
     def evaluate(self,
                  eval_dataset,

+ 9 - 2
paddlex/cv/models/faster_rcnn.py

@@ -163,7 +163,9 @@ class FasterRCNN(BaseAPI):
               lr_decay_epochs=[8, 11],
               lr_decay_gamma=0.1,
               metric=None,
-              use_vdl=False):
+              use_vdl=False,
+              early_stop=False,
+              early_stop_patience=5):
         """训练。
 
         Args:
@@ -186,6 +188,9 @@ class FasterRCNN(BaseAPI):
             lr_decay_gamma (float): 默认优化器的学习率衰减率。默认为0.1。
             metric (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。默认值为None。
             use_vdl (bool): 是否使用VisualDL进行可视化。默认值为False。
+            early_stop (bool): 是否使用提前终止训练策略。默认值为False。
+            early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
+                连续下降或持平,则终止训练。默认值为5。
 
         Raises:
             ValueError: 评估类型不在指定列表中。
@@ -233,7 +238,9 @@ class FasterRCNN(BaseAPI):
             save_interval_epochs=save_interval_epochs,
             log_interval_steps=log_interval_steps,
             save_dir=save_dir,
-            use_vdl=use_vdl)
+            use_vdl=use_vdl,
+            early_stop=early_stop,
+            early_stop_patience=early_stop_patience)
 
     def evaluate(self,
                  eval_dataset,

+ 9 - 2
paddlex/cv/models/mask_rcnn.py

@@ -128,7 +128,9 @@ class MaskRCNN(FasterRCNN):
               lr_decay_epochs=[8, 11],
               lr_decay_gamma=0.1,
               metric=None,
-              use_vdl=False):
+              use_vdl=False,
+              early_stop=False,
+              early_stop_patience=5):
         """训练。
 
         Args:
@@ -151,6 +153,9 @@ class MaskRCNN(FasterRCNN):
             lr_decay_gamma (float): 默认优化器的学习率衰减率。默认为0.1。
             metric (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。
             use_vdl (bool): 是否使用VisualDL进行可视化。默认值为False。
+            early_stop (bool): 是否使用提前终止训练策略。默认值为False。
+            early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
+                连续下降或持平,则终止训练。默认值为5。
 
         Raises:
             ValueError: 评估类型不在指定列表中。
@@ -199,7 +204,9 @@ class MaskRCNN(FasterRCNN):
             save_interval_epochs=save_interval_epochs,
             log_interval_steps=log_interval_steps,
             save_dir=save_dir,
-            use_vdl=use_vdl)
+            use_vdl=use_vdl,
+            early_stop=early_stop,
+            early_stop_patience=early_stop_patience)
 
     def evaluate(self,
                  eval_dataset,

+ 13 - 6
paddlex/cv/models/unet.py

@@ -117,7 +117,9 @@ class UNet(DeepLabv3p):
               lr_decay_power=0.9,
               use_vdl=False,
               sensitivities_file=None,
-              eval_metric_loss=0.05):
+              eval_metric_loss=0.05,
+              early_stop=False,
+              early_stop_patience=5):
         """训练。
 
         Args:
@@ -138,12 +140,17 @@ class UNet(DeepLabv3p):
             sensitivities_file (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT',
                 则自动下载在ImageNet图片数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。
             eval_metric_loss (float): 可容忍的精度损失。默认为0.05。
+            early_stop (bool): 是否使用提前终止训练策略。默认值为False。
+            early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
+                连续下降或持平,则终止训练。默认值为5。
 
         Raises:
             ValueError: 模型从inference model进行加载。
         """
-        return super(UNet, self).train(
-            num_epochs, train_dataset, train_batch_size, eval_dataset,
-            save_interval_epochs, log_interval_steps, save_dir,
-            pretrain_weights, optimizer, learning_rate, lr_decay_power,
-            use_vdl, sensitivities_file, eval_metric_loss)
+        return super(
+            UNet,
+            self).train(num_epochs, train_dataset, train_batch_size,
+                        eval_dataset, save_interval_epochs, log_interval_steps,
+                        save_dir, pretrain_weights, optimizer, learning_rate,
+                        lr_decay_power, use_vdl, sensitivities_file,
+                        eval_metric_loss, early_stop, early_stop_patience)

+ 9 - 2
paddlex/cv/models/yolo_v3.py

@@ -162,7 +162,9 @@ class YOLOv3(BaseAPI):
               metric=None,
               use_vdl=False,
               sensitivities_file=None,
-              eval_metric_loss=0.05):
+              eval_metric_loss=0.05,
+              early_stop=False,
+              early_stop_patience=5):
         """训练。
 
         Args:
@@ -188,6 +190,9 @@ class YOLOv3(BaseAPI):
             sensitivities_file (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT',
                 则自动下载在ImageNet图片数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。
             eval_metric_loss (float): 可容忍的精度损失。默认为0.05。
+            early_stop (bool): 是否使用提前终止训练策略。默认值为False。
+            early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
+                连续下降或持平,则终止训练。默认值为5。
 
         Raises:
             ValueError: 评估类型不在指定列表中。
@@ -238,7 +243,9 @@ class YOLOv3(BaseAPI):
             save_interval_epochs=save_interval_epochs,
             log_interval_steps=log_interval_steps,
             save_dir=save_dir,
-            use_vdl=use_vdl)
+            use_vdl=use_vdl,
+            early_stop=early_stop,
+            early_stop_patience=early_stop_patience)
 
     def evaluate(self,
                  eval_dataset,

+ 36 - 0
paddlex/utils/utils.py

@@ -220,3 +220,39 @@ def load_pretrain_weights(exe, main_prog, weights_dir, fuse_bn=False):
             len(vars_to_load), weights_dir))
     if fuse_bn:
         fuse_bn_weights(exe, main_prog, weights_dir)
+
+
+class EarlyStop:
+    def __init__(self, patience, thresh):
+        self.patience = patience
+        self.counter = 0
+        self.score = None
+        self.max = 0
+        self.thresh = thresh
+        if patience < 1:
+            raise Exception("Argument patience should be a positive integer.")
+
+    def __call__(self, current_score):
+        if self.score is None:
+            self.score = current_score
+            return False
+        elif current_score > self.max:
+            self.counter = 0
+            self.score = current_score
+            self.max = current_score
+            return False
+        else:
+            if (abs(self.score - current_score) < self.thresh
+                    or current_score < self.score):
+                self.counter += 1
+                self.score = current_score
+                logging.debug(
+                    "EarlyStopping: %i / %i" % (self.counter, self.patience))
+                if self.counter >= self.patience:
+                    logging.info("EarlyStopping: Stop training")
+                    return True
+                return False
+            else:
+                self.counter = 0
+                self.score = current_score
+                return False