will-jl944 4 tahun lalu
induk
melakukan
fce21d27e7

+ 46 - 6
dygraph/paddlex/cv/models/base.py

@@ -22,6 +22,7 @@ import yaml
 import json
 import paddle
 from paddle.io import DataLoader, DistributedBatchSampler
+from paddleslim import QAT
 from paddleslim.analysis import flops
 from paddleslim import L1NormFilterPruner, FPGMFilterPruner
 import paddlex
@@ -53,6 +54,7 @@ class BaseModel:
         self.completed_epochs = 0
         self.pruner = None
         self.pruning_ratios = None
+        self.quanter = None
 
     def net_initialize(self, pretrain_weights=None, save_dir='.'):
         if pretrain_weights is not None and \
@@ -129,10 +131,15 @@ class BaseModel:
             os.makedirs(save_dir)
         model_info = self.get_model_info()
         model_info['status'] = self.status
-        paddle.save(self.net.state_dict(),
-                    os.path.join(save_dir, 'model.pdparams'))
-        paddle.save(self.optimizer.state_dict(),
-                    os.path.join(save_dir, 'model.pdopt'))
+
+        if self.status == 'Quantized':
+            self.quanter.save_quantized_model(
+                self.net, save_dir, input_spec=self.test_inputs)
+        else:
+            paddle.save(self.net.state_dict(),
+                        os.path.join(save_dir, 'model.pdparams'))
+            paddle.save(self.optimizer.state_dict(),
+                        os.path.join(save_dir, 'model.pdopt'))
 
         with open(
                 osp.join(save_dir, 'model.yml'), encoding='utf-8',
@@ -400,8 +407,8 @@ class BaseModel:
 
         Args:
             pruned_flops(float): Ratio of FLOPs to be pruned.
-            save_dir(None or str, optional): If None, the pruned model will not be saved
-            Otherwise, the pruned model will be saved at save_dir. Defaults to None.
+            save_dir(None or str, optional): If None, the pruned model will not be saved.
+                Otherwise, the pruned model will be saved at save_dir. Defaults to None.
 
         """
         if self.status == "Pruned":
@@ -427,6 +434,39 @@ class BaseModel:
             self.save_model(save_dir)
             logging.info("Pruned model is saved at {}".format(save_dir))
 
+    def _prepare_qat(self, quant_config, image_shape):
+        if quant_config is None:
+            # default quantization configuration
+            quant_config = {
+                # {None, 'PACT'}. Weight preprocess type. If None, no preprocessing is performed.
+                'weight_preprocess_type': None,
+                # {None, 'PACT'}. Activation preprocess type. If None, no preprocessing is performed.
+                'activation_preprocess_type': None,
+                # {'abs_max', 'channel_wise_abs_max', 'range_abs_max', 'moving_average_abs_max'}.
+                # Weight quantization type.
+                'weight_quantize_type': 'channel_wise_abs_max',
+                # {'abs_max', 'range_abs_max', 'moving_average_abs_max'}. Activation quantization type.
+                'activation_quantize_type': 'moving_average_abs_max',
+                # The number of bits of weights after quantization.
+                'weight_bits': 8,
+                # The number of bits of activation after quantization.
+                'activation_bits': 8,
+                # Data type after quantization, such as 'uint8', 'int8', etc.
+                'dtype': 'int8',
+                # Window size for 'range_abs_max' quantization.
+                'window_size': 10000,
+                # Decay coefficient of moving average.
+                'moving_rate': .9,
+                # Types of layers that will be quantized.
+                'quantizable_layer_type': ['Conv2D', 'Linear']
+            }
+        self.quanter = QAT(config=quant_config)
+        logging.info("Preparing the model for quantization-aware training...")
+        self.quanter.quantize(self.net)
+        logging.info("Model is ready for quantization-aware training.")
+        self.status = 'Quantized'
+        self.test_inputs = self.get_test_inputs(image_shape)
+
     def _export_inference_model(self, save_dir, image_shape=[-1, -1]):
         save_dir = osp.join(save_dir, 'inference_model')
         self.net.eval()

+ 71 - 0
dygraph/paddlex/cv/models/classifier.py

@@ -21,6 +21,7 @@ import paddle
 from paddle import to_tensor
 import paddle.nn.functional as F
 from paddle.static import InputSpec
+from paddleslim import QAT
 from paddlex.utils import logging, TrainingStats, DisablePrint
 from paddlex.cv.models.base import BaseModel
 from paddlex.cv.transforms import arrange_transforms
@@ -242,6 +243,76 @@ class BaseClassifier(BaseModel):
             early_stop_patience=early_stop_patience,
             use_vdl=use_vdl)
 
+    def quant_aware_train(self,
+                          num_epochs,
+                          train_dataset,
+                          train_batch_size=64,
+                          eval_dataset=None,
+                          optimizer=None,
+                          save_interval_epochs=1,
+                          log_interval_steps=10,
+                          save_dir='output',
+                          pretrain_weights='IMAGENET',
+                          learning_rate=.025,
+                          warmup_steps=0,
+                          warmup_start_lr=0.0,
+                          lr_decay_epochs=(30, 60, 90),
+                          lr_decay_gamma=0.1,
+                          early_stop=False,
+                          early_stop_patience=5,
+                          use_vdl=True,
+                          infer_image_shape=[-1, -1],
+                          quant_config=None):
+        """
+        Quantization-aware training.
+        Args:
+            num_epochs(int): The number of epochs.
+            train_dataset(paddlex.dataset): Training dataset.
+            train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
+            eval_dataset(paddlex.dataset, optional):
+                Evaluation dataset. If None, the model will not be evaluated during training process. Defaults to None.
+            optimizer(paddle.optimizer.Optimizer or None, optional):
+                Optimizer used for training. If None, a default optimizer is used. Defaults to None.
+            save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
+            log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
+            save_dir(str, optional): Directory to save the model. Defaults to 'output'.
+            pretrain_weights(str or None, optional):
+                None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to 'IMAGENET'.
+            learning_rate(float, optional): Learning rate for training. Defaults to .025.
+            warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
+            warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0..
+            lr_decay_epochs(List[int] or Tuple[int], optional):
+                Epoch milestones for learning rate decay. Defaults to (20, 60, 90).
+            lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay, default .1.
+            early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
+            early_stop_patience(int, optional): Early stop patience. Defaults to 5.
+            use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
+            infer_image_shape(List[int], optional): The shape of input images during inference process, in [w, h] format.
+                If the shape of images is variable, set `infer_image_shape` to [-1, -1]. Defaults to [-1, -1].
+            quant_config(dict or None, optional): Quantization configuration. If None, a default rule of thumb
+                configuration will be used. Defaults to None.
+
+        """
+        self._prepare_qat(quant_config, infer_image_shape)
+        self.train(
+            num_epochs=num_epochs,
+            train_dataset=train_dataset,
+            train_batch_size=train_batch_size,
+            eval_dataset=eval_dataset,
+            optimizer=optimizer,
+            save_interval_epochs=save_interval_epochs,
+            log_interval_steps=log_interval_steps,
+            save_dir=save_dir,
+            pretrain_weights=pretrain_weights,
+            learning_rate=learning_rate,
+            warmup_steps=warmup_steps,
+            warmup_start_lr=warmup_start_lr,
+            lr_decay_epochs=lr_decay_epochs,
+            lr_decay_gamma=lr_decay_gamma,
+            early_stop=early_stop,
+            early_stop_patience=early_stop_patience,
+            use_vdl=use_vdl)
+
     def evaluate(self, eval_dataset, batch_size=1, return_details=False):
         """
         Evaluate the model.

+ 76 - 0
dygraph/paddlex/cv/models/detector.py

@@ -246,6 +246,82 @@ class BaseDetector(BaseModel):
             early_stop_patience=early_stop_patience,
             use_vdl=use_vdl)
 
+    def quant_aware_train(self,
+                          num_epochs,
+                          train_dataset,
+                          train_batch_size=64,
+                          eval_dataset=None,
+                          optimizer=None,
+                          save_interval_epochs=1,
+                          log_interval_steps=10,
+                          save_dir='output',
+                          pretrain_weights='IMAGENET',
+                          learning_rate=.001,
+                          warmup_steps=0,
+                          warmup_start_lr=0.0,
+                          lr_decay_epochs=(216, 243),
+                          lr_decay_gamma=0.1,
+                          metric=None,
+                          use_ema=False,
+                          early_stop=False,
+                          early_stop_patience=5,
+                          use_vdl=True,
+                          infer_image_shape=[-1, -1],
+                          quant_config=None):
+        """
+        Quantization-aware training.
+        Args:
+            num_epochs(int): The number of epochs.
+            train_dataset(paddlex.dataset): Training dataset.
+            train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
+            eval_dataset(paddlex.dataset, optional):
+                Evaluation dataset. If None, the model will not be evaluated during training process. Defaults to None.
+            optimizer(paddle.optimizer.Optimizer or None, optional):
+                Optimizer used for training. If None, a default optimizer is used. Defaults to None.
+            save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
+            log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
+            save_dir(str, optional): Directory to save the model. Defaults to 'output'.
+            pretrain_weights(str or None, optional):
+                None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to 'IMAGENET'.
+            learning_rate(float, optional): Learning rate for training. Defaults to .001.
+            warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
+            warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0..
+            lr_decay_epochs(list or tuple, optional): Epoch milestones for learning rate decay. Defaults to (216, 243).
+            lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
+            metric({'VOC', 'COCO', None}, optional):
+                Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
+            use_ema(bool, optional): Whether to use exponential moving average strategy. Defaults to False.
+            early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
+            early_stop_patience(int, optional): Early stop patience. Defaults to 5.
+            use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
+            infer_image_shape(List[int], optional): The shape of input images during inference process, in [w, h] format.
+                If the shape of images is variable, set `infer_image_shape` to [-1, -1]. Defaults to [-1, -1].
+            quant_config(dict or None, optional): Quantization configuration. If None, a default rule of thumb
+                configuration will be used. Defaults to None.
+
+        """
+        self._prepare_qat(quant_config, infer_image_shape)
+        self.train(
+            num_epochs=num_epochs,
+            train_dataset=train_dataset,
+            train_batch_size=train_batch_size,
+            eval_dataset=eval_dataset,
+            optimizer=optimizer,
+            save_interval_epochs=save_interval_epochs,
+            log_interval_steps=log_interval_steps,
+            save_dir=save_dir,
+            pretrain_weights=pretrain_weights,
+            learning_rate=learning_rate,
+            warmup_steps=warmup_steps,
+            warmup_start_lr=warmup_start_lr,
+            lr_decay_epochs=lr_decay_epochs,
+            lr_decay_gamma=lr_decay_gamma,
+            metric=metric,
+            use_ema=use_ema,
+            early_stop=early_stop,
+            early_stop_patience=early_stop_patience,
+            use_vdl=use_vdl)
+
     def evaluate(self,
                  eval_dataset,
                  batch_size=1,

+ 60 - 0
dygraph/paddlex/cv/models/segmenter.py

@@ -228,6 +228,66 @@ class BaseSegmenter(BaseModel):
             early_stop_patience=early_stop_patience,
             use_vdl=use_vdl)
 
+    def quant_aware_train(self,
+                          num_epochs,
+                          train_dataset,
+                          train_batch_size=2,
+                          eval_dataset=None,
+                          optimizer=None,
+                          save_interval_epochs=1,
+                          log_interval_steps=2,
+                          save_dir='output',
+                          pretrain_weights='CITYSCAPES',
+                          learning_rate=0.01,
+                          lr_decay_power=0.9,
+                          early_stop=False,
+                          early_stop_patience=5,
+                          use_vdl=True,
+                          infer_image_shape=[-1, -1],
+                          quant_config=None):
+        """
+        Quantization-aware training.
+        Args:
+            num_epochs(int): The number of epochs.
+            train_dataset(paddlex.dataset): Training dataset.
+            train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 2.
+            eval_dataset(paddlex.dataset, optional):
+                Evaluation dataset. If None, the model will not be evaluated furing training process. Defaults to None.
+            optimizer(paddle.optimizer.Optimizer or None, optional):
+                Optimizer used in training. If None, a default optimizer is used. Defaults to None.
+            save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
+            log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
+            save_dir(str, optional): Directory to save the model. Defaults to 'output'.
+            pretrain_weights(str or None, optional):
+                None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to 'IMAGENET'.
+            learning_rate(float, optional): Learning rate for training. Defaults to .025.
+            lr_decay_power(float, optional): Learning decay power. Defaults to .9.
+            early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
+            early_stop_patience(int, optional): Early stop patience. Defaults to 5.
+            use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
+            infer_image_shape(List[int], optional): The shape of input images during inference process, in [w, h] format.
+                If the shape of images is variable, set `infer_image_shape` to [-1, -1]. Defaults to [-1, -1].
+            quant_config(dict or None, optional): Quantization configuration. If None, a default rule of thumb
+                configuration will be used. Defaults to None.
+
+        """
+        self._prepare_qat(quant_config, infer_image_shape)
+        self.train(
+            num_epochs=num_epochs,
+            train_dataset=train_dataset,
+            train_batch_size=train_batch_size,
+            eval_dataset=eval_dataset,
+            optimizer=optimizer,
+            save_interval_epochs=save_interval_epochs,
+            log_interval_steps=log_interval_steps,
+            save_dir=save_dir,
+            pretrain_weights=pretrain_weights,
+            learning_rate=learning_rate,
+            lr_decay_power=lr_decay_power,
+            early_stop=early_stop,
+            early_stop_patience=early_stop_patience,
+            use_vdl=use_vdl)
+
     def evaluate(self, eval_dataset, batch_size=1, return_details=False):
         """
         Evaluate the model.