Parcourir la source

when loading a classifier or segmenter exporting by PaddleX 2.0.0rc, add 'net.' to each parameter name

FlyingQianMM il y a 4 ans
Parent
commit
f6900b61e1

+ 1 - 1
examples/meter_reader/reader_infer.py

@@ -263,7 +263,7 @@ class MeterReader:
         eroded_results = seg_results
         for i in range(len(seg_results)):
             eroded_results[i]['label_map'] = cv2.erode(
-                seg_results[i]['label_map'], kernel)
+                seg_results[i]['label_map'].astype(np.uint8), kernel)
         return eroded_results
 
     def circle_to_rectangle(self, seg_results):

+ 10 - 0
paddlex/cv/models/load_model.py

@@ -121,6 +121,16 @@ def load_model(model_dir, **params):
                     net_state_dict = load_rcnn_inference_model(model_dir)
                 else:
                     net_state_dict = paddle.load(osp.join(model_dir, 'model'))
+                    if model.model_type in ['classifier', 'segmenter'
+                                            ] and 'rc' in version:
+                        # For PaddleX>=2.0.0, when exporting a classifier and segmenter,
+                        # InferNet is defined to append softmax and argmax operators to the model,
+                        # so parameter name starts with 'net.'
+                        new_net_state_dict = {}
+                        for k, v in net_state_dict.items():
+                            new_net_state_dict['net.' + k] = v
+                        net_state_dict = new_net_state_dict
+
             else:
                 net_state_dict = paddle.load(
                     osp.join(model_dir, 'model.pdparams'))

+ 1 - 1
paddlex_restful/restful/project/task.py

@@ -388,7 +388,7 @@ def start_train_task(data, workspace, monitored_processes):
         parent_id = workspace.tasks[tid].parent_id
         assert parent_id != "", "任务{}不是裁剪训练任务".format(tid)
         parent_path = workspace.tasks[parent_id].path
-        sensitivities_path = osp.join(parent_path, 'prune')
+        sensitivities_path = osp.join(parent_path, 'prune', 'model.sensi.data')
         pruned_flops = data['pruned_flops']
         parent_best_model_path = osp.join(parent_path, 'output', 'best_model')
         params_conf_file = osp.join(path, 'params.pkl')