post_quantization.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from paddle.fluid.contrib.slim.quantization.quantization_pass import QuantizationTransformPass
  15. from paddle.fluid.contrib.slim.quantization.quantization_pass import AddQuantDequantPass
  16. from paddle.fluid.contrib.slim.quantization.quantization_pass import _out_scale_op_list
  17. from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
  18. import paddlex.utils.logging as logging
  19. import paddle.fluid as fluid
  20. import os
  21. import re
  22. import numpy as np
  23. import time
  24. def _load_variable_data(scope, var_name):
  25. '''
  26. Load variable value from scope
  27. '''
  28. var_node = scope.find_var(var_name)
  29. assert var_node is not None, \
  30. "Cannot find " + var_name + " in scope."
  31. return np.array(var_node.get_tensor())
  32. class PaddleXPostTrainingQuantization(PostTrainingQuantization):
  33. def __init__(self,
  34. executor,
  35. dataset,
  36. program,
  37. inputs,
  38. outputs,
  39. batch_size=10,
  40. batch_nums=None,
  41. scope=None,
  42. algo="KL",
  43. quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
  44. is_full_quantize=False,
  45. is_use_cache_file=False,
  46. cache_dir="./temp_post_training"):
  47. '''
  48. The class utilizes post training quantization methon to quantize the
  49. fp32 model. It uses calibrate data to calculate the scale factor of
  50. quantized variables, and inserts fake quant/dequant op to obtain the
  51. quantized model.
  52. Args:
  53. executor(fluid.Executor): The executor to load, run and save the
  54. quantized model.
  55. dataset(Python Iterator): The data Reader.
  56. program(fluid.Program): The paddle program, save the parameters for model.
  57. inputs(dict): The input of prigram.
  58. outputs(dict): The output of program.
  59. batch_size(int, optional): The batch size of DataLoader. Default is 10.
  60. batch_nums(int, optional): If batch_nums is not None, the number of
  61. calibrate data is batch_size*batch_nums. If batch_nums is None, use
  62. all data provided by sample_generator as calibrate data.
  63. scope(fluid.Scope, optional): The scope of the program, use it to load
  64. and save variables. If scope=None, get scope by global_scope().
  65. algo(str, optional): If algo=KL, use KL-divergenc method to
  66. get the more precise scale factor. If algo='direct', use
  67. abs_max methon to get the scale factor. Default is KL.
  68. quantizable_op_type(list[str], optional): List the type of ops
  69. that will be quantized. Default is ["conv2d", "depthwise_conv2d",
  70. "mul"].
  71. is_full_quantized(bool, optional): If set is_full_quantized as True,
  72. apply quantization to all supported quantizable op type. If set
  73. is_full_quantized as False, only apply quantization to the op type
  74. according to the input quantizable_op_type.
  75. is_use_cache_file(bool, optional): If set is_use_cache_file as False,
  76. all temp data will be saved in memory. If set is_use_cache_file as True,
  77. it will save temp data to disk. When the fp32 model is complex or
  78. the number of calibrate data is large, we should set is_use_cache_file
  79. as True. Defalut is False.
  80. cache_dir(str, optional): When is_use_cache_file is True, set cache_dir as
  81. the directory for saving temp data. Default is ./temp_post_training.
  82. Returns:
  83. None
  84. '''
  85. self._support_activation_quantize_type = [
  86. 'range_abs_max', 'moving_average_abs_max', 'abs_max'
  87. ]
  88. self._support_weight_quantize_type = [
  89. 'abs_max', 'channel_wise_abs_max'
  90. ]
  91. self._support_algo_type = ['KL', 'abs_max', 'min_max']
  92. self._support_quantize_op_type = \
  93. list(set(QuantizationTransformPass._supported_quantizable_op_type +
  94. AddQuantDequantPass._supported_quantizable_op_type))
  95. # Check inputs
  96. assert executor is not None, "The executor cannot be None."
  97. assert batch_size > 0, "The batch_size should be greater than 0."
  98. assert algo in self._support_algo_type, \
  99. "The algo should be KL, abs_max or min_max."
  100. self._executor = executor
  101. self._dataset = dataset
  102. self._batch_size = batch_size
  103. self._batch_nums = batch_nums
  104. self._scope = fluid.global_scope() if scope == None else scope
  105. self._algo = algo
  106. self._is_use_cache_file = is_use_cache_file
  107. self._cache_dir = cache_dir
  108. self._activation_bits = 8
  109. self._weight_bits = 8
  110. self._activation_quantize_type = 'range_abs_max'
  111. self._weight_quantize_type = 'channel_wise_abs_max'
  112. if self._is_use_cache_file and not os.path.exists(self._cache_dir):
  113. os.mkdir(self._cache_dir)
  114. if is_full_quantize:
  115. self._quantizable_op_type = self._support_quantize_op_type
  116. else:
  117. self._quantizable_op_type = quantizable_op_type
  118. for op_type in self._quantizable_op_type:
  119. assert op_type in self._support_quantize_op_type + \
  120. AddQuantDequantPass._activation_type, \
  121. op_type + " is not supported for quantization."
  122. self._place = self._executor.place
  123. self._program = program
  124. self._feed_list = list(inputs.values())
  125. self._fetch_list = list(outputs.values())
  126. self._data_loader = None
  127. self._out_scale_op_list = _out_scale_op_list
  128. self._bit_length = 8
  129. self._quantized_weight_var_name = set()
  130. self._quantized_act_var_name = set()
  131. self._sampling_data = {}
  132. self._quantized_var_kl_threshold = {}
  133. self._quantized_var_min = {}
  134. self._quantized_var_max = {}
  135. self._quantized_var_abs_max = {}
  136. def quantize(self):
  137. '''
  138. Quantize the fp32 model. Use calibrate data to calculate the scale factor of
  139. quantized variables, and inserts fake quant/dequant op to obtain the
  140. quantized model.
  141. Args:
  142. None
  143. Returns:
  144. the program of quantized model.
  145. '''
  146. self._load_model_data()
  147. self._collect_target_varnames()
  148. self._set_activation_persistable()
  149. batch_ct = 0
  150. for data in self._data_loader():
  151. batch_ct += 1
  152. if self._batch_nums and batch_ct >= self._batch_nums:
  153. break
  154. batch_id = 0
  155. logging.info("Start to run batch!")
  156. for data in self._data_loader():
  157. start = time.time()
  158. with fluid.scope_guard(self._scope):
  159. self._executor.run(program=self._program,
  160. feed=data,
  161. fetch_list=self._fetch_list,
  162. return_numpy=False)
  163. if self._algo == "KL":
  164. self._sample_data(batch_id)
  165. else:
  166. self._sample_threshold()
  167. end = time.time()
  168. logging.debug(
  169. '[Run batch data] Batch={}/{}, time_each_batch={} s.'.format(
  170. str(batch_id + 1), str(batch_ct), str(end - start)))
  171. batch_id += 1
  172. if self._batch_nums and batch_id >= self._batch_nums:
  173. break
  174. logging.info("All run batch: ".format(batch_id))
  175. self._reset_activation_persistable()
  176. logging.info("Calculate scale factor ...")
  177. if self._algo == "KL":
  178. self._calculate_kl_threshold()
  179. logging.info("Update the program ...")
  180. if self._algo in ["KL", "abs_max"]:
  181. self._update_program()
  182. else:
  183. self._save_input_threhold()
  184. logging.info("Save ...")
  185. self._save_output_threshold()
  186. logging.info("Finish quant!")
  187. return self._program
  188. def save_quantized_model(self, save_model_path):
  189. '''
  190. Save the quantized model to the disk.
  191. Args:
  192. save_model_path(str): The path to save the quantized model
  193. Returns:
  194. None
  195. '''
  196. with fluid.scope_guard(self._scope):
  197. feed_vars_names = [var.name for var in self._feed_list]
  198. fluid.io.save_inference_model(
  199. dirname=save_model_path,
  200. feeded_var_names=feed_vars_names,
  201. target_vars=self._fetch_list,
  202. executor=self._executor,
  203. params_filename='__params__',
  204. main_program=self._program)
  205. def _load_model_data(self):
  206. '''
  207. Set data loader.
  208. '''
  209. feed_vars = [fluid.framework._get_var(var.name, self._program) \
  210. for var in self._feed_list]
  211. self._data_loader = fluid.io.DataLoader.from_generator(
  212. feed_list=feed_vars, capacity=3 * self._batch_size, iterable=True)
  213. self._data_loader.set_sample_list_generator(
  214. self._dataset.generator(
  215. self._batch_size, drop_last=True),
  216. places=self._place)
  217. def _calculate_kl_threshold(self):
  218. '''
  219. Calculate the KL threshold of quantized variables.
  220. '''
  221. assert self._algo == "KL", "The algo should be KL to calculate kl threshold."
  222. ct = 1
  223. # Abs_max threshold for weights
  224. for var_name in self._quantized_weight_var_name:
  225. start = time.time()
  226. weight_data = self._sampling_data[var_name]
  227. weight_threshold = None
  228. if self._weight_quantize_type == "abs_max":
  229. weight_threshold = np.max(np.abs(weight_data))
  230. elif self._weight_quantize_type == "channel_wise_abs_max":
  231. weight_threshold = []
  232. for i in range(weight_data.shape[0]):
  233. abs_max_value = np.max(np.abs(weight_data[i]))
  234. weight_threshold.append(abs_max_value)
  235. self._quantized_var_kl_threshold[var_name] = weight_threshold
  236. end = time.time()
  237. logging.debug(
  238. '[Calculate weight] Weight_id={}/{}, time_each_weight={} s.'.
  239. format(
  240. str(ct),
  241. str(len(self._quantized_weight_var_name)),
  242. str(end - start)))
  243. ct += 1
  244. ct = 1
  245. # KL threshold for activations
  246. if self._is_use_cache_file:
  247. for var_name in self._quantized_act_var_name:
  248. start = time.time()
  249. sampling_data = []
  250. filenames = [f for f in os.listdir(self._cache_dir) \
  251. if re.match(var_name + '_[0-9]+.npy', f)]
  252. for filename in filenames:
  253. file_path = os.path.join(self._cache_dir, filename)
  254. sampling_data.append(np.load(file_path))
  255. os.remove(file_path)
  256. sampling_data = np.concatenate(sampling_data)
  257. self._quantized_var_kl_threshold[var_name] = \
  258. self._get_kl_scaling_factor(np.abs(sampling_data))
  259. end = time.time()
  260. logging.debug(
  261. '[Calculate activation] Activation_id={}/{}, time_each_activation={} s.'.
  262. format(
  263. str(ct),
  264. str(len(self._quantized_act_var_name)),
  265. str(end - start)))
  266. ct += 1
  267. else:
  268. for var_name in self._quantized_act_var_name:
  269. start = time.time()
  270. self._sampling_data[var_name] = np.concatenate(
  271. self._sampling_data[var_name])
  272. self._quantized_var_kl_threshold[var_name] = \
  273. self._get_kl_scaling_factor(np.abs(self._sampling_data[var_name]))
  274. end = time.time()
  275. logging.debug(
  276. '[Calculate activation] Activation_id={}/{}, time_each_activation={} s.'.
  277. format(
  278. str(ct),
  279. str(len(self._quantized_act_var_name)),
  280. str(end - start)))
  281. ct += 1
  282. def _sample_data(self, iter):
  283. '''
  284. Sample the tensor data of quantized variables,
  285. applied in every iteration.
  286. '''
  287. assert self._algo == "KL", "The algo should be KL to sample data."
  288. for var_name in self._quantized_weight_var_name:
  289. if var_name not in self._sampling_data:
  290. var_tensor = _load_variable_data(self._scope, var_name)
  291. self._sampling_data[var_name] = var_tensor
  292. if self._is_use_cache_file:
  293. for var_name in self._quantized_act_var_name:
  294. var_tensor = _load_variable_data(self._scope, var_name)
  295. var_tensor = var_tensor.ravel()
  296. save_path = os.path.join(self._cache_dir,
  297. var_name + "_" + str(iter) + ".npy")
  298. save_dir, file_name = os.path.split(save_path)
  299. if not os.path.exists(save_dir):
  300. os.mkdirs(save_dir)
  301. np.save(save_path, var_tensor)
  302. else:
  303. for var_name in self._quantized_act_var_name:
  304. if var_name not in self._sampling_data:
  305. self._sampling_data[var_name] = []
  306. var_tensor = _load_variable_data(self._scope, var_name)
  307. var_tensor = var_tensor.ravel()
  308. self._sampling_data[var_name].append(var_tensor)