|
@@ -48,6 +48,8 @@ class BaseModel:
|
|
|
self.train_data_loader = None
|
|
self.train_data_loader = None
|
|
|
self.eval_data_loader = None
|
|
self.eval_data_loader = None
|
|
|
self.eval_metrics = None
|
|
self.eval_metrics = None
|
|
|
|
|
+ self.best_accuracy = -1.
|
|
|
|
|
+ self.best_model_epoch = -1
|
|
|
# 是否使用多卡间同步BatchNorm均值和方差
|
|
# 是否使用多卡间同步BatchNorm均值和方差
|
|
|
self.sync_bn = False
|
|
self.sync_bn = False
|
|
|
self.status = 'Normal'
|
|
self.status = 'Normal'
|
|
@@ -115,6 +117,8 @@ class BaseModel:
|
|
|
with open(osp.join(resume_checkpoint, "model.yml")) as f:
|
|
with open(osp.join(resume_checkpoint, "model.yml")) as f:
|
|
|
info = yaml.load(f.read(), Loader=yaml.Loader)
|
|
info = yaml.load(f.read(), Loader=yaml.Loader)
|
|
|
self.completed_epochs = info['completed_epochs']
|
|
self.completed_epochs = info['completed_epochs']
|
|
|
|
|
+ self.best_accuracy = info['_Attributes']['best_accuracy']
|
|
|
|
|
+ self.best_model_epoch = info['_Attributes']['best_model_epoch']
|
|
|
load_checkpoint(
|
|
load_checkpoint(
|
|
|
self.net,
|
|
self.net,
|
|
|
self.optimizer,
|
|
self.optimizer,
|
|
@@ -125,7 +129,12 @@ class BaseModel:
|
|
|
info = dict()
|
|
info = dict()
|
|
|
info['version'] = paddlex.__version__
|
|
info['version'] = paddlex.__version__
|
|
|
info['Model'] = self.__class__.__name__
|
|
info['Model'] = self.__class__.__name__
|
|
|
- info['_Attributes'] = {'model_type': self.model_type}
|
|
|
|
|
|
|
+ info['_Attributes'] = dict(
|
|
|
|
|
+ [('model_type', self.model_type),
|
|
|
|
|
+ ('num_classes', self.num_classes), ('labels', self.labels),
|
|
|
|
|
+ ('fixed_input_shape', self.fixed_input_shape),
|
|
|
|
|
+ ('best_accuracy', self.best_accuracy),
|
|
|
|
|
+ ('best_model_epoch', self.best_model_epoch)])
|
|
|
if 'self' in self.init_params:
|
|
if 'self' in self.init_params:
|
|
|
del self.init_params['self']
|
|
del self.init_params['self']
|
|
|
if '__class__' in self.init_params:
|
|
if '__class__' in self.init_params:
|
|
@@ -137,10 +146,6 @@ class BaseModel:
|
|
|
|
|
|
|
|
info['_init_params'] = self.init_params
|
|
info['_init_params'] = self.init_params
|
|
|
|
|
|
|
|
- info['_Attributes']['num_classes'] = self.num_classes
|
|
|
|
|
- info['_Attributes']['labels'] = self.labels
|
|
|
|
|
- info['_Attributes']['fixed_input_shape'] = self.fixed_input_shape
|
|
|
|
|
-
|
|
|
|
|
try:
|
|
try:
|
|
|
primary_metric_key = list(self.eval_metrics.keys())[0]
|
|
primary_metric_key = list(self.eval_metrics.keys())[0]
|
|
|
primary_metric_value = float(self.eval_metrics[primary_metric_key])
|
|
primary_metric_value = float(self.eval_metrics[primary_metric_key])
|
|
@@ -317,9 +322,6 @@ class BaseModel:
|
|
|
eval_batch_size = train_batch_size
|
|
eval_batch_size = train_batch_size
|
|
|
eval_epoch_time = 0
|
|
eval_epoch_time = 0
|
|
|
|
|
|
|
|
- best_accuracy_key = ""
|
|
|
|
|
- best_accuracy = -1.0
|
|
|
|
|
- best_model_epoch = -1
|
|
|
|
|
current_step = 0
|
|
current_step = 0
|
|
|
for i in range(start_epoch, num_epochs):
|
|
for i in range(start_epoch, num_epochs):
|
|
|
self.net.train()
|
|
self.net.train()
|
|
@@ -384,11 +386,12 @@ class BaseModel:
|
|
|
.format(i + 1, train_avg_metrics.log()))
|
|
.format(i + 1, train_avg_metrics.log()))
|
|
|
self.completed_epochs += 1
|
|
self.completed_epochs += 1
|
|
|
|
|
|
|
|
- # 每间隔save_interval_epochs, 在验证集上评估和对模型进行保存
|
|
|
|
|
if ema is not None:
|
|
if ema is not None:
|
|
|
weight = copy.deepcopy(self.net.state_dict())
|
|
weight = copy.deepcopy(self.net.state_dict())
|
|
|
self.net.set_state_dict(ema.apply())
|
|
self.net.set_state_dict(ema.apply())
|
|
|
eval_epoch_tic = time.time()
|
|
eval_epoch_tic = time.time()
|
|
|
|
|
+
|
|
|
|
|
+ # 每间隔save_interval_epochs, 在验证集上评估和对模型进行保存
|
|
|
if (i + 1) % save_interval_epochs == 0 or i == num_epochs - 1:
|
|
if (i + 1) % save_interval_epochs == 0 or i == num_epochs - 1:
|
|
|
if eval_dataset is not None and eval_dataset.num_samples > 0:
|
|
if eval_dataset is not None and eval_dataset.num_samples > 0:
|
|
|
eval_result = self.evaluate(
|
|
eval_result = self.evaluate(
|
|
@@ -410,16 +413,16 @@ class BaseModel:
|
|
|
i + 1, dict2str(self.eval_metrics)))
|
|
i + 1, dict2str(self.eval_metrics)))
|
|
|
best_accuracy_key = list(self.eval_metrics.keys())[0]
|
|
best_accuracy_key = list(self.eval_metrics.keys())[0]
|
|
|
current_accuracy = self.eval_metrics[best_accuracy_key]
|
|
current_accuracy = self.eval_metrics[best_accuracy_key]
|
|
|
- if current_accuracy > best_accuracy:
|
|
|
|
|
- best_accuracy = current_accuracy
|
|
|
|
|
- best_model_epoch = i + 1
|
|
|
|
|
|
|
+ if current_accuracy > self.best_accuracy:
|
|
|
|
|
+ self.best_accuracy = current_accuracy
|
|
|
|
|
+ self.best_model_epoch = i + 1
|
|
|
best_model_dir = osp.join(save_dir, "best_model")
|
|
best_model_dir = osp.join(save_dir, "best_model")
|
|
|
self.save_model(save_dir=best_model_dir)
|
|
self.save_model(save_dir=best_model_dir)
|
|
|
- if best_model_epoch > 0:
|
|
|
|
|
|
|
+ if self.best_model_epoch > 0:
|
|
|
logging.info(
|
|
logging.info(
|
|
|
'Current evaluated best model on eval_dataset is epoch_{}, {}={}'
|
|
'Current evaluated best model on eval_dataset is epoch_{}, {}={}'
|
|
|
- .format(best_model_epoch, best_accuracy_key,
|
|
|
|
|
- best_accuracy))
|
|
|
|
|
|
|
+ .format(self.best_model_epoch,
|
|
|
|
|
+ best_accuracy_key, self.best_accuracy))
|
|
|
eval_epoch_time = time.time() - eval_epoch_tic
|
|
eval_epoch_time = time.time() - eval_epoch_tic
|
|
|
|
|
|
|
|
current_save_dir = osp.join(save_dir, "epoch_{}".format(i + 1))
|
|
current_save_dir = osp.join(save_dir, "epoch_{}".format(i + 1))
|