|
@@ -257,8 +257,8 @@ class BaseAPI:
|
|
|
logging.info(
|
|
logging.info(
|
|
|
"Load pretrain weights from {}.".format(pretrain_weights),
|
|
"Load pretrain weights from {}.".format(pretrain_weights),
|
|
|
use_color=True)
|
|
use_color=True)
|
|
|
- paddlex.utils.utils.load_pretrain_weights(
|
|
|
|
|
- self.exe, self.train_prog, pretrain_weights, fuse_bn)
|
|
|
|
|
|
|
+ paddlex.utils.utils.load_pretrain_weights(self.exe, self.train_prog,
|
|
|
|
|
+ pretrain_weights, fuse_bn)
|
|
|
# 进行裁剪
|
|
# 进行裁剪
|
|
|
if sensitivities_file is not None:
|
|
if sensitivities_file is not None:
|
|
|
import paddleslim
|
|
import paddleslim
|
|
@@ -364,9 +364,7 @@ class BaseAPI:
|
|
|
logging.info("Model saved in {}.".format(save_dir))
|
|
logging.info("Model saved in {}.".format(save_dir))
|
|
|
|
|
|
|
|
def export_inference_model(self, save_dir):
|
|
def export_inference_model(self, save_dir):
|
|
|
- test_input_names = [
|
|
|
|
|
- var.name for var in list(self.test_inputs.values())
|
|
|
|
|
- ]
|
|
|
|
|
|
|
+ test_input_names = [var.name for var in list(self.test_inputs.values())]
|
|
|
test_outputs = list(self.test_outputs.values())
|
|
test_outputs = list(self.test_outputs.values())
|
|
|
with fluid.scope_guard(self.scope):
|
|
with fluid.scope_guard(self.scope):
|
|
|
fluid.io.save_inference_model(
|
|
fluid.io.save_inference_model(
|
|
@@ -394,8 +392,7 @@ class BaseAPI:
|
|
|
|
|
|
|
|
# 模型保存成功的标志
|
|
# 模型保存成功的标志
|
|
|
open(osp.join(save_dir, '.success'), 'w').close()
|
|
open(osp.join(save_dir, '.success'), 'w').close()
|
|
|
- logging.info("Model for inference deploy saved in {}.".format(
|
|
|
|
|
- save_dir))
|
|
|
|
|
|
|
+ logging.info("Model for inference deploy saved in {}.".format(save_dir))
|
|
|
|
|
|
|
|
def train_loop(self,
|
|
def train_loop(self,
|
|
|
num_epochs,
|
|
num_epochs,
|
|
@@ -480,6 +477,9 @@ class BaseAPI:
|
|
|
best_accuracy = -1.0
|
|
best_accuracy = -1.0
|
|
|
best_model_epoch = -1
|
|
best_model_epoch = -1
|
|
|
start_epoch = self.completed_epochs
|
|
start_epoch = self.completed_epochs
|
|
|
|
|
+ # task_id: 目前由PaddleX GUI赋值
|
|
|
|
|
+ # 用于在VisualDL日志中注明所属任务id
|
|
|
|
|
+ task_id = getattr(paddlex, "task_id", "")
|
|
|
for i in range(start_epoch, num_epochs):
|
|
for i in range(start_epoch, num_epochs):
|
|
|
records = list()
|
|
records = list()
|
|
|
step_start_time = time.time()
|
|
step_start_time = time.time()
|
|
@@ -510,8 +510,8 @@ class BaseAPI:
|
|
|
if use_vdl:
|
|
if use_vdl:
|
|
|
for k, v in step_metrics.items():
|
|
for k, v in step_metrics.items():
|
|
|
log_writer.add_scalar(
|
|
log_writer.add_scalar(
|
|
|
- 'Metrics/Training(Step): {}'.format(k), v,
|
|
|
|
|
- num_steps)
|
|
|
|
|
|
|
+ '{}-Metrics/Training(Step): {}'.format(
|
|
|
|
|
+ task_id, k), v, num_steps)
|
|
|
|
|
|
|
|
# 估算剩余时间
|
|
# 估算剩余时间
|
|
|
avg_step_time = np.mean(time_stat)
|
|
avg_step_time = np.mean(time_stat)
|
|
@@ -522,13 +522,11 @@ class BaseAPI:
|
|
|
eta = ((num_epochs - i) * total_num_steps - step - 1
|
|
eta = ((num_epochs - i) * total_num_steps - step - 1
|
|
|
) * avg_step_time
|
|
) * avg_step_time
|
|
|
if time_eval_one_epoch is not None:
|
|
if time_eval_one_epoch is not None:
|
|
|
- eval_eta = (
|
|
|
|
|
- total_eval_times - i // save_interval_epochs
|
|
|
|
|
- ) * time_eval_one_epoch
|
|
|
|
|
|
|
+ eval_eta = (total_eval_times - i // save_interval_epochs
|
|
|
|
|
+ ) * time_eval_one_epoch
|
|
|
else:
|
|
else:
|
|
|
- eval_eta = (
|
|
|
|
|
- total_eval_times - i // save_interval_epochs
|
|
|
|
|
- ) * total_num_steps_eval * avg_step_time
|
|
|
|
|
|
|
+ eval_eta = (total_eval_times - i // save_interval_epochs
|
|
|
|
|
+ ) * total_num_steps_eval * avg_step_time
|
|
|
eta_str = seconds_to_hms(eta + eval_eta)
|
|
eta_str = seconds_to_hms(eta + eval_eta)
|
|
|
|
|
|
|
|
logging.info(
|
|
logging.info(
|
|
@@ -577,7 +575,8 @@ class BaseAPI:
|
|
|
if v.size > 1:
|
|
if v.size > 1:
|
|
|
continue
|
|
continue
|
|
|
log_writer.add_scalar(
|
|
log_writer.add_scalar(
|
|
|
- "Metrics/Eval(Epoch): {}".format(k), v, i + 1)
|
|
|
|
|
|
|
+ "{}-Metrics/Eval(Epoch): {}".format(task_id, k),
|
|
|
|
|
+ v, i + 1)
|
|
|
self.save_model(save_dir=current_save_dir)
|
|
self.save_model(save_dir=current_save_dir)
|
|
|
if getattr(self, 'use_ema', False):
|
|
if getattr(self, 'use_ema', False):
|
|
|
self.exe.run(self.ema.restore_program)
|
|
self.exe.run(self.ema.restore_program)
|