Merge pull request #76 from SunAhong1993/syf0519

fix the post quant
Jason 5 years ago
Parent commit fe50020f88
2 changed files with 71 additions and 92 deletions
  1. paddlex/cv/models/base.py (+3 −2)
  2. paddlex/cv/models/slim/post_quantization.py (+68 −90)

+ 3 - 2
paddlex/cv/models/base.py

@@ -139,9 +139,10 @@ class BaseAPI:
         dataset.num_samples = batch_size * batch_num
         try:
             from .slim.post_quantization import PaddleXPostTrainingQuantization
+            PaddleXPostTrainingQuantization._collect_target_varnames
         except:
             raise Exception(
-                "Model Quantization is not available, try to upgrade your paddlepaddle>=1.7.0"
+                "Model Quantization is not available, try to upgrade your paddlepaddle>=1.8.0"
             )
         is_use_cache_file = True
         if cache_dir is None:
@@ -544,4 +545,4 @@ class BaseAPI:
                                 best_accuracy))
                 if eval_dataset is not None and early_stop:
                     if earlystop(current_accuracy):
-                        break
+                        break
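Note: the added line inside the try block is a feature probe, not a call. Looking up PaddleXPostTrainingQuantization._collect_target_varnames raises AttributeError on paddlepaddle builds that predate the 1.8.0 post-quantization API, and the bare except converts that into the upgrade hint. A minimal sketch of the same pattern (the wrapper function name is illustrative, not from this commit):

    def check_quantization_available():
        try:
            from .slim.post_quantization import PaddleXPostTrainingQuantization
            # Attribute lookup only: raises AttributeError when the installed
            # paddlepaddle lacks the >=1.8.0 post-quantization internals.
            PaddleXPostTrainingQuantization._collect_target_varnames
        except Exception:
            raise Exception(
                "Model Quantization is not available, try to upgrade your "
                "paddlepaddle>=1.8.0")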

+ 68 - 90
paddlex/cv/models/slim/post_quantization.py

@@ -14,7 +14,7 @@
 
 from paddle.fluid.contrib.slim.quantization.quantization_pass import QuantizationTransformPass
 from paddle.fluid.contrib.slim.quantization.quantization_pass import AddQuantDequantPass
-from paddle.fluid.contrib.slim.quantization.quantization_pass import _op_real_in_out_name
+from paddle.fluid.contrib.slim.quantization.quantization_pass import _out_scale_op_list
 from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
 import paddlex.utils.logging as logging
 import paddle.fluid as fluid
@@ -44,7 +44,6 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
         fp32 model. It uses calibrate data to calculate the scale factor of
         quantized variables, and inserts fake quant/dequant op to obtain the
         quantized model.
-
         Args:
             executor(fluid.Executor): The executor to load, run and save the
                 quantized model.
@@ -78,6 +77,21 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
         Returns:
             None
         '''
+        self._support_activation_quantize_type = [
+            'range_abs_max', 'moving_average_abs_max', 'abs_max'
+        ]
+        self._support_weight_quantize_type = ['abs_max', 'channel_wise_abs_max']
+        self._support_algo_type = ['KL', 'abs_max', 'min_max']
+        self._support_quantize_op_type = \
+            list(set(QuantizationTransformPass._supported_quantizable_op_type +
+                AddQuantDequantPass._supported_quantizable_op_type))
+        
+        # Check inputs
+        assert executor is not None, "The executor cannot be None."
+        assert batch_size > 0, "The batch_size should be greater than 0."
+        assert algo in self._support_algo_type, \
+            "The algo should be KL, abs_max or min_max."
+        
         self._executor = executor
         self._dataset = dataset
         self._batch_size = batch_size
@@ -86,18 +100,19 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
         self._algo = algo
         self._is_use_cache_file = is_use_cache_file
         self._cache_dir = cache_dir
+        self._activation_bits = 8
+        self._weight_bits = 8
+        self._activation_quantize_type = 'range_abs_max'
+        self._weight_quantize_type = 'channel_wise_abs_max'
         if self._is_use_cache_file and not os.path.exists(self._cache_dir):
             os.mkdir(self._cache_dir)
 
-        supported_quantizable_op_type = \
-            QuantizationTransformPass._supported_quantizable_op_type + \
-            AddQuantDequantPass._supported_quantizable_op_type
         if is_full_quantize:
-            self._quantizable_op_type = supported_quantizable_op_type
+            self._quantizable_op_type = self._support_quantize_op_type
         else:
             self._quantizable_op_type = quantizable_op_type
             for op_type in self._quantizable_op_type:
-                assert op_type in supported_quantizable_op_type + \
+                assert op_type in self._support_quantize_op_type + \
                     AddQuantDequantPass._activation_type, \
                     op_type + " is not supported for quantization."
 
@@ -107,25 +122,29 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
         self._fetch_list = list(outputs.values())
         self._data_loader = None
 
-        self._op_real_in_out_name = _op_real_in_out_name
+        self._out_scale_op_list = _out_scale_op_list
         self._bit_length = 8
         self._quantized_weight_var_name = set()
         self._quantized_act_var_name = set()
         self._sampling_data = {}
-        self._quantized_var_scale_factor = {}
+        self._quantized_var_kl_threshold = {}
+        self._quantized_var_min = {}
+        self._quantized_var_max = {}
+        self._quantized_var_abs_max = {}
 
     def quantize(self):
         '''
         Quantize the fp32 model. Use calibrate data to calculate the scale factor of
         quantized variables, and inserts fake quant/dequant op to obtain the
         quantized model.
-
         Args:
             None
         Returns:
             the program of quantized model.
         '''
-        self._preprocess()
+        self._load_model_data()
+        self._collect_target_varnames()
+        self._set_activation_persistable()
         batch_ct = 0
         for data in self._data_loader():
             batch_ct += 1
@@ -140,7 +159,10 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
                 feed=data,
                 fetch_list=self._fetch_list,
                 return_numpy=False)
-            self._sample_data(batch_id)
+            if self._algo == "KL":
+                self._sample_data(batch_id)
+            else:
+                self._sample_threshold()
             end = time.time()
             logging.debug('[Run batch data] Batch={}/{}, time_each_batch={} s.'.format(
                 str(batch_id + 1),
@@ -150,19 +172,23 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
             if self._batch_nums and batch_id >= self._batch_nums:
                 break
         logging.info("All run batch: ".format(batch_id))
+        self._reset_activation_persistable()
         logging.info("Calculate scale factor ...")
         logging.info("Calculate scale factor ...")
-        self._calculate_scale_factor()
+        if self._algo == "KL":
+            self._calculate_kl_threshold()
         logging.info("Update the program ...")
         logging.info("Update the program ...")
-        self._update_program()
+        if self._algo in ["KL", "abs_max"]:
+            self._update_program()
+        else:
+            self._save_input_threhold()
         logging.info("Save ...")
         logging.info("Save ...")
-        self._save_output_scale()
+        self._save_output_threshold()
         logging.info("Finish quant!")
         logging.info("Finish quant!")
         return self._program
         return self._program
 
 
     def save_quantized_model(self, save_model_path):
     def save_quantized_model(self, save_model_path):
         '''
         '''
         Save the quantized model to the disk.
         Save the quantized model to the disk.
-
         Args:
             save_model_path(str): The path to save the quantized model
         Returns:
@@ -176,88 +202,47 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
             executor=self._executor,
             params_filename='__params__',
             main_program=self._program)
-
-    def _preprocess(self):
+        
+    def _load_model_data(self):
         '''
-        Load model and set data loader, collect the variable names for sampling,
-        and set activation variables to be persistable.
+        Set data loader.
         '''
         feed_vars = [fluid.framework._get_var(var.name, self._program) \
             for var in self._feed_list]
-
         self._data_loader = fluid.io.DataLoader.from_generator(
             feed_list=feed_vars, capacity=3 * self._batch_size, iterable=True)
         self._data_loader.set_sample_list_generator(
             self._dataset.generator(self._batch_size, drop_last=True),
             places=self._place)
 
-        # collect the variable names for sampling
-        persistable_var_names = []
-        for var in self._program.list_vars():
-            if var.persistable:
-                persistable_var_names.append(var.name)
-
-        for op in self._program.global_block().ops:
-            op_type = op.type
-            if op_type in self._quantizable_op_type:
-                if op_type in ("conv2d", "depthwise_conv2d"):
-                    self._quantized_act_var_name.add(op.input("Input")[0])
-                    self._quantized_weight_var_name.add(op.input("Filter")[0])
-                    self._quantized_act_var_name.add(op.output("Output")[0])
-                elif op_type == "mul":
-                    if self._is_input_all_not_persistable(
-                            op, persistable_var_names):
-                        op._set_attr("skip_quant", True)
-                        logging.warning(
-                            "Skip quant a mul op for two input variables are not persistable"
-                        )
-                    else:
-                        self._quantized_act_var_name.add(op.input("X")[0])
-                        self._quantized_weight_var_name.add(op.input("Y")[0])
-                        self._quantized_act_var_name.add(op.output("Out")[0])
-                else:
-                    # process other quantizable op type, the input must all not persistable
-                    if self._is_input_all_not_persistable(
-                            op, persistable_var_names):
-                        input_output_name_list = self._op_real_in_out_name[
-                            op_type]
-                        for input_name in input_output_name_list[0]:
-                            for var_name in op.input(input_name):
-                                self._quantized_act_var_name.add(var_name)
-                        for output_name in input_output_name_list[1]:
-                            for var_name in op.output(output_name):
-                                self._quantized_act_var_name.add(var_name)
-
-        # set activation variables to be persistable, so can obtain
-        # the tensor data in sample_data
-        for var in self._program.list_vars():
-            if var.name in self._quantized_act_var_name:
-                var.persistable = True
-                
-    def _calculate_scale_factor(self):
+    def _calculate_kl_threshold(self):
         '''
-        Calculate the scale factor of quantized variables.
+        Calculate the KL threshold of quantized variables.
         '''
-        # apply channel_wise_abs_max quantization for weights
+        assert self._algo == "KL", "The algo should be KL to calculate kl threshold."
         ct = 1
+        # Abs_max threshold for weights
         for var_name in self._quantized_weight_var_name:
             start = time.time()
-            data = self._sampling_data[var_name]
-            scale_factor_per_channel = []
-            for i in range(data.shape[0]):
-                abs_max_value = np.max(np.abs(data[i]))
-                scale_factor_per_channel.append(abs_max_value)
-            self._quantized_var_scale_factor[
-                var_name] = scale_factor_per_channel
+            weight_data = self._sampling_data[var_name]
+            weight_threshold = None
+            if self._weight_quantize_type == "abs_max":
+                weight_threshold = np.max(np.abs(weight_data))
+            elif self._weight_quantize_type == "channel_wise_abs_max":
+                weight_threshold = []
+                for i in range(weight_data.shape[0]):
+                    abs_max_value = np.max(np.abs(weight_data[i]))
+                    weight_threshold.append(abs_max_value)
+            self._quantized_var_kl_threshold[var_name] = weight_threshold
             end = time.time()
             logging.debug('[Calculate weight] Weight_id={}/{}, time_each_weight={} s.'.format(
                 str(ct),
                 str(len(self._quantized_weight_var_name)),
                 str(end-start)))
             ct += 1
-            
+
         ct = 1
-        # apply kl quantization for activation
+        # KL threshold for activations
         if self._is_use_cache_file:
             for var_name in self._quantized_act_var_name:
                 start = time.time()
@@ -269,13 +254,8 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
                     sampling_data.append(np.load(file_path))
                     os.remove(file_path)
                 sampling_data = np.concatenate(sampling_data)
-
-                if self._algo == "KL":
-                    self._quantized_var_scale_factor[var_name] = \
-                        self._get_kl_scaling_factor(np.abs(sampling_data))
-                else:
-                    self._quantized_var_scale_factor[var_name] = \
-                        np.max(np.abs(sampling_data))
+                self._quantized_var_kl_threshold[var_name] = \
+                    self._get_kl_scaling_factor(np.abs(sampling_data))
                 end = time.time()
                 logging.debug('[Calculate activation] Activation_id={}/{}, time_each_activation={} s.'.format(
                     str(ct),
@@ -287,15 +267,13 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
                 start = time.time()
                 self._sampling_data[var_name] = np.concatenate(
                     self._sampling_data[var_name])
-                if self._algo == "KL":
-                    self._quantized_var_scale_factor[var_name] = \
-                        self._get_kl_scaling_factor(np.abs(self._sampling_data[var_name]))
-                else:
-                    self._quantized_var_scale_factor[var_name] = \
-                        np.max(np.abs(self._sampling_data[var_name]))
+                self._quantized_var_kl_threshold[var_name] = \
+                    self._get_kl_scaling_factor(np.abs(self._sampling_data[var_name]))
                 end = time.time()
                 logging.debug('[Calculate activation] Activation_id={}/{}, time_each_activation={} s.'.format(
                     str(ct),
                     str(len(self._quantized_act_var_name)),
                     str(end-start)))
-                ct += 1
+                ct += 1
+
+
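Note: after this change quantize() dispatches on algo. "KL" accumulates sampled activation data and derives KL thresholds in _calculate_kl_threshold(); "abs_max" and "min_max" sample thresholds directly during the calibration runs. Only "KL" and "abs_max" rewrite the program with fake quant/dequant ops, while "min_max" saves input thresholds instead. A hypothetical usage sketch, with parameter names inferred from the diff (the calibration dataset, program, and feed/fetch dicts are placeholders, not part of this commit):

    import paddle.fluid as fluid
    from paddlex.cv.models.slim.post_quantization import PaddleXPostTrainingQuantization

    exe = fluid.Executor(fluid.CPUPlace())
    post_quant = PaddleXPostTrainingQuantization(
        executor=exe,
        dataset=calib_dataset,   # assumed: exposes generator(batch_size, drop_last)
        program=test_prog,       # assumed: the fp32 inference program
        inputs=test_inputs,      # assumed: dict mapping feed names to Variables
        outputs=test_outputs,    # assumed: dict mapping fetch names to Variables
        batch_size=32,
        batch_nums=10,
        algo='KL',               # or 'abs_max' / 'min_max'
        is_use_cache_file=True,
        cache_dir='./temp')
    post_quant.quantize()
    post_quant.save_quantized_model('./quant_model')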