|
|
@@ -54,7 +54,8 @@ class BaseModel:
|
|
|
self.completed_epochs = 0
|
|
|
self.pruner = None
|
|
|
self.pruning_ratios = None
|
|
|
- self.quanter = None
|
|
|
+ self.quantizer = None
|
|
|
+ self.quant_config = None
|
|
|
|
|
|
def net_initialize(self, pretrain_weights=None, save_dir='.'):
|
|
|
if pretrain_weights is not None and \
|
|
|
@@ -124,6 +125,11 @@ class BaseModel:
|
|
|
info['pruner_inputs'] = self.pruner.inputs
|
|
|
return info
|
|
|
|
|
|
+ def get_quant_info(self):  # Collect the quantization settings (self.quant_config) into a plain dict.
|
|
|
+ info = dict()
|
|
|
+ info['quant_config'] = self.quant_config
|
|
|
+ return info  # return the local dict built above, not `self.info`
|
|
|
+
|
|
|
def save_model(self, save_dir):
|
|
|
if not osp.isdir(save_dir):
|
|
|
if osp.exists(save_dir):
|
|
|
@@ -132,16 +138,10 @@ class BaseModel:
|
|
|
model_info = self.get_model_info()
|
|
|
model_info['status'] = self.status
|
|
|
|
|
|
- if self.status == 'Quantized':
|
|
|
- self.quanter.save_quantized_model(
|
|
|
- self.net,
|
|
|
- osp.join(save_dir, 'model'),
|
|
|
- input_spec=self.test_inputs)
|
|
|
- else:
|
|
|
- paddle.save(self.net.state_dict(),
|
|
|
- osp.join(save_dir, 'model.pdparams'))
|
|
|
- paddle.save(self.optimizer.state_dict(),
|
|
|
- osp.join(save_dir, 'model.pdopt'))
|
|
|
+ paddle.save(self.net.state_dict(),
|
|
|
+ osp.join(save_dir, 'model.pdparams'))
|
|
|
+ paddle.save(self.optimizer.state_dict(),
|
|
|
+ osp.join(save_dir, 'model.pdopt'))
|
|
|
|
|
|
with open(
|
|
|
osp.join(save_dir, 'model.yml'), encoding='utf-8',
|
|
|
@@ -160,6 +160,13 @@ class BaseModel:
|
|
|
mode='w') as f:
|
|
|
yaml.dump(pruning_info, f)
|
|
|
|
|
|
+ if self.status == 'Quantized' and self.quantizer is not None:
|
|
|
+ quant_info = self.get_quant_info()
|
|
|
+ with open(
|
|
|
+ osp.join(save_dir, 'quant.yml'), encoding='utf-8',
|
|
|
+ mode='w') as f:
|
|
|
+ yaml.dump(quant_info, f)
|
|
|
+
|
|
|
# 模型保存成功的标志
|
|
|
open(osp.join(save_dir, '.success'), 'w').close()
|
|
|
logging.info("Model saved in {}.".format(save_dir))
|
|
|
@@ -413,7 +420,7 @@ class BaseModel:
|
|
|
Otherwise, the pruned model will be saved at save_dir. Defaults to None.
|
|
|
|
|
|
"""
|
|
|
- if self.status == "Pruned":
|
|
|
+ if self.status == "Pruned":
|
|
|
raise Exception(
|
|
|
"A pruned model cannot be done model pruning again!")
|
|
|
pre_pruning_flops = flops(self.net, self.pruner.inputs)
|
|
|
@@ -462,9 +469,10 @@ class BaseModel:
|
|
|
# Types of layers that will be quantized.
|
|
|
'quantizable_layer_type': ['Conv2D', 'Linear']
|
|
|
}
|
|
|
- self.quanter = QAT(config=quant_config)
|
|
|
+ self.quant_config = quant_config
|
|
|
+ self.quantizer = QAT(config=self.quant_config)
|
|
|
logging.info("Preparing the model for quantization-aware training...")
|
|
|
- self.quanter.quantize(self.net)
|
|
|
+ self.quantizer.quantize(self.net)
|
|
|
logging.info("Model is ready for quantization-aware training.")
|
|
|
self.status = 'Quantized'
|
|
|
self.test_inputs = self.get_test_inputs(image_shape)
|
|
|
@@ -473,9 +481,20 @@ class BaseModel:
|
|
|
save_dir = osp.join(save_dir, 'inference_model')
|
|
|
self.net.eval()
|
|
|
self.test_inputs = self.get_test_inputs(image_shape)
|
|
|
- static_net = paddle.jit.to_static(
|
|
|
- self.net, input_spec=self.test_inputs)
|
|
|
- paddle.jit.save(static_net, osp.join(save_dir, 'model'))
|
|
|
+
|
|
|
+ if self.status == 'Quantized':
|
|
|
+ self.quantizer.save_quantized_model(self.net,
|
|
|
+ osp.join(save_dir, 'model'),
|
|
|
+ self.test_inputs)
|
|
|
+ quant_info = self.get_quant_info()
|
|
|
+ with open(
|
|
|
+ osp.join(save_dir, 'quant.yml'), encoding='utf-8',
|
|
|
+ mode='w') as f:
|
|
|
+ yaml.dump(quant_info, f)
|
|
|
+ else:
|
|
|
+ static_net = paddle.jit.to_static(
|
|
|
+ self.net, input_spec=self.test_inputs)
|
|
|
+ paddle.jit.save(static_net, osp.join(save_dir, 'model'))
|
|
|
|
|
|
if self.status == 'Pruned':
|
|
|
pruning_info = self.get_pruning_info()
|