@@ -15,6 +15,7 @@
 from __future__ import absolute_import
 import paddle.fluid as fluid
 import os
+import sys
 import numpy as np
 import time
 import math
@@ -252,6 +253,9 @@ class BaseAPI:
             del self.init_params['self']
         if '__class__' in self.init_params:
             del self.init_params['__class__']
+        if 'model_name' in self.init_params:
+            del self.init_params['model_name']
+
         info['_init_params'] = self.init_params

         info['_Attributes']['num_classes'] = self.num_classes
@@ -372,6 +376,8 @@ class BaseAPI:
                    use_vdl=False,
                    early_stop=False,
                    early_stop_patience=5):
+        if train_dataset.num_samples < train_batch_size:
+            raise Exception('The number of training samples must be no less than the batch size.')
         if not osp.isdir(save_dir):
             if osp.exists(save_dir):
                 os.remove(save_dir)
@@ -429,9 +435,7 @@

         if use_vdl:
             # VisualDL component
-            log_writer = LogWriter(vdl_logdir, sync_cycle=20)
-            train_step_component = OrderedDict()
-            eval_component = OrderedDict()
+            log_writer = LogWriter(vdl_logdir)

         thresh = 0.0001
         if early_stop:
@@ -469,13 +473,7 @@

                     if use_vdl:
                         for k, v in step_metrics.items():
-                            if k not in train_step_component.keys():
-                                with log_writer.mode('Each_Step_while_Training'
-                                                     ) as step_logger:
-                                    train_step_component[
-                                        k] = step_logger.scalar(
-                                            'Training: {}'.format(k))
-                            train_step_component[k].add_record(num_steps, v)
+                            log_writer.add_scalar('Metrics/Training(Step): {}'.format(k), v, num_steps)

                     # 估算剩余时间
                     avg_step_time = np.mean(time_stat)
@@ -536,12 +534,7 @@
                             if isinstance(v, np.ndarray):
                                 if v.size > 1:
                                     continue
-                            if k not in eval_component:
-                                with log_writer.mode('Each_Epoch_on_Eval_Data'
-                                                     ) as eval_logger:
-                                    eval_component[k] = eval_logger.scalar(
-                                        'Evaluation: {}'.format(k))
-                            eval_component[k].add_record(i + 1, v)
+                            log_writer.add_scalar("Metrics/Eval(Epoch): {}".format(k), v, i + 1)
                 self.save_model(save_dir=current_save_dir)
                 time_eval_one_epoch = time.time() - eval_epoch_start_time
                 eval_epoch_start_time = time.time()
@@ -552,4 +545,4 @@
                             best_accuracy))
             if eval_dataset is not None and early_stop:
                 if earlystop(current_accuracy):
-                        break
+                    break
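Note: the VisualDL logging above moves from the removed 1.x mode()/scalar()/add_record() components to the 2.x LogWriter.add_scalar(tag, value, step) call. A minimal usage sketch follows; the log directory, tag name, and metric values are illustrative only and not part of the patch.

from visualdl import LogWriter

# Hypothetical log directory and dummy loss values, for illustration only.
log_writer = LogWriter('./output/vdl_log')
for step, loss in enumerate([0.9, 0.7, 0.5], start=1):
    # add_scalar(tag, value, step) replaces the old scalar()/add_record() pair.
    log_writer.add_scalar('Metrics/Training(Step): loss', loss, step)
log_writer.close()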