|
|
@@ -382,21 +382,20 @@ def start_train_task(data, workspace, monitored_processes):
|
|
|
tid = data['tid']
|
|
|
assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
|
|
|
path = workspace.tasks[tid].path
|
|
|
- if 'eval_metric_loss' in data and \
|
|
|
- data['eval_metric_loss'] is not None:
|
|
|
+ if 'pruned_flops' in data and \
|
|
|
+ data['pruned_flops'] is not None:
|
|
|
# 裁剪任务
|
|
|
parent_id = workspace.tasks[tid].parent_id
|
|
|
assert parent_id != "", "任务{}不是裁剪训练任务".format(tid)
|
|
|
parent_path = workspace.tasks[parent_id].path
|
|
|
- sensitivities_path = osp.join(parent_path, 'prune',
|
|
|
- 'sensitivities.data')
|
|
|
- eval_metric_loss = data['eval_metric_loss']
|
|
|
+ sensitivities_path = osp.join(parent_path, 'prune')
|
|
|
+ pruned_flops = data['pruned_flops']
|
|
|
parent_best_model_path = osp.join(parent_path, 'output', 'best_model')
|
|
|
params_conf_file = osp.join(path, 'params.pkl')
|
|
|
with open(params_conf_file, 'rb') as f:
|
|
|
params = pickle.load(f)
|
|
|
params['train'].sensitivities_path = sensitivities_path
|
|
|
- params['train'].eval_metric_loss = eval_metric_loss
|
|
|
+ params['train'].pruned_flops = pruned_flops
|
|
|
params['train'].pretrain_weights = parent_best_model_path
|
|
|
with open(params_conf_file, 'wb') as f:
|
|
|
pickle.dump(params, f)
|