prune.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import yaml
  16. import time
  17. import pickle
  18. import os
  19. import os.path as osp
  20. from functools import reduce
  21. import paddle.fluid as fluid
  22. from multiprocessing import Process, Queue
  23. from .prune_config import get_prune_params
  24. import paddlex.utils.logging as logging
  25. from paddlex.utils import seconds_to_hms
  26. def sensitivity(program,
  27. place,
  28. param_names,
  29. eval_func,
  30. sensitivities_file=None,
  31. pruned_ratios=None,
  32. scope=None):
  33. import paddleslim
  34. from paddleslim.prune import Pruner, load_sensitivities
  35. from paddleslim.core import GraphWrapper
  36. if scope is None:
  37. scope = fluid.global_scope()
  38. else:
  39. scope = scope
  40. graph = GraphWrapper(program)
  41. sensitivities = load_sensitivities(sensitivities_file)
  42. if pruned_ratios is None:
  43. pruned_ratios = np.arange(0.1, 1, step=0.1)
  44. total_evaluate_iters = 0
  45. for name in param_names:
  46. if name not in sensitivities:
  47. sensitivities[name] = {}
  48. total_evaluate_iters += len(list(pruned_ratios))
  49. else:
  50. total_evaluate_iters += (
  51. len(list(pruned_ratios)) - len(sensitivities[name]))
  52. eta = '-'
  53. start_time = time.time()
  54. baseline = eval_func(graph.program)
  55. cost = time.time() - start_time
  56. eta = cost * (total_evaluate_iters - 1)
  57. current_iter = 1
  58. for name in sensitivities:
  59. for ratio in pruned_ratios:
  60. if ratio in sensitivities[name]:
  61. logging.debug('{}, {} has computed.'.format(name, ratio))
  62. continue
  63. progress = float(current_iter) / total_evaluate_iters
  64. progress = "%.2f%%" % (progress * 100)
  65. logging.info(
  66. "Total evaluate iters={}, current={}, progress={}, eta={}".
  67. format(total_evaluate_iters, current_iter, progress,
  68. seconds_to_hms(
  69. int(cost * (total_evaluate_iters - current_iter)))),
  70. use_color=True)
  71. current_iter += 1
  72. pruner = Pruner()
  73. logging.info("sensitive - param: {}; ratios: {}".format(name,
  74. ratio))
  75. pruned_program, param_backup, _ = pruner.prune(
  76. program=graph.program,
  77. scope=scope,
  78. params=[name],
  79. ratios=[ratio],
  80. place=place,
  81. lazy=True,
  82. only_graph=False,
  83. param_backup=True)
  84. pruned_metric = eval_func(pruned_program)
  85. loss = (baseline - pruned_metric) / baseline
  86. logging.info("pruned param: {}; {}; loss={}".format(name, ratio,
  87. loss))
  88. sensitivities[name][ratio] = loss
  89. with open(sensitivities_file, 'wb') as f:
  90. pickle.dump(sensitivities, f)
  91. for param_name in param_backup.keys():
  92. param_t = scope.find_var(param_name).get_tensor()
  93. param_t.set(param_backup[param_name], place)
  94. return sensitivities
  95. def channel_prune(program,
  96. prune_names,
  97. prune_ratios,
  98. place,
  99. only_graph=False,
  100. scope=None):
  101. """通道裁剪。
  102. Args:
  103. program (paddle.fluid.Program): 需要裁剪的Program,Program的具体介绍可参见
  104. https://paddlepaddle.org.cn/documentation/docs/zh/beginners_guide/basic_concept/program.html#program。
  105. prune_names (list): 由裁剪参数名组成的参数列表。
  106. prune_ratios (list): 由裁剪率组成的参数列表,与prune_names中的参数列表意义对应。
  107. place (paddle.fluid.CUDAPlace/paddle.fluid.CPUPlace): 运行设备。
  108. only_graph (bool): 是否只修改网络图,当为False时代表同时修改网络图和
  109. scope(全局作用域)中的参数。默认为False。
  110. Returns:
  111. paddle.fluid.Program: 裁剪后的Program。
  112. """
  113. import paddleslim
  114. from paddleslim.prune import Pruner, load_sensitivities
  115. from paddleslim.core import GraphWrapper
  116. prog_var_shape_dict = {}
  117. for var in program.list_vars():
  118. try:
  119. prog_var_shape_dict[var.name] = var.shape
  120. except Exception:
  121. pass
  122. index = 0
  123. for param, ratio in zip(prune_names, prune_ratios):
  124. origin_num = prog_var_shape_dict[param][0]
  125. pruned_num = int(round(origin_num * ratio))
  126. while origin_num == pruned_num:
  127. ratio -= 0.1
  128. pruned_num = int(round(origin_num * (ratio)))
  129. prune_ratios[index] = ratio
  130. index += 1
  131. if scope is None:
  132. scope = fluid.global_scope()
  133. pruner = Pruner()
  134. program, _, _ = pruner.prune(
  135. program,
  136. scope,
  137. params=prune_names,
  138. ratios=prune_ratios,
  139. place=place,
  140. lazy=False,
  141. only_graph=only_graph,
  142. param_backup=False,
  143. param_shape_backup=False)
  144. return program
  145. def prune_program(model, prune_params_ratios=None):
  146. """根据裁剪参数和裁剪率裁剪Program。
  147. 1. 裁剪训练Program和测试Program。
  148. 2. 使用裁剪后的Program更新模型中的train_prog和test_prog。
  149. 【注意】Program的具体介绍可参见
  150. https://paddlepaddle.org.cn/documentation/docs/zh/beginners_guide/basic_concept/program.html#program。
  151. Args:
  152. model (paddlex.cv.models): paddlex中的模型。
  153. prune_params_ratios (dict): 由裁剪参数名和裁剪率组成的字典,当为None时
  154. 使用默认裁剪参数名和裁剪率。默认为None。
  155. """
  156. import paddleslim
  157. from paddleslim.prune import Pruner, load_sensitivities
  158. from paddleslim.core import GraphWrapper
  159. assert model.status == 'Normal', 'Only the models saved while training are supported!'
  160. place = model.places[0]
  161. train_prog = model.train_prog
  162. eval_prog = model.test_prog
  163. valid_prune_names = get_prune_params(model)
  164. assert set(list(prune_params_ratios.keys())) & set(valid_prune_names), \
  165. "All params in 'prune_params_ratios' can't be pruned!"
  166. prune_names = list(
  167. set(list(prune_params_ratios.keys())) & set(valid_prune_names))
  168. prune_ratios = [
  169. prune_params_ratios[prune_name] for prune_name in prune_names
  170. ]
  171. model.train_prog = channel_prune(
  172. train_prog, prune_names, prune_ratios, place, scope=model.scope)
  173. model.test_prog = channel_prune(
  174. eval_prog,
  175. prune_names,
  176. prune_ratios,
  177. place,
  178. only_graph=True,
  179. scope=model.scope)
  180. def update_program(program, model_dir, place, scope=None):
  181. """根据裁剪信息更新Program和参数。
  182. Args:
  183. program (paddle.fluid.Program): 需要更新的Program,Program的具体介绍可参见
  184. https://paddlepaddle.org.cn/documentation/docs/zh/beginners_guide/basic_concept/program.html#program。
  185. model_dir (str): 模型存储路径。
  186. place (paddle.fluid.CUDAPlace/paddle.fluid.CPUPlace): 运行设备。
  187. Returns:
  188. paddle.fluid.Program: 更新后的Program。
  189. """
  190. import paddleslim
  191. from paddleslim.prune import Pruner, load_sensitivities
  192. from paddleslim.core import GraphWrapper
  193. graph = GraphWrapper(program)
  194. with open(osp.join(model_dir, "prune.yml")) as f:
  195. shapes = yaml.load(f.read(), Loader=yaml.Loader)
  196. for param, shape in shapes.items():
  197. graph.var(param).set_shape(shape)
  198. if scope is None:
  199. scope = fluid.global_scope()
  200. for block in program.blocks:
  201. for param in block.all_parameters():
  202. if param.name in shapes:
  203. param_tensor = scope.find_var(param.name).get_tensor()
  204. param_tensor.set(
  205. np.zeros(list(shapes[param.name])).astype('float32'), place)
  206. graph.update_groups_of_conv()
  207. graph.infer_shape()
  208. return program
  209. def cal_params_sensitivities(model, save_file, eval_dataset, batch_size=8):
  210. """计算模型中可裁剪卷积Kernel的敏感度。
  211. 1. 获取模型中可裁剪卷积Kernel的名称。
  212. 2. 计算每个可裁剪卷积Kernel不同裁剪率下的敏感度。
  213. 【注意】卷积的敏感度是指在不同裁剪率下评估数据集预测精度的损失,
  214. 通过得到的敏感度,可以决定最终模型需要裁剪的参数列表和各裁剪参数对应的裁剪率。
  215. Args:
  216. model (paddlex.cv.models): paddlex中的模型。
  217. save_file (str): 计算的得到的sensetives文件存储路径。
  218. eval_dataset (paddlex.datasets): 验证数据读取器。
  219. batch_size (int): 验证数据批大小。默认为8。
  220. Returns:
  221. dict: 由参数名和不同裁剪率下敏感度组成的字典。存储的信息如下:
  222. .. code-block:: python
  223. {"weight_0":
  224. {0.1: 0.22,
  225. 0.2: 0.33
  226. },
  227. "weight_1":
  228. {0.1: 0.21,
  229. 0.2: 0.4
  230. }
  231. }
  232. 其中``weight_0``是卷积Kernel名;``sensitivities['weight_0']``是一个字典,key是裁剪率,value是敏感度。
  233. """
  234. import paddleslim
  235. from paddleslim.prune import Pruner, load_sensitivities
  236. from paddleslim.core import GraphWrapper
  237. assert model.status == 'Normal', 'Only the models saved while training are supported!'
  238. if os.path.exists(save_file):
  239. os.remove(save_file)
  240. prune_names = get_prune_params(model)
  241. def eval_for_prune(program):
  242. eval_metrics = model.evaluate(
  243. eval_dataset=eval_dataset,
  244. batch_size=batch_size,
  245. return_details=False)
  246. primary_key = list(eval_metrics.keys())[0]
  247. return eval_metrics[primary_key]
  248. sensitivitives = sensitivity(
  249. model.test_prog,
  250. model.places[0],
  251. prune_names,
  252. eval_for_prune,
  253. sensitivities_file=save_file,
  254. pruned_ratios=list(np.arange(0.1, 1, 0.1)),
  255. scope=model.scope)
  256. return sensitivitives
  257. def analysis(model, dataset, batch_size=8, save_file='./model.sensi.data'):
  258. return cal_params_sensitivities(
  259. model, eval_dataset=dataset, batch_size=batch_size, save_file=save_file)
  260. def get_params_ratios(sensitivities_file, eval_metric_loss=0.05):
  261. """根据设定的精度损失容忍度metric_loss_thresh和计算保存的模型参数敏感度信息文件sensetive_file,
  262. 获取裁剪的参数配置。
  263. 【注意】metric_loss_thresh并不确保最终裁剪后的模型在fine-tune后的模型效果,仅为预估值。
  264. Args:
  265. sensitivities_file (str): 敏感度文件存储路径。
  266. eval_metric_loss (float): 可容忍的精度损失。默认为0.05。
  267. Returns:
  268. dict: 由参数名和裁剪率组成的字典。存储的信息如下:
  269. .. code-block:: python
  270. {"weight_0": 0.1,
  271. "weight_1": 0.2
  272. }
  273. 其中key是卷积Kernel名;value是裁剪率。
  274. """
  275. import paddleslim
  276. from paddleslim.prune import Pruner, load_sensitivities
  277. from paddleslim.core import GraphWrapper
  278. if not osp.exists(sensitivities_file):
  279. raise Exception('The sensitivities file is not exists!')
  280. sensitivitives = paddleslim.prune.load_sensitivities(sensitivities_file)
  281. params_ratios = paddleslim.prune.get_ratios_by_loss(sensitivitives,
  282. eval_metric_loss)
  283. return params_ratios
  284. def cal_model_size(program,
  285. place,
  286. sensitivities_file,
  287. eval_metric_loss=0.05,
  288. scope=None):
  289. """在可容忍的精度损失下,计算裁剪后模型大小相对于当前模型大小的比例。
  290. Args:
  291. program (paddle.fluid.Program): 需要裁剪的Program,Program的具体介绍可参见
  292. https://paddlepaddle.org.cn/documentation/docs/zh/beginners_guide/basic_concept/program.html#program。
  293. place (paddle.fluid.CUDAPlace/paddle.fluid.CPUPlace): 运行设备。
  294. sensitivities_file (str): 敏感度文件存储路径。
  295. eval_metric_loss (float): 可容忍的精度损失。默认为0.05。
  296. Returns:
  297. float: 裁剪后模型大小相对于当前模型大小的比例。
  298. """
  299. import paddleslim
  300. from paddleslim.prune import Pruner, load_sensitivities
  301. from paddleslim.core import GraphWrapper
  302. prune_params_ratios = get_params_ratios(sensitivities_file,
  303. eval_metric_loss)
  304. prog_var_shape_dict = {}
  305. for var in program.list_vars():
  306. try:
  307. prog_var_shape_dict[var.name] = var.shape
  308. except Exception:
  309. pass
  310. for param, ratio in prune_params_ratios.items():
  311. origin_num = prog_var_shape_dict[param][0]
  312. pruned_num = int(round(origin_num * ratio))
  313. while origin_num == pruned_num:
  314. ratio -= 0.1
  315. pruned_num = int(round(origin_num * (ratio)))
  316. prune_params_ratios[param] = ratio
  317. prune_program = channel_prune(
  318. program,
  319. list(prune_params_ratios.keys()),
  320. list(prune_params_ratios.values()),
  321. place,
  322. only_graph=True,
  323. scope=scope)
  324. origin_size = 0
  325. new_size = 0
  326. for var in program.list_vars():
  327. name = var.name
  328. shape = var.shape
  329. for prune_block in prune_program.blocks:
  330. if prune_block.has_var(name):
  331. prune_var = prune_block.var(name)
  332. prune_shape = prune_var.shape
  333. break
  334. origin_size += reduce(lambda x, y: x * y, shape)
  335. new_size += reduce(lambda x, y: x * y, prune_shape)
  336. return (new_size * 1.0) / origin_size