Преглед на файлове

add resume training from a checkpoint

FlyingQianMM преди 5 години
родител
ревизия
31bf269ed7

+ 12 - 6
docs/apis/models.md

@@ -17,7 +17,7 @@ paddlex.cls.ResNet50(num_classes=1000)
 #### 分类器训练函数接口
 
 > ```python
-> train(self, num_epochs, train_dataset, train_batch_size=64, eval_dataset=None, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.025, lr_decay_epochs=[30, 60, 90], lr_decay_gamma=0.1, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5)
+> train(self, num_epochs, train_dataset, train_batch_size=64, eval_dataset=None, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.025, lr_decay_epochs=[30, 60, 90], lr_decay_gamma=0.1, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5, resume_checkpoint=None)
 > ```
 >
 > **参数:**
@@ -39,6 +39,7 @@ paddlex.cls.ResNet50(num_classes=1000)
 > > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。
 > > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。
 > > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
+> > - **resume_checkpoint** (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
 
 #### 分类器评估函数接口
 
@@ -111,7 +112,7 @@ paddlex.det.YOLOv3(num_classes=80, backbone='MobileNetV1', anchors=None, anchor_
 #### YOLOv3训练函数接口
 
 > ```python
-> train(self, num_epochs, train_dataset, train_batch_size=8, eval_dataset=None, save_interval_epochs=20, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=1.0/8000, warmup_steps=1000, warmup_start_lr=0.0, lr_decay_epochs=[213, 240], lr_decay_gamma=0.1, metric=None, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5)
+> train(self, num_epochs, train_dataset, train_batch_size=8, eval_dataset=None, save_interval_epochs=20, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=1.0/8000, warmup_steps=1000, warmup_start_lr=0.0, lr_decay_epochs=[213, 240], lr_decay_gamma=0.1, metric=None, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5, resume_checkpoint=None)
 > ```
 >
 > **参数:**
@@ -136,6 +137,7 @@ paddlex.det.YOLOv3(num_classes=80, backbone='MobileNetV1', anchors=None, anchor_
 > > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。
 > > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。
 > > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
+> > - **resume_checkpoint** (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
 
 #### YOLOv3评估函数接口
 
@@ -190,7 +192,7 @@ paddlex.det.FasterRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspec
 #### FasterRCNN训练函数接口
 
 > ```python
-> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, save_interval_epochs=1, log_interval_steps=2,save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.0025, warmup_steps=500, warmup_start_lr=1.0/1200, lr_decay_epochs=[8, 11], lr_decay_gamma=0.1, metric=None, use_vdl=False, early_stop=False, early_stop_patience=5)
+> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, save_interval_epochs=1, log_interval_steps=2,save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.0025, warmup_steps=500, warmup_start_lr=1.0/1200, lr_decay_epochs=[8, 11], lr_decay_gamma=0.1, metric=None, use_vdl=False, early_stop=False, early_stop_patience=5, resume_checkpoint=None)
 >
 > ```
 >
@@ -214,6 +216,7 @@ paddlex.det.FasterRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspec
 > > - **use_vdl** (bool): 是否使用VisualDL进行可视化。默认值为False。
 > > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。
 > > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
+> > - **resume_checkpoint** (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
 
 #### FasterRCNN评估函数接口
 
@@ -270,7 +273,7 @@ paddlex.det.MaskRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspect_
 #### MaskRCNN训练函数接口
 
 > ```python
-> train(self, num_epochs, train_dataset, train_batch_size=1, eval_dataset=None, save_interval_epochs=1, log_interval_steps=20, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=1.0/800, warmup_steps=500, warmup_start_lr=1.0 / 2400, lr_decay_epochs=[8, 11], lr_decay_gamma=0.1, metric=None, use_vdl=False, early_stop=False, early_stop_patience=5)
+> train(self, num_epochs, train_dataset, train_batch_size=1, eval_dataset=None, save_interval_epochs=1, log_interval_steps=20, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=1.0/800, warmup_steps=500, warmup_start_lr=1.0 / 2400, lr_decay_epochs=[8, 11], lr_decay_gamma=0.1, metric=None, use_vdl=False, early_stop=False, early_stop_patience=5, resume_checkpoint=None)
 >
 > ```
 >
@@ -294,6 +297,7 @@ paddlex.det.MaskRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspect_
 > > - **use_vdl** (bool): 是否使用VisualDL进行可视化。默认值为False。
 > > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。
 > > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
+> > - **resume_checkpoint** (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
 
 #### MaskRCNN评估函数接口
 
@@ -358,7 +362,7 @@ paddlex.seg.DeepLabv3p(num_classes=2, backbone='MobileNetV2_x1.0', output_stride
 #### DeepLabv3训练函数接口
 
 > ```python
-> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, eval_batch_size=1, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.01, lr_decay_power=0.9, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5):
+> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, eval_batch_size=1, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.01, lr_decay_power=0.9, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5, resume_checkpoint=None):
 >
 > ```
 >
@@ -380,6 +384,7 @@ paddlex.seg.DeepLabv3p(num_classes=2, backbone='MobileNetV2_x1.0', output_stride
 > > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。
 > > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。
 > > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
+> > - **resume_checkpoint** (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
 
 #### DeepLabv3评估函数接口
 
@@ -437,7 +442,7 @@ paddlex.seg.UNet(num_classes=2, upsample_mode='bilinear', use_bce_loss=False, us
 #### Unet训练函数接口
 
 > ```python
-> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, eval_batch_size=1, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='COCO', optimizer=None, learning_rate=0.01, lr_decay_power=0.9, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5):
+> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, eval_batch_size=1, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='COCO', optimizer=None, learning_rate=0.01, lr_decay_power=0.9, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5, resume_checkpoint=None):
 > ```
 >
 > **参数:**
@@ -458,6 +463,7 @@ paddlex.seg.UNet(num_classes=2, upsample_mode='bilinear', use_bce_loss=False, us
 > > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。
 > > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。
 > > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
+> > - **resume_checkpoint** (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
 
 #### Unet评估函数接口
 

+ 13 - 1
paddlex/cv/models/base.py

@@ -213,6 +213,17 @@ class BaseAPI:
             prune_program(self, prune_params_ratios)
             self.status = 'Prune'
 
+    def resume_checkpoint(self, path, startup_prog=None):
+        if not osp.isdir(path):
+            raise Exception("Model pretrain path {} does not "
+                            "exists.".format(path))
+        if osp.exists(osp.join(path, 'model.pdparams')):
+            path = osp.join(path, 'model')
+        if startup_prog is None:
+            startup_prog = fluid.default_startup_program()
+        self.exe.run(startup_prog)
+        fluid.load(self.train_prog, path, executor=self.exe)
+
     def get_model_info(self):
         info = dict()
         info['version'] = paddlex.__version__
@@ -334,6 +345,7 @@ class BaseAPI:
                    num_epochs,
                    train_dataset,
                    train_batch_size,
+                   start_epoch=0,
                    eval_dataset=None,
                    save_interval_epochs=1,
                    log_interval_steps=10,
@@ -408,7 +420,7 @@ class BaseAPI:
         best_accuracy_key = ""
         best_accuracy = -1.0
         best_model_epoch = 1
-        for i in range(num_epochs):
+        for i in range(start_epoch, num_epochs):
             records = list()
             step_start_time = time.time()
             epoch_start_time = time.time()

+ 21 - 7
paddlex/cv/models/classifier.py

@@ -112,7 +112,8 @@ class BaseClassifier(BaseAPI):
               sensitivities_file=None,
               eval_metric_loss=0.05,
               early_stop=False,
-              early_stop_patience=5):
+              early_stop_patience=5,
+              resume_checkpoint=None):
         """训练。
 
         Args:
@@ -137,6 +138,7 @@ class BaseClassifier(BaseAPI):
             early_stop (bool): 是否使用提前终止训练策略。默认值为False。
             early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
                 连续下降或持平,则终止训练。默认值为5。
+            resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
 
         Raises:
             ValueError: 模型从inference model进行加载。
@@ -155,15 +157,27 @@ class BaseClassifier(BaseAPI):
         # 构建训练、验证、预测网络
         self.build_program()
         # 初始化网络权重
-        self.net_initialize(
-            startup_prog=fluid.default_startup_program(),
-            pretrain_weights=pretrain_weights,
-            save_dir=save_dir,
-            sensitivities_file=sensitivities_file,
-            eval_metric_loss=eval_metric_loss)
+        if resume_checkpoint:
+            self.resume_checkpoint(
+                path=resume_checkpoint,
+                startup_prog=fluid.default_startup_program())
+            scope = fluid.global_scope()
+            v = scope.find_var('@LR_DECAY_COUNTER@')
+            step = np.array(v.get_tensor())[0] if v else 0
+            num_steps_each_epoch = train_dataset.num_samples // train_batch_size
+            start_epoch = step // num_steps_each_epoch + 1
+        else:
+            self.net_initialize(
+                startup_prog=fluid.default_startup_program(),
+                pretrain_weights=pretrain_weights,
+                save_dir=save_dir,
+                sensitivities_file=sensitivities_file,
+                eval_metric_loss=eval_metric_loss)
+            start_epoch = 0
 
         # 训练
         self.train_loop(
+            start_epoch=start_epoch,
             num_epochs=num_epochs,
             train_dataset=train_dataset,
             train_batch_size=train_batch_size,

+ 24 - 8
paddlex/cv/models/deeplabv3p.py

@@ -234,7 +234,8 @@ class DeepLabv3p(BaseAPI):
               sensitivities_file=None,
               eval_metric_loss=0.05,
               early_stop=False,
-              early_stop_patience=5):
+              early_stop_patience=5,
+              resume_checkpoint=None):
         """训练。
 
         Args:
@@ -258,6 +259,7 @@ class DeepLabv3p(BaseAPI):
             early_stop (bool): 是否使用提前终止训练策略。默认值为False。
             early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
                 连续下降或持平,则终止训练。默认值为5。
+            resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
 
         Raises:
             ValueError: 模型从inference model进行加载。
@@ -279,14 +281,27 @@ class DeepLabv3p(BaseAPI):
         # 构建训练、验证、预测网络
         self.build_program()
         # 初始化网络权重
-        self.net_initialize(
-            startup_prog=fluid.default_startup_program(),
-            pretrain_weights=pretrain_weights,
-            save_dir=save_dir,
-            sensitivities_file=sensitivities_file,
-            eval_metric_loss=eval_metric_loss)
+        if resume_checkpoint:
+            self.resume_checkpoint(
+                path=resume_checkpoint,
+                startup_prog=fluid.default_startup_program())
+            scope = fluid.global_scope()
+            v = scope.find_var('@LR_DECAY_COUNTER@')
+            step = np.array(v.get_tensor())[0] if v else 0
+            num_steps_each_epoch = train_dataset.num_samples // train_batch_size
+            start_epoch = step // num_steps_each_epoch + 1
+        else:
+            self.net_initialize(
+                startup_prog=fluid.default_startup_program(),
+                pretrain_weights=pretrain_weights,
+                save_dir=save_dir,
+                sensitivities_file=sensitivities_file,
+                eval_metric_loss=eval_metric_loss)
+            start_epoch = 0
+
         # 训练
         self.train_loop(
+            start_epoch=start_epoch,
             num_epochs=num_epochs,
             train_dataset=train_dataset,
             train_batch_size=train_batch_size,
@@ -405,5 +420,6 @@ class DeepLabv3p(BaseAPI):
                 w, h = info[1][1], info[1][0]
                 pred = pred[0:h, 0:w]
             else:
-                raise Exception("Unexpected info '{}' in im_info".format(info[0]))
+                raise Exception("Unexpected info '{}' in im_info".format(
+                    info[0]))
         return {'label_map': pred, 'score_map': result[1]}

+ 21 - 6
paddlex/cv/models/faster_rcnn.py

@@ -167,7 +167,8 @@ class FasterRCNN(BaseAPI):
               metric=None,
               use_vdl=False,
               early_stop=False,
-              early_stop_patience=5):
+              early_stop_patience=5,
+              resume_checkpoint=None):
         """训练。
 
         Args:
@@ -193,6 +194,7 @@ class FasterRCNN(BaseAPI):
             early_stop (bool): 是否使用提前终止训练策略。默认值为False。
             early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
                 连续下降或持平,则终止训练。默认值为5。
+            resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
 
         Raises:
             ValueError: 评估类型不在指定列表中。
@@ -227,13 +229,26 @@ class FasterRCNN(BaseAPI):
         fuse_bn = True
         if self.with_fpn and self.backbone in ['ResNet18', 'ResNet50']:
             fuse_bn = False
-        self.net_initialize(
-            startup_prog=fluid.default_startup_program(),
-            pretrain_weights=pretrain_weights,
-            fuse_bn=fuse_bn,
-            save_dir=save_dir)
+        if resume_checkpoint:
+            self.resume_checkpoint(
+                path=resume_checkpoint,
+                startup_prog=fluid.default_startup_program())
+            scope = fluid.global_scope()
+            v = scope.find_var('@LR_DECAY_COUNTER@')
+            step = np.array(v.get_tensor())[0] if v else 0
+            num_steps_each_epoch = train_dataset.num_samples // train_batch_size
+            start_epoch = step // num_steps_each_epoch + 1
+        else:
+            self.net_initialize(
+                startup_prog=fluid.default_startup_program(),
+                pretrain_weights=pretrain_weights,
+                fuse_bn=fuse_bn,
+                save_dir=save_dir)
+            start_epoch = 0
+
         # 训练
         self.train_loop(
+            start_epoch=start_epoch,
             num_epochs=num_epochs,
             train_dataset=train_dataset,
             train_batch_size=train_batch_size,

+ 23 - 7
paddlex/cv/models/mask_rcnn.py

@@ -132,7 +132,8 @@ class MaskRCNN(FasterRCNN):
               metric=None,
               use_vdl=False,
               early_stop=False,
-              early_stop_patience=5):
+              early_stop_patience=5,
+              resume_checkpoint=None):
         """训练。
 
         Args:
@@ -158,6 +159,7 @@ class MaskRCNN(FasterRCNN):
             early_stop (bool): 是否使用提前终止训练策略。默认值为False。
             early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
                 连续下降或持平,则终止训练。默认值为5。
+            resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
 
         Raises:
             ValueError: 评估类型不在指定列表中。
@@ -169,7 +171,8 @@ class MaskRCNN(FasterRCNN):
                 metric = 'COCO'
             else:
                 raise Exception(
-                    "train_dataset should be datasets.COCODetection or datasets.EasyDataDet.")
+                    "train_dataset should be datasets.COCODetection or datasets.EasyDataDet."
+                )
         assert metric in ['COCO', 'VOC'], "Metric only support 'VOC' or 'COCO'"
         self.metric = metric
         if not self.trainable:
@@ -193,13 +196,26 @@ class MaskRCNN(FasterRCNN):
         fuse_bn = True
         if self.with_fpn and self.backbone in ['ResNet18', 'ResNet50']:
             fuse_bn = False
-        self.net_initialize(
-            startup_prog=fluid.default_startup_program(),
-            pretrain_weights=pretrain_weights,
-            fuse_bn=fuse_bn,
-            save_dir=save_dir)
+        if resume_checkpoint:
+            self.resume_checkpoint(
+                path=resume_checkpoint,
+                startup_prog=fluid.default_startup_program())
+            scope = fluid.global_scope()
+            v = scope.find_var('@LR_DECAY_COUNTER@')
+            step = np.array(v.get_tensor())[0] if v else 0
+            num_steps_each_epoch = train_dataset.num_samples // train_batch_size
+            start_epoch = step // num_steps_each_epoch + 1
+        else:
+            self.net_initialize(
+                startup_prog=fluid.default_startup_program(),
+                pretrain_weights=pretrain_weights,
+                fuse_bn=fuse_bn,
+                save_dir=save_dir)
+            start_epoch = 0
+
         # 训练
         self.train_loop(
+            start_epoch=start_epoch,
             num_epochs=num_epochs,
             train_dataset=train_dataset,
             train_batch_size=train_batch_size,

+ 9 - 8
paddlex/cv/models/unet.py

@@ -121,7 +121,8 @@ class UNet(DeepLabv3p):
               sensitivities_file=None,
               eval_metric_loss=0.05,
               early_stop=False,
-              early_stop_patience=5):
+              early_stop_patience=5,
+              resume_checkpoint=None):
         """训练。
 
         Args:
@@ -145,14 +146,14 @@ class UNet(DeepLabv3p):
             early_stop (bool): 是否使用提前终止训练策略。默认值为False。
             early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
                 连续下降或持平,则终止训练。默认值为5。
+            resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
 
         Raises:
             ValueError: 模型从inference model进行加载。
         """
-        return super(
-            UNet,
-            self).train(num_epochs, train_dataset, train_batch_size,
-                        eval_dataset, save_interval_epochs, log_interval_steps,
-                        save_dir, pretrain_weights, optimizer, learning_rate,
-                        lr_decay_power, use_vdl, sensitivities_file,
-                        eval_metric_loss, early_stop, early_stop_patience)
+        return super(UNet, self).train(
+            num_epochs, train_dataset, train_batch_size, eval_dataset,
+            save_interval_epochs, log_interval_steps, save_dir,
+            pretrain_weights, optimizer, learning_rate, lr_decay_power,
+            use_vdl, sensitivities_file, eval_metric_loss, early_stop,
+            early_stop_patience, resume_checkpoint)

+ 22 - 7
paddlex/cv/models/yolo_v3.py

@@ -166,7 +166,8 @@ class YOLOv3(BaseAPI):
               sensitivities_file=None,
               eval_metric_loss=0.05,
               early_stop=False,
-              early_stop_patience=5):
+              early_stop_patience=5,
+              resume_checkpoint=None):
         """训练。
 
         Args:
@@ -195,6 +196,7 @@ class YOLOv3(BaseAPI):
             early_stop (bool): 是否使用提前终止训练策略。默认值为False。
             early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
                 连续下降或持平,则终止训练。默认值为5。
+            resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
 
         Raises:
             ValueError: 评估类型不在指定列表中。
@@ -231,14 +233,27 @@ class YOLOv3(BaseAPI):
         # 构建训练、验证、预测网络
         self.build_program()
         # 初始化网络权重
-        self.net_initialize(
-            startup_prog=fluid.default_startup_program(),
-            pretrain_weights=pretrain_weights,
-            save_dir=save_dir,
-            sensitivities_file=sensitivities_file,
-            eval_metric_loss=eval_metric_loss)
+        if resume_checkpoint:
+            self.resume_checkpoint(
+                path=resume_checkpoint,
+                startup_prog=fluid.default_startup_program())
+            scope = fluid.global_scope()
+            v = scope.find_var('@LR_DECAY_COUNTER@')
+            step = np.array(v.get_tensor())[0] if v else 0
+            num_steps_each_epoch = train_dataset.num_samples // train_batch_size
+            start_epoch = step // num_steps_each_epoch + 1
+        else:
+            self.net_initialize(
+                startup_prog=fluid.default_startup_program(),
+                pretrain_weights=pretrain_weights,
+                save_dir=save_dir,
+                sensitivities_file=sensitivities_file,
+                eval_metric_loss=eval_metric_loss)
+            start_epoch = 0
+
         # 训练
         self.train_loop(
+            start_epoch=start_epoch,
             num_epochs=num_epochs,
             train_dataset=train_dataset,
             train_batch_size=train_batch_size,