prune_config.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import os.path as osp
  16. import paddle.fluid as fluid
  17. #import paddlehub as hub
  18. import paddlex
  19. sensitivities_data = {
  20. 'AlexNet':
  21. 'https://bj.bcebos.com/paddlex/slim_prune/alexnet_sensitivities.data',
  22. 'ResNet18':
  23. 'https://bj.bcebos.com/paddlex/slim_prune/resnet18.sensitivities',
  24. 'ResNet34':
  25. 'https://bj.bcebos.com/paddlex/slim_prune/resnet34.sensitivities',
  26. 'ResNet50':
  27. 'https://bj.bcebos.com/paddlex/slim_prune/resnet50.sensitivities',
  28. 'ResNet101':
  29. 'https://bj.bcebos.com/paddlex/slim_prune/resnet101.sensitivities',
  30. 'ResNet50_vd':
  31. 'https://bj.bcebos.com/paddlex/slim_prune/resnet50vd.sensitivities',
  32. 'ResNet101_vd':
  33. 'https://bj.bcebos.com/paddlex/slim_prune/resnet101vd.sensitivities',
  34. 'DarkNet53':
  35. 'https://bj.bcebos.com/paddlex/slim_prune/darknet53.sensitivities',
  36. 'MobileNetV1':
  37. 'https://bj.bcebos.com/paddlex/slim_prune/mobilenetv1.sensitivities',
  38. 'MobileNetV2':
  39. 'https://bj.bcebos.com/paddlex/slim_prune/mobilenetv2.sensitivities',
  40. 'MobileNetV3_large':
  41. 'https://bj.bcebos.com/paddlex/slim_prune/mobilenetv3_large.sensitivities',
  42. 'MobileNetV3_small':
  43. 'https://bj.bcebos.com/paddlex/slim_prune/mobilenetv3_small.sensitivities',
  44. 'MobileNetV3_large_ssld':
  45. 'https://bj.bcebos.com/paddlex/slim_prune/mobilenetv3_large_ssld_sensitivities.data',
  46. 'MobileNetV3_small_ssld':
  47. 'https://bj.bcebos.com/paddlex/slim_prune/mobilenetv3_small_ssld_sensitivities.data',
  48. 'DenseNet121':
  49. 'https://bj.bcebos.com/paddlex/slim_prune/densenet121.sensitivities',
  50. 'DenseNet161':
  51. 'https://bj.bcebos.com/paddlex/slim_prune/densenet161.sensitivities',
  52. 'DenseNet201':
  53. 'https://bj.bcebos.com/paddlex/slim_prune/densenet201.sensitivities',
  54. 'Xception41':
  55. 'https://bj.bcebos.com/paddlex/slim_prune/xception41.sensitivities',
  56. 'Xception65':
  57. 'https://bj.bcebos.com/paddlex/slim_prune/xception65.sensitivities',
  58. 'ShuffleNetV2':
  59. 'https://bj.bcebos.com/paddlex/slim_prune/shufflenetv2_sensitivities.data',
  60. 'YOLOv3_MobileNetV1':
  61. 'https://bj.bcebos.com/paddlex/slim_prune/yolov3_mobilenetv1.sensitivities',
  62. 'YOLOv3_MobileNetV3_large':
  63. 'https://bj.bcebos.com/paddlex/slim_prune/yolov3_mobilenetv3.sensitivities',
  64. 'YOLOv3_DarkNet53':
  65. 'https://bj.bcebos.com/paddlex/slim_prune/yolov3_darknet53.sensitivities',
  66. 'YOLOv3_ResNet34':
  67. 'https://bj.bcebos.com/paddlex/slim_prune/yolov3_resnet34.sensitivities',
  68. 'UNet': 'https://bj.bcebos.com/paddlex/slim_prune/unet.sensitivities',
  69. 'DeepLabv3p_MobileNetV2_x0.25':
  70. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_mobilenetv2_x0.25_no_aspp_decoder.sensitivities',
  71. 'DeepLabv3p_MobileNetV2_x0.5':
  72. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_mobilenetv2_x0.5_no_aspp_decoder.sensitivities',
  73. 'DeepLabv3p_MobileNetV2_x1.0':
  74. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_mobilenetv2_x1.0_no_aspp_decoder.sensitivities',
  75. 'DeepLabv3p_MobileNetV2_x1.5':
  76. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_mobilenetv2_x1.5_no_aspp_decoder.sensitivities',
  77. 'DeepLabv3p_MobileNetV2_x2.0':
  78. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_mobilenetv2_x2.0_no_aspp_decoder.sensitivities',
  79. 'DeepLabv3p_MobileNetV2_x0.25_aspp_decoder':
  80. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_mobilenetv2_x0.25_with_aspp_decoder.sensitivities',
  81. 'DeepLabv3p_MobileNetV2_x0.5_aspp_decoder':
  82. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_mobilenetv2_x0.5_with_aspp_decoder.sensitivities',
  83. 'DeepLabv3p_MobileNetV2_x1.0_aspp_decoder':
  84. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_mobilenetv2_x1.0_with_aspp_decoder.sensitivities',
  85. 'DeepLabv3p_MobileNetV2_x1.5_aspp_decoder':
  86. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_mobilenetv2_x1.5_with_aspp_decoder.sensitivities',
  87. 'DeepLabv3p_MobileNetV2_x2.0_aspp_decoder':
  88. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_mobilenetv2_x2.0_with_aspp_decoder.sensitivities',
  89. 'DeepLabv3p_Xception65_aspp_decoder':
  90. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_xception65_with_aspp_decoder.sensitivities',
  91. 'DeepLabv3p_Xception41_aspp_decoder':
  92. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_xception41_with_aspp_decoder.sensitivities'
  93. }
  94. def get_sensitivities(flag, model, save_dir):
  95. model_name = model.__class__.__name__
  96. model_type = model_name
  97. if hasattr(model, 'backbone'):
  98. model_type = model_name + '_' + model.backbone
  99. if model_type.startswith('DeepLabv3p_Xception'):
  100. model_type = model_type + '_' + 'aspp' + '_' + 'decoder'
  101. elif hasattr(model, 'encoder_with_aspp') or hasattr(model,
  102. 'enable_decoder'):
  103. model_type = model_type + '_' + 'aspp' + '_' + 'decoder'
  104. if osp.isfile(flag):
  105. return flag
  106. elif flag == 'DEFAULT':
  107. assert model_type in sensitivities_data, "There is not sensitivities data file for {}, you may need to calculate it by your self.".format(
  108. model_type)
  109. url = sensitivities_data[model_type]
  110. fname = osp.split(url)[-1]
  111. paddlex.utils.download(url, path=save_dir)
  112. return osp.join(save_dir, fname)
  113. # try:
  114. # hub.download(fname, save_path=save_dir)
  115. # except Exception as e:
  116. # if isinstance(e, hub.ResourceNotFoundError):
  117. # raise Exception(
  118. # "Resource for model {}(key='{}') not found".format(
  119. # model_type, fname))
  120. # elif isinstance(e, hub.ServerConnectionError):
  121. # raise Exception(
  122. # "Cannot get reource for model {}(key='{}'), please check your internet connection"
  123. # .format(model_type, fname))
  124. # else:
  125. # raise Exception(
  126. # "Unexpected error, please make sure paddlehub >= 1.6.2 {}".
  127. # format(str(e)))
  128. # return osp.join(save_dir, fname)
  129. else:
  130. raise Exception(
  131. "sensitivities need to be defined as directory path or `DEFAULT`(download sensitivities automatically)."
  132. )
  133. def get_prune_params(model):
  134. prune_names = []
  135. model_type = model.__class__.__name__
  136. if model_type == 'BaseClassifier':
  137. model_type = model.model_name
  138. if hasattr(model, 'backbone'):
  139. backbone = model.backbone
  140. model_type += ('_' + backbone)
  141. program = model.test_prog
  142. if model_type.startswith('ResNet') or \
  143. model_type.startswith('DenseNet') or \
  144. model_type.startswith('DarkNet') or \
  145. model_type.startswith('AlexNet') or \
  146. model_type.startswith('ShuffleNetV2'):
  147. for block in program.blocks:
  148. for param in block.all_parameters():
  149. pd_var = fluid.global_scope().find_var(param.name)
  150. pd_param = pd_var.get_tensor()
  151. if len(np.array(pd_param).shape) == 4:
  152. prune_names.append(param.name)
  153. if model_type == 'AlexNet':
  154. prune_names.remove('conv5_weights')
  155. if model_type == 'ShuffleNetV2':
  156. not_prune_names = [
  157. 'stage_2_1_conv5_weights',
  158. 'stage_2_1_conv3_weights',
  159. 'stage_2_2_conv3_weights',
  160. 'stage_2_3_conv3_weights',
  161. 'stage_2_4_conv3_weights',
  162. 'stage_3_1_conv5_weights',
  163. 'stage_3_1_conv3_weights',
  164. 'stage_3_2_conv3_weights',
  165. 'stage_3_3_conv3_weights',
  166. 'stage_3_4_conv3_weights',
  167. 'stage_3_5_conv3_weights',
  168. 'stage_3_6_conv3_weights',
  169. 'stage_3_7_conv3_weights',
  170. 'stage_3_8_conv3_weights',
  171. 'stage_4_1_conv5_weights',
  172. 'stage_4_1_conv3_weights',
  173. 'stage_4_2_conv3_weights',
  174. 'stage_4_3_conv3_weights',
  175. 'stage_4_4_conv3_weights',
  176. ]
  177. for name in not_prune_names:
  178. prune_names.remove(name)
  179. elif model_type == "MobileNetV1":
  180. prune_names.append("conv1_weights")
  181. for param in program.global_block().all_parameters():
  182. if "_sep_weights" in param.name:
  183. prune_names.append(param.name)
  184. elif model_type == "MobileNetV2":
  185. for param in program.global_block().all_parameters():
  186. if 'weight' not in param.name \
  187. or 'dwise' in param.name \
  188. or 'fc' in param.name :
  189. continue
  190. prune_names.append(param.name)
  191. elif model_type.startswith("MobileNetV3"):
  192. if model_type.startswith('MobileNetV3_small'):
  193. expand_prune_id = [3, 4]
  194. else:
  195. expand_prune_id = [2, 3, 4, 8, 9, 11]
  196. for param in program.global_block().all_parameters():
  197. if ('expand_weights' in param.name and \
  198. int(param.name.split('_')[0][4:]) in expand_prune_id)\
  199. or 'linear_weights' in param.name \
  200. or 'se_1_weights' in param.name:
  201. prune_names.append(param.name)
  202. elif model_type.startswith('Xception') or \
  203. model_type.startswith('DeepLabv3p_Xception'):
  204. params_not_prune = [
  205. 'weights',
  206. 'xception_{}/exit_flow/block2/separable_conv3/pointwise/weights'.
  207. format(model_type[-2:]), 'encoder/concat/weights',
  208. 'decoder/concat/weights'
  209. ]
  210. for param in program.global_block().all_parameters():
  211. if 'weight' not in param.name \
  212. or 'dwise' in param.name \
  213. or 'depthwise' in param.name \
  214. or 'logit' in param.name:
  215. continue
  216. if param.name in params_not_prune:
  217. continue
  218. prune_names.append(param.name)
  219. elif model_type.startswith('YOLOv3'):
  220. for block in program.blocks:
  221. for param in block.all_parameters():
  222. if 'weights' in param.name and 'yolo_block' in param.name:
  223. prune_names.append(param.name)
  224. elif model_type.startswith('UNet'):
  225. for param in program.global_block().all_parameters():
  226. if 'weight' not in param.name:
  227. continue
  228. if 'logit' in param.name:
  229. continue
  230. prune_names.append(param.name)
  231. params_not_prune = [
  232. 'encode/block4/down/conv1/weights',
  233. 'encode/block3/down/conv1/weights',
  234. 'encode/block2/down/conv1/weights', 'encode/block1/conv1/weights'
  235. ]
  236. for i in params_not_prune:
  237. if i in prune_names:
  238. prune_names.remove(i)
  239. elif model_type.startswith('DeepLabv3p'):
  240. for param in program.global_block().all_parameters():
  241. if 'weight' not in param.name:
  242. continue
  243. if 'dwise' in param.name or 'depthwise' in param.name or 'logit' in param.name:
  244. continue
  245. prune_names.append(param.name)
  246. params_not_prune = [
  247. 'xception_{}/exit_flow/block2/separable_conv3/pointwise/weights'.
  248. format(model_type[-2:]), 'encoder/concat/weights',
  249. 'decoder/concat/weights'
  250. ]
  251. if model.encoder_with_aspp == True:
  252. params_not_prune.append(
  253. 'xception_{}/exit_flow/block2/separable_conv3/pointwise/weights'
  254. .format(model_type[-2:]))
  255. params_not_prune.append('conv8_1_linear_weights')
  256. for i in params_not_prune:
  257. if i in prune_names:
  258. prune_names.remove(i)
  259. else:
  260. raise Exception('The {} is not implement yet!'.format(model_type))
  261. return prune_names