Browse Source

fix the post quant

sunyanfang01 5 years ago
parent
commit
0b9a4c4c46
2 changed files with 86 additions and 100 deletions
  1. paddlex/cv/models/base.py  + 18 - 10
  2. paddlex/cv/models/slim/post_quantization.py  + 68 - 90

+ 18 - 10
paddlex/cv/models/base.py

@@ -15,7 +15,6 @@
 from __future__ import absolute_import
 import paddle.fluid as fluid
 import os
-import sys
 import numpy as np
 import time
 import math
@@ -139,9 +138,10 @@ class BaseAPI:
         dataset.num_samples = batch_size * batch_num
         try:
             from .slim.post_quantization import PaddleXPostTrainingQuantization
+            PaddleXPostTrainingQuantization._collect_target_varnames
         except:
             raise Exception(
-                "Model Quantization is not available, try to upgrade your paddlepaddle>=1.7.0"
+                "Model Quantization is not available, try to upgrade your paddlepaddle>=1.8.0"
             )
         is_use_cache_file = True
         if cache_dir is None:
@@ -252,9 +252,6 @@ class BaseAPI:
             del self.init_params['self']
         if '__class__' in self.init_params:
             del self.init_params['__class__']
-        if 'model_name' in self.init_params:
-            del self.init_params['model_name']
-
         info['_init_params'] = self.init_params
 
         info['_Attributes']['num_classes'] = self.num_classes
@@ -375,8 +372,6 @@ class BaseAPI:
                    use_vdl=False,
                    early_stop=False,
                    early_stop_patience=5):
-        if train_dataset.num_samples < train_batch_size:
-            raise Exception('The amount of training datset must be larger than batch size.')
         if not osp.isdir(save_dir):
             if osp.exists(save_dir):
                 os.remove(save_dir)
@@ -434,7 +429,9 @@ class BaseAPI:
 
         if use_vdl:
             # VisualDL component
-            log_writer = LogWriter(vdl_logdir)
+            log_writer = LogWriter(vdl_logdir, sync_cycle=20)
+            train_step_component = OrderedDict()
+            eval_component = OrderedDict()
 
         thresh = 0.0001
         if early_stop:
@@ -472,7 +469,13 @@ class BaseAPI:
 
                     if use_vdl:
                         for k, v in step_metrics.items():
-                            log_writer.add_scalar('Metrics/Training(Step): {}'.format(k), v, num_steps)
+                            if k not in train_step_component.keys():
+                                with log_writer.mode('Each_Step_while_Training'
+                                                     ) as step_logger:
+                                    train_step_component[
+                                        k] = step_logger.scalar(
+                                            'Training: {}'.format(k))
+                            train_step_component[k].add_record(num_steps, v)
 
                     # Estimate the remaining time
                     avg_step_time = np.mean(time_stat)
@@ -533,7 +536,12 @@ class BaseAPI:
                             if isinstance(v, np.ndarray):
                                 if v.size > 1:
                                     continue
-                            log_writer.add_scalar("Metrics/Eval(Epoch): {}".format(k), v, i+1)
+                            if k not in eval_component:
+                                with log_writer.mode('Each_Epoch_on_Eval_Data'
+                                                     ) as eval_logger:
+                                    eval_component[k] = eval_logger.scalar(
+                                        'Evaluation: {}'.format(k))
+                            eval_component[k].add_record(i + 1, v)
                 self.save_model(save_dir=current_save_dir)
                 time_eval_one_epoch = time.time() - eval_epoch_start_time
                 eval_epoch_start_time = time.time()

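The base.py changes above replace the VisualDL 2.x style log_writer.add_scalar(...) calls with the 1.x component style: a LogWriter opened with sync_cycle, per-metric scalar components created lazily under a named mode, and records appended with add_record. Below is a minimal standalone sketch of that pattern, assuming a VisualDL 1.x install; the log directory and the metric values are made-up placeholders.

# Sketch of the VisualDL 1.x component-style logging the diff switches to.
# Assumes visualdl < 2.0; "./vdl_log" and the loss values are dummy placeholders.
from collections import OrderedDict
from visualdl import LogWriter

log_writer = LogWriter("./vdl_log", sync_cycle=20)
train_step_component = OrderedDict()

for num_steps, loss in enumerate([0.9, 0.7, 0.5], start=1):
    step_metrics = {"loss": loss}
    for k, v in step_metrics.items():
        # Create one scalar component per metric the first time it appears,
        # then keep appending (step, value) records to it on later steps.
        if k not in train_step_component:
            with log_writer.mode("Each_Step_while_Training") as step_logger:
                train_step_component[k] = step_logger.scalar("Training: {}".format(k))
        train_step_component[k].add_record(num_steps, v)
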
+ 68 - 90
paddlex/cv/models/slim/post_quantization.py

@@ -14,7 +14,7 @@
 
 from paddle.fluid.contrib.slim.quantization.quantization_pass import QuantizationTransformPass
 from paddle.fluid.contrib.slim.quantization.quantization_pass import AddQuantDequantPass
-from paddle.fluid.contrib.slim.quantization.quantization_pass import _op_real_in_out_name
+from paddle.fluid.contrib.slim.quantization.quantization_pass import _out_scale_op_list
 from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
 import paddlex.utils.logging as logging
 import paddle.fluid as fluid
@@ -44,7 +44,6 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
         fp32 model. It uses calibrate data to calculate the scale factor of
         quantized variables, and inserts fake quant/dequant op to obtain the
         quantized model.
-
         Args:
             executor(fluid.Executor): The executor to load, run and save the
                 quantized model.
@@ -78,6 +77,21 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
         Returns:
             None
         '''
+        self._support_activation_quantize_type = [
+            'range_abs_max', 'moving_average_abs_max', 'abs_max'
+        ]
+        self._support_weight_quantize_type = ['abs_max', 'channel_wise_abs_max']
+        self._support_algo_type = ['KL', 'abs_max', 'min_max']
+        self._support_quantize_op_type = \
+            list(set(QuantizationTransformPass._supported_quantizable_op_type +
+                AddQuantDequantPass._supported_quantizable_op_type))
+        
+        # Check inputs
+        assert executor is not None, "The executor cannot be None."
+        assert batch_size > 0, "The batch_size should be greater than 0."
+        assert algo in self._support_algo_type, \
+            "The algo should be KL, abs_max or min_max."
+        
         self._executor = executor
         self._dataset = dataset
         self._batch_size = batch_size
@@ -86,18 +100,19 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
         self._algo = algo
         self._is_use_cache_file = is_use_cache_file
         self._cache_dir = cache_dir
+        self._activation_bits = 8
+        self._weight_bits = 8
+        self._activation_quantize_type = 'range_abs_max'
+        self._weight_quantize_type = 'channel_wise_abs_max'
         if self._is_use_cache_file and not os.path.exists(self._cache_dir):
             os.mkdir(self._cache_dir)
 
-        supported_quantizable_op_type = \
-            QuantizationTransformPass._supported_quantizable_op_type + \
-            AddQuantDequantPass._supported_quantizable_op_type
         if is_full_quantize:
-            self._quantizable_op_type = supported_quantizable_op_type
+            self._quantizable_op_type = self._support_quantize_op_type
         else:
             self._quantizable_op_type = quantizable_op_type
             for op_type in self._quantizable_op_type:
-                assert op_type in supported_quantizable_op_type + \
+                assert op_type in self._support_quantize_op_type + \
                     AddQuantDequantPass._activation_type, \
                     op_type + " is not supported for quantization."
 
@@ -107,25 +122,29 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
         self._fetch_list = list(outputs.values())
         self._data_loader = None
 
-        self._op_real_in_out_name = _op_real_in_out_name
+        self._out_scale_op_list = _out_scale_op_list
         self._bit_length = 8
         self._quantized_weight_var_name = set()
         self._quantized_act_var_name = set()
         self._sampling_data = {}
-        self._quantized_var_scale_factor = {}
+        self._quantized_var_kl_threshold = {}
+        self._quantized_var_min = {}
+        self._quantized_var_max = {}
+        self._quantized_var_abs_max = {}
 
     def quantize(self):
         '''
         Quantize the fp32 model. Use calibrate data to calculate the scale factor of
         quantized variables, and inserts fake quant/dequant op to obtain the
         quantized model.
-
         Args:
             None
         Returns:
             the program of quantized model.
         '''
-        self._preprocess()
+        self._load_model_data()
+        self._collect_target_varnames()
+        self._set_activation_persistable()
         batch_ct = 0
         for data in self._data_loader():
             batch_ct += 1
@@ -140,7 +159,10 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
                 feed=data,
                 fetch_list=self._fetch_list,
                 return_numpy=False)
-            self._sample_data(batch_id)
+            if self._algo == "KL":
+                self._sample_data(batch_id)
+            else:
+                self._sample_threshold()
             end = time.time()
             logging.debug('[Run batch data] Batch={}/{}, time_each_batch={} s.'.format(
                 str(batch_id + 1),
@@ -150,19 +172,23 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
             if self._batch_nums and batch_id >= self._batch_nums:
                 break
         logging.info("All run batch: ".format(batch_id))
+        self._reset_activation_persistable()
         logging.info("Calculate scale factor ...")
-        self._calculate_scale_factor()
+        if self._algo == "KL":
+            self._calculate_kl_threshold()
         logging.info("Update the program ...")
-        self._update_program()
+        if self._algo in ["KL", "abs_max"]:
+            self._update_program()
+        else:
+            self._save_input_threhold()
         logging.info("Save ...")
-        self._save_output_scale()
+        self._save_output_threshold()
         logging.info("Finish quant!")
         return self._program
 
     def save_quantized_model(self, save_model_path):
         '''
         Save the quantized model to the disk.
-
         Args:
             save_model_path(str): The path to save the quantized model
         Returns:
@@ -176,88 +202,47 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
             executor=self._executor,
             params_filename='__params__',
             main_program=self._program)
-
-    def _preprocess(self):
+        
+    def _load_model_data(self):
         '''
-        Load model and set data loader, collect the variable names for sampling,
-        and set activation variables to be persistable.
+        Set data loader.
         '''
         feed_vars = [fluid.framework._get_var(var.name, self._program) \
             for var in self._feed_list]
-
         self._data_loader = fluid.io.DataLoader.from_generator(
             feed_list=feed_vars, capacity=3 * self._batch_size, iterable=True)
         self._data_loader.set_sample_list_generator(
             self._dataset.generator(self._batch_size, drop_last=True),
             places=self._place)
 
-        # collect the variable names for sampling
-        persistable_var_names = []
-        for var in self._program.list_vars():
-            if var.persistable:
-                persistable_var_names.append(var.name)
-
-        for op in self._program.global_block().ops:
-            op_type = op.type
-            if op_type in self._quantizable_op_type:
-                if op_type in ("conv2d", "depthwise_conv2d"):
-                    self._quantized_act_var_name.add(op.input("Input")[0])
-                    self._quantized_weight_var_name.add(op.input("Filter")[0])
-                    self._quantized_act_var_name.add(op.output("Output")[0])
-                elif op_type == "mul":
-                    if self._is_input_all_not_persistable(
-                            op, persistable_var_names):
-                        op._set_attr("skip_quant", True)
-                        logging.warning(
-                            "Skip quant a mul op for two input variables are not persistable"
-                        )
-                    else:
-                        self._quantized_act_var_name.add(op.input("X")[0])
-                        self._quantized_weight_var_name.add(op.input("Y")[0])
-                        self._quantized_act_var_name.add(op.output("Out")[0])
-                else:
-                    # process other quantizable op type, the input must all not persistable
-                    if self._is_input_all_not_persistable(
-                            op, persistable_var_names):
-                        input_output_name_list = self._op_real_in_out_name[
-                            op_type]
-                        for input_name in input_output_name_list[0]:
-                            for var_name in op.input(input_name):
-                                self._quantized_act_var_name.add(var_name)
-                        for output_name in input_output_name_list[1]:
-                            for var_name in op.output(output_name):
-                                self._quantized_act_var_name.add(var_name)
-
-        # set activation variables to be persistable, so can obtain
-        # the tensor data in sample_data
-        for var in self._program.list_vars():
-            if var.name in self._quantized_act_var_name:
-                var.persistable = True
-                
-    def _calculate_scale_factor(self):
+    def _calculate_kl_threshold(self):
         '''
-        Calculate the scale factor of quantized variables.
+        Calculate the KL threshold of quantized variables.
         '''
-        # apply channel_wise_abs_max quantization for weights
+        assert self._algo == "KL", "The algo should be KL to calculate kl threshold."
         ct = 1
+        # Abs_max threshold for weights
         for var_name in self._quantized_weight_var_name:
             start = time.time()
-            data = self._sampling_data[var_name]
-            scale_factor_per_channel = []
-            for i in range(data.shape[0]):
-                abs_max_value = np.max(np.abs(data[i]))
-                scale_factor_per_channel.append(abs_max_value)
-            self._quantized_var_scale_factor[
-                var_name] = scale_factor_per_channel
+            weight_data = self._sampling_data[var_name]
+            weight_threshold = None
+            if self._weight_quantize_type == "abs_max":
+                weight_threshold = np.max(np.abs(weight_data))
+            elif self._weight_quantize_type == "channel_wise_abs_max":
+                weight_threshold = []
+                for i in range(weight_data.shape[0]):
+                    abs_max_value = np.max(np.abs(weight_data[i]))
+                    weight_threshold.append(abs_max_value)
+            self._quantized_var_kl_threshold[var_name] = weight_threshold
             end = time.time()
             logging.debug('[Calculate weight] Weight_id={}/{}, time_each_weight={} s.'.format(
                 str(ct),
                 str(len(self._quantized_weight_var_name)),
                 str(end-start)))
             ct += 1
-            
+
         ct = 1
-        # apply kl quantization for activation
+        # KL threshold for activations
         if self._is_use_cache_file:
             for var_name in self._quantized_act_var_name:
                 start = time.time()
@@ -269,13 +254,8 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
                     sampling_data.append(np.load(file_path))
                     os.remove(file_path)
                 sampling_data = np.concatenate(sampling_data)
-
-                if self._algo == "KL":
-                    self._quantized_var_scale_factor[var_name] = \
-                        self._get_kl_scaling_factor(np.abs(sampling_data))
-                else:
-                    self._quantized_var_scale_factor[var_name] = \
-                        np.max(np.abs(sampling_data))
+                self._quantized_var_kl_threshold[var_name] = \
+                    self._get_kl_scaling_factor(np.abs(sampling_data))
                 end = time.time()
                 logging.debug('[Calculate activation] Activation_id={}/{}, time_each_activation={} s.'.format(
                     str(ct),
@@ -287,15 +267,13 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
                 start = time.time()
                 self._sampling_data[var_name] = np.concatenate(
                     self._sampling_data[var_name])
-                if self._algo == "KL":
-                    self._quantized_var_scale_factor[var_name] = \
-                        self._get_kl_scaling_factor(np.abs(self._sampling_data[var_name]))
-                else:
-                    self._quantized_var_scale_factor[var_name] = \
-                        np.max(np.abs(self._sampling_data[var_name]))
+                self._quantized_var_kl_threshold[var_name] = \
+                    self._get_kl_scaling_factor(np.abs(self._sampling_data[var_name]))
                 end = time.time()
                 logging.debug('[Calculate activation] Activation_id={}/{}, time_each_activation={} s.'.format(
                     str(ct),
                     str(len(self._quantized_act_var_name)),
                     str(end-start)))
-                ct += 1
+                ct += 1
+
+
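
The rewritten _calculate_kl_threshold computes weight thresholds with plain abs-max statistics: one global value when _weight_quantize_type is 'abs_max', or one value per output channel for 'channel_wise_abs_max' (the default set in __init__). The snippet below is a small numpy illustration of just that thresholding step, not PaddleX code; the 4x3x3x3 filter tensor is dummy data.

# Standalone numpy illustration of the weight-threshold logic in
# _calculate_kl_threshold; the filter below is random dummy data.
import numpy as np

weight_data = np.random.randn(4, 3, 3, 3).astype("float32")  # (out_ch, in_ch, kh, kw)

# 'abs_max': a single scalar threshold for the whole tensor
abs_max_threshold = np.max(np.abs(weight_data))

# 'channel_wise_abs_max': one abs-max per output channel, as in the diff
channel_thresholds = [np.max(np.abs(weight_data[i])) for i in range(weight_data.shape[0])]

print(abs_max_threshold)
print(channel_thresholds)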