# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
import time

import numpy as np

import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization.quantization_pass import QuantizationTransformPass
from paddle.fluid.contrib.slim.quantization.quantization_pass import AddQuantDequantPass
from paddle.fluid.contrib.slim.quantization.quantization_pass import _op_real_in_out_name
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
import paddlex.utils.logging as logging


class PaddleXPostTrainingQuantization(PostTrainingQuantization):
    def __init__(self,
                 executor,
                 dataset,
                 program,
                 inputs,
                 outputs,
                 batch_size=10,
                 batch_nums=None,
                 scope=None,
                 algo="KL",
                 quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
                 is_full_quantize=False,
                 is_use_cache_file=False,
                 cache_dir="./temp_post_training"):
        '''
        This class uses post-training quantization to quantize an fp32
        model. It runs calibration data through the model to compute the
        scale factors of the quantized variables, then inserts fake
        quant/dequant ops to obtain the quantized model.

        Args:
            executor(fluid.Executor): The executor used to load, run and
                save the quantized model.
            dataset(Python Iterator): The data reader that provides the
                calibration data.
            program(fluid.Program): The paddle program that holds the model
                parameters.
            inputs(dict): The inputs of the program.
            outputs(dict): The outputs of the program.
            batch_size(int, optional): The batch size of the DataLoader.
                Default is 10.
            batch_nums(int, optional): If batch_nums is not None, the number
                of calibration samples is batch_size * batch_nums. If
                batch_nums is None, all data provided by the dataset is used
                for calibration.
            scope(fluid.Scope, optional): The scope of the program, used to
                load and save variables. If scope is None, global_scope() is
                used.
            algo(str, optional): If algo='KL', use the KL-divergence method
                to obtain a more precise scale factor. If algo='direct', use
                the abs_max method. Default is 'KL'.
            quantizable_op_type(list[str], optional): The types of ops that
                will be quantized. Default is ["conv2d", "depthwise_conv2d",
                "mul"].
            is_full_quantize(bool, optional): If True, apply quantization to
                all supported quantizable op types. If False, only quantize
                the op types given in quantizable_op_type.
            is_use_cache_file(bool, optional): If False, all temporary data
                is kept in memory. If True, temporary data is saved to disk.
                When the fp32 model is complex or the amount of calibration
                data is large, is_use_cache_file should be set to True.
                Default is False.
            cache_dir(str, optional): When is_use_cache_file is True, the
                directory for saving temporary data. Default is
                ./temp_post_training.
        Returns:
            None
        '''
        self._executor = executor
        self._dataset = dataset
        self._batch_size = batch_size
        self._batch_nums = batch_nums
        self._scope = fluid.global_scope() if scope is None else scope
        self._algo = algo
        self._is_use_cache_file = is_use_cache_file
        self._cache_dir = cache_dir
        if self._is_use_cache_file and not os.path.exists(self._cache_dir):
            os.mkdir(self._cache_dir)

        supported_quantizable_op_type = \
            QuantizationTransformPass._supported_quantizable_op_type + \
            AddQuantDequantPass._supported_quantizable_op_type
        if is_full_quantize:
            self._quantizable_op_type = supported_quantizable_op_type
        else:
            self._quantizable_op_type = quantizable_op_type
            for op_type in self._quantizable_op_type:
                assert op_type in supported_quantizable_op_type + \
                    AddQuantDequantPass._activation_type, \
                    op_type + " is not supported for quantization."

        self._place = self._executor.place
        self._program = program
        self._feed_list = list(inputs.values())
        self._fetch_list = list(outputs.values())
        self._data_loader = None

        self._op_real_in_out_name = _op_real_in_out_name
        self._bit_length = 8
        self._quantized_weight_var_name = set()
        self._quantized_act_var_name = set()
        self._sampling_data = {}
        self._quantized_var_scale_factor = {}

    def quantize(self):
        '''
        Quantize the fp32 model. Runs the calibration data through the
        model to compute the scale factors of the quantized variables, then
        inserts fake quant/dequant ops to obtain the quantized model.

        Args:
            None
        Returns:
            the program of the quantized model.
        '''
        self._preprocess()

        # First pass: count the calibration batches, so progress can be
        # reported while sampling.
        batch_ct = 0
        for data in self._data_loader():
            batch_ct += 1
            if self._batch_nums and batch_ct >= self._batch_nums:
                break

        batch_id = 0
        logging.info("Start to run batch!")
        for data in self._data_loader():
            start = time.time()
            self._executor.run(
                program=self._program,
                feed=data,
                fetch_list=self._fetch_list,
                return_numpy=False)
            self._sample_data(batch_id)
            end = time.time()
            logging.debug(
                '[Run batch data] Batch={}/{}, time_each_batch={} s.'.format(
                    str(batch_id + 1), str(batch_ct), str(end - start)))
            batch_id += 1
            if self._batch_nums and batch_id >= self._batch_nums:
                break
        logging.info("All batches run: {}".format(batch_id))

        logging.info("Calculate scale factor ...")
        self._calculate_scale_factor()
        logging.info("Update the program ...")
        self._update_program()
        logging.info("Save ...")
        self._save_output_scale()
        logging.info("Finish quant!")
        return self._program
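
    # A minimal usage sketch (hypothetical setup; `exe`, `calib_dataset`,
    # `test_prog`, `inputs` and `outputs` are assumed to come from an
    # already-loaded PaddleX inference model):
    #
    #   ptq = PaddleXPostTrainingQuantization(
    #       executor=exe, dataset=calib_dataset, program=test_prog,
    #       inputs=inputs, outputs=outputs, batch_size=10, algo="KL")
    #   quant_program = ptq.quantize()
    #   ptq.save_quantized_model("./quant_model")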

    def save_quantized_model(self, save_model_path):
        '''
        Save the quantized model to disk.

        Args:
            save_model_path(str): The path for saving the quantized model.
        Returns:
            None
        '''
        feed_vars_names = [var.name for var in self._feed_list]
        fluid.io.save_inference_model(
            dirname=save_model_path,
            feeded_var_names=feed_vars_names,
            target_vars=self._fetch_list,
            executor=self._executor,
            params_filename='__params__',
            main_program=self._program)
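
    # Note: since the parameters are written to a single '__params__' file,
    # the saved model can later be reloaded with the matching fluid API,
    # e.g. fluid.io.load_inference_model(dirname=save_model_path,
    # executor=exe, params_filename='__params__').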

    def _preprocess(self):
        '''
        Set up the data loader, collect the variable names for sampling,
        and set activation variables to be persistable.
        '''
        feed_vars = [fluid.framework._get_var(var.name, self._program) \
            for var in self._feed_list]
        self._data_loader = fluid.io.DataLoader.from_generator(
            feed_list=feed_vars, capacity=3 * self._batch_size, iterable=True)
        self._data_loader.set_sample_list_generator(
            self._dataset.generator(self._batch_size, drop_last=True),
            places=self._place)

        # collect the variable names for sampling
        persistable_var_names = []
        for var in self._program.list_vars():
            if var.persistable:
                persistable_var_names.append(var.name)

        for op in self._program.global_block().ops:
            op_type = op.type
            if op_type in self._quantizable_op_type:
                if op_type in ("conv2d", "depthwise_conv2d"):
                    self._quantized_act_var_name.add(op.input("Input")[0])
                    self._quantized_weight_var_name.add(op.input("Filter")[0])
                    self._quantized_act_var_name.add(op.output("Output")[0])
                elif op_type == "mul":
                    if self._is_input_all_not_persistable(
                            op, persistable_var_names):
                        op._set_attr("skip_quant", True)
                        logging.warning(
                            "Skip quantizing a mul op, since neither of its "
                            "input variables is persistable.")
                    else:
                        self._quantized_act_var_name.add(op.input("X")[0])
                        self._quantized_weight_var_name.add(op.input("Y")[0])
                        self._quantized_act_var_name.add(op.output("Out")[0])
                else:
                    # process other quantizable op types, whose inputs must
                    # all be not persistable
                    if self._is_input_all_not_persistable(
                            op, persistable_var_names):
                        input_output_name_list = self._op_real_in_out_name[
                            op_type]
                        for input_name in input_output_name_list[0]:
                            for var_name in op.input(input_name):
                                self._quantized_act_var_name.add(var_name)
                        for output_name in input_output_name_list[1]:
                            for var_name in op.output(output_name):
                                self._quantized_act_var_name.add(var_name)

        # set activation variables to be persistable, so their tensor data
        # can be obtained in _sample_data
        for var in self._program.list_vars():
            if var.name in self._quantized_act_var_name:
                var.persistable = True

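    # Because the activations are made persistable above, their tensors
    # remain in the scope after each executor.run() call, so the inherited
    # _sample_data can read them, roughly along these lines (sketch):
    #   np.array(self._scope.find_var(var_name).get_tensor())
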
    def _calculate_scale_factor(self):
        '''
        Calculate the scale factors of the quantized variables.
        '''
        # apply channel_wise_abs_max quantization for weights
        ct = 1
        for var_name in self._quantized_weight_var_name:
            start = time.time()
            data = self._sampling_data[var_name]
            scale_factor_per_channel = []
            for i in range(data.shape[0]):
                abs_max_value = np.max(np.abs(data[i]))
                scale_factor_per_channel.append(abs_max_value)
            self._quantized_var_scale_factor[
                var_name] = scale_factor_per_channel
            end = time.time()
            logging.debug(
                '[Calculate weight] Weight_id={}/{}, time_each_weight={} s.'.format(
                    str(ct), str(len(self._quantized_weight_var_name)),
                    str(end - start)))
            ct += 1

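        # Example: a conv2d filter of shape [32, 3, 3, 3] is sliced along
        # axis 0 above, yielding 32 scale factors, one abs-max value per
        # output channel.
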
        # apply KL or abs_max quantization for activations
        ct = 1
        if self._is_use_cache_file:
            for var_name in self._quantized_act_var_name:
                start = time.time()
                sampling_data = []
                # escape var_name and anchor the pattern, so files belonging
                # to unrelated variables are not matched
                filenames = [f for f in os.listdir(self._cache_dir) \
                    if re.match(re.escape(var_name) + r'_[0-9]+\.npy$', f)]
                for filename in filenames:
                    file_path = os.path.join(self._cache_dir, filename)
                    sampling_data.append(np.load(file_path))
                    os.remove(file_path)
                sampling_data = np.concatenate(sampling_data)
                if self._algo == "KL":
                    self._quantized_var_scale_factor[var_name] = \
                        self._get_kl_scaling_factor(np.abs(sampling_data))
                else:
                    self._quantized_var_scale_factor[var_name] = \
                        np.max(np.abs(sampling_data))
                end = time.time()
                logging.debug(
                    '[Calculate activation] Activation_id={}/{}, time_each_activation={} s.'.format(
                        str(ct), str(len(self._quantized_act_var_name)),
                        str(end - start)))
                ct += 1
        else:
            for var_name in self._quantized_act_var_name:
                start = time.time()
                self._sampling_data[var_name] = np.concatenate(
                    self._sampling_data[var_name])
                if self._algo == "KL":
                    self._quantized_var_scale_factor[var_name] = \
                        self._get_kl_scaling_factor(
                            np.abs(self._sampling_data[var_name]))
                else:
                    self._quantized_var_scale_factor[var_name] = \
                        np.max(np.abs(self._sampling_data[var_name]))
                end = time.time()
                logging.debug(
                    '[Calculate activation] Activation_id={}/{}, time_each_activation={} s.'.format(
                        str(ct), str(len(self._quantized_act_var_name)),
                        str(end - start)))
                ct += 1
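
    # Note: _get_kl_scaling_factor is inherited from Paddle's
    # PostTrainingQuantization. Conceptually, it searches over candidate
    # clipping thresholds and picks the one whose quantized histogram has
    # the smallest KL divergence from the original activation distribution
    # (TensorRT-style calibration); see the parent class for the exact
    # implementation.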