Răsfoiți Sursa

udpate restful

will-jl944 4 ani în urmă
părinte
comite
18584d9aca

+ 5 - 6
paddlex_restful/restful/project/task.py

@@ -382,21 +382,20 @@ def start_train_task(data, workspace, monitored_processes):
     tid = data['tid']
     assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
     path = workspace.tasks[tid].path
-    if 'eval_metric_loss' in data and \
-        data['eval_metric_loss'] is not None:
+    if 'pruned_flops' in data and \
+        data['pruned_flops'] is not None:
         # 裁剪任务
         parent_id = workspace.tasks[tid].parent_id
         assert parent_id != "", "任务{}不是裁剪训练任务".format(tid)
         parent_path = workspace.tasks[parent_id].path
-        sensitivities_path = osp.join(parent_path, 'prune',
-                                      'sensitivities.data')
-        eval_metric_loss = data['eval_metric_loss']
+        sensitivities_path = osp.join(parent_path, 'prune')
+        pruned_flops = data['pruned_flops']
         parent_best_model_path = osp.join(parent_path, 'output', 'best_model')
         params_conf_file = osp.join(path, 'params.pkl')
         with open(params_conf_file, 'rb') as f:
             params = pickle.load(f)
         params['train'].sensitivities_path = sensitivities_path
-        params['train'].eval_metric_loss = eval_metric_loss
+        params['train'].pruned_flops = pruned_flops
         params['train'].pretrain_weights = parent_best_model_path
         with open(params_conf_file, 'wb') as f:
             pickle.dump(params, f)

+ 2 - 2
paddlex_restful/restful/project/train/params.py

@@ -89,8 +89,8 @@ class Params(object):
     def set_sensitivities_path(self, sensitivities_path):
         self.sensitivities_path = sensitivities_path
 
-    def set_eval_metric_loss(self, eval_metric_loss):
-        self.eval_metric_loss = eval_metric_loss
+    def set_pruned_flops(self, pruned_flops):
+        self.pruned_flops = pruned_flops
 
     def set_image_shape(self, image_shape):
         self.image_shape = image_shape