|
|
@@ -1204,7 +1204,7 @@ def _legacy_train(model, num_epochs, train_dataset, train_batch_size,
|
|
|
|
|
|
if sensitivities_file is not None:
|
|
|
dataset = eval_dataset or train_dataset
|
|
|
- inputs = [1] + list(dataset[0]['image'].shape)
|
|
|
+ inputs = [1, 3] + list(dataset[0]['image'].shape[:2])
|
|
|
model.pruner = L1NormFilterPruner(
|
|
|
model.net, inputs=inputs, sen_file=sensitivities_file)
|
|
|
model.pruner.sensitive_prune(pruned_flops=pruned_flops)
|