|
|
@@ -417,7 +417,7 @@ class BaseAPI:
|
|
|
earlystop = EarlyStop(early_stop_patience, thresh)
|
|
|
best_accuracy_key = ""
|
|
|
best_accuracy = -1.0
|
|
|
- best_model_epoch = 1
|
|
|
+ best_model_epoch = -1
|
|
|
for i in range(num_epochs):
|
|
|
records = list()
|
|
|
step_start_time = time.time()
|
|
|
@@ -490,7 +490,7 @@ class BaseAPI:
|
|
|
current_save_dir = osp.join(save_dir, "epoch_{}".format(i + 1))
|
|
|
if not osp.isdir(current_save_dir):
|
|
|
os.makedirs(current_save_dir)
|
|
|
- if eval_dataset is not None:
|
|
|
+ if eval_dataset is not None and eval_dataset.num_samples > 0:
|
|
|
self.eval_metrics, self.eval_details = self.evaluate(
|
|
|
eval_dataset=eval_dataset,
|
|
|
batch_size=eval_batch_size,
|
|
|
@@ -522,10 +522,11 @@ class BaseAPI:
|
|
|
self.save_model(save_dir=current_save_dir)
|
|
|
time_eval_one_epoch = time.time() - eval_epoch_start_time
|
|
|
eval_epoch_start_time = time.time()
|
|
|
- logging.info(
|
|
|
- 'Current evaluated best model in eval_dataset is epoch_{}, {}={}'
|
|
|
- .format(best_model_epoch, best_accuracy_key,
|
|
|
- best_accuracy))
|
|
|
+ if best_model_epoch > 0:
|
|
|
+ logging.info(
|
|
|
+ 'Current evaluated best model in eval_dataset is epoch_{}, {}={}'
|
|
|
+ .format(best_model_epoch, best_accuracy_key,
|
|
|
+ best_accuracy))
|
|
|
if eval_dataset is not None and early_stop:
|
|
|
if earlystop(current_accuracy):
|
|
|
break
|