post_quantization.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from paddle.fluid.contrib.slim.quantization.quantization_pass import QuantizationTransformPass
  15. from paddle.fluid.contrib.slim.quantization.quantization_pass import AddQuantDequantPass
  16. from paddle.fluid.contrib.slim.quantization.quantization_pass import _op_real_in_out_name
  17. from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
  18. import paddlex.utils.logging as logging
  19. import paddle.fluid as fluid
  20. import os
  21. import re
  22. import numpy as np
  23. import time
  24. class PaddleXPostTrainingQuantization(PostTrainingQuantization):
  25. def __init__(self,
  26. executor,
  27. dataset,
  28. program,
  29. inputs,
  30. outputs,
  31. batch_size=10,
  32. batch_nums=None,
  33. scope=None,
  34. algo="KL",
  35. quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
  36. is_full_quantize=False,
  37. is_use_cache_file=False,
  38. cache_dir="./temp_post_training"):
  39. '''
  40. The class utilizes post training quantization methon to quantize the
  41. fp32 model. It uses calibrate data to calculate the scale factor of
  42. quantized variables, and inserts fake quant/dequant op to obtain the
  43. quantized model.
  44. Args:
  45. executor(fluid.Executor): The executor to load, run and save the
  46. quantized model.
  47. dataset(Python Iterator): The data Reader.
  48. program(fluid.Program): The paddle program, save the parameters for model.
  49. inputs(dict): The input of prigram.
  50. outputs(dict): The output of program.
  51. batch_size(int, optional): The batch size of DataLoader. Default is 10.
  52. batch_nums(int, optional): If batch_nums is not None, the number of
  53. calibrate data is batch_size*batch_nums. If batch_nums is None, use
  54. all data provided by sample_generator as calibrate data.
  55. scope(fluid.Scope, optional): The scope of the program, use it to load
  56. and save variables. If scope=None, get scope by global_scope().
  57. algo(str, optional): If algo=KL, use KL-divergenc method to
  58. get the more precise scale factor. If algo='direct', use
  59. abs_max methon to get the scale factor. Default is KL.
  60. quantizable_op_type(list[str], optional): List the type of ops
  61. that will be quantized. Default is ["conv2d", "depthwise_conv2d",
  62. "mul"].
  63. is_full_quantized(bool, optional): If set is_full_quantized as True,
  64. apply quantization to all supported quantizable op type. If set
  65. is_full_quantized as False, only apply quantization to the op type
  66. according to the input quantizable_op_type.
  67. is_use_cache_file(bool, optional): If set is_use_cache_file as False,
  68. all temp data will be saved in memory. If set is_use_cache_file as True,
  69. it will save temp data to disk. When the fp32 model is complex or
  70. the number of calibrate data is large, we should set is_use_cache_file
  71. as True. Defalut is False.
  72. cache_dir(str, optional): When is_use_cache_file is True, set cache_dir as
  73. the directory for saving temp data. Default is ./temp_post_training.
  74. Returns:
  75. None
  76. '''
  77. self._executor = executor
  78. self._dataset = dataset
  79. self._batch_size = batch_size
  80. self._batch_nums = batch_nums
  81. self._scope = fluid.global_scope() if scope == None else scope
  82. self._algo = algo
  83. self._is_use_cache_file = is_use_cache_file
  84. self._cache_dir = cache_dir
  85. if self._is_use_cache_file and not os.path.exists(self._cache_dir):
  86. os.mkdir(self._cache_dir)
  87. supported_quantizable_op_type = \
  88. QuantizationTransformPass._supported_quantizable_op_type + \
  89. AddQuantDequantPass._supported_quantizable_op_type
  90. if is_full_quantize:
  91. self._quantizable_op_type = supported_quantizable_op_type
  92. else:
  93. self._quantizable_op_type = quantizable_op_type
  94. for op_type in self._quantizable_op_type:
  95. assert op_type in supported_quantizable_op_type + \
  96. AddQuantDequantPass._activation_type, \
  97. op_type + " is not supported for quantization."
  98. self._place = self._executor.place
  99. self._program = program
  100. self._feed_list = list(inputs.values())
  101. self._fetch_list = list(outputs.values())
  102. self._data_loader = None
  103. self._op_real_in_out_name = _op_real_in_out_name
  104. self._bit_length = 8
  105. self._quantized_weight_var_name = set()
  106. self._quantized_act_var_name = set()
  107. self._sampling_data = {}
  108. self._quantized_var_scale_factor = {}
  109. def quantize(self):
  110. '''
  111. Quantize the fp32 model. Use calibrate data to calculate the scale factor of
  112. quantized variables, and inserts fake quant/dequant op to obtain the
  113. quantized model.
  114. Args:
  115. None
  116. Returns:
  117. the program of quantized model.
  118. '''
  119. self._preprocess()
  120. batch_ct = 0
  121. for data in self._data_loader():
  122. batch_ct += 1
  123. if self._batch_nums and batch_ct >= self._batch_nums:
  124. break
  125. batch_id = 0
  126. logging.info("Start to run batch!")
  127. for data in self._data_loader():
  128. start = time.time()
  129. self._executor.run(
  130. program=self._program,
  131. feed=data,
  132. fetch_list=self._fetch_list,
  133. return_numpy=False)
  134. self._sample_data(batch_id)
  135. end = time.time()
  136. logging.debug('[Run batch data] Batch={}/{}, time_each_batch={} s.'.format(
  137. str(batch_id + 1),
  138. str(batch_ct),
  139. str(end-start)))
  140. batch_id += 1
  141. if self._batch_nums and batch_id >= self._batch_nums:
  142. break
  143. logging.info("All run batch: ".format(batch_id))
  144. logging.info("Calculate scale factor ...")
  145. self._calculate_scale_factor()
  146. logging.info("Update the program ...")
  147. self._update_program()
  148. logging.info("Save ...")
  149. self._save_output_scale()
  150. logging.info("Finish quant!")
  151. return self._program
  152. def save_quantized_model(self, save_model_path):
  153. '''
  154. Save the quantized model to the disk.
  155. Args:
  156. save_model_path(str): The path to save the quantized model
  157. Returns:
  158. None
  159. '''
  160. feed_vars_names = [var.name for var in self._feed_list]
  161. fluid.io.save_inference_model(
  162. dirname=save_model_path,
  163. feeded_var_names=feed_vars_names,
  164. target_vars=self._fetch_list,
  165. executor=self._executor,
  166. params_filename='__params__',
  167. main_program=self._program)
  168. def _preprocess(self):
  169. '''
  170. Load model and set data loader, collect the variable names for sampling,
  171. and set activation variables to be persistable.
  172. '''
  173. feed_vars = [fluid.framework._get_var(var.name, self._program) \
  174. for var in self._feed_list]
  175. self._data_loader = fluid.io.DataLoader.from_generator(
  176. feed_list=feed_vars, capacity=3 * self._batch_size, iterable=True)
  177. self._data_loader.set_sample_list_generator(
  178. self._dataset.generator(self._batch_size, drop_last=True),
  179. places=self._place)
  180. # collect the variable names for sampling
  181. persistable_var_names = []
  182. for var in self._program.list_vars():
  183. if var.persistable:
  184. persistable_var_names.append(var.name)
  185. for op in self._program.global_block().ops:
  186. op_type = op.type
  187. if op_type in self._quantizable_op_type:
  188. if op_type in ("conv2d", "depthwise_conv2d"):
  189. self._quantized_act_var_name.add(op.input("Input")[0])
  190. self._quantized_weight_var_name.add(op.input("Filter")[0])
  191. self._quantized_act_var_name.add(op.output("Output")[0])
  192. elif op_type == "mul":
  193. if self._is_input_all_not_persistable(
  194. op, persistable_var_names):
  195. op._set_attr("skip_quant", True)
  196. logging.warning(
  197. "Skip quant a mul op for two input variables are not persistable"
  198. )
  199. else:
  200. self._quantized_act_var_name.add(op.input("X")[0])
  201. self._quantized_weight_var_name.add(op.input("Y")[0])
  202. self._quantized_act_var_name.add(op.output("Out")[0])
  203. else:
  204. # process other quantizable op type, the input must all not persistable
  205. if self._is_input_all_not_persistable(
  206. op, persistable_var_names):
  207. input_output_name_list = self._op_real_in_out_name[
  208. op_type]
  209. for input_name in input_output_name_list[0]:
  210. for var_name in op.input(input_name):
  211. self._quantized_act_var_name.add(var_name)
  212. for output_name in input_output_name_list[1]:
  213. for var_name in op.output(output_name):
  214. self._quantized_act_var_name.add(var_name)
  215. # set activation variables to be persistable, so can obtain
  216. # the tensor data in sample_data
  217. for var in self._program.list_vars():
  218. if var.name in self._quantized_act_var_name:
  219. var.persistable = True
    def _calculate_scale_factor(self):
        '''
        Calculate the scale factor of quantized variables.
        '''
        # apply channel_wise_abs_max quantization for weights
        ct = 1
        for var_name in self._quantized_weight_var_name:
            start = time.time()
            data = self._sampling_data[var_name]
            scale_factor_per_channel = []
            # Axis 0 is treated as the output-channel axis: one abs-max
            # scale per channel (channel-wise weight quantization).
            for i in range(data.shape[0]):
                abs_max_value = np.max(np.abs(data[i]))
                scale_factor_per_channel.append(abs_max_value)
            self._quantized_var_scale_factor[
                var_name] = scale_factor_per_channel
            end = time.time()
            logging.debug('[Calculate weight] Weight_id={}/{}, time_each_weight={} s.'.format(
                str(ct),
                str(len(self._quantized_weight_var_name)),
                str(end-start)))
            ct += 1
        ct = 1
        # apply kl quantization for activation
        if self._is_use_cache_file:
            # Cached mode: per-batch activation samples were dumped to
            # "<cache_dir>/<var_name>_<batch>.npy"; load and concatenate
            # them, deleting each file once it has been read.
            for var_name in self._quantized_act_var_name:
                start = time.time()
                sampling_data = []
                filenames = [f for f in os.listdir(self._cache_dir) \
                    if re.match(var_name + '_[0-9]+.npy', f)]
                for filename in filenames:
                    file_path = os.path.join(self._cache_dir, filename)
                    sampling_data.append(np.load(file_path))
                    os.remove(file_path)
                sampling_data = np.concatenate(sampling_data)
                # KL: search for the scale minimizing KL divergence;
                # otherwise fall back to a plain abs-max scale.
                if self._algo == "KL":
                    self._quantized_var_scale_factor[var_name] = \
                        self._get_kl_scaling_factor(np.abs(sampling_data))
                else:
                    self._quantized_var_scale_factor[var_name] = \
                        np.max(np.abs(sampling_data))
                end = time.time()
                logging.debug('[Calculate activation] Activation_id={}/{}, time_each_activation={} s.'.format(
                    str(ct),
                    str(len(self._quantized_act_var_name)),
                    str(end-start)))
                ct += 1
        else:
            # In-memory mode: samples for each activation were accumulated
            # as a list of arrays in self._sampling_data.
            for var_name in self._quantized_act_var_name:
                start = time.time()
                self._sampling_data[var_name] = np.concatenate(
                    self._sampling_data[var_name])
                if self._algo == "KL":
                    self._quantized_var_scale_factor[var_name] = \
                        self._get_kl_scaling_factor(np.abs(self._sampling_data[var_name]))
                else:
                    self._quantized_var_scale_factor[var_name] = \
                        np.max(np.abs(self._sampling_data[var_name]))
                end = time.time()
                logging.debug('[Calculate activation] Activation_id={}/{}, time_each_activation={} s.'.format(
                    str(ct),
                    str(len(self._quantized_act_var_name)),
                    str(end-start)))
                ct += 1