Эх сурвалжийг харах

Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleX into param_tuning

will-jl944 4 жил өмнө
parent
commit
3a519a241d

+ 9 - 7
dygraph/paddlex/cv/models/base.py

@@ -122,7 +122,14 @@ class BaseModel:
         info = dict()
         info['pruner'] = self.pruner.__class__.__name__
         info['pruning_ratios'] = self.pruning_ratios
-        info['pruner_inputs'] = self.pruner.inputs
+        pruner_inputs = self.pruner.inputs
+        if self.model_type == 'detector':
+            pruner_inputs = {
+                k: v.tolist()
+                for k, v in pruner_inputs[0].items()
+            }
+        info['pruner_inputs'] = pruner_inputs
+
         return info
 
     def get_quant_info(self):
@@ -427,12 +434,7 @@ class BaseModel:
         pre_pruning_flops = flops(self.net, self.pruner.inputs)
         logging.info("Pre-pruning FLOPs: {}. Pruning starts...".format(
             pre_pruning_flops))
-        skip_vars = []
-        for param in self.net.parameters():
-            if param.shape[0] <= 8:
-                skip_vars.append(param.name)
-        _, self.pruning_ratios = sensitive_prune(self.pruner, pruned_flops,
-                                                 skip_vars)
+        _, self.pruning_ratios = sensitive_prune(self.pruner, pruned_flops)
         post_pruning_flops = flops(self.net, self.pruner.inputs)
         logging.info("Pruning is complete. Post-pruning FLOPs: {}".format(
             post_pruning_flops))

+ 5 - 0
dygraph/paddlex/cv/models/load_model.py

@@ -69,6 +69,11 @@ def load_model(model_dir):
             with open(osp.join(model_dir, "prune.yml")) as f:
                 pruning_info = yaml.load(f.read(), Loader=yaml.Loader)
                 inputs = pruning_info['pruner_inputs']
+                if model.model_type == 'detector':
+                    inputs = [{
+                        k: paddle.to_tensor(v)
+                        for k, v in inputs.items()
+                    }]
                 model.pruner = getattr(paddleslim, pruning_info['pruner'])(
                     model.net, inputs=inputs)
                 model.pruning_ratios = pruning_info['pruning_ratios']

+ 1 - 1
dygraph/tutorials/slim/prune/image_classification/mobilenetv2_train.py

@@ -32,7 +32,7 @@ eval_dataset = pdx.datasets.ImageNet(
 # 初始化模型,并进行训练
 # 可使用VisualDL查看训练指标,参考https://github.com/PaddlePaddle/PaddleX/tree/release/2.0-rc/tutorials/train#visualdl可视化训练指标
 num_classes = len(train_dataset.labels)
-model = pdx.models.MobileNetV3_large(num_classes=num_classes)
+model = pdx.models.MobileNetV2(num_classes=num_classes)
 
 # API说明:https://github.com/PaddlePaddle/PaddleX/blob/95c53dec89ab0f3769330fa445c6d9213986ca5f/paddlex/cv/models/classifier.py#L153
 # 各参数介绍与调整说明:https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html