|
|
@@ -333,12 +333,13 @@ class BaseModel:
|
|
|
eval_epoch_tic = time.time()
|
|
|
if (i + 1) % save_interval_epochs == 0 or i == num_epochs - 1:
|
|
|
if eval_dataset is not None and eval_dataset.num_samples > 0:
|
|
|
- self.eval_metrics, self.eval_details = self.evaluate(
|
|
|
+ eval_result = self.evaluate(
|
|
|
eval_dataset,
|
|
|
batch_size=eval_batch_size,
|
|
|
return_details=True)
|
|
|
# 保存最优模型
|
|
|
if local_rank == 0:
|
|
|
+ self.eval_metrics, self.eval_details = eval_result
|
|
|
logging.info('[EVAL] Finished, Epoch={}, {} .'.format(
|
|
|
i + 1, dict2str(self.eval_metrics)))
|
|
|
best_accuracy_key = list(self.eval_metrics.keys())[0]
|