|
|
@@ -444,11 +444,11 @@ class BaseModel:
|
|
|
criterion({'l1_norm', 'fpgm'}, optional): Pruning criterion. Defaults to 'l1_norm'.
|
|
|
save_dir(str, optional): The directory to save sensitivity file of the model. Defaults to 'output'.
|
|
|
"""
|
|
|
- if self.__class__.__name__ in ['FasterRCNN', 'MaskRCNN']:
|
|
|
+ if self.__class__.__name__ in {'FasterRCNN', 'MaskRCNN', 'PicoDet'}:
|
|
|
raise Exception("{} does not support pruning currently!".format(
|
|
|
self.__class__.__name__))
|
|
|
|
|
|
- assert criterion in ['l1_norm', 'fpgm'], \
|
|
|
+ assert criterion in {'l1_norm', 'fpgm'}, \
|
|
|
"Pruning criterion {} is not supported. Please choose from ['l1_norm', 'fpgm']"
|
|
|
arrange_transforms(
|
|
|
model_type=self.model_type,
|