|
|
@@ -19,6 +19,9 @@ from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
|
|
|
import paddlex.utils.logging as logging
|
|
|
import paddle.fluid as fluid
|
|
|
import os
|
|
|
+import re
|
|
|
+import numpy as np
|
|
|
+import datetime
|
|
|
|
|
|
|
|
|
class PaddleXPostTrainingQuantization(PostTrainingQuantization):
|
|
|
@@ -123,28 +126,37 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
|
|
|
the program of quantized model.
|
|
|
'''
|
|
|
self._preprocess()
|
|
|
-
|
|
|
+ batch_ct = 0
|
|
|
+ for data in self._data_loader():
|
|
|
+ batch_ct += 1
|
|
|
+ if self._batch_nums and batch_ct >= self._batch_nums:
|
|
|
+ break
|
|
|
batch_id = 0
|
|
|
+ logging.info("Start to run batch!")
|
|
|
for data in self._data_loader():
|
|
|
+ start = datetime.datetime.now()
|
|
|
self._executor.run(
|
|
|
program=self._program,
|
|
|
feed=data,
|
|
|
fetch_list=self._fetch_list,
|
|
|
return_numpy=False)
|
|
|
self._sample_data(batch_id)
|
|
|
-
|
|
|
- if batch_id % 5 == 0:
|
|
|
- logging.info("run batch: {}".format(batch_id))
|
|
|
+ end = datetime.datetime.now()
|
|
|
+ logging.debug('[Run batch data] Batch={}/{}, time_each_batch={} ms.'.format(
|
|
|
+ str(batch_id + 1),
|
|
|
+ str(batch_ct),
|
|
|
+ str((end-start).microseconds)))
|
|
|
batch_id += 1
|
|
|
if self._batch_nums and batch_id >= self._batch_nums:
|
|
|
break
|
|
|
- logging.info("all run batch: ".format(batch_id))
|
|
|
- logging.info("calculate scale factor ...")
|
|
|
+ logging.info("All run batch: ".format(batch_id))
|
|
|
+ logging.info("Calculate scale factor ...")
|
|
|
self._calculate_scale_factor()
|
|
|
- logging.info("update the program ...")
|
|
|
+ logging.info("Update the program ...")
|
|
|
self._update_program()
|
|
|
-
|
|
|
+ logging.info("Save ...")
|
|
|
self._save_output_scale()
|
|
|
+ logging.info("Finish quant!")
|
|
|
return self._program
|
|
|
|
|
|
def save_quantized_model(self, save_model_path):
|
|
|
@@ -221,3 +233,69 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
|
|
|
for var in self._program.list_vars():
|
|
|
if var.name in self._quantized_act_var_name:
|
|
|
var.persistable = True
|
|
|
+
|
|
|
+ def _calculate_scale_factor(self):
|
|
|
+ '''
|
|
|
+ Calculate the scale factor of quantized variables.
|
|
|
+ '''
|
|
|
+ # apply channel_wise_abs_max quantization for weights
|
|
|
+ ct = 1
|
|
|
+ for var_name in self._quantized_weight_var_name:
|
|
|
+ start = datetime.datetime.now()
|
|
|
+ data = self._sampling_data[var_name]
|
|
|
+ scale_factor_per_channel = []
|
|
|
+ for i in range(data.shape[0]):
|
|
|
+ abs_max_value = np.max(np.abs(data[i]))
|
|
|
+ scale_factor_per_channel.append(abs_max_value)
|
|
|
+ self._quantized_var_scale_factor[
|
|
|
+ var_name] = scale_factor_per_channel
|
|
|
+ end = datetime.datetime.now()
|
|
|
+ logging.debug('[Calculate weight] Weight_id={}/{}, time_each_weight={} ms.'.format(
|
|
|
+ str(ct),
|
|
|
+ str(len(self._quantized_weight_var_name)),
|
|
|
+ str((end-start).microseconds)))
|
|
|
+ ct += 1
|
|
|
+
|
|
|
+ ct = 1
|
|
|
+ # apply kl quantization for activation
|
|
|
+ if self._is_use_cache_file:
|
|
|
+ for var_name in self._quantized_act_var_name:
|
|
|
+ start = datetime.datetime.now()
|
|
|
+ sampling_data = []
|
|
|
+ filenames = [f for f in os.listdir(self._cache_dir) \
|
|
|
+ if re.match(var_name + '_[0-9]+.npy', f)]
|
|
|
+ for filename in filenames:
|
|
|
+ file_path = os.path.join(self._cache_dir, filename)
|
|
|
+ sampling_data.append(np.load(file_path))
|
|
|
+ os.remove(file_path)
|
|
|
+ sampling_data = np.concatenate(sampling_data)
|
|
|
+
|
|
|
+ if self._algo == "KL":
|
|
|
+ self._quantized_var_scale_factor[var_name] = \
|
|
|
+ self._get_kl_scaling_factor(np.abs(sampling_data))
|
|
|
+ else:
|
|
|
+ self._quantized_var_scale_factor[var_name] = \
|
|
|
+ np.max(np.abs(sampling_data))
|
|
|
+ end = datetime.datetime.now()
|
|
|
+ logging.debug('[Calculate activation] Activation_id={}/{}, time_each_activation={} ms.'.format(
|
|
|
+ str(ct),
|
|
|
+ str(len(self._quantized_act_var_name)),
|
|
|
+ str((end-start).microseconds)))
|
|
|
+ ct += 1
|
|
|
+ else:
|
|
|
+ for var_name in self._quantized_act_var_name:
|
|
|
+ start = datetime.datetime.now()
|
|
|
+ self._sampling_data[var_name] = np.concatenate(
|
|
|
+ self._sampling_data[var_name])
|
|
|
+ if self._algo == "KL":
|
|
|
+ self._quantized_var_scale_factor[var_name] = \
|
|
|
+ self._get_kl_scaling_factor(np.abs(self._sampling_data[var_name]))
|
|
|
+ else:
|
|
|
+ self._quantized_var_scale_factor[var_name] = \
|
|
|
+ np.max(np.abs(self._sampling_data[var_name]))
|
|
|
+ end = datetime.datetime.now()
|
|
|
+ logging.debug('[Calculate activation] Activation_id={}/{}, time_each_activation={} ms.'.format(
|
|
|
+ str(ct),
|
|
|
+ str(len(self._quantized_act_var_name)),
|
|
|
+ str((end-start).microseconds)))
|
|
|
+ ct += 1
|