瀏覽代碼

resume best acc and best acc epoch when resume training

will-jl944 3 年之前
父節點
當前提交
f1dd517173
共有 1 個文件被更改,包括 18 次插入15 次删除
  1. 18 15
      paddlex/cv/models/base.py

+ 18 - 15
paddlex/cv/models/base.py

@@ -48,6 +48,8 @@ class BaseModel:
         self.train_data_loader = None
         self.eval_data_loader = None
         self.eval_metrics = None
+        self.best_accuracy = -1.
+        self.best_model_epoch = -1
         # 是否使用多卡间同步BatchNorm均值和方差
         self.sync_bn = False
         self.status = 'Normal'
@@ -115,6 +117,8 @@ class BaseModel:
             with open(osp.join(resume_checkpoint, "model.yml")) as f:
                 info = yaml.load(f.read(), Loader=yaml.Loader)
                 self.completed_epochs = info['completed_epochs']
+                self.best_accuracy = info['_Attributes']['best_accuracy']
+                self.best_model_epoch = info['_Attributes']['best_model_epoch']
             load_checkpoint(
                 self.net,
                 self.optimizer,
@@ -125,7 +129,12 @@ class BaseModel:
         info = dict()
         info['version'] = paddlex.__version__
         info['Model'] = self.__class__.__name__
-        info['_Attributes'] = {'model_type': self.model_type}
+        info['_Attributes'] = dict(
+            [('model_type', self.model_type),
+             ('num_classes', self.num_classes), ('labels', self.labels),
+             ('fixed_input_shape', self.fixed_input_shape),
+             ('best_accuracy', self.best_accuracy),
+             ('best_model_epoch', self.best_model_epoch)])
         if 'self' in self.init_params:
             del self.init_params['self']
         if '__class__' in self.init_params:
@@ -137,10 +146,6 @@ class BaseModel:
 
         info['_init_params'] = self.init_params
 
-        info['_Attributes']['num_classes'] = self.num_classes
-        info['_Attributes']['labels'] = self.labels
-        info['_Attributes']['fixed_input_shape'] = self.fixed_input_shape
-
         try:
             primary_metric_key = list(self.eval_metrics.keys())[0]
             primary_metric_value = float(self.eval_metrics[primary_metric_key])
@@ -317,9 +322,6 @@ class BaseModel:
             eval_batch_size = train_batch_size
             eval_epoch_time = 0
 
-        best_accuracy_key = ""
-        best_accuracy = -1.0
-        best_model_epoch = -1
         current_step = 0
         for i in range(start_epoch, num_epochs):
             self.net.train()
@@ -384,11 +386,12 @@ class BaseModel:
                          .format(i + 1, train_avg_metrics.log()))
             self.completed_epochs += 1
 
-            # 每间隔save_interval_epochs, 在验证集上评估和对模型进行保存
             if ema is not None:
                 weight = copy.deepcopy(self.net.state_dict())
                 self.net.set_state_dict(ema.apply())
             eval_epoch_tic = time.time()
+
+            # 每间隔save_interval_epochs, 在验证集上评估和对模型进行保存
             if (i + 1) % save_interval_epochs == 0 or i == num_epochs - 1:
                 if eval_dataset is not None and eval_dataset.num_samples > 0:
                     eval_result = self.evaluate(
@@ -410,16 +413,16 @@ class BaseModel:
                             i + 1, dict2str(self.eval_metrics)))
                         best_accuracy_key = list(self.eval_metrics.keys())[0]
                         current_accuracy = self.eval_metrics[best_accuracy_key]
-                        if current_accuracy > best_accuracy:
-                            best_accuracy = current_accuracy
-                            best_model_epoch = i + 1
+                        if current_accuracy > self.best_accuracy:
+                            self.best_accuracy = current_accuracy
+                            self.best_model_epoch = i + 1
                             best_model_dir = osp.join(save_dir, "best_model")
                             self.save_model(save_dir=best_model_dir)
-                        if best_model_epoch > 0:
+                        if self.best_model_epoch > 0:
                             logging.info(
                                 'Current evaluated best model on eval_dataset is epoch_{}, {}={}'
-                                .format(best_model_epoch, best_accuracy_key,
-                                        best_accuracy))
+                                .format(self.best_model_epoch,
+                                        best_accuracy_key, self.best_accuracy))
                     eval_epoch_time = time.time() - eval_epoch_tic
 
                 current_save_dir = osp.join(save_dir, "epoch_{}".format(i + 1))