
Merge pull request #593 from FlyingQianMM/develop_qh

support model quantization when using paddle2.0
Jason 4 years ago
parent
current commit
e4cba1c396

+ 16 - 4
paddlex/cv/datasets/shared_queue/sharedmemory.py

@@ -221,7 +221,10 @@ class PageAllocator(object):
 
         start = self.s_allocator_header
         end = start + self._page_size * hpages
-        alloc_flags = self._base[start:end].tostring()
+        try:
+            alloc_flags = self._base[start:end].tobytes()
+        except:
+            alloc_flags = self._base[start:end].tostring()
         info = {
             'magic_num': self._magic_num,
             'header_pages': hpages,
@@ -250,7 +253,10 @@ class PageAllocator(object):
     def header(self):
         """ get header info of this allocator
         """
-        header_str = self._base[0:self.s_allocator_header].tostring()
+        try:
+            header_str = self._base[0:self.s_allocator_header].tobytes()
+        except:
+            header_str = self._base[0:self.s_allocator_header].tostring()
         magic, pos, used = struct.unpack(str('III'), header_str)
 
         assert magic == self._magic_num, \
@@ -297,7 +303,10 @@ class PageAllocator(object):
         end = start + page_num
         assert start >= 0 and end <= self._header_size, 'invalid end[%d] of pages '\
             'in allocator[%s]' % (end, str(self))
-        status = self._base[start:end].tostring().decode()
+        try:
+            status = self._base[start:end].tobytes().decode()
+        except:
+            status = self._base[start:end].tostring().decode()
         if ret_flag:
             return status
 
@@ -515,7 +524,10 @@ class SharedMemoryMgr(object):
         if no_copy:
             return self._base[start:start + size]
         else:
-            return self._base[start:start + size].tostring()
+            try:
+                return self._base[start:start + size].tobytes()
+            except:
+                return self._base[start:start + size].tostring()
 
     def __str__(self):
         return 'SharedMemoryMgr:{id:%d, %s}' % (self._id, str(self._allocator))
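
Note on the change above: numpy deprecated `ndarray.tostring()` in favour of `tobytes()` and later removed the old name, so the try/except keeps the shared-memory code working on both old and new numpy. A minimal standalone sketch of the same pattern (illustrative only, not part of this commit; `_to_bytes` is a hypothetical helper):

    import numpy as np

    def _to_bytes(arr):
        """Return the raw bytes of an ndarray on both old and new numpy."""
        try:
            return arr.tobytes()      # preferred name on current numpy
        except AttributeError:
            return arr.tostring()     # fallback for very old numpy releases

    buf = np.frombuffer(b"hello", dtype=np.uint8)
    assert _to_bytes(buf) == b"hello"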

+ 6 - 5
paddlex/cv/models/base.py

@@ -143,13 +143,14 @@ class BaseAPI:
             mode='quant',
             input_channel=input_channel)
         dataset.num_samples = batch_size * batch_num
-        try:
+        import paddle
+        version = paddle.__version__.strip().split('.')
+        if version[0] == '2' or (version[0] == '0' and
+                                 hasattr(paddle, 'enable_static')):
+            from .slim.post_quantization import PaddleXPostTrainingQuantizationV2 as PaddleXPostTrainingQuantization
+        else:
             from .slim.post_quantization import PaddleXPostTrainingQuantization
             PaddleXPostTrainingQuantization._collect_target_varnames
-        except:
-            raise Exception(
-                "Model Quantization is not available, try to upgrade your paddlepaddle>=1.8.0"
-            )
         is_use_cache_file = True
         if cache_dir is None:
             is_use_cache_file = False
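
Note on the change above: the try/except import probe is replaced by an explicit version gate, so the old "upgrade your paddlepaddle>=1.8.0" exception is no longer raised. A hedged restatement of the gate (illustrative only; `_use_quantization_v2` is a hypothetical helper, not a name from this commit):

    import paddle

    def _use_quantization_v2():
        """True for Paddle 2.x, or development builds (version 0.x) that already expose enable_static."""
        major = paddle.__version__.strip().split('.')[0]
        return major == '2' or (major == '0' and hasattr(paddle, 'enable_static'))

When this returns True, `PaddleXPostTrainingQuantizationV2` is imported under the old `PaddleXPostTrainingQuantization` name, so the rest of `BaseAPI` stays unchanged.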

+ 189 - 0
paddlex/cv/models/slim/post_quantization.py

@@ -325,3 +325,192 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
                 var_tensor = _load_variable_data(self._scope, var_name)
                 var_tensor = var_tensor.ravel()
                 self._sampling_data[var_name].append(var_tensor)
+
+
+class PaddleXPostTrainingQuantizationV2(PostTrainingQuantization):
+    def __init__(self,
+                 executor,
+                 dataset,
+                 program,
+                 inputs,
+                 outputs,
+                 batch_size=10,
+                 batch_nums=None,
+                 scope=None,
+                 algo="KL",
+                 quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
+                 is_full_quantize=False,
+                 activation_bits=8,
+                 weight_bits=8,
+                 activation_quantize_type='range_abs_max',
+                 weight_quantize_type='channel_wise_abs_max',
+                 optimize_model=False,
+                 is_use_cache_file=False,
+                 cache_dir="./temp_post_training"):
+        '''
+        Constructor.
+
+        Args:
+            executor(fluid.Executor): The executor to load, run and save the
+                quantized model.
+            dataset(Python Iterator): The data Reader.
+            program(fluid.Program): The paddle program, save the parameters for model.
+            inputs(dict): The inputs of the program.
+            outputs(dict): The outputs of the program.
+
+            scope(fluid.Scope, optional): The scope of the program, use it to load
+                and save variables. If scope=None, get scope by global_scope().
+            batch_size(int, optional): The batch size of DataLoader. Default is 10.
+            batch_nums(int, optional): If batch_nums is not None, the number of
+                calibrate data is batch_size*batch_nums. If batch_nums is None, use
+                all data provided by sample_generator as calibrate data.
+            algo(str, optional): If algo='KL', use the KL-divergence method to
+                get the KL threshold for quantized activations and get the abs_max
+                value for quantized weights. If algo='abs_max', get the abs max
+                value for activations and weights. If algo= 'min_max', get the min
+                and max value for quantized activations and weights. Default is KL.
+            quantizable_op_type(list[str], optional): List the type of ops
+                that will be quantized. Default is ["conv2d", "depthwise_conv2d",
+                "mul"].
+            is_full_quantize(bool, optional): If is_full_quantize is True, apply
+                quantization to all supported quantizable op types. If it is
+                False, only apply quantization to the op types listed in the
+                input quantizable_op_type. Default is False.
+            activation_bits(int): quantization bit number for activation.
+            weight_bits(int, optional): quantization bit number for weights.
+            activation_quantize_type(str): quantization type for activation,
+                now support 'range_abs_max', 'moving_average_abs_max' and 'abs_max'.
+                This param only specifies the fake ops in saving quantized model.
+                If it is 'range_abs_max' or 'moving_average_abs_max', we save the scale
+                obtained by post training quantization in fake ops. Note that, if it
+                is 'abs_max', the scale will not be saved in fake ops.
+            weight_quantize_type(str): quantization type for weights,
+                support 'abs_max' and 'channel_wise_abs_max'. This param only specifies
+                the fake ops in saving quantized model, and we save the scale obtained
+                by post training quantization in fake ops. Compared to 'abs_max',
+                the model accuracy is usually higher when it is 'channel_wise_abs_max'.
+            optimize_model(bool, optional): If optimize_model is True, some passes
+                are applied to the model before quantization; so far only the
+                `conv2d/depthwise_conv2d + bn` fusion pass is supported. Some targets
+                require the weights to be quantized tensor-wise, i.e. all channels
+                share the same weight scale. However, fusing
+                `conv2d/depthwise_conv2d + bn` makes the per-channel weight scales
+                differ, so to address this the pattern is fused before
+                quantization. Default is False.
+            is_use_cache_file(bool, optional): This param is deprecated.
+            cache_dir(str, optional): This param is deprecated.
+        Returns:
+            None
+
+        '''
+
+        self._support_activation_quantize_type = [
+            'range_abs_max', 'moving_average_abs_max', 'abs_max'
+        ]
+        self._support_weight_quantize_type = [
+            'abs_max', 'channel_wise_abs_max'
+        ]
+        self._support_algo_type = ['KL', 'abs_max', 'min_max']
+        self._dynamic_quantize_op_type = ['lstm']
+        self._support_quantize_op_type = \
+            list(set(QuantizationTransformPass._supported_quantizable_op_type +
+                AddQuantDequantPass._supported_quantizable_op_type +
+                self._dynamic_quantize_op_type))
+
+        # Check inputs
+        assert executor is not None, "The executor cannot be None."
+
+        assert batch_size > 0, "The batch_size should be greater than 0."
+        assert algo in self._support_algo_type, \
+            "The algo should be KL, abs_max or min_max."
+        assert activation_quantize_type in self._support_activation_quantize_type, \
+            "The activation_quantize_type ({}) should in ({}).".format(
+            activation_quantize_type, self._support_activation_quantize_type)
+        assert weight_quantize_type in self._support_weight_quantize_type, \
+            "The weight_quantize_type ({}) shoud in ({}).".format(
+            weight_quantize_type, self._support_weight_quantize_type)
+
+        # Save input params
+        self._executor = executor
+        self._dataset = dataset
+        self._scope = fluid.global_scope() if scope == None else scope
+        self._batch_size = batch_size
+        self._batch_nums = batch_nums
+        self._algo = algo
+        self._activation_bits = activation_bits
+        self._weight_bits = weight_bits
+        self._activation_quantize_type = activation_quantize_type
+        self._weight_quantize_type = weight_quantize_type
+        self._is_full_quantize = is_full_quantize
+        if is_full_quantize:
+            self._quantizable_op_type = self._support_quantize_op_type
+        else:
+            self._quantizable_op_type = quantizable_op_type
+            for op_type in self._quantizable_op_type:
+                assert op_type in self._support_quantize_op_type, \
+                    op_type + " is not supported for quantization."
+        self._optimize_model = optimize_model
+
+        # Define variables
+        self._place = self._executor.place
+        self._program = program
+        self._feed_list = [var.name for var in inputs.values()]
+        self._fetch_list = list(outputs.values())
+        self._data_loader = None
+
+        self._out_scale_op_list = _out_scale_op_list
+        self._quantized_weight_var_name = set()
+        self._quantized_act_var_name = set()
+        self._weight_op_pairs = {}
+        # The vars for algo = KL
+        self._sampling_act_abs_min_max = {}
+        self._sampling_act_histogram = {}
+        self._sampling_data = {}
+        self._quantized_var_kl_threshold = {}
+        self._histogram_bins = 2048
+        # The vars for algo = min_max
+        self._quantized_var_min = {}
+        self._quantized_var_max = {}
+        # The vars for algo = abs_max
+        self._quantized_var_abs_max = {}
+
+    def _load_model_data(self):
+        '''
+        Set data loader.
+        '''
+        logging.info("Set data loader ...")
+        if self._program.num_blocks > 1:
+            _logger.error("The post training quantization requires that the "
+                          "program only has one block.")
+
+        if self._optimize_model:
+            self._optimize_fp32_model()
+
+        feed_vars = [fluid.framework._get_var(var_name, self._program) \
+            for var_name in self._feed_list]
+        self._data_loader = fluid.io.DataLoader.from_generator(
+            feed_list=feed_vars, capacity=3 * self._batch_size, iterable=True)
+        self._data_loader.set_sample_list_generator(
+            self._dataset.generator(
+                self._batch_size, drop_last=True),
+            places=self._place)
+
+    def save_quantized_model(self, save_model_path):
+        '''
+        Save the quantized model to the disk.
+
+        Args:
+            save_model_path(str): The path to save the quantized model.
+        Returns:
+            None
+        '''
+        with fluid.scope_guard(self._scope):
+            fluid.io.save_inference_model(
+                dirname=save_model_path,
+                model_filename='__model__',
+                params_filename='__params__',
+                feeded_var_names=self._feed_list,
+                target_vars=self._fetch_list,
+                executor=self._executor,
+                main_program=self._program)
+        logging.info("The quantized model is saved in " + save_model_path)
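
For context, `PaddleXPostTrainingQuantizationV2` is instantiated from `BaseAPI` in `base.py` via the aliased import shown above. A hypothetical standalone call could look like the sketch below; `exe`, `quant_dataset`, `test_prog`, `model_inputs` and `model_outputs` are placeholders the caller must provide, and `quantize()` is inherited from Paddle's `PostTrainingQuantization`:

    quantizer = PaddleXPostTrainingQuantizationV2(
        executor=exe,                 # fluid.Executor bound to the target place
        dataset=quant_dataset,        # iterator yielding calibration batches
        program=test_prog,            # inference fluid.Program to quantize
        inputs=model_inputs,          # dict of input variables
        outputs=model_outputs,        # dict of output variables
        batch_size=10,
        batch_nums=10,
        algo='KL',
        quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul'])
    quantizer.quantize()                              # calibrate and insert fake quant ops
    quantizer.save_quantized_model('./quant_model')   # writes __model__ and __params__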