浏览代码

deal with image shape while doing pruning in legacy_train

will-jl944 4 年之前
父节点
当前提交
5b66dee5c1
共有 2 个文件被更改,包括 2 次插入2 次删除
  1. 1 1
      dygraph/paddlex/cls.py
  2. 1 1
      dygraph/paddlex/seg.py

+ 1 - 1
dygraph/paddlex/cls.py

@@ -1204,7 +1204,7 @@ def _legacy_train(model, num_epochs, train_dataset, train_batch_size,
 
 
     if sensitivities_file is not None:
     if sensitivities_file is not None:
         dataset = eval_dataset or train_dataset
         dataset = eval_dataset or train_dataset
-        inputs = [1] + list(dataset[0]['image'].shape)
+        inputs = [1, 3] + list(dataset[0]['image'].shape[:2])
         model.pruner = L1NormFilterPruner(
         model.pruner = L1NormFilterPruner(
             model.net, inputs=inputs, sen_file=sensitivities_file)
             model.net, inputs=inputs, sen_file=sensitivities_file)
         model.pruner.sensitive_prune(pruned_flops=pruned_flops)
         model.pruner.sensitive_prune(pruned_flops=pruned_flops)

+ 1 - 1
dygraph/paddlex/seg.py

@@ -389,7 +389,7 @@ def _legacy_train(model, num_epochs, train_dataset, train_batch_size,
 
 
     if sensitivities_file is not None:
     if sensitivities_file is not None:
         dataset = eval_dataset or train_dataset
         dataset = eval_dataset or train_dataset
-        inputs = [1] + list(dataset[0]['image'].shape)
+        inputs = [1, 3] + list(dataset[0]['image'].shape[:2])
         model.pruner = L1NormFilterPruner(
         model.pruner = L1NormFilterPruner(
             model.net, inputs=inputs, sen_file=sensitivities_file)
             model.net, inputs=inputs, sen_file=sensitivities_file)
         model.pruner.sensitive_prune(pruned_flops=pruned_flops)
         model.pruner.sensitive_prune(pruned_flops=pruned_flops)