
train and quant_aware_train support resume_checkpoint

will-jl944 4 years ago
Parent
Current commit
fd2093e6b6
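
For context, a minimal usage sketch of the new argument, assuming a model and datasets built the same way as in the interrupted run; the save directory and checkpoint path below are illustrative, not part of this commit:

    # Resume an interrupted run from a checkpoint directory written by a previous train() call.
    # pretrain_weights must be None: the two arguments cannot be set simultaneously.
    model.train(
        num_epochs=60,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        pretrain_weights=None,
        save_dir='output/my_model',
        resume_checkpoint='output/my_model/epoch_10')

    # quant_aware_train() accepts the same resume_checkpoint argument.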

+ 56 - 13
dygraph/paddlex/cv/models/base.py

@@ -29,7 +29,7 @@ import paddlex
 from paddlex.cv.transforms import arrange_transforms
 from paddlex.utils import (seconds_to_hms, get_single_card_bs, dict2str,
                            get_pretrain_weights, load_pretrain_weights,
-                           SmoothedValue, TrainingStats,
+                           load_checkpoint, SmoothedValue, TrainingStats,
                            _get_shared_memory_size_in_M, EarlyStop)
 import paddlex.utils.logging as logging
 from .slim.prune import _pruner_eval_fn, _pruner_template_input, sensitive_prune
@@ -57,11 +57,14 @@ class BaseModel:
         self.quantizer = None
         self.quant_config = None
 
-    def net_initialize(self, pretrain_weights=None, save_dir='.'):
+    def net_initialize(self,
+                       pretrain_weights=None,
+                       save_dir='.',
+                       resume_checkpoint=None):
         if pretrain_weights is not None and \
-                not os.path.exists(pretrain_weights):
-            if not os.path.isdir(save_dir):
-                if os.path.exists(save_dir):
+                not osp.exists(pretrain_weights):
+            if not osp.isdir(save_dir):
+                if osp.exists(save_dir):
                     os.remove(save_dir)
                 os.makedirs(save_dir)
             if self.model_type == 'classifier':
@@ -77,6 +80,37 @@ class BaseModel:
         if pretrain_weights is not None:
             load_pretrain_weights(
                 self.net, pretrain_weights, model_name=self.model_name)
+        if resume_checkpoint is not None:
+            if not osp.exists(resume_checkpoint):
+                logging.error(
+                    "The checkpoint path {} to resume training from does not exist."
+                    .format(resume_checkpoint),
+                    exit=True)
+            if not osp.exists(osp.join(resume_checkpoint, 'model.pdparams')):
+                logging.error(
+                    "Model parameter state dictionary file 'model.pdparams' "
+                    "not found under given checkpoint path {}".format(
+                        resume_checkpoint),
+                    exit=True)
+            if not osp.exists(osp.join(resume_checkpoint, 'model.pdopt')):
+                logging.error(
+                    "Optimizer state dictionary file 'model.pdopt' "
+                    "not found under given checkpoint path {}".format(
+                        resume_checkpoint),
+                    exit=True)
+            if not osp.exists(osp.join(resume_checkpoint, 'model.yml')):
+                logging.error(
+                    "'model.yml' not found under given checkpoint path {}".
+                    format(resume_checkpoint),
+                    exit=True)
+            with open(osp.join(resume_checkpoint, "model.yml")) as f:
+                info = yaml.load(f.read(), Loader=yaml.Loader)
+                self.completed_epochs = info['completed_epochs']
+            load_checkpoint(
+                self.net,
+                self.optimizer,
+                model_name=self.model_name,
+                checkpoint=resume_checkpoint)
 
     def get_model_info(self):
         info = dict()
@@ -339,7 +373,7 @@ class BaseModel:
             # Evaluate on the validation set and save the model every save_interval_epochs epochs
             if ema is not None:
                 weight = self.net.state_dict()
-                self.net.set_dict(ema.apply())
+                self.net.set_state_dict(ema.apply())
             eval_epoch_tic = time.time()
             if (i + 1) % save_interval_epochs == 0 or i == num_epochs - 1:
                 if eval_dataset is not None and eval_dataset.num_samples > 0:
@@ -374,7 +408,7 @@ class BaseModel:
                         if earlystop(current_accuracy):
                             break
             if ema is not None:
-                self.net.set_dict(weight)
+                self.net.set_state_dict(weight)
 
     def analyze_sensitivity(self,
                             dataset,
@@ -475,12 +509,21 @@ class BaseModel:
                 # Types of layers that will be quantized.
                 'quantizable_layer_type': ['Conv2D', 'Linear']
             }
-        self.quant_config = quant_config
-        self.quantizer = QAT(config=self.quant_config)
-        logging.info("Preparing the model for quantization-aware training...")
-        self.quantizer.quantize(self.net)
-        logging.info("Model is ready for quantization-aware training.")
-        self.status = 'Quantized'
+        if self.status != 'Quantized':
+            self.quant_config = quant_config
+            self.quantizer = QAT(config=self.quant_config)
+            logging.info(
+                "Preparing the model for quantization-aware training...")
+            self.quantizer.quantize(self.net)
+            logging.info("Model is ready for quantization-aware training.")
+            self.status = 'Quantized'
+        elif quant_config != self.quant_config:
+            logging.error(
+                "The model has been quantized with the following quant_config: {}. "
+                "Doing quantization-aware training with a quantized model "
+                "using a different configuration is not supported."
+                .format(self.quant_config),
+                exit=True)
 
     def _export_inference_model(self, save_dir, image_shape=None):
         save_dir = osp.join(save_dir, 'inference_model')
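
The resume path added to net_initialize expects a checkpoint directory shaped like the ones the training loop writes at each save interval. A minimal sketch of that expectation, with a hypothetical helper name that is not part of this commit:

    import os.path as osp
    import yaml

    def describe_checkpoint(checkpoint_dir):
        # Hypothetical helper mirroring the checks in net_initialize above.
        required = ['model.pdparams',  # model parameter state dict
                    'model.pdopt',     # optimizer state dict
                    'model.yml']       # metadata, including 'completed_epochs'
        missing = [f for f in required if not osp.exists(osp.join(checkpoint_dir, f))]
        if missing:
            raise FileNotFoundError(
                'Checkpoint {} is missing {}'.format(checkpoint_dir, missing))
        with open(osp.join(checkpoint_dir, 'model.yml')) as f:
            info = yaml.load(f.read(), Loader=yaml.Loader)
        return info['completed_epochs']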

+ 20 - 4
dygraph/paddlex/cv/models/classifier.py

@@ -191,7 +191,8 @@ class BaseClassifier(BaseModel):
               lr_decay_gamma=0.1,
               early_stop=False,
               early_stop_patience=5,
-              use_vdl=True):
+              use_vdl=True,
+              resume_checkpoint=None):
         """
         Train the model.
         Args:
@@ -206,7 +207,9 @@ class BaseClassifier(BaseModel):
             log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
             save_dir(str, optional): Directory to save the model. Defaults to 'output'.
             pretrain_weights(str or None, optional):
-                None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to 'IMAGENET'.
+                None or name/path of pretrained weights. If None, no pretrained weights will be loaded.
+                At most one of `resume_checkpoint` and `pretrain_weights` can be set simultaneously.
+                Defaults to 'IMAGENET'.
             learning_rate(float, optional): Learning rate for training. Defaults to .025.
             warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
             warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0..
@@ -216,8 +219,15 @@ class BaseClassifier(BaseModel):
             early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
             early_stop_patience(int, optional): Early stop patience. Defaults to 5.
             use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
+            resume_checkpoint(str or None, optional): The path of the checkpoint to resume training from.
+                If None, no training checkpoint will be resumed. At most one of `resume_checkpoint` and
+                `pretrain_weights` can be set simultaneously. Defaults to None.
 
         """
+        if pretrain_weights is not None and resume_checkpoint is not None:
+            logging.error(
+                "pretrain_weights and resume_checkpoint cannot be set simultaneously.",
+                exit=True)
         self.labels = train_dataset.labels
 
         # build optimizer if not defined
@@ -252,7 +262,9 @@ class BaseClassifier(BaseModel):
                     exit=True)
         pretrained_dir = osp.join(save_dir, 'pretrain')
         self.net_initialize(
-            pretrain_weights=pretrain_weights, save_dir=pretrained_dir)
+            pretrain_weights=pretrain_weights,
+            save_dir=pretrained_dir,
+            resume_checkpoint=resume_checkpoint)
 
         # start train loop
         self.train_loop(
@@ -284,6 +296,7 @@ class BaseClassifier(BaseModel):
                           early_stop=False,
                           early_stop_patience=5,
                           use_vdl=True,
+                          resume_checkpoint=None,
                           quant_config=None):
         """
         Quantization-aware training.
@@ -309,6 +322,8 @@ class BaseClassifier(BaseModel):
             use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
             quant_config(dict or None, optional): Quantization configuration. If None, a default rule of thumb
                 configuration will be used. Defaults to None.
+            resume_checkpoint(str or None, optional): The path of the checkpoint to resume quantization-aware training
+                from. If None, no training checkpoint will be resumed. Defaults to None.
 
         """
         self._prepare_qat(quant_config)
@@ -329,7 +344,8 @@ class BaseClassifier(BaseModel):
             lr_decay_gamma=lr_decay_gamma,
             early_stop=early_stop,
             early_stop_patience=early_stop_patience,
-            use_vdl=use_vdl)
+            use_vdl=use_vdl,
+            resume_checkpoint=resume_checkpoint)
 
     def evaluate(self, eval_dataset, batch_size=1, return_details=False):
         """

+ 17 - 3
dygraph/paddlex/cv/models/detector.py

@@ -158,7 +158,8 @@ class BaseDetector(BaseModel):
               use_ema=False,
               early_stop=False,
               early_stop_patience=5,
-              use_vdl=True):
+              use_vdl=True,
+              resume_checkpoint=None):
         """
         Train the model.
         Args:
@@ -185,8 +186,15 @@ class BaseDetector(BaseModel):
             early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
             early_stop_patience(int, optional): Early stop patience. Defaults to 5.
             use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
+            resume_checkpoint(str or None, optional): The path of the checkpoint to resume training from.
+                If None, no training checkpoint will be resumed. At most one of `resume_checkpoint` and
+                `pretrain_weights` can be set simultaneously. Defaults to None.
 
         """
+        if pretrain_weights is not None and resume_checkpoint is not None:
+            logging.error(
+                "pretrain_weights and resume_checkpoint cannot be set simultaneously.",
+                exit=True)
         if train_dataset.__class__.__name__ == 'VOCDetection':
             train_dataset.data_fields = {
                 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
@@ -253,7 +261,9 @@ class BaseDetector(BaseModel):
                     exit=True)
         pretrained_dir = osp.join(save_dir, 'pretrain')
         self.net_initialize(
-            pretrain_weights=pretrain_weights, save_dir=pretrained_dir)
+            pretrain_weights=pretrain_weights,
+            save_dir=pretrained_dir,
+            resume_checkpoint=resume_checkpoint)
 
         if use_ema:
             ema = ExponentialMovingAverage(
@@ -293,6 +303,7 @@ class BaseDetector(BaseModel):
                           early_stop=False,
                           early_stop_patience=5,
                           use_vdl=True,
+                          resume_checkpoint=None,
                           quant_config=None):
         """
         Quantization-aware training.
@@ -320,6 +331,8 @@ class BaseDetector(BaseModel):
             use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
             quant_config(dict or None, optional): Quantization configuration. If None, a default rule of thumb
                 configuration will be used. Defaults to None.
+            resume_checkpoint(str or None, optional): The path of the checkpoint to resume quantization-aware training
+                from. If None, no training checkpoint will be resumed. Defaults to None.
 
         """
         self._prepare_qat(quant_config)
@@ -342,7 +355,8 @@ class BaseDetector(BaseModel):
             use_ema=use_ema,
             early_stop=early_stop,
             early_stop_patience=early_stop_patience,
-            use_vdl=use_vdl)
+            use_vdl=use_vdl,
+            resume_checkpoint=resume_checkpoint)
 
     def evaluate(self,
                  eval_dataset,

+ 2 - 2
dygraph/paddlex/cv/models/load_model.py

@@ -107,8 +107,8 @@ def load_model(model_dir):
         if status == 'Quantized':
             with open(osp.join(model_dir, "quant.yml")) as f:
                 quant_info = yaml.load(f.read(), Loader=yaml.Loader)
-                quant_config = quant_info['quant_config']
-                model.quantizer = paddleslim.QAT(quant_config)
+                model.quant_config = quant_info['quant_config']
+                model.quantizer = paddleslim.QAT(model.quant_config)
                 model.quantizer.quantize(model.net)
 
         if status == 'Infer':
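
Storing quant_config back onto the model here works together with the _prepare_qat guard added in base.py: a model reloaded from a 'Quantized' checkpoint is not quantized a second time, and passing a different quant_config raises an error. A hedged sketch of that flow; the paths are illustrative and train_dataset is assumed to be built the same way as in the original run:

    import paddlex as pdx

    # Reload a quantized model; load_model restores model.quant_config from quant.yml.
    model = pdx.load_model('output/quant_model/best_model')

    # _prepare_qat sees status == 'Quantized' and skips quantizing the net again.
    model.quant_aware_train(
        num_epochs=5,
        train_dataset=train_dataset,
        save_dir='output/quant_model_round2')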

+ 17 - 3
dygraph/paddlex/cv/models/segmenter.py

@@ -193,7 +193,8 @@ class BaseSegmenter(BaseModel):
               lr_decay_power=0.9,
               early_stop=False,
               early_stop_patience=5,
-              use_vdl=True):
+              use_vdl=True,
+              resume_checkpoint=None):
         """
         Train the model.
         Args:
@@ -214,8 +215,15 @@ class BaseSegmenter(BaseModel):
             early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
             early_stop_patience(int, optional): Early stop patience. Defaults to 5.
             use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
+            resume_checkpoint(str or None, optional): The path of the checkpoint to resume training from.
+                If None, no training checkpoint will be resumed. At most one of `resume_checkpoint` and
+                `pretrain_weights` can be set simultaneously. Defaults to None.
 
         """
+        if pretrain_weights is not None and resume_checkpoint is not None:
+            logging.error(
+                "pretrain_weights and resume_checkpoint cannot be set simultaneously.",
+                exit=True)
         self.labels = train_dataset.labels
         if self.losses is None:
             self.losses = self.default_loss()
@@ -248,7 +256,9 @@ class BaseSegmenter(BaseModel):
                     exit=True)
         pretrained_dir = osp.join(save_dir, 'pretrain')
         self.net_initialize(
-            pretrain_weights=pretrain_weights, save_dir=pretrained_dir)
+            pretrain_weights=pretrain_weights,
+            save_dir=pretrained_dir,
+            resume_checkpoint=resume_checkpoint)
 
         self.train_loop(
             num_epochs=num_epochs,
@@ -276,6 +286,7 @@ class BaseSegmenter(BaseModel):
                           early_stop=False,
                           early_stop_patience=5,
                           use_vdl=True,
+                          resume_checkpoint=None,
                           quant_config=None):
         """
         Quantization-aware training.
@@ -297,6 +308,8 @@ class BaseSegmenter(BaseModel):
             use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
             quant_config(dict or None, optional): Quantization configuration. If None, a default rule of thumb
                 configuration will be used. Defaults to None.
+            resume_checkpoint(str or None, optional): The path of the checkpoint to resume quantization-aware training
+                from. If None, no training checkpoint will be resumed. Defaults to None.
 
         """
         self._prepare_qat(quant_config)
@@ -314,7 +327,8 @@ class BaseSegmenter(BaseModel):
             lr_decay_power=lr_decay_power,
             early_stop=early_stop,
             early_stop_patience=early_stop_patience,
-            use_vdl=use_vdl)
+            use_vdl=use_vdl,
+            resume_checkpoint=resume_checkpoint)
 
     def evaluate(self, eval_dataset, batch_size=1, return_details=False):
         """

+ 1 - 1
dygraph/paddlex/utils/__init__.py

@@ -17,7 +17,7 @@ from . import utils
 from .utils import (seconds_to_hms, get_encoding, get_single_card_bs, dict2str,
                     EarlyStop, path_normalization, is_pic, MyEncoder,
                     DisablePrint)
-from .checkpoint import get_pretrain_weights, load_pretrain_weights
+from .checkpoint import get_pretrain_weights, load_pretrain_weights, load_checkpoint
 from .env import get_environ_info, get_num_workers, init_parallel_env
 from .download import download_and_decompress, decompress
 from .stats import SmoothedValue, TrainingStats

+ 19 - 1
dygraph/paddlex/utils/checkpoint.py

@@ -394,7 +394,7 @@ def load_pretrain_weights(model, pretrain_weights=None, model_name=None):
                 else:
                     model_state_dict[k] = para_state_dict[k]
                     num_params_loaded += 1
-            model.set_dict(model_state_dict)
+            model.set_state_dict(model_state_dict)
             logging.info("There are {}/{} variables loaded into {}.".format(
                 num_params_loaded, len(model_state_dict), model_name))
         else:
@@ -404,3 +404,21 @@ def load_pretrain_weights(model, pretrain_weights=None, model_name=None):
         logging.info(
             'No pretrained model to load, {} will be trained from scratch.'.
             format(model_name))
+
+
+def load_optimizer(optimizer, state_dict_path):
+    logging.info("Loading optimizer from {}".format(state_dict_path))
+    optim_state_dict = paddle.load(state_dict_path)
+    if 'last_epoch' in optim_state_dict:
+        optim_state_dict.pop('last_epoch')
+    optimizer.set_state_dict(optim_state_dict)
+
+
+def load_checkpoint(model, optimizer, model_name, checkpoint):
+    logging.info("Loading checkpoint from {}".format(checkpoint))
+    load_pretrain_weights(
+        model,
+        pretrain_weights=osp.join(checkpoint, 'model.pdparams'),
+        model_name=model_name)
+    load_optimizer(
+        optimizer, state_dict_path=osp.join(checkpoint, "model.pdopt"))
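
For reference, these helpers rely on the save/load symmetry of Paddle's state-dict APIs: training writes model.pdparams and model.pdopt with paddle.save, and load_checkpoint restores both, dropping the optimizer's 'last_epoch' entry if present, presumably so the resumed learning-rate schedule follows completed_epochs from model.yml rather than the stale counter. A minimal sketch under those assumptions; the network, optimizer, and paths are illustrative:

    import os
    import paddle

    net = paddle.nn.Linear(4, 2)
    opt = paddle.optimizer.Momentum(learning_rate=0.01, parameters=net.parameters())

    # What a checkpoint directory contains after a save interval:
    os.makedirs('epoch_1', exist_ok=True)
    paddle.save(net.state_dict(), 'epoch_1/model.pdparams')
    paddle.save(opt.state_dict(), 'epoch_1/model.pdopt')

    # What load_checkpoint() effectively does on resume:
    net.set_state_dict(paddle.load('epoch_1/model.pdparams'))
    optim_state = paddle.load('epoch_1/model.pdopt')
    optim_state.pop('last_epoch', None)  # load_optimizer() drops this entry if present
    opt.set_state_dict(optim_state)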