prune_config.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import os.path as osp
  16. import paddle.fluid as fluid
  17. #import paddlehub as hub
  18. import paddlex
  19. sensitivities_data = {
  20. 'ResNet18':
  21. 'https://bj.bcebos.com/paddlex/slim_prune/resnet18.sensitivities',
  22. 'ResNet34':
  23. 'https://bj.bcebos.com/paddlex/slim_prune/resnet34.sensitivities',
  24. 'ResNet50':
  25. 'https://bj.bcebos.com/paddlex/slim_prune/resnet50.sensitivities',
  26. 'ResNet101':
  27. 'https://bj.bcebos.com/paddlex/slim_prune/resnet101.sensitivities',
  28. 'ResNet50_vd':
  29. 'https://bj.bcebos.com/paddlex/slim_prune/resnet50vd.sensitivities',
  30. 'ResNet101_vd':
  31. 'https://bj.bcebos.com/paddlex/slim_prune/resnet101vd.sensitivities',
  32. 'DarkNet53':
  33. 'https://bj.bcebos.com/paddlex/slim_prune/darknet53.sensitivities',
  34. 'MobileNetV1':
  35. 'https://bj.bcebos.com/paddlex/slim_prune/mobilenetv1.sensitivities',
  36. 'MobileNetV2':
  37. 'https://bj.bcebos.com/paddlex/slim_prune/mobilenetv2.sensitivities',
  38. 'MobileNetV3_large':
  39. 'https://bj.bcebos.com/paddlex/slim_prune/mobilenetv3_large.sensitivities',
  40. 'MobileNetV3_small':
  41. 'https://bj.bcebos.com/paddlex/slim_prune/mobilenetv3_small.sensitivities',
  42. 'DenseNet121':
  43. 'https://bj.bcebos.com/paddlex/slim_prune/densenet121.sensitivities',
  44. 'DenseNet161':
  45. 'https://bj.bcebos.com/paddlex/slim_prune/densenet161.sensitivities',
  46. 'DenseNet201':
  47. 'https://bj.bcebos.com/paddlex/slim_prune/densenet201.sensitivities',
  48. 'Xception41':
  49. 'https://bj.bcebos.com/paddlex/slim_prune/xception41.sensitivities',
  50. 'Xception65':
  51. 'https://bj.bcebos.com/paddlex/slim_prune/xception65.sensitivities',
  52. 'YOLOv3_MobileNetV1':
  53. 'https://bj.bcebos.com/paddlex/slim_prune/yolov3_mobilenetv1.sensitivities',
  54. 'YOLOv3_MobileNetV3_large':
  55. 'https://bj.bcebos.com/paddlex/slim_prune/yolov3_mobilenetv3.sensitivities',
  56. 'YOLOv3_DarkNet53':
  57. 'https://bj.bcebos.com/paddlex/slim_prune/yolov3_darknet53.sensitivities',
  58. 'YOLOv3_ResNet34':
  59. 'https://bj.bcebos.com/paddlex/slim_prune/yolov3_resnet34.sensitivities',
  60. 'UNet':
  61. 'https://bj.bcebos.com/paddlex/slim_prune/unet.sensitivities',
  62. 'DeepLabv3p_MobileNetV2_x0.25':
  63. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_mobilenetv2_x0.25_no_aspp_decoder.sensitivities',
  64. 'DeepLabv3p_MobileNetV2_x0.5':
  65. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_mobilenetv2_x0.5_no_aspp_decoder.sensitivities',
  66. 'DeepLabv3p_MobileNetV2_x1.0':
  67. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_mobilenetv2_x1.0_no_aspp_decoder.sensitivities',
  68. 'DeepLabv3p_MobileNetV2_x1.5':
  69. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_mobilenetv2_x1.5_no_aspp_decoder.sensitivities',
  70. 'DeepLabv3p_MobileNetV2_x2.0':
  71. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_mobilenetv2_x2.0_no_aspp_decoder.sensitivities',
  72. 'DeepLabv3p_MobileNetV2_x0.25_aspp_decoder':
  73. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_mobilenetv2_x0.25_with_aspp_decoder.sensitivities',
  74. 'DeepLabv3p_MobileNetV2_x0.5_aspp_decoder':
  75. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_mobilenetv2_x0.5_with_aspp_decoder.sensitivities',
  76. 'DeepLabv3p_MobileNetV2_x1.0_aspp_decoder':
  77. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_mobilenetv2_x1.0_with_aspp_decoder.sensitivities',
  78. 'DeepLabv3p_MobileNetV2_x1.5_aspp_decoder':
  79. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_mobilenetv2_x1.5_with_aspp_decoder.sensitivities',
  80. 'DeepLabv3p_MobileNetV2_x2.0_aspp_decoder':
  81. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_mobilenetv2_x2.0_with_aspp_decoder.sensitivities',
  82. 'DeepLabv3p_Xception65_aspp_decoder':
  83. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_xception65_with_aspp_decoder.sensitivities',
  84. 'DeepLabv3p_Xception41_aspp_decoder':
  85. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_xception41_with_aspp_decoder.sensitivities'
  86. }
  87. def get_sensitivities(flag, model, save_dir):
  88. model_name = model.__class__.__name__
  89. model_type = model_name
  90. if hasattr(model, 'backbone'):
  91. model_type = model_name + '_' + model.backbone
  92. if model_type.startswith('DeepLabv3p_Xception'):
  93. model_type = model_type + '_' + 'aspp' + '_' + 'decoder'
  94. elif hasattr(model, 'encoder_with_aspp') or hasattr(
  95. model, 'enable_decoder'):
  96. model_type = model_type + '_' + 'aspp' + '_' + 'decoder'
  97. if osp.isfile(flag):
  98. return flag
  99. elif flag == 'DEFAULT':
  100. assert model_type in sensitivities_data, "There is not sensitivities data file for {}, you may need to calculate it by your self.".format(
  101. model_type)
  102. url = sensitivities_data[model_type]
  103. fname = osp.split(url)[-1]
  104. paddlex.utils.download(url, path=save_dir)
  105. return osp.join(save_dir, fname)
  106. # try:
  107. # hub.download(fname, save_path=save_dir)
  108. # except Exception as e:
  109. # if isinstance(e, hub.ResourceNotFoundError):
  110. # raise Exception(
  111. # "Resource for model {}(key='{}') not found".format(
  112. # model_type, fname))
  113. # elif isinstance(e, hub.ServerConnectionError):
  114. # raise Exception(
  115. # "Cannot get reource for model {}(key='{}'), please check your internet connecgtion"
  116. # .format(model_type, fname))
  117. # else:
  118. # raise Exception(
  119. # "Unexpected error, please make sure paddlehub >= 1.6.2 {}".
  120. # format(str(e)))
  121. # return osp.join(save_dir, fname)
  122. else:
  123. raise Exception(
  124. "sensitivities need to be defined as directory path or `DEFAULT`(download sensitivities automatically)."
  125. )
  126. def get_prune_params(model):
  127. prune_names = []
  128. model_type = model.__class__.__name__
  129. if model_type == 'BaseClassifier':
  130. model_type = model.model_name
  131. if hasattr(model, 'backbone'):
  132. backbone = model.backbone
  133. model_type += ('_' + backbone)
  134. program = model.test_prog
  135. if model_type.startswith('ResNet') or \
  136. model_type.startswith('DenseNet') or \
  137. model_type.startswith('DarkNet'):
  138. for block in program.blocks:
  139. for param in block.all_parameters():
  140. pd_var = fluid.global_scope().find_var(param.name)
  141. pd_param = pd_var.get_tensor()
  142. if len(np.array(pd_param).shape) == 4:
  143. prune_names.append(param.name)
  144. elif model_type == "MobileNetV1":
  145. prune_names.append("conv1_weights")
  146. for param in program.global_block().all_parameters():
  147. if "_sep_weights" in param.name:
  148. prune_names.append(param.name)
  149. elif model_type == "MobileNetV2":
  150. for param in program.global_block().all_parameters():
  151. if 'weight' not in param.name \
  152. or 'dwise' in param.name \
  153. or 'fc' in param.name :
  154. continue
  155. prune_names.append(param.name)
  156. elif model_type.startswith("MobileNetV3"):
  157. if model_type == 'MobileNetV3_small':
  158. expand_prune_id = [3, 4]
  159. else:
  160. expand_prune_id = [2, 3, 4, 8, 9, 11]
  161. for param in program.global_block().all_parameters():
  162. if ('expand_weights' in param.name and \
  163. int(param.name.split('_')[0][4:]) in expand_prune_id)\
  164. or 'linear_weights' in param.name \
  165. or 'se_1_weights' in param.name:
  166. prune_names.append(param.name)
  167. elif model_type.startswith('Xception') or \
  168. model_type.startswith('DeepLabv3p_Xception'):
  169. params_not_prune = [
  170. 'weights',
  171. 'xception_{}/exit_flow/block2/separable_conv3/pointwise/weights'.
  172. format(model_type[-2:]), 'encoder/concat/weights',
  173. 'decoder/concat/weights'
  174. ]
  175. for param in program.global_block().all_parameters():
  176. if 'weight' not in param.name \
  177. or 'dwise' in param.name \
  178. or 'depthwise' in param.name \
  179. or 'logit' in param.name:
  180. continue
  181. if param.name in params_not_prune:
  182. continue
  183. prune_names.append(param.name)
  184. elif model_type.startswith('YOLOv3'):
  185. for block in program.blocks:
  186. for param in block.all_parameters():
  187. if 'weights' in param.name and 'yolo_block' in param.name:
  188. prune_names.append(param.name)
  189. elif model_type.startswith('UNet'):
  190. for param in program.global_block().all_parameters():
  191. if 'weight' not in param.name:
  192. continue
  193. if 'logit' in param.name:
  194. continue
  195. prune_names.append(param.name)
  196. params_not_prune = [
  197. 'encode/block4/down/conv1/weights',
  198. 'encode/block3/down/conv1/weights',
  199. 'encode/block2/down/conv1/weights', 'encode/block1/conv1/weights'
  200. ]
  201. for i in params_not_prune:
  202. if i in prune_names:
  203. prune_names.remove(i)
  204. elif model_type.startswith('DeepLabv3p'):
  205. for param in program.global_block().all_parameters():
  206. if 'weight' not in param.name:
  207. continue
  208. if 'dwise' in param.name or 'depthwise' in param.name or 'logit' in param.name:
  209. continue
  210. prune_names.append(param.name)
  211. params_not_prune = [
  212. 'xception_{}/exit_flow/block2/separable_conv3/pointwise/weights'.
  213. format(model_type[-2:]), 'encoder/concat/weights',
  214. 'decoder/concat/weights'
  215. ]
  216. if model.encoder_with_aspp == True:
  217. params_not_prune.append(
  218. 'xception_{}/exit_flow/block2/separable_conv3/pointwise/weights'
  219. .format(model_type[-2:]))
  220. params_not_prune.append('conv8_1_linear_weights')
  221. for i in params_not_prune:
  222. if i in prune_names:
  223. prune_names.remove(i)
  224. else:
  225. raise Exception('The {} is not implement yet!'.format(model_type))
  226. return prune_names