prune_config.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import os.path as osp
  16. import paddle.fluid as fluid
  17. #import paddlehub as hub
  18. import paddlex
  19. sensitivities_data = {
  20. 'AlexNet':
  21. 'https://bj.bcebos.com/paddlex/slim_prune/alexnet_sensitivities.data',
  22. 'ResNet18':
  23. 'https://bj.bcebos.com/paddlex/slim_prune/resnet18.sensitivities',
  24. 'ResNet34':
  25. 'https://bj.bcebos.com/paddlex/slim_prune/resnet34.sensitivities',
  26. 'ResNet50':
  27. 'https://bj.bcebos.com/paddlex/slim_prune/resnet50.sensitivities',
  28. 'ResNet101':
  29. 'https://bj.bcebos.com/paddlex/slim_prune/resnet101.sensitivities',
  30. 'ResNet50_vd':
  31. 'https://bj.bcebos.com/paddlex/slim_prune/resnet50vd.sensitivities',
  32. 'ResNet101_vd':
  33. 'https://bj.bcebos.com/paddlex/slim_prune/resnet101vd.sensitivities',
  34. 'DarkNet53':
  35. 'https://bj.bcebos.com/paddlex/slim_prune/darknet53.sensitivities',
  36. 'MobileNetV1':
  37. 'https://bj.bcebos.com/paddlex/slim_prune/mobilenetv1.sensitivities',
  38. 'MobileNetV2':
  39. 'https://bj.bcebos.com/paddlex/slim_prune/mobilenetv2.sensitivities',
  40. 'MobileNetV3_large':
  41. 'https://bj.bcebos.com/paddlex/slim_prune/mobilenetv3_large.sensitivities',
  42. 'MobileNetV3_small':
  43. 'https://bj.bcebos.com/paddlex/slim_prune/mobilenetv3_small.sensitivities',
  44. 'MobileNetV3_large_ssld':
  45. 'https://bj.bcebos.com/paddlex/slim_prune/mobilenetv3_large_ssld_sensitivities.data',
  46. 'MobileNetV3_small_ssld':
  47. 'https://bj.bcebos.com/paddlex/slim_prune/mobilenetv3_small_ssld_sensitivities.data',
  48. 'DenseNet121':
  49. 'https://bj.bcebos.com/paddlex/slim_prune/densenet121.sensitivities',
  50. 'DenseNet161':
  51. 'https://bj.bcebos.com/paddlex/slim_prune/densenet161.sensitivities',
  52. 'DenseNet201':
  53. 'https://bj.bcebos.com/paddlex/slim_prune/densenet201.sensitivities',
  54. 'Xception41':
  55. 'https://bj.bcebos.com/paddlex/slim_prune/xception41.sensitivities',
  56. 'Xception65':
  57. 'https://bj.bcebos.com/paddlex/slim_prune/xception65.sensitivities',
  58. 'ShuffleNetV2':
  59. 'https://bj.bcebos.com/paddlex/slim_prune/shufflenetv2_sensitivities.data',
  60. 'YOLOv3_MobileNetV1':
  61. 'https://bj.bcebos.com/paddlex/slim_prune/yolov3_mobilenetv1.sensitivities',
  62. 'YOLOv3_MobileNetV3_large':
  63. 'https://bj.bcebos.com/paddlex/slim_prune/yolov3_mobilenetv3.sensitivities',
  64. 'YOLOv3_DarkNet53':
  65. 'https://bj.bcebos.com/paddlex/slim_prune/yolov3_darknet53.sensitivities',
  66. 'YOLOv3_ResNet34':
  67. 'https://bj.bcebos.com/paddlex/slim_prune/yolov3_resnet34.sensitivities',
  68. 'UNet': 'https://bj.bcebos.com/paddlex/slim_prune/unet.sensitivities',
  69. 'DeepLabv3p_MobileNetV2_x0.25':
  70. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_mobilenetv2_x0.25_no_aspp_decoder.sensitivities',
  71. 'DeepLabv3p_MobileNetV2_x0.5':
  72. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_mobilenetv2_x0.5_no_aspp_decoder.sensitivities',
  73. 'DeepLabv3p_MobileNetV2_x1.0':
  74. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_mobilenetv2_x1.0_no_aspp_decoder.sensitivities',
  75. 'DeepLabv3p_MobileNetV2_x1.5':
  76. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_mobilenetv2_x1.5_no_aspp_decoder.sensitivities',
  77. 'DeepLabv3p_MobileNetV2_x2.0':
  78. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_mobilenetv2_x2.0_no_aspp_decoder.sensitivities',
  79. 'DeepLabv3p_MobileNetV2_x0.25_aspp_decoder':
  80. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_mobilenetv2_x0.25_with_aspp_decoder.sensitivities',
  81. 'DeepLabv3p_MobileNetV2_x0.5_aspp_decoder':
  82. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_mobilenetv2_x0.5_with_aspp_decoder.sensitivities',
  83. 'DeepLabv3p_MobileNetV2_x1.0_aspp_decoder':
  84. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_mobilenetv2_x1.0_with_aspp_decoder.sensitivities',
  85. 'DeepLabv3p_MobileNetV2_x1.5_aspp_decoder':
  86. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_mobilenetv2_x1.5_with_aspp_decoder.sensitivities',
  87. 'DeepLabv3p_MobileNetV2_x2.0_aspp_decoder':
  88. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_mobilenetv2_x2.0_with_aspp_decoder.sensitivities',
  89. 'DeepLabv3p_Xception65_aspp_decoder':
  90. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_xception65_with_aspp_decoder.sensitivities',
  91. 'DeepLabv3p_Xception41_aspp_decoder':
  92. 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_xception41_with_aspp_decoder.sensitivities',
  93. 'HRNet_W18_Seg':
  94. 'https://bj.bcebos.com/paddlex/slim_prune/hrnet_w18.sensitivities',
  95. 'HRNet_W30_Seg':
  96. 'https://bj.bcebos.com/paddlex/slim_prune/hrnet_w30.sensitivities',
  97. 'HRNet_W32_Seg':
  98. 'https://bj.bcebos.com/paddlex/slim_prune/hrnet_w32.sensitivities',
  99. 'HRNet_W40_Seg':
  100. 'https://bj.bcebos.com/paddlex/slim_prune/hrnet_w40.sensitivities',
  101. 'HRNet_W44_Seg':
  102. 'https://bj.bcebos.com/paddlex/slim_prune/hrnet_w44.sensitivities',
  103. 'HRNet_W48_Seg':
  104. 'https://bj.bcebos.com/paddlex/slim_prune/hrnet_w48.sensitivities',
  105. 'HRNet_W64_Seg':
  106. 'https://bj.bcebos.com/paddlex/slim_prune/hrnet_w64.sensitivities',
  107. 'FastSCNN':
  108. 'https://bj.bcebos.com/paddlex/slim_prune/fast_scnn.sensitivities'
  109. }
  110. def get_sensitivities(flag, model, save_dir):
  111. model_name = model.__class__.__name__
  112. model_type = model_name
  113. if hasattr(model, 'backbone'):
  114. model_type = model_name + '_' + model.backbone
  115. if model_type.startswith('DeepLabv3p_Xception'):
  116. model_type = model_type + '_' + 'aspp' + '_' + 'decoder'
  117. elif hasattr(model, 'encoder_with_aspp') or hasattr(model,
  118. 'enable_decoder'):
  119. model_type = model_type + '_' + 'aspp' + '_' + 'decoder'
  120. if model_type.startswith('HRNet') and model.model_type == 'segmenter':
  121. model_type = '{}_W{}_Seg'.format(model_type, model.width)
  122. if osp.isfile(flag):
  123. return flag
  124. elif flag == 'DEFAULT':
  125. assert model_type in sensitivities_data, "There is not sensitivities data file for {}, you may need to calculate it by your self.".format(
  126. model_type)
  127. url = sensitivities_data[model_type]
  128. fname = osp.split(url)[-1]
  129. paddlex.utils.download(url, path=save_dir)
  130. return osp.join(save_dir, fname)
  131. # try:
  132. # hub.download(fname, save_path=save_dir)
  133. # except Exception as e:
  134. # if isinstance(e, hub.ResourceNotFoundError):
  135. # raise Exception(
  136. # "Resource for model {}(key='{}') not found".format(
  137. # model_type, fname))
  138. # elif isinstance(e, hub.ServerConnectionError):
  139. # raise Exception(
  140. # "Cannot get reource for model {}(key='{}'), please check your internet connection"
  141. # .format(model_type, fname))
  142. # else:
  143. # raise Exception(
  144. # "Unexpected error, please make sure paddlehub >= 1.6.2 {}".
  145. # format(str(e)))
  146. # return osp.join(save_dir, fname)
  147. else:
  148. raise Exception(
  149. "sensitivities need to be defined as directory path or `DEFAULT`(download sensitivities automatically)."
  150. )
def get_prune_params(model):
    """Collect the names of the parameters that should be pruned for ``model``.

    The lookup key ``model_type`` is the model's class name (or
    ``model.model_name`` when the class is ``BaseClassifier``), suffixed
    with ``'_' + model.backbone`` when the model has a ``backbone``
    attribute. Candidate parameters come from ``model.test_prog``; each
    architecture family then applies its own name-based selection and
    exclusion rules. The branches are tested in order, so e.g.
    ``DeepLabv3p_Xception*`` is handled by the Xception branch, never by the
    generic DeepLabv3p branch.

    Args:
        model: PaddleX model instance exposing ``test_prog``. Some branches
            also read ``model.scope`` (ResNet/DenseNet/DarkNet/AlexNet/
            ShuffleNetV2 and RCNN), ``model.model_type`` (HRNet) or
            ``model.encoder_with_aspp`` (DeepLabv3p).

    Returns:
        list: Names of the parameters to prune, in discovery order.

    Raises:
        Exception: ``model_type`` matches none of the supported families.
        ValueError: raised by ``list.remove`` if an AlexNet/ShuffleNetV2
            exclusion name was not among the collected parameters.
    """
    prune_names = []
    model_type = model.__class__.__name__
    if model_type == 'BaseClassifier':
        model_type = model.model_name
    if hasattr(model, 'backbone'):
        backbone = model.backbone
        model_type += ('_' + backbone)
    program = model.test_prog
    # Plain CNN backbones: prune every parameter whose tensor is 4-D
    # (i.e. every convolution kernel), across all program blocks.
    if model_type.startswith('ResNet') or \
            model_type.startswith('DenseNet') or \
            model_type.startswith('DarkNet') or \
            model_type.startswith('AlexNet') or \
            model_type.startswith('ShuffleNetV2'):
        for block in program.blocks:
            for param in block.all_parameters():
                pd_var = model.scope.find_var(param.name)
                try:
                    pd_param = pd_var.get_tensor()
                    if len(np.array(pd_param).shape) == 4:
                        prune_names.append(param.name)
                except Exception as e:
                    # Best-effort: a parameter whose tensor cannot be read
                    # (presumably pd_var is None) is reported and skipped.
                    print("None Tensor Name: ", param.name)
                    print("Error message: {}".format(e))
        if model_type == 'AlexNet':
            prune_names.remove('conv5_weights')
        if model_type == 'ShuffleNetV2':
            # Convolutions excluded from pruning for ShuffleNetV2.
            not_prune_names = [
                'stage_2_1_conv5_weights',
                'stage_2_1_conv3_weights',
                'stage_2_2_conv3_weights',
                'stage_2_3_conv3_weights',
                'stage_2_4_conv3_weights',
                'stage_3_1_conv5_weights',
                'stage_3_1_conv3_weights',
                'stage_3_2_conv3_weights',
                'stage_3_3_conv3_weights',
                'stage_3_4_conv3_weights',
                'stage_3_5_conv3_weights',
                'stage_3_6_conv3_weights',
                'stage_3_7_conv3_weights',
                'stage_3_8_conv3_weights',
                'stage_4_1_conv5_weights',
                'stage_4_1_conv3_weights',
                'stage_4_2_conv3_weights',
                'stage_4_3_conv3_weights',
                'stage_4_4_conv3_weights',
            ]
            for name in not_prune_names:
                prune_names.remove(name)
    elif model_type == "MobileNetV1":
        # The stem conv plus every parameter with "_sep_weights" in its name.
        prune_names.append("conv1_weights")
        for param in program.global_block().all_parameters():
            if "_sep_weights" in param.name:
                prune_names.append(param.name)
    elif model_type == "MobileNetV2":
        # Every weight except depthwise ("dwise") and "fc" parameters.
        for param in program.global_block().all_parameters():
            if 'weight' not in param.name \
                    or 'dwise' in param.name \
                    or 'fc' in param.name :
                continue
            prune_names.append(param.name)
    elif model_type.startswith("MobileNetV3"):
        # Only the expand convs whose layer index is listed are pruned.
        if model_type.startswith('MobileNetV3_small'):
            expand_prune_id = [3, 4]
        else:
            expand_prune_id = [2, 3, 4, 8, 9, 11]
        for param in program.global_block().all_parameters():
            # The layer index is parsed from the first "_"-separated token
            # after dropping a 4-character prefix; assumes names shaped like
            # "conv<N>_expand_weights" — TODO confirm for all variants.
            if ('expand_weights' in param.name and \
                    int(param.name.split('_')[0][4:]) in expand_prune_id)\
                    or 'linear_weights' in param.name \
                    or 'se_1_weights' in param.name:
                prune_names.append(param.name)
    elif model_type.startswith('Xception') or \
            model_type.startswith('DeepLabv3p_Xception'):
        # model_type[-2:] is taken as the Xception depth (e.g. "41", "65");
        # assumes model_type ends with it — TODO confirm for new variants.
        params_not_prune = [
            'weights',
            'xception_{}/exit_flow/block2/separable_conv3/pointwise/weights'.
            format(model_type[-2:]), 'encoder/concat/weights',
            'decoder/concat/weights'
        ]
        for param in program.global_block().all_parameters():
            if 'weight' not in param.name \
                    or 'dwise' in param.name \
                    or 'depthwise' in param.name \
                    or 'logit' in param.name:
                continue
            if param.name in params_not_prune:
                continue
            prune_names.append(param.name)
    elif model_type.startswith('YOLOv3'):
        # Only weights inside the YOLO head blocks are pruned.
        for block in program.blocks:
            for param in block.all_parameters():
                if 'weights' in param.name and 'yolo_block' in param.name:
                    prune_names.append(param.name)
    elif model_type.startswith('UNet'):
        # All non-logit weights except the listed encoder convs.
        for param in program.global_block().all_parameters():
            if 'weight' not in param.name:
                continue
            if 'logit' in param.name:
                continue
            prune_names.append(param.name)
        params_not_prune = [
            'encode/block4/down/conv1/weights',
            'encode/block3/down/conv1/weights',
            'encode/block2/down/conv1/weights', 'encode/block1/conv1/weights'
        ]
        for i in params_not_prune:
            if i in prune_names:
                prune_names.remove(i)
    elif model_type.startswith('HRNet') and model.model_type == 'segmenter':
        # All weights except 'conv-1_weights'.
        for param in program.global_block().all_parameters():
            if 'weight' not in param.name:
                continue
            prune_names.append(param.name)
        params_not_prune = ['conv-1_weights']
        for i in params_not_prune:
            if i in prune_names:
                prune_names.remove(i)
    elif model_type.startswith('FastSCNN'):
        # All non-depthwise, non-logit weights except the classifier.
        for param in program.global_block().all_parameters():
            if 'weight' not in param.name:
                continue
            if 'dwise' in param.name or 'depthwise' in param.name or 'logit' in param.name:
                continue
            prune_names.append(param.name)
        params_not_prune = ['classifier/weights']
        for i in params_not_prune:
            if i in prune_names:
                prune_names.remove(i)
    elif model_type.startswith('DeepLabv3p'):
        # Non-Xception DeepLabv3p backbones (Xception is caught above).
        if model_type.lower() == "deeplabv3p_mobilenetv3_large_x1_0_ssld":
            # Extra per-parameter exclusions for the MobileNetV3 ssld backbone.
            params_not_prune = [
                'last_1x1_conv_weights', 'conv14_se_2_weights',
                'conv16_depthwise_weights', 'conv13_depthwise_weights',
                'conv15_se_2_weights', 'conv2_depthwise_weights',
                'conv6_depthwise_weights', 'conv8_depthwise_weights',
                'fc_weights', 'conv3_depthwise_weights', 'conv7_se_2_weights',
                'conv16_expand_weights', 'conv16_se_2_weights',
                'conv10_depthwise_weights', 'conv11_depthwise_weights',
                'conv15_expand_weights', 'conv5_expand_weights',
                'conv15_depthwise_weights', 'conv14_depthwise_weights',
                'conv12_se_2_weights', 'conv1_weights',
                'conv13_expand_weights', 'conv_last_weights',
                'conv12_depthwise_weights', 'conv13_se_2_weights',
                'conv12_expand_weights', 'conv5_depthwise_weights',
                'conv6_se_2_weights', 'conv10_expand_weights',
                'conv9_depthwise_weights', 'conv6_expand_weights',
                'conv5_se_2_weights', 'conv14_expand_weights',
                'conv4_depthwise_weights', 'conv7_expand_weights',
                'conv7_depthwise_weights', 'encoder/aspp0/weights',
                'decoder/merge/weights', 'encoder/image_pool/weights',
                'decoder/weights'
            ]
        for param in program.global_block().all_parameters():
            if 'weight' not in param.name:
                continue
            if 'dwise' in param.name or 'depthwise' in param.name or 'logit' in param.name:
                continue
            if model_type.lower() == "deeplabv3p_mobilenetv3_large_x1_0_ssld":
                if param.name in params_not_prune:
                    continue
            prune_names.append(param.name)
        # Second, unconditional exclusion pass. NOTE(review): the xception_*
        # names look like no-ops here, since Xception backbones never reach
        # this branch — confirm before relying on them.
        params_not_prune = [
            'xception_{}/exit_flow/block2/separable_conv3/pointwise/weights'.
            format(model_type[-2:]), 'encoder/concat/weights',
            'decoder/concat/weights'
        ]
        if model.encoder_with_aspp == True:
            params_not_prune.append(
                'xception_{}/exit_flow/block2/separable_conv3/pointwise/weights'
                .format(model_type[-2:]))
            params_not_prune.append('conv8_1_linear_weights')
        for i in params_not_prune:
            if i in prune_names:
                prune_names.remove(i)
    elif 'RCNN' in model_type:
        # 4-D params outside the fpn/rpn/fc/cls/bbox heads. Unlike the first
        # branch there is no try/except, so an unreadable tensor raises here.
        for block in program.blocks:
            for param in block.all_parameters():
                pd_var = model.scope.find_var(param.name)
                pd_param = pd_var.get_tensor()
                if len(np.array(pd_param).shape) == 4:
                    if 'fpn' in param.name or 'rpn' in param.name or 'fc' in param.name or 'cls' in param.name or 'bbox' in param.name:
                        continue
                    prune_names.append(param.name)
    else:
        raise Exception('The {} is not implement yet!'.format(model_type))
    return prune_names