post_quantization.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from paddle.fluid.contrib.slim.quantization.quantization_pass import QuantizationTransformPass
  15. from paddle.fluid.contrib.slim.quantization.quantization_pass import AddQuantDequantPass
  16. from paddle.fluid.contrib.slim.quantization.quantization_pass import _out_scale_op_list
  17. from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
  18. import paddlex.utils.logging as logging
  19. import paddle.fluid as fluid
  20. import os
  21. import re
  22. import numpy as np
  23. import time
  24. class PaddleXPostTrainingQuantization(PostTrainingQuantization):
  def __init__(self,
               executor,
               dataset,
               program,
               inputs,
               outputs,
               batch_size=10,
               batch_nums=None,
               scope=None,
               algo="KL",
               quantizable_op_type=("conv2d", "depthwise_conv2d", "mul"),
               is_full_quantize=False,
               is_use_cache_file=False,
               cache_dir="./temp_post_training"):
      '''
      Post-training quantization for PaddleX models.

      The class uses post-training quantization to quantize an fp32 model.
      It runs calibration data through the program to compute the scale
      factors of the quantized variables. It then inserts fake
      quant/dequant ops to obtain the quantized model.
      Args:
          executor(fluid.Executor): The executor used to load, run and save
              the quantized model. Must not be None.
          dataset(Python Iterator): The data reader. Its
              generator(batch_size, drop_last=True) method supplies the
              calibration batches.
          program(fluid.Program): The paddle program to quantize; its
              parameters are saved with the model.
          inputs(dict): The inputs of the program. The dict values are used
              as the feed variables.
          outputs(dict): The outputs of the program. The dict values are used
              as the fetch targets.
          batch_size(int, optional): The batch size of the DataLoader. Must be
              greater than 0. Default is 10.
          batch_nums(int, optional): If batch_nums is not None, the number of
              calibration samples is batch_size*batch_nums. If batch_nums is
              None, all data provided by the dataset is used for calibration.
          scope(fluid.Scope, optional): The scope used to load and save
              variables. If scope is None, fluid.global_scope() is used.
          algo(str, optional): The calibration algorithm: 'KL', 'abs_max' or
              'min_max'. 'KL' uses KL-divergence to get a more precise scale
              factor for activations. 'abs_max' and 'min_max' collect
              thresholds through self._sample_threshold() during calibration.
              Default is 'KL'.
          quantizable_op_type(list[str], optional): The op types to quantize.
              Ignored when is_full_quantize is True. Default is
              ["conv2d", "depthwise_conv2d", "mul"].
          is_full_quantize(bool, optional): If True, apply quantization to all
              supported quantizable op types. If False, only quantize the op
              types listed in quantizable_op_type. Default is False.
          is_use_cache_file(bool, optional): If False, all temporary data is
              kept in memory. If True, temporary data is saved to disk. Set it
              to True when the fp32 model is complex or there is a lot of
              calibration data. Default is False.
          cache_dir(str, optional): When is_use_cache_file is True, the
              directory for temporary data. It is created (with parents) if
              missing. Default is "./temp_post_training".
      Returns:
          None
      '''
      self._support_activation_quantize_type = [
          'range_abs_max', 'moving_average_abs_max', 'abs_max'
      ]
      self._support_weight_quantize_type = ['abs_max', 'channel_wise_abs_max']
      self._support_algo_type = ['KL', 'abs_max', 'min_max']
      self._support_quantize_op_type = \
          list(set(QuantizationTransformPass._supported_quantizable_op_type +
                   AddQuantDequantPass._supported_quantizable_op_type))

      # Check inputs
      assert executor is not None, "The executor cannot be None."
      assert batch_size > 0, "The batch_size should be greater than 0."
      assert algo in self._support_algo_type, \
          "The algo should be KL, abs_max or min_max."

      self._executor = executor
      self._dataset = dataset
      self._batch_size = batch_size
      self._batch_nums = batch_nums
      self._scope = fluid.global_scope() if scope is None else scope
      self._algo = algo
      self._is_use_cache_file = is_use_cache_file
      self._cache_dir = cache_dir
      # Bit widths and quantize types are fixed here, not exposed as arguments.
      self._activation_bits = 8
      self._weight_bits = 8
      self._activation_quantize_type = 'range_abs_max'
      self._weight_quantize_type = 'channel_wise_abs_max'
      if self._is_use_cache_file:
          # makedirs also creates missing parent directories. It is a no-op
          # if the directory already exists, which avoids an exists/mkdir race.
          os.makedirs(self._cache_dir, exist_ok=True)

      if is_full_quantize:
          self._quantizable_op_type = self._support_quantize_op_type
      else:
          # Copy so this instance never shares a list object with the caller
          # (or with a default value).
          self._quantizable_op_type = list(quantizable_op_type)
      supported_op_types = self._support_quantize_op_type + \
          AddQuantDequantPass._activation_type
      for op_type in self._quantizable_op_type:
          assert op_type in supported_op_types, \
              op_type + " is not supported for quantization."

      self._place = self._executor.place
      self._program = program
      self._feed_list = list(inputs.values())
      self._fetch_list = list(outputs.values())
      self._data_loader = None
      self._out_scale_op_list = _out_scale_op_list
      self._bit_length = 8
      # Calibration state collected while quantize() runs.
      self._quantized_weight_var_name = set()
      self._quantized_act_var_name = set()
      self._sampling_data = {}
      self._quantized_var_kl_threshold = {}
      self._quantized_var_min = {}
      self._quantized_var_max = {}
      self._quantized_var_abs_max = {}
  def quantize(self):
      '''
      Quantize the fp32 model.

      Runs the calibration data through the program and computes the scale
      factors of the quantized variables. Then inserts fake quant/dequant ops
      (for 'KL' and 'abs_max') or saves the input thresholds (for 'min_max').
      Also saves the output thresholds.
      Args:
          None
      Returns:
          fluid.Program: the program of the quantized model (self._program).
      '''
      self._load_model_data()
      self._collect_target_varnames()
      self._set_activation_persistable()

      # Count the calibration batches (capped at batch_nums) so the progress
      # log can report "Batch=i/total". This is an extra pass over the data
      # loader that does not run the program.
      batch_ct = 0
      for _ in self._data_loader():
          batch_ct += 1
          if self._batch_nums and batch_ct >= self._batch_nums:
              break

      batch_id = 0
      logging.info("Start to run batch!")
      for data in self._data_loader():
          start = time.time()
          with fluid.scope_guard(self._scope):
              self._executor.run(program=self._program,
                                 feed=data,
                                 fetch_list=self._fetch_list,
                                 return_numpy=False)
          # KL keeps raw samples for a later histogram-based search; the other
          # algorithms update running thresholds on each batch.
          if self._algo == "KL":
              self._sample_data(batch_id)
          else:
              self._sample_threshold()
          end = time.time()
          logging.debug(
              '[Run batch data] Batch={}/{}, time_each_batch={} s.'.format(
                  str(batch_id + 1), str(batch_ct), str(end - start)))
          batch_id += 1
          if self._batch_nums and batch_id >= self._batch_nums:
              break
      logging.info("All run batch: {}".format(batch_id))
      self._reset_activation_persistable()

      logging.info("Calculate scale factor ...")
      if self._algo == "KL":
          self._calculate_kl_threshold()

      logging.info("Update the program ...")
      if self._algo in ["KL", "abs_max"]:
          self._update_program()
      else:
          self._save_input_threhold()

      logging.info("Save ...")
      self._save_output_threshold()
      logging.info("Finish quant!")
      return self._program
  def save_quantized_model(self, save_model_path):
      '''
      Write the quantized program to disk as an inference model.

      Calls fluid.io.save_inference_model inside self._scope. The feed
      variables are taken by name from self._feed_list and the targets from
      self._fetch_list. Parameters are written under the filename
      '__params__'.
      Args:
          save_model_path(str): Directory the inference model is written to.
      Returns:
          None
      '''
      input_names = []
      for feed_var in self._feed_list:
          input_names.append(feed_var.name)
      with fluid.scope_guard(self._scope):
          fluid.io.save_inference_model(
              dirname=save_model_path,
              feeded_var_names=input_names,
              target_vars=self._fetch_list,
              executor=self._executor,
              main_program=self._program,
              params_filename='__params__')
  def _load_model_data(self):
      '''
      Create the iterable DataLoader that supplies calibration batches.

      Each feed variable is looked up by name in self._program. Batches come
      from self._dataset.generator with drop_last=True and are placed on
      self._place. The loader is stored in self._data_loader.
      '''
      program = self._program
      feed_vars = []
      for feed_var in self._feed_list:
          feed_vars.append(fluid.framework._get_var(feed_var.name, program))
      loader = fluid.io.DataLoader.from_generator(
          feed_list=feed_vars, iterable=True, capacity=self._batch_size * 3)
      self._data_loader = loader
      sample_lists = self._dataset.generator(self._batch_size, drop_last=True)
      loader.set_sample_list_generator(sample_lists, places=self._place)
  def _calculate_kl_threshold(self):
      '''
      Calculate the thresholds of the quantized variables and store them in
      self._quantized_var_kl_threshold.

      Weights use abs-max thresholds:
      - 'abs_max': a single value for the whole tensor.
      - 'channel_wise_abs_max': one value per slice along axis 0.

      Activations use self._get_kl_scaling_factor on the absolute values of
      their samples. When is_use_cache_file is set, the samples are read from
      the "<var_name>_<n>.npy" files in the cache directory, which are
      deleted after loading. Otherwise the in-memory sample lists in
      self._sampling_data are concatenated and stored back.
      '''
      assert self._algo == "KL", "The algo should be KL to calculate kl threshold."

      # Abs_max threshold for weights
      weight_total = len(self._quantized_weight_var_name)
      for ct, var_name in enumerate(self._quantized_weight_var_name, 1):
          start = time.time()
          weight_data = self._sampling_data[var_name]
          weight_threshold = None
          if self._weight_quantize_type == "abs_max":
              weight_threshold = np.max(np.abs(weight_data))
          elif self._weight_quantize_type == "channel_wise_abs_max":
              weight_threshold = [
                  np.max(np.abs(channel)) for channel in weight_data
              ]
          self._quantized_var_kl_threshold[var_name] = weight_threshold
          end = time.time()
          logging.debug(
              '[Calculate weight] Weight_id={}/{}, time_each_weight={} s.'.
              format(str(ct), str(weight_total), str(end - start)))

      # KL threshold for activations
      act_total = len(self._quantized_act_var_name)
      # List the cache directory once. Each variable's files form a disjoint
      # subset of this snapshot, so re-listing per variable is unnecessary.
      cached_files = os.listdir(self._cache_dir) \
          if self._is_use_cache_file else []
      for ct, var_name in enumerate(self._quantized_act_var_name, 1):
          start = time.time()
          if self._is_use_cache_file:
              # Variable names may contain regex metacharacters such as '.'.
              # Escaping them and using fullmatch ensures a pattern only
              # selects this variable's own "<var_name>_<n>.npy" files.
              pattern = re.compile(re.escape(var_name) + r'_[0-9]+\.npy')
              chunks = []
              for filename in cached_files:
                  if pattern.fullmatch(filename):
                      file_path = os.path.join(self._cache_dir, filename)
                      chunks.append(np.load(file_path))
                      os.remove(file_path)
              sampling_data = np.concatenate(chunks)
          else:
              sampling_data = np.concatenate(self._sampling_data[var_name])
              self._sampling_data[var_name] = sampling_data
          self._quantized_var_kl_threshold[var_name] = \
              self._get_kl_scaling_factor(np.abs(sampling_data))
          end = time.time()
          logging.debug(
              '[Calculate activation] Activation_id={}/{}, time_each_activation={} s.'.
              format(str(ct), str(act_total), str(end - start)))