瀏覽代碼

fix scope problem for slim

jiangjiajun 5 年之前
父節點
當前提交
6a89ed1380
共有 3 個文件被更改,包括 45 次插入13 次删除
  1. 12 8
      paddlex/cv/models/slim/prune.py
  2. 32 4
      paddlex/cv/models/slim/prune_config.py
  3. 1 1
      paddlex/cv/models/slim/visualize.py

+ 12 - 8
paddlex/cv/models/slim/prune.py

@@ -104,7 +104,7 @@ def sensitivity(program,
     return sensitivities
     return sensitivities
 
 
 
 
-def channel_prune(program, prune_names, prune_ratios, place, only_graph=False):
+def channel_prune(program, prune_names, prune_ratios, place, only_graph=False, scope=None):
     """通道裁剪。
     """通道裁剪。
 
 
     Args:
     Args:
@@ -134,7 +134,8 @@ def channel_prune(program, prune_names, prune_ratios, place, only_graph=False):
             pruned_num = int(round(origin_num * (ratio)))
             pruned_num = int(round(origin_num * (ratio)))
             prune_ratios[index] = ratio
             prune_ratios[index] = ratio
         index += 1
         index += 1
-    scope = fluid.global_scope()
+    if scope is None:
+        scope = fluid.global_scope()
     pruner = Pruner()
     pruner = Pruner()
     program, _, _ = pruner.prune(
     program, _, _ = pruner.prune(
         program,
         program,
@@ -175,12 +176,12 @@ def prune_program(model, prune_params_ratios=None):
         prune_params_ratios[prune_name] for prune_name in prune_names
         prune_params_ratios[prune_name] for prune_name in prune_names
     ]
     ]
     model.train_prog = channel_prune(train_prog, prune_names, prune_ratios,
     model.train_prog = channel_prune(train_prog, prune_names, prune_ratios,
-                                     place)
+                                     place, scope=model.scope)
     model.test_prog = channel_prune(
     model.test_prog = channel_prune(
-        eval_prog, prune_names, prune_ratios, place, only_graph=True)
+        eval_prog, prune_names, prune_ratios, place, only_graph=True, scope=model.scope)
 
 
 
 
-def update_program(program, model_dir, place):
+def update_program(program, model_dir, place, scope=None):
     """根据裁剪信息更新Program和参数。
     """根据裁剪信息更新Program和参数。
 
 
     Args:
     Args:
@@ -197,10 +198,12 @@ def update_program(program, model_dir, place):
         shapes = yaml.load(f.read(), Loader=yaml.Loader)
         shapes = yaml.load(f.read(), Loader=yaml.Loader)
     for param, shape in shapes.items():
     for param, shape in shapes.items():
         graph.var(param).set_shape(shape)
         graph.var(param).set_shape(shape)
+    if scope is None:
+        scope = fluid.global_scope()
     for block in program.blocks:
     for block in program.blocks:
         for param in block.all_parameters():
         for param in block.all_parameters():
             if param.name in shapes:
             if param.name in shapes:
-                param_tensor = fluid.global_scope().find_var(
+                param_tensor = scope.find_var(
                     param.name).get_tensor()
                     param.name).get_tensor()
                 param_tensor.set(
                 param_tensor.set(
                     np.zeros(list(shapes[param.name])).astype('float32'),
                     np.zeros(list(shapes[param.name])).astype('float32'),
@@ -293,7 +296,7 @@ def get_params_ratios(sensitivities_file, eval_metric_loss=0.05):
     return params_ratios
     return params_ratios
 
 
 
 
-def cal_model_size(program, place, sensitivities_file, eval_metric_loss=0.05):
+def cal_model_size(program, place, sensitivities_file, eval_metric_loss=0.05, scope=None):
     """在可容忍的精度损失下,计算裁剪后模型大小相对于当前模型大小的比例。
     """在可容忍的精度损失下,计算裁剪后模型大小相对于当前模型大小的比例。
 
 
     Args:
     Args:
@@ -326,7 +329,8 @@ def cal_model_size(program, place, sensitivities_file, eval_metric_loss=0.05):
         list(prune_params_ratios.keys()),
         list(prune_params_ratios.keys()),
         list(prune_params_ratios.values()),
         list(prune_params_ratios.values()),
         place,
         place,
-        only_graph=True)
+        only_graph=True,
+        scope=scope)
     origin_size = 0
     origin_size = 0
     new_size = 0
     new_size = 0
     for var in program.list_vars():
     for var in program.list_vars():

+ 32 - 4
paddlex/cv/models/slim/prune_config.py

@@ -171,10 +171,14 @@ def get_prune_params(model):
             model_type.startswith('ShuffleNetV2'):
             model_type.startswith('ShuffleNetV2'):
         for block in program.blocks:
         for block in program.blocks:
             for param in block.all_parameters():
             for param in block.all_parameters():
-                pd_var = fluid.global_scope().find_var(param.name)
-                pd_param = pd_var.get_tensor()
-                if len(np.array(pd_param).shape) == 4:
-                    prune_names.append(param.name)
+                pd_var = model.scope.find_var(param.name)
+                try:
+                    pd_param = pd_var.get_tensor()
+                    if len(np.array(pd_param).shape) == 4:
+                        prune_names.append(param.name)
+                except Exception as e:
+                    print("None Tensor Name: ", param.name)
+                    print("Error message: {}".format(e))
         if model_type == 'AlexNet':
         if model_type == 'AlexNet':
             prune_names.remove('conv5_weights')
             prune_names.remove('conv5_weights')
         if model_type == 'ShuffleNetV2':
         if model_type == 'ShuffleNetV2':
@@ -285,11 +289,35 @@ def get_prune_params(model):
                 prune_names.remove(i)
                 prune_names.remove(i)
 
 
     elif model_type.startswith('DeepLabv3p'):
     elif model_type.startswith('DeepLabv3p'):
+        if model_type.lower() == "deeplabv3p_mobilenetv3_large_x1_0_ssld":
+            params_not_prune = [
+                'last_1x1_conv_weights', 'conv14_se_2_weights',
+                'conv16_depthwise_weights', 'conv13_depthwise_weights',
+                'conv15_se_2_weights', 'conv2_depthwise_weights',
+                'conv6_depthwise_weights', 'conv8_depthwise_weights',
+                'fc_weights', 'conv3_depthwise_weights', 'conv7_se_2_weights',
+                'conv16_expand_weights', 'conv16_se_2_weights',
+                'conv10_depthwise_weights', 'conv11_depthwise_weights',
+                'conv15_expand_weights', 'conv5_expand_weights',
+                'conv15_depthwise_weights', 'conv14_depthwise_weights',
+                'conv12_se_2_weights', 'conv1_weights',
+                'conv13_expand_weights', 'conv_last_weights',
+                'conv12_depthwise_weights', 'conv13_se_2_weights',
+                'conv12_expand_weights', 'conv5_depthwise_weights',
+                'conv6_se_2_weights', 'conv10_expand_weights',
+                'conv9_depthwise_weights', 'conv6_expand_weights',
+                'conv5_se_2_weights', 'conv14_expand_weights',
+                'conv4_depthwise_weights', 'conv7_expand_weights',
+                'conv7_depthwise_weights'
+            ]
         for param in program.global_block().all_parameters():
         for param in program.global_block().all_parameters():
             if 'weight' not in param.name:
             if 'weight' not in param.name:
                 continue
                 continue
             if 'dwise' in param.name or 'depthwise' in param.name or 'logit' in param.name:
             if 'dwise' in param.name or 'depthwise' in param.name or 'logit' in param.name:
                 continue
                 continue
+            if model_type.lower() == "deeplabv3p_mobilenetv3_large_x1_0_ssld":
+                if param.name in params_not_prune:
+                    continue
             prune_names.append(param.name)
             prune_names.append(param.name)
         params_not_prune = [
         params_not_prune = [
             'xception_{}/exit_flow/block2/separable_conv3/pointwise/weights'.
             'xception_{}/exit_flow/block2/separable_conv3/pointwise/weights'.

+ 1 - 1
paddlex/cv/models/slim/visualize.py

@@ -42,7 +42,7 @@ def visualize(model, sensitivities_file, save_dir='./'):
     y = list()
     y = list()
     for loss_thresh in tqdm.tqdm(list(np.arange(0.05, 1, 0.05))):
     for loss_thresh in tqdm.tqdm(list(np.arange(0.05, 1, 0.05))):
         prune_ratio = 1 - cal_model_size(
         prune_ratio = 1 - cal_model_size(
-            program, place, sensitivities_file, eval_metric_loss=loss_thresh)
+            program, place, sensitivities_file, eval_metric_loss=loss_thresh, scope=model.scope)
         x.append(prune_ratio)
         x.append(prune_ratio)
         y.append(loss_thresh)
         y.append(loss_thresh)
     plt.plot(x, y, color='green', linewidth=0.5, marker='o', markersize=3)
     plt.plot(x, y, color='green', linewidth=0.5, marker='o', markersize=3)