Parcourir la source

Merge pull request #38 from FlyingQianMM/develop_qh

add resume training from a checkpoint
Jason il y a 5 ans
Parent
commit
a1036d6e90

+ 12 - 6
docs/apis/models.md

@@ -17,7 +17,7 @@ paddlex.cls.ResNet50(num_classes=1000)
 #### 分类器训练函数接口
 
 > ```python
-> train(self, num_epochs, train_dataset, train_batch_size=64, eval_dataset=None, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.025, lr_decay_epochs=[30, 60, 90], lr_decay_gamma=0.1, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5)
+> train(self, num_epochs, train_dataset, train_batch_size=64, eval_dataset=None, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.025, lr_decay_epochs=[30, 60, 90], lr_decay_gamma=0.1, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5, resume_checkpoint=None)
 > ```
 >
 > **参数:**
@@ -39,6 +39,7 @@ paddlex.cls.ResNet50(num_classes=1000)
 > > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。
 > > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。
 > > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
+> > - **resume_checkpoint** (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
 
 #### 分类器评估函数接口
 
@@ -111,7 +112,7 @@ paddlex.det.YOLOv3(num_classes=80, backbone='MobileNetV1', anchors=None, anchor_
 #### YOLOv3训练函数接口
 
 > ```python
-> train(self, num_epochs, train_dataset, train_batch_size=8, eval_dataset=None, save_interval_epochs=20, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=1.0/8000, warmup_steps=1000, warmup_start_lr=0.0, lr_decay_epochs=[213, 240], lr_decay_gamma=0.1, metric=None, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5)
+> train(self, num_epochs, train_dataset, train_batch_size=8, eval_dataset=None, save_interval_epochs=20, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=1.0/8000, warmup_steps=1000, warmup_start_lr=0.0, lr_decay_epochs=[213, 240], lr_decay_gamma=0.1, metric=None, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5, resume_checkpoint=None)
 > ```
 >
 > **参数:**
@@ -136,6 +137,7 @@ paddlex.det.YOLOv3(num_classes=80, backbone='MobileNetV1', anchors=None, anchor_
 > > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。
 > > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。
 > > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
+> > - **resume_checkpoint** (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
 
 #### YOLOv3评估函数接口
 
@@ -190,7 +192,7 @@ paddlex.det.FasterRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspec
 #### FasterRCNN训练函数接口
 
 > ```python
-> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, save_interval_epochs=1, log_interval_steps=2,save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.0025, warmup_steps=500, warmup_start_lr=1.0/1200, lr_decay_epochs=[8, 11], lr_decay_gamma=0.1, metric=None, use_vdl=False, early_stop=False, early_stop_patience=5)
+> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, save_interval_epochs=1, log_interval_steps=2,save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.0025, warmup_steps=500, warmup_start_lr=1.0/1200, lr_decay_epochs=[8, 11], lr_decay_gamma=0.1, metric=None, use_vdl=False, early_stop=False, early_stop_patience=5, resume_checkpoint=None)
 >
 > ```
 >
@@ -214,6 +216,7 @@ paddlex.det.FasterRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspec
 > > - **use_vdl** (bool): 是否使用VisualDL进行可视化。默认值为False。
 > > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。
 > > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
+> > - **resume_checkpoint** (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
 
 #### FasterRCNN评估函数接口
 
@@ -270,7 +273,7 @@ paddlex.det.MaskRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspect_
 #### MaskRCNN训练函数接口
 
 > ```python
-> train(self, num_epochs, train_dataset, train_batch_size=1, eval_dataset=None, save_interval_epochs=1, log_interval_steps=20, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=1.0/800, warmup_steps=500, warmup_start_lr=1.0 / 2400, lr_decay_epochs=[8, 11], lr_decay_gamma=0.1, metric=None, use_vdl=False, early_stop=False, early_stop_patience=5)
+> train(self, num_epochs, train_dataset, train_batch_size=1, eval_dataset=None, save_interval_epochs=1, log_interval_steps=20, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=1.0/800, warmup_steps=500, warmup_start_lr=1.0 / 2400, lr_decay_epochs=[8, 11], lr_decay_gamma=0.1, metric=None, use_vdl=False, early_stop=False, early_stop_patience=5, resume_checkpoint=None)
 >
 > ```
 >
@@ -294,6 +297,7 @@ paddlex.det.MaskRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspect_
 > > - **use_vdl** (bool): 是否使用VisualDL进行可视化。默认值为False。
 > > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。
 > > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
+> > - **resume_checkpoint** (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
 
 #### MaskRCNN评估函数接口
 
@@ -358,7 +362,7 @@ paddlex.seg.DeepLabv3p(num_classes=2, backbone='MobileNetV2_x1.0', output_stride
 #### DeepLabv3训练函数接口
 
 > ```python
-> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, eval_batch_size=1, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.01, lr_decay_power=0.9, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5):
+> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, eval_batch_size=1, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.01, lr_decay_power=0.9, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5, resume_checkpoint=None):
 >
 > ```
 >
@@ -380,6 +384,7 @@ paddlex.seg.DeepLabv3p(num_classes=2, backbone='MobileNetV2_x1.0', output_stride
 > > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。
 > > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。
 > > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
+> > - **resume_checkpoint** (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
 
 #### DeepLabv3评估函数接口
 
@@ -437,7 +442,7 @@ paddlex.seg.UNet(num_classes=2, upsample_mode='bilinear', use_bce_loss=False, us
 #### Unet训练函数接口
 
 > ```python
-> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, eval_batch_size=1, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='COCO', optimizer=None, learning_rate=0.01, lr_decay_power=0.9, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5):
+> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, eval_batch_size=1, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='COCO', optimizer=None, learning_rate=0.01, lr_decay_power=0.9, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5, resume_checkpoint=None):
 > ```
 >
 > **参数:**
@@ -458,6 +463,7 @@ paddlex.seg.UNet(num_classes=2, upsample_mode='bilinear', use_bce_loss=False, us
 > > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。
 > > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。
 > > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
+> > - **resume_checkpoint** (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
 
 #### Unet评估函数接口
 

+ 39 - 17
paddlex/cv/models/base.py

@@ -70,6 +70,8 @@ class BaseAPI:
         self.sync_bn = False
         # 当前模型状态
         self.status = 'Normal'
+        # 已完成迭代轮数,为恢复训练时的起始轮数
+        self.completed_epochs = 0
 
     def _get_single_card_bs(self, batch_size):
         if batch_size % len(self.places) == 0:
@@ -182,24 +184,39 @@ class BaseAPI:
                        fuse_bn=False,
                        save_dir='.',
                        sensitivities_file=None,
-                       eval_metric_loss=0.05):
-        pretrain_dir = osp.join(save_dir, 'pretrain')
-        if not os.path.isdir(pretrain_dir):
-            if os.path.exists(pretrain_dir):
-                os.remove(pretrain_dir)
-            os.makedirs(pretrain_dir)
-        if hasattr(self, 'backbone'):
-            backbone = self.backbone
-        else:
-            backbone = self.__class__.__name__
-        pretrain_weights = get_pretrain_weights(
-            pretrain_weights, self.model_type, backbone, pretrain_dir)
+                       eval_metric_loss=0.05,
+                       resume_checkpoint=None):
+        if not resume_checkpoint:
+            pretrain_dir = osp.join(save_dir, 'pretrain')
+            if not os.path.isdir(pretrain_dir):
+                if os.path.exists(pretrain_dir):
+                    os.remove(pretrain_dir)
+                os.makedirs(pretrain_dir)
+            if hasattr(self, 'backbone'):
+                backbone = self.backbone
+            else:
+                backbone = self.__class__.__name__
+            pretrain_weights = get_pretrain_weights(
+                pretrain_weights, self.model_type, backbone, pretrain_dir)
         if startup_prog is None:
             startup_prog = fluid.default_startup_program()
         self.exe.run(startup_prog)
-        if pretrain_weights is not None:
+        if resume_checkpoint:
+            logging.info(
+                "Resume checkpoint from {}.".format(resume_checkpoint),
+                use_color=True)
+            paddlex.utils.utils.load_pretrain_weights(
+                self.exe, self.train_prog, resume_checkpoint, resume=True)
+            if not osp.exists(osp.join(resume_checkpoint, "model.yml")):
+                raise Exception(
+                    "There's not model.yml in {}".format(resume_checkpoint))
+            with open(osp.join(resume_checkpoint, "model.yml")) as f:
+                info = yaml.load(f.read(), Loader=yaml.Loader)
+                self.completed_epochs = info['completed_epochs']
+        elif pretrain_weights is not None:
             logging.info(
-                "Load pretrain weights from {}.".format(pretrain_weights), use_color=True)
+                "Load pretrain weights from {}.".format(pretrain_weights),
+                use_color=True)
             paddlex.utils.utils.load_pretrain_weights(
                 self.exe, self.train_prog, pretrain_weights, fuse_bn)
         # 进行裁剪
@@ -211,7 +228,8 @@ class BaseAPI:
             from .slim.prune import get_params_ratios, prune_program
             logging.info(
                 "Start to prune program with eval_metric_loss = {}".format(
-                    eval_metric_loss), use_color=True)
+                    eval_metric_loss),
+                use_color=True)
             origin_flops = paddleslim.analysis.flops(self.test_prog)
             prune_params_ratios = get_params_ratios(
                 sensitivities_file, eval_metric_loss=eval_metric_loss)
@@ -220,7 +238,8 @@ class BaseAPI:
             remaining_ratio = current_flops / origin_flops
             logging.info(
                 "Finish prune program, before FLOPs:{}, after prune FLOPs:{}, remaining ratio:{}"
-                .format(origin_flops, current_flops, remaining_ratio), use_color=True)
+                .format(origin_flops, current_flops, remaining_ratio),
+                use_color=True)
             self.status = 'Prune'
 
     def get_model_info(self):
@@ -258,6 +277,7 @@ class BaseAPI:
                     name = op.__class__.__name__
                     attr = op.__dict__
                     info['Transforms'].append({name: attr})
+        info['completed_epochs'] = self.completed_epochs
         return info
 
     def save_model(self, save_dir):
@@ -418,7 +438,8 @@ class BaseAPI:
         best_accuracy_key = ""
         best_accuracy = -1.0
         best_model_epoch = -1
-        for i in range(num_epochs):
+        start_epoch = self.completed_epochs
+        for i in range(start_epoch, num_epochs):
             records = list()
             step_start_time = time.time()
             epoch_start_time = time.time()
@@ -498,6 +519,7 @@ class BaseAPI:
                         return_details=True)
                     logging.info('[EVAL] Finished, Epoch={}, {} .'.format(
                         i + 1, dict2str(self.eval_metrics)))
+                    self.completed_epochs += 1
                     # 保存最优模型
                     best_accuracy_key = list(self.eval_metrics.keys())[0]
                     current_accuracy = self.eval_metrics[best_accuracy_key]

+ 5 - 3
paddlex/cv/models/classifier.py

@@ -112,7 +112,8 @@ class BaseClassifier(BaseAPI):
               sensitivities_file=None,
               eval_metric_loss=0.05,
               early_stop=False,
-              early_stop_patience=5):
+              early_stop_patience=5,
+              resume_checkpoint=None):
         """训练。
 
         Args:
@@ -137,6 +138,7 @@ class BaseClassifier(BaseAPI):
             early_stop (bool): 是否使用提前终止训练策略。默认值为False。
             early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
                 连续下降或持平,则终止训练。默认值为5。
+            resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
 
         Raises:
             ValueError: 模型从inference model进行加载。
@@ -160,8 +162,8 @@ class BaseClassifier(BaseAPI):
             pretrain_weights=pretrain_weights,
             save_dir=save_dir,
             sensitivities_file=sensitivities_file,
-            eval_metric_loss=eval_metric_loss)
-
+            eval_metric_loss=eval_metric_loss,
+            resume_checkpoint=resume_checkpoint)
         # 训练
         self.train_loop(
             num_epochs=num_epochs,

+ 7 - 3
paddlex/cv/models/deeplabv3p.py

@@ -234,7 +234,8 @@ class DeepLabv3p(BaseAPI):
               sensitivities_file=None,
               eval_metric_loss=0.05,
               early_stop=False,
-              early_stop_patience=5):
+              early_stop_patience=5,
+              resume_checkpoint=None):
         """训练。
 
         Args:
@@ -258,6 +259,7 @@ class DeepLabv3p(BaseAPI):
             early_stop (bool): 是否使用提前终止训练策略。默认值为False。
             early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
                 连续下降或持平,则终止训练。默认值为5。
+            resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
 
         Raises:
             ValueError: 模型从inference model进行加载。
@@ -284,7 +286,8 @@ class DeepLabv3p(BaseAPI):
             pretrain_weights=pretrain_weights,
             save_dir=save_dir,
             sensitivities_file=sensitivities_file,
-            eval_metric_loss=eval_metric_loss)
+            eval_metric_loss=eval_metric_loss,
+            resume_checkpoint=resume_checkpoint)
         # 训练
         self.train_loop(
             num_epochs=num_epochs,
@@ -405,5 +408,6 @@ class DeepLabv3p(BaseAPI):
                 w, h = info[1][1], info[1][0]
                 pred = pred[0:h, 0:w]
             else:
-                raise Exception("Unexpected info '{}' in im_info".format(info[0]))
+                raise Exception("Unexpected info '{}' in im_info".format(
+                    info[0]))
         return {'label_map': pred, 'score_map': result[1]}

+ 6 - 2
paddlex/cv/models/faster_rcnn.py

@@ -167,7 +167,8 @@ class FasterRCNN(BaseAPI):
               metric=None,
               use_vdl=False,
               early_stop=False,
-              early_stop_patience=5):
+              early_stop_patience=5,
+              resume_checkpoint=None):
         """训练。
 
         Args:
@@ -193,6 +194,7 @@ class FasterRCNN(BaseAPI):
             early_stop (bool): 是否使用提前终止训练策略。默认值为False。
             early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
                 连续下降或持平,则终止训练。默认值为5。
+            resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
 
         Raises:
             ValueError: 评估类型不在指定列表中。
@@ -231,7 +233,9 @@ class FasterRCNN(BaseAPI):
             startup_prog=fluid.default_startup_program(),
             pretrain_weights=pretrain_weights,
             fuse_bn=fuse_bn,
-            save_dir=save_dir)
+            save_dir=save_dir,
+            resume_checkpoint=resume_checkpoint)
+
         # 训练
         self.train_loop(
             num_epochs=num_epochs,

+ 7 - 3
paddlex/cv/models/mask_rcnn.py

@@ -132,7 +132,8 @@ class MaskRCNN(FasterRCNN):
               metric=None,
               use_vdl=False,
               early_stop=False,
-              early_stop_patience=5):
+              early_stop_patience=5,
+              resume_checkpoint=None):
         """训练。
 
         Args:
@@ -158,6 +159,7 @@ class MaskRCNN(FasterRCNN):
             early_stop (bool): 是否使用提前终止训练策略。默认值为False。
             early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
                 连续下降或持平,则终止训练。默认值为5。
+            resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
 
         Raises:
             ValueError: 评估类型不在指定列表中。
@@ -169,7 +171,8 @@ class MaskRCNN(FasterRCNN):
                 metric = 'COCO'
             else:
                 raise Exception(
-                    "train_dataset should be datasets.COCODetection or datasets.EasyDataDet.")
+                    "train_dataset should be datasets.COCODetection or datasets.EasyDataDet."
+                )
         assert metric in ['COCO', 'VOC'], "Metric only support 'VOC' or 'COCO'"
         self.metric = metric
         if not self.trainable:
@@ -197,7 +200,8 @@ class MaskRCNN(FasterRCNN):
             startup_prog=fluid.default_startup_program(),
             pretrain_weights=pretrain_weights,
             fuse_bn=fuse_bn,
-            save_dir=save_dir)
+            save_dir=save_dir,
+            resume_checkpoint=resume_checkpoint)
         # 训练
         self.train_loop(
             num_epochs=num_epochs,

+ 9 - 8
paddlex/cv/models/unet.py

@@ -121,7 +121,8 @@ class UNet(DeepLabv3p):
               sensitivities_file=None,
               eval_metric_loss=0.05,
               early_stop=False,
-              early_stop_patience=5):
+              early_stop_patience=5,
+              resume_checkpoint=None):
         """训练。
 
         Args:
@@ -145,14 +146,14 @@ class UNet(DeepLabv3p):
             early_stop (bool): 是否使用提前终止训练策略。默认值为False。
             early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
                 连续下降或持平,则终止训练。默认值为5。
+            resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
 
         Raises:
             ValueError: 模型从inference model进行加载。
         """
-        return super(
-            UNet,
-            self).train(num_epochs, train_dataset, train_batch_size,
-                        eval_dataset, save_interval_epochs, log_interval_steps,
-                        save_dir, pretrain_weights, optimizer, learning_rate,
-                        lr_decay_power, use_vdl, sensitivities_file,
-                        eval_metric_loss, early_stop, early_stop_patience)
+        return super(UNet, self).train(
+            num_epochs, train_dataset, train_batch_size, eval_dataset,
+            save_interval_epochs, log_interval_steps, save_dir,
+            pretrain_weights, optimizer, learning_rate, lr_decay_power,
+            use_vdl, sensitivities_file, eval_metric_loss, early_stop,
+            early_stop_patience, resume_checkpoint)

+ 5 - 2
paddlex/cv/models/yolo_v3.py

@@ -166,7 +166,8 @@ class YOLOv3(BaseAPI):
               sensitivities_file=None,
               eval_metric_loss=0.05,
               early_stop=False,
-              early_stop_patience=5):
+              early_stop_patience=5,
+              resume_checkpoint=None):
         """训练。
 
         Args:
@@ -195,6 +196,7 @@ class YOLOv3(BaseAPI):
             early_stop (bool): 是否使用提前终止训练策略。默认值为False。
             early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
                 连续下降或持平,则终止训练。默认值为5。
+            resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
 
         Raises:
             ValueError: 评估类型不在指定列表中。
@@ -236,7 +238,8 @@ class YOLOv3(BaseAPI):
             pretrain_weights=pretrain_weights,
             save_dir=save_dir,
             sensitivities_file=sensitivities_file,
-            eval_metric_loss=eval_metric_loss)
+            eval_metric_loss=eval_metric_loss,
+            resume_checkpoint=resume_checkpoint)
         # 训练
         self.train_loop(
             num_epochs=num_epochs,

+ 115 - 2
paddlex/utils/utils.py

@@ -170,11 +170,85 @@ def load_pdparams(exe, main_prog, model_dir):
             len(vars_to_load), model_dir))
 
 
-def load_pretrain_weights(exe, main_prog, weights_dir, fuse_bn=False):
+def is_persistable(var):
+    import paddle.fluid as fluid
+    from paddle.fluid.proto.framework_pb2 import VarType
+
+    if var.desc.type() == fluid.core.VarDesc.VarType.FEED_MINIBATCH or \
+        var.desc.type() == fluid.core.VarDesc.VarType.FETCH_LIST or \
+        var.desc.type() == fluid.core.VarDesc.VarType.READER:
+        return False
+    return var.persistable
+
+
+def is_belong_to_optimizer(var):
+    import paddle.fluid as fluid
+    from paddle.fluid.proto.framework_pb2 import VarType
+
+    if not (isinstance(var, fluid.framework.Parameter)
+            or var.desc.need_check_feed()):
+        return is_persistable(var)
+    return False
+
+
+def load_pdopt(exe, main_prog, model_dir):
+    import paddle.fluid as fluid
+
+    optimizer_var_list = list()
+    vars_to_load = list()
+    import pickle
+    with open(osp.join(model_dir, 'model.pdopt'), 'rb') as f:
+        opt_dict = pickle.load(f) if six.PY2 else pickle.load(
+            f, encoding='latin1')
+    optimizer_var_list = list(
+        filter(is_belong_to_optimizer, main_prog.list_vars()))
+    exception_message = "the training process can not be resumed due to optimizer set now and last time is different. Recommend to use `pretrain_weights` instead of `resume_checkpoint`"
+    if len(optimizer_var_list) > 0:
+        for var in optimizer_var_list:
+            if var.name not in opt_dict:
+                raise Exception(
+                    "{} is not in saved paddlex optimizer, {}".format(
+                        var.name, exception_message))
+            if var.shape != opt_dict[var.name].shape:
+                raise Exception(
+                    "Shape of optimizer variable {} doesn't match.(Last: {}, Now: {}), {}"
+                    .format(var.name, opt_dict[var.name].shape,
+                            var.shape), exception_message)
+        optimizer_varname_list = [var.name for var in optimizer_var_list]
+        for k, v in opt_dict.items():
+            if k not in optimizer_varname_list:
+                raise Exception(
+                    "{} in saved paddlex optimizer is not in the model, {}".
+                    format(k, exception_message))
+        fluid.io.set_program_state(main_prog, opt_dict)
+
+    if len(optimizer_var_list) == 0:
+        raise Exception(
+            "There is no optimizer parameters in the model, please set the optimizer!"
+        )
+    else:
+        logging.info(
+            "There are {} optimizer parameters in {} are loaded.".format(
+                len(optimizer_var_list), model_dir))
+
+
+def load_pretrain_weights(exe,
+                          main_prog,
+                          weights_dir,
+                          fuse_bn=False,
+                          resume=False):
     if not osp.exists(weights_dir):
         raise Exception("Path {} not exists.".format(weights_dir))
     if osp.exists(osp.join(weights_dir, "model.pdparams")):
-        return load_pdparams(exe, main_prog, weights_dir)
+        load_pdparams(exe, main_prog, weights_dir)
+        if resume:
+            if osp.exists(osp.join(weights_dir, "model.pdopt")):
+                load_pdopt(exe, main_prog, weights_dir)
+            else:
+                raise Exception(
+                    "Optimizer file {} does not exist. Stop resumming training. Recommend to use `pretrain_weights` instead of `resume_checkpoint`"
+                    .format(osp.join(weights_dir, "model.pdopt")))
+        return
     import paddle.fluid as fluid
     vars_to_load = list()
     for var in main_prog.list_vars():
@@ -209,6 +283,45 @@ def load_pretrain_weights(exe, main_prog, weights_dir, fuse_bn=False):
             len(vars_to_load), weights_dir))
     if fuse_bn:
         fuse_bn_weights(exe, main_prog, weights_dir)
+    if resume:
+        exception_message = "the training process can not be resumed due to optimizer set now and last time is different. Recommend to use `pretrain_weights` instead of `resume_checkpoint`"
+        optimizer_var_list = list(
+            filter(is_belong_to_optimizer, main_prog.list_vars()))
+        if len(optimizer_var_list) > 0:
+            for var in optimizer_var_list:
+                if not osp.exists(osp.join(weights_dir, var.name)):
+                    raise Exception(
+                        "Optimizer parameter {} doesn't exist, {}".format(
+                            osp.join(weights_dir, var.name),
+                            exception_message))
+                pretrained_shape = parse_param_file(
+                    osp.join(weights_dir, var.name))
+                actual_shape = tuple(var.shape)
+                if pretrained_shape != actual_shape:
+                    raise Exception(
+                        "Shape of optimizer variable {} doesn't match.(Last: {}, Now: {}), {}"
+                        .format(var.name, pretrained_shape,
+                                actual_shape), exception_message)
+            optimizer_varname_list = [var.name for var in optimizer_var_list]
+            if os.exists(osp.join(weights_dir, 'learning_rate')
+                         ) and 'learning_rate' not in optimizer_varname_list:
+                raise Exception(
+                    "Optimizer parameter {}/learning_rate is not in the model, {}"
+                    .format(weights_dir, exception_message))
+            fluid.io.load_vars(
+                executor=exe,
+                dirname=weights_dir,
+                main_program=main_prog,
+                vars=optimizer_var_list)
+
+        if len(optimizer_var_list) == 0:
+            raise Exception(
+                "There is no optimizer parameters in the model, please set the optimizer!"
+            )
+        else:
+            logging.info(
+                "There are {} optimizer parameters in {} are loaded.".format(
+                    len(optimizer_var_list), weights_dir))
 
 
 class EarlyStop: