@@ -22,6 +22,7 @@ import yaml
 import json
 import paddle
 from paddle.io import DataLoader, DistributedBatchSampler
+from paddleslim import QAT
 from paddleslim.analysis import flops
 from paddleslim import L1NormFilterPruner, FPGMFilterPruner
 import paddlex
@@ -53,6 +54,8 @@ class BaseModel:
         self.completed_epochs = 0
         self.pruner = None
         self.pruning_ratios = None
+        self.quantizer = None
+        self.quant_config = None
 
     def net_initialize(self, pretrain_weights=None, save_dir='.'):
         if pretrain_weights is not None and \
@@ -122,6 +125,11 @@ class BaseModel:
         info['pruner_inputs'] = self.pruner.inputs
         return info
 
+    def get_quant_info(self):
+        info = dict()
+        info['quant_config'] = self.quant_config
+        return info
+
     def save_model(self, save_dir):
         if not osp.isdir(save_dir):
             if osp.exists(save_dir):
@@ -129,10 +137,11 @@
             os.makedirs(save_dir)
         model_info = self.get_model_info()
         model_info['status'] = self.status
+
         paddle.save(self.net.state_dict(),
-                    os.path.join(save_dir, 'model.pdparams'))
+                    osp.join(save_dir, 'model.pdparams'))
         paddle.save(self.optimizer.state_dict(),
-                    os.path.join(save_dir, 'model.pdopt'))
+                    osp.join(save_dir, 'model.pdopt'))
 
         with open(
                 osp.join(save_dir, 'model.yml'), encoding='utf-8',
@@ -151,6 +160,13 @@
                     mode='w') as f:
                 yaml.dump(pruning_info, f)
 
+        if self.status == 'Quantized' and self.quantizer is not None:
+            quant_info = self.get_quant_info()
+            with open(
+                    osp.join(save_dir, 'quant.yml'), encoding='utf-8',
+                    mode='w') as f:
+                yaml.dump(quant_info, f)
+
         # flag marking that the model was saved successfully
         open(osp.join(save_dir, '.success'), 'w').close()
         logging.info("Model saved in {}.".format(save_dir))
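
The quant.yml written above round-trips through PyYAML, so downstream tooling can rebuild the quantizer from the saved file. A minimal sketch of that restore path, assuming paddleslim 2.x (the loading side is not part of this patch, and the file path is a placeholder):

    import yaml

    from paddleslim import QAT

    # Hypothetical restore path mirroring the save branch above: read the
    # dumped config back and build an equivalent QAT object from it.
    with open('output/qat/quant.yml', encoding='utf-8') as f:
        quant_info = yaml.safe_load(f)

    quantizer = QAT(config=quant_info['quant_config'])
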
@@ -400,8 +416,8 @@
 
         Args:
             pruned_flops(float): Ratio of FLOPs to be pruned.
-            save_dir(None or str, optional): If None, the pruned model will not be saved
-                Otherwise, the pruned model will be saved at save_dir. Defaults to None.
+            save_dir(None or str, optional): If None, the pruned model will not be saved.
+                Otherwise, the pruned model will be saved at save_dir. Defaults to None.
 
         """
         if self.status == "Pruned":
@@ -427,13 +443,57 @@
             self.save_model(save_dir)
             logging.info("Pruned model is saved at {}".format(save_dir))
 
+    def _prepare_qat(self, quant_config):
+        if quant_config is None:
+            # default quantization configuration
+            quant_config = {
+                # {None, 'PACT'}. Weight preprocess type. If None, no preprocessing is performed.
+                'weight_preprocess_type': None,
+                # {None, 'PACT'}. Activation preprocess type. If None, no preprocessing is performed.
+                'activation_preprocess_type': None,
+                # {'abs_max', 'channel_wise_abs_max', 'range_abs_max', 'moving_average_abs_max'}.
+                # Weight quantization type.
+                'weight_quantize_type': 'channel_wise_abs_max',
+                # {'abs_max', 'range_abs_max', 'moving_average_abs_max'}. Activation quantization type.
+                'activation_quantize_type': 'moving_average_abs_max',
+                # The number of bits of weights after quantization.
+                'weight_bits': 8,
+                # The number of bits of activation after quantization.
+                'activation_bits': 8,
+                # Data type after quantization, such as 'uint8', 'int8', etc.
+                'dtype': 'int8',
+                # Window size for 'range_abs_max' quantization.
+                'window_size': 10000,
+                # Decay coefficient of moving average.
+                'moving_rate': 0.9,
+                # Types of layers that will be quantized.
+                'quantizable_layer_type': ['Conv2D', 'Linear']
+            }
+        self.quant_config = quant_config
+        self.quantizer = QAT(config=self.quant_config)
+        logging.info("Preparing the model for quantization-aware training...")
+        self.quantizer.quantize(self.net)
+        logging.info("Model is ready for quantization-aware training.")
+        self.status = 'Quantized'
+
     def _export_inference_model(self, save_dir, image_shape=[-1, -1]):
         save_dir = osp.join(save_dir, 'inference_model')
         self.net.eval()
         self.test_inputs = self.get_test_inputs(image_shape)
-        static_net = paddle.jit.to_static(
-            self.net, input_spec=self.test_inputs)
-        paddle.jit.save(static_net, osp.join(save_dir, 'model'))
+
+        if self.status == 'Quantized':
+            self.quantizer.save_quantized_model(self.net,
+                                                osp.join(save_dir, 'model'),
+                                                self.test_inputs)
+            quant_info = self.get_quant_info()
+            with open(
+                    osp.join(save_dir, 'quant.yml'), encoding='utf-8',
+                    mode='w') as f:
+                yaml.dump(quant_info, f)
+        else:
+            static_net = paddle.jit.to_static(
+                self.net, input_spec=self.test_inputs)
+            paddle.jit.save(static_net, osp.join(save_dir, 'model'))
 
         if self.status == 'Pruned':
             pruning_info = self.get_pruning_info()
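
Taken together, the new hooks wrap paddleslim's dygraph quantization API: _prepare_qat inserts fake-quant ops into the network, and _export_inference_model hands the calibrated network to save_quantized_model. A standalone sketch of that underlying flow, assuming paddleslim 2.x (the toy network, the output path, and the fine-tuning step are placeholders, not part of this patch):

    import paddle
    from paddleslim import QAT

    # Toy network standing in for self.net; any model built from the
    # quantizable layer types (Conv2D, Linear) works.
    net = paddle.nn.Sequential(
        paddle.nn.Conv2D(3, 8, 3),
        paddle.nn.Flatten(),
        paddle.nn.Linear(8 * 30 * 30, 10))

    # Subset of the default config above; keys left out fall back to
    # paddleslim's own defaults.
    quanter = QAT(config={
        'weight_quantize_type': 'channel_wise_abs_max',
        'activation_quantize_type': 'moving_average_abs_max',
        'quantizable_layer_type': ['Conv2D', 'Linear']})

    quanter.quantize(net)  # modifies the network in place, as in _prepare_qat

    # ... fine-tune `net` here so the quantization scales calibrate ...

    # Mirrors the 'Quantized' branch of _export_inference_model.
    quanter.save_quantized_model(
        net, 'inference_model/model',
        input_spec=[paddle.static.InputSpec(
            shape=[None, 3, 32, 32], dtype='float32')])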