@@ -29,7 +29,7 @@ import paddlex
 from paddlex.cv.transforms import arrange_transforms
 from paddlex.utils import (seconds_to_hms, get_single_card_bs, dict2str,
                            get_pretrain_weights, load_pretrain_weights,
-                           SmoothedValue, TrainingStats,
+                           load_checkpoint, SmoothedValue, TrainingStats,
                            _get_shared_memory_size_in_M, EarlyStop)
 import paddlex.utils.logging as logging
 from .slim.prune import _pruner_eval_fn, _pruner_template_input, sensitive_prune
@@ -57,11 +57,14 @@ class BaseModel:
         self.quantizer = None
         self.quant_config = None

-    def net_initialize(self, pretrain_weights=None, save_dir='.'):
+    def net_initialize(self,
+                       pretrain_weights=None,
+                       save_dir='.',
+                       resume_checkpoint=None):
         if pretrain_weights is not None and \
-                not os.path.exists(pretrain_weights):
-            if not os.path.isdir(save_dir):
-                if os.path.exists(save_dir):
+                not osp.exists(pretrain_weights):
+            if not osp.isdir(save_dir):
+                if osp.exists(save_dir):
                     os.remove(save_dir)
                 os.makedirs(save_dir)
             if self.model_type == 'classifier':
@@ -77,6 +80,37 @@ class BaseModel:
         if pretrain_weights is not None:
             load_pretrain_weights(
                 self.net, pretrain_weights, model_name=self.model_name)
+        if resume_checkpoint is not None:
+            if not osp.exists(resume_checkpoint):
+                logging.error(
+                    "The checkpoint path {} to resume training from does not exist."
+                    .format(resume_checkpoint),
+                    exit=True)
+            if not osp.exists(osp.join(resume_checkpoint, 'model.pdparams')):
+                logging.error(
+                    "Model parameter state dictionary file 'model.pdparams' "
+                    "not found under given checkpoint path {}".format(
+                        resume_checkpoint),
+                    exit=True)
+            if not osp.exists(osp.join(resume_checkpoint, 'model.pdopt')):
+                logging.error(
+                    "Optimizer state dictionary file 'model.pdopt' "
+                    "not found under given checkpoint path {}".format(
+                        resume_checkpoint),
+                    exit=True)
+            if not osp.exists(osp.join(resume_checkpoint, 'model.yml')):
+                logging.error(
+                    "'model.yml' not found under given checkpoint path {}".
+                    format(resume_checkpoint),
+                    exit=True)
+            with open(osp.join(resume_checkpoint, "model.yml")) as f:
+                info = yaml.load(f.read(), Loader=yaml.Loader)
+            self.completed_epochs = info['completed_epochs']
+            load_checkpoint(
+                self.net,
+                self.optimizer,
+                model_name=self.model_name,
+                checkpoint=resume_checkpoint)

     def get_model_info(self):
         info = dict()
@@ -339,7 +373,7 @@ class BaseModel:
             # Evaluate on the validation set and save the model every save_interval_epochs epochs
             if ema is not None:
                 weight = self.net.state_dict()
-                self.net.set_dict(ema.apply())
+                self.net.set_state_dict(ema.apply())
             eval_epoch_tic = time.time()
             if (i + 1) % save_interval_epochs == 0 or i == num_epochs - 1:
                 if eval_dataset is not None and eval_dataset.num_samples > 0:
@@ -374,7 +408,7 @@ class BaseModel:
                     if earlystop(current_accuracy):
                         break
             if ema is not None:
-                self.net.set_dict(weight)
+                self.net.set_state_dict(weight)

     def analyze_sensitivity(self,
                             dataset,
@@ -475,12 +509,21 @@ class BaseModel:
                 # Types of layers that will be quantized.
                 'quantizable_layer_type': ['Conv2D', 'Linear']
             }
-        self.quant_config = quant_config
-        self.quantizer = QAT(config=self.quant_config)
-        logging.info("Preparing the model for quantization-aware training...")
-        self.quantizer.quantize(self.net)
-        logging.info("Model is ready for quantization-aware training.")
-        self.status = 'Quantized'
+        if self.status != 'Quantized':
+            self.quant_config = quant_config
+            self.quantizer = QAT(config=self.quant_config)
+            logging.info(
+                "Preparing the model for quantization-aware training...")
+            self.quantizer.quantize(self.net)
+            logging.info("Model is ready for quantization-aware training.")
+            self.status = 'Quantized'
+        elif quant_config != self.quant_config:
+            logging.error(
+                "The model has been quantized with the following quant_config: {}. "
+                "Doing quantization-aware training with a quantized model "
+                "using a different configuration is not supported."
+                .format(self.quant_config),
+                exit=True)

     def _export_inference_model(self, save_dir, image_shape=None):
         save_dir = osp.join(save_dir, 'inference_model')
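Taken together, the checks added in net_initialize define the contract for a resumable checkpoint: a directory containing model.pdparams (network weight state dict), model.pdopt (optimizer state dict), and a model.yml whose completed_epochs field records how far training got. The following stand-alone sketch mirrors those checks outside of BaseModel; the helper name and the example path are illustrative and not part of this patch.

import os.path as osp

import yaml


def inspect_checkpoint(checkpoint_dir):
    """Validate a resume checkpoint the same way BaseModel.net_initialize does
    and return the number of completed epochs recorded in model.yml.

    Illustrative only: the real checks in the patch abort via
    logging.error(..., exit=True) instead of raising.
    """
    required = ('model.pdparams',  # network weight state dict
                'model.pdopt',     # optimizer state dict
                'model.yml')       # training metadata, incl. completed_epochs
    missing = [f for f in required
               if not osp.exists(osp.join(checkpoint_dir, f))]
    if missing:
        raise FileNotFoundError(
            "Checkpoint {} is missing {}".format(checkpoint_dir, missing))
    with open(osp.join(checkpoint_dir, 'model.yml')) as f:
        info = yaml.load(f.read(), Loader=yaml.Loader)
    return info['completed_epochs']


# Example (placeholder path): completed = inspect_checkpoint('output/epoch_10')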