prune_config.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import os.path as osp
  16. import paddle.fluid as fluid
  17. #import paddlehub as hub
  18. import paddlex
  19. sensitivities_data = {
  20. 'ResNet18':
  21. 'https://bj.bcebos.com/paddlex/slim_prune/resnet18.sensitivities',
  22. 'ResNet34':
  23. 'https://bj.bcebos.com/paddlex/slim_prune/resnet34.sensitivities',
  24. 'ResNet50':
  25. 'https://bj.bcebos.com/paddlex/slim_prune/resnet50.sensitivities',
  26. 'ResNet101':
  27. 'https://bj.bcebos.com/paddlex/slim_prune/resnet101.sensitivities',
  28. 'ResNet50_vd':
  29. 'https://bj.bcebos.com/paddlex/slim_prune/resnet50vd.sensitivities',
  30. 'ResNet101_vd':
  31. 'https://bj.bcebos.com/paddlex/slim_prune/resnet101vd.sensitivities',
  32. 'DarkNet53':
  33. 'https://bj.bcebos.com/paddlex/slim_prune/darknet53.sensitivities',
  34. 'MobileNetV1':
  35. 'https://bj.bcebos.com/paddlex/slim_prune/mobilenetv1.sensitivities',
  36. 'MobileNetV2':
  37. 'https://bj.bcebos.com/paddlex/slim_prune/mobilenetv2.sensitivities',
  38. 'MobileNetV3_large':
  39. 'https://bj.bcebos.com/paddlex/slim_prune/mobilenetv3_large.sensitivities',
  40. 'MobileNetV3_small':
  41. 'https://bj.bcebos.com/paddlex/slim_prune/mobilenetv3_small.sensitivities',
  42. 'DenseNet121':
  43. 'https://bj.bcebos.com/paddlex/slim_prune/densenet121.sensitivities',
  44. 'DenseNet161':
  45. 'https://bj.bcebos.com/paddlex/slim_prune/densenet161.sensitivities',
  46. 'DenseNet201':
  47. 'https://bj.bcebos.com/paddlex/slim_prune/densenet201.sensitivities',
  48. 'Xception41':
  49. 'https://bj.bcebos.com/paddlex/slim_prune/xception41.sensitivities',
  50. 'Xception65':
  51. 'https://bj.bcebos.com/paddlex/slim_prune/xception65.sensitivities',
  52. 'YOLOv3_MobileNetV1':
  53. 'https://bj.bcebos.com/paddlex/slim_prune/yolov3_mobilenetv1.sensitivities',
  54. 'YOLOv3_MobileNetV3_large':
  55. 'https://bj.bcebos.com/paddlex/slim_prune/yolov3_mobilenetv3.sensitivities',
  56. 'YOLOv3_DarkNet53':
  57. 'https://bj.bcebos.com/paddlex/slim_prune/yolov3_darknet53.sensitivities',
  58. 'YOLOv3_ResNet34':
  59. 'https://bj.bcebos.com/paddlex/slim_prune/yolov3_resnet34.sensitivities',
  60. 'UNet':
  61. 'https://bj.bcebos.com/paddlex/slim_prune/unet.sensitivities',
  62. 'DeepLabv3p_MobileNetV2_x0.25':
  63. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_mobilenetv2_x0.25_no_aspp_decoder.sensitivities',
  64. 'DeepLabv3p_MobileNetV2_x0.5':
  65. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_mobilenetv2_x0.5_no_aspp_decoder.sensitivities',
  66. 'DeepLabv3p_MobileNetV2_x1.0':
  67. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_mobilenetv2_x1.0_no_aspp_decoder.sensitivities',
  68. 'DeepLabv3p_MobileNetV2_x1.5':
  69. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_mobilenetv2_x1.5_no_aspp_decoder.sensitivities',
  70. 'DeepLabv3p_MobileNetV2_x2.0':
  71. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_mobilenetv2_x2.0_no_aspp_decoder.sensitivities',
  72. 'DeepLabv3p_MobileNetV2_x0.25_aspp_decoder':
  73. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_mobilenetv2_x0.25_with_aspp_decoder.sensitivities',
  74. 'DeepLabv3p_MobileNetV2_x0.5_aspp_decoder':
  75. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_mobilenetv2_x0.5_with_aspp_decoder.sensitivities',
  76. 'DeepLabv3p_MobileNetV2_x1.0_aspp_decoder':
  77. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_mobilenetv2_x1.0_with_aspp_decoder.sensitivities',
  78. 'DeepLabv3p_MobileNetV2_x1.5_aspp_decoder':
  79. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_mobilenetv2_x1.5_with_aspp_decoder.sensitivities',
  80. 'DeepLabv3p_MobileNetV2_x2.0_aspp_decoder':
  81. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_mobilenetv2_x2.0_with_aspp_decoder.sensitivities',
  82. 'DeepLabv3p_Xception65_aspp_decoder':
  83. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_xception65_with_aspp_decoder.sensitivities',
  84. 'DeepLabv3p_Xception41_aspp_decoder':
  85. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_xception41_with_aspp_decoder.sensitivities'
  86. }
  87. def get_sensitivities(flag, model, save_dir):
  88. model_name = model.__class__.__name__
  89. model_type = model_name
  90. if hasattr(model, 'backbone'):
  91. model_type = model_name + '_' + model.backbone
  92. if model_type.startswith('DeepLabv3p_Xception'):
  93. model_type = model_type + '_' + 'aspp' + '_' + 'decoder'
  94. elif hasattr(model, 'encoder_with_aspp') or hasattr(
  95. model, 'enable_decoder'):
  96. model_type = model_type + '_' + 'aspp' + '_' + 'decoder'
  97. if osp.isfile(flag):
  98. return flag
  99. elif flag == 'DEFAULT':
  100. assert model_type in sensitivities_data, "There is not sensitivities data file for {}, you may need to calculate it by your self.".format(
  101. model_type)
  102. url = sensitivities_data[model_type]
  103. fname = osp.split(url)[-1]
  104. paddlex.utils.download(url, path=save_dir)
  105. return osp.join(save_dir, fname)
  106. # try:
  107. # hub.download(fname, save_path=save_dir)
  108. # except Exception as e:
  109. # if isinstance(e, hub.ResourceNotFoundError):
  110. # raise Exception(
  111. # "Resource for model {}(key='{}') not found".format(
  112. # model_type, fname))
  113. # elif isinstance(e, hub.ServerConnectionError):
  114. # raise Exception(
  115. # "Cannot get reource for model {}(key='{}'), please check your internet connecgtion"
  116. # .format(model_type, fname))
  117. # else:
  118. # raise Exception(
  119. # "Unexpected error, please make sure paddlehub >= 1.6.2 {}".
  120. # format(str(e)))
  121. # return osp.join(save_dir, fname)
  122. else:
  123. raise Exception(
  124. "sensitivities need to be defined as directory path or `DEFAULT`(download sensitivities automatically)."
  125. )
  126. def get_prune_params(model):
  127. prune_names = []
  128. model_type = model.__class__.__name__
  129. if model_type == 'BaseClassifier':
  130. model_type = model.model_name
  131. if hasattr(model, 'backbone'):
  132. backbone = model.backbone
  133. model_type += ('_' + backbone)
  134. program = model.test_prog
  135. if model_type.startswith('ResNet') or \
  136. model_type.startswith('DenseNet') or \
  137. model_type.startswith('DarkNet') or \
  138. model_type.startswith('AlexNet'):
  139. for block in program.blocks:
  140. for param in block.all_parameters():
  141. pd_var = fluid.global_scope().find_var(param.name)
  142. pd_param = pd_var.get_tensor()
  143. if len(np.array(pd_param).shape) == 4:
  144. prune_names.append(param.name)
  145. if model_type == 'AlexNet':
  146. prune_names.remove('conv5_weights')
  147. elif model_type == "MobileNetV1":
  148. prune_names.append("conv1_weights")
  149. for param in program.global_block().all_parameters():
  150. if "_sep_weights" in param.name:
  151. prune_names.append(param.name)
  152. elif model_type == "MobileNetV2":
  153. for param in program.global_block().all_parameters():
  154. if 'weight' not in param.name \
  155. or 'dwise' in param.name \
  156. or 'fc' in param.name :
  157. continue
  158. prune_names.append(param.name)
  159. elif model_type.startswith("MobileNetV3"):
  160. if model_type.startswith('MobileNetV3_small'):
  161. expand_prune_id = [3, 4]
  162. else:
  163. expand_prune_id = [2, 3, 4, 8, 9, 11]
  164. for param in program.global_block().all_parameters():
  165. if ('expand_weights' in param.name and \
  166. int(param.name.split('_')[0][4:]) in expand_prune_id)\
  167. or 'linear_weights' in param.name \
  168. or 'se_1_weights' in param.name:
  169. prune_names.append(param.name)
  170. elif model_type.startswith('Xception') or \
  171. model_type.startswith('DeepLabv3p_Xception'):
  172. params_not_prune = [
  173. 'weights',
  174. 'xception_{}/exit_flow/block2/separable_conv3/pointwise/weights'.
  175. format(model_type[-2:]), 'encoder/concat/weights',
  176. 'decoder/concat/weights'
  177. ]
  178. for param in program.global_block().all_parameters():
  179. if 'weight' not in param.name \
  180. or 'dwise' in param.name \
  181. or 'depthwise' in param.name \
  182. or 'logit' in param.name:
  183. continue
  184. if param.name in params_not_prune:
  185. continue
  186. prune_names.append(param.name)
  187. elif model_type.startswith('YOLOv3'):
  188. for block in program.blocks:
  189. for param in block.all_parameters():
  190. if 'weights' in param.name and 'yolo_block' in param.name:
  191. prune_names.append(param.name)
  192. elif model_type.startswith('UNet'):
  193. for param in program.global_block().all_parameters():
  194. if 'weight' not in param.name:
  195. continue
  196. if 'logit' in param.name:
  197. continue
  198. prune_names.append(param.name)
  199. params_not_prune = [
  200. 'encode/block4/down/conv1/weights',
  201. 'encode/block3/down/conv1/weights',
  202. 'encode/block2/down/conv1/weights', 'encode/block1/conv1/weights'
  203. ]
  204. for i in params_not_prune:
  205. if i in prune_names:
  206. prune_names.remove(i)
  207. elif model_type.startswith('DeepLabv3p'):
  208. for param in program.global_block().all_parameters():
  209. if 'weight' not in param.name:
  210. continue
  211. if 'dwise' in param.name or 'depthwise' in param.name or 'logit' in param.name:
  212. continue
  213. prune_names.append(param.name)
  214. params_not_prune = [
  215. 'xception_{}/exit_flow/block2/separable_conv3/pointwise/weights'.
  216. format(model_type[-2:]), 'encoder/concat/weights',
  217. 'decoder/concat/weights'
  218. ]
  219. if model.encoder_with_aspp == True:
  220. params_not_prune.append(
  221. 'xception_{}/exit_flow/block2/separable_conv3/pointwise/weights'
  222. .format(model_type[-2:]))
  223. params_not_prune.append('conv8_1_linear_weights')
  224. for i in params_not_prune:
  225. if i in prune_names:
  226. prune_names.remove(i)
  227. else:
  228. raise Exception('The {} is not implement yet!'.format(model_type))
  229. return prune_names