Răsfoiți Sursa

quantization aware training

will-jl944 4 ani în urmă
părinte
comite
620898b48a

+ 36 - 17
dygraph/paddlex/cv/models/base.py

@@ -54,7 +54,8 @@ class BaseModel:
         self.completed_epochs = 0
         self.completed_epochs = 0
         self.pruner = None
         self.pruner = None
         self.pruning_ratios = None
         self.pruning_ratios = None
-        self.quanter = None
+        self.quantizer = None
+        self.quant_config = None
 
 
     def net_initialize(self, pretrain_weights=None, save_dir='.'):
     def net_initialize(self, pretrain_weights=None, save_dir='.'):
         if pretrain_weights is not None and \
         if pretrain_weights is not None and \
@@ -124,6 +125,11 @@ class BaseModel:
         info['pruner_inputs'] = self.pruner.inputs
         info['pruner_inputs'] = self.pruner.inputs
         return info
         return info
 
 
+    def get_quant_info(self):
+        info = dict()
+        info['quant_config'] = self.quant_config
+        return info
+
     def save_model(self, save_dir):
     def save_model(self, save_dir):
         if not osp.isdir(save_dir):
         if not osp.isdir(save_dir):
             if osp.exists(save_dir):
             if osp.exists(save_dir):
@@ -132,16 +138,10 @@ class BaseModel:
         model_info = self.get_model_info()
         model_info = self.get_model_info()
         model_info['status'] = self.status
         model_info['status'] = self.status
 
 
-        if self.status == 'Quantized':
-            self.quanter.save_quantized_model(
-                self.net,
-                osp.join(save_dir, 'model'),
-                input_spec=self.test_inputs)
-        else:
-            paddle.save(self.net.state_dict(),
-                        osp.join(save_dir, 'model.pdparams'))
-            paddle.save(self.optimizer.state_dict(),
-                        osp.join(save_dir, 'model.pdopt'))
+        paddle.save(self.net.state_dict(),
+                    osp.join(save_dir, 'model.pdparams'))
+        paddle.save(self.optimizer.state_dict(),
+                    osp.join(save_dir, 'model.pdopt'))
 
 
         with open(
         with open(
                 osp.join(save_dir, 'model.yml'), encoding='utf-8',
                 osp.join(save_dir, 'model.yml'), encoding='utf-8',
@@ -160,6 +160,13 @@ class BaseModel:
                     mode='w') as f:
                     mode='w') as f:
                 yaml.dump(pruning_info, f)
                 yaml.dump(pruning_info, f)
 
 
+        if self.status == 'Quantized' and self.quantizer is not None:
+            quant_info = self.get_quant_info()
+            with open(
+                    osp.join(save_dir, 'quant.yml'), encoding='utf-8',
+                    mode='w') as f:
+                yaml.dump(quant_info, f)
+
         # 模型保存成功的标志
         # 模型保存成功的标志
         open(osp.join(save_dir, '.success'), 'w').close()
         open(osp.join(save_dir, '.success'), 'w').close()
         logging.info("Model saved in {}.".format(save_dir))
         logging.info("Model saved in {}.".format(save_dir))
@@ -413,7 +420,7 @@ class BaseModel:
                 Otherwise, the pruned model will be saved at save_dir. Defaults to None.
                 Otherwise, the pruned model will be saved at save_dir. Defaults to None.
 
 
         """
         """
-        if self.status == "Pruned":
+        if self.status == "Pruned":
             raise Exception(
             raise Exception(
                 "A pruned model cannot be done model pruning again!")
                 "A pruned model cannot be done model pruning again!")
         pre_pruning_flops = flops(self.net, self.pruner.inputs)
         pre_pruning_flops = flops(self.net, self.pruner.inputs)
@@ -462,9 +469,10 @@ class BaseModel:
                 # Types of layers that will be quantized.
                 # Types of layers that will be quantized.
                 'quantizable_layer_type': ['Conv2D', 'Linear']
                 'quantizable_layer_type': ['Conv2D', 'Linear']
             }
             }
-        self.quanter = QAT(config=quant_config)
+        self.quant_config = quant_config
+        self.quantizer = QAT(config=self.quant_config)
         logging.info("Preparing the model for quantization-aware training...")
         logging.info("Preparing the model for quantization-aware training...")
-        self.quanter.quantize(self.net)
+        self.quantizer.quantize(self.net)
         logging.info("Model is ready for quantization-aware training.")
         logging.info("Model is ready for quantization-aware training.")
         self.status = 'Quantized'
         self.status = 'Quantized'
         self.test_inputs = self.get_test_inputs(image_shape)
         self.test_inputs = self.get_test_inputs(image_shape)
@@ -473,9 +481,20 @@ class BaseModel:
         save_dir = osp.join(save_dir, 'inference_model')
         save_dir = osp.join(save_dir, 'inference_model')
         self.net.eval()
         self.net.eval()
         self.test_inputs = self.get_test_inputs(image_shape)
         self.test_inputs = self.get_test_inputs(image_shape)
-        static_net = paddle.jit.to_static(
-            self.net, input_spec=self.test_inputs)
-        paddle.jit.save(static_net, osp.join(save_dir, 'model'))
+
+        if self.status == 'Quantized':
+            self.quantizer.save_quantized_model(self.net,
+                                                osp.join(save_dir, 'model'),
+                                                self.test_inputs)
+            quant_info = self.get_quant_info()
+            with open(
+                    osp.join(save_dir, 'quant.yml'), encoding='utf-8',
+                    mode='w') as f:
+                yaml.dump(quant_info, f)
+        else:
+            static_net = paddle.jit.to_static(
+                self.net, input_spec=self.test_inputs)
+            paddle.jit.save(static_net, osp.join(save_dir, 'model'))
 
 
         if self.status == 'Pruned':
         if self.status == 'Pruned':
             pruning_info = self.get_pruning_info()
             pruning_info = self.get_pruning_info()

+ 0 - 1
dygraph/paddlex/cv/models/classifier.py

@@ -21,7 +21,6 @@ import paddle
 from paddle import to_tensor
 from paddle import to_tensor
 import paddle.nn.functional as F
 import paddle.nn.functional as F
 from paddle.static import InputSpec
 from paddle.static import InputSpec
-from paddleslim import QAT
 from paddlex.utils import logging, TrainingStats, DisablePrint
 from paddlex.utils import logging, TrainingStats, DisablePrint
 from paddlex.cv.models.base import BaseModel
 from paddlex.cv.models.base import BaseModel
 from paddlex.cv.transforms import arrange_transforms
 from paddlex.cv.transforms import arrange_transforms

+ 8 - 0
dygraph/paddlex/cv/models/load_model.py

@@ -76,6 +76,14 @@ def load_model(model_dir):
                     ratios=model.pruning_ratios,
                     ratios=model.pruning_ratios,
                     axis=paddleslim.dygraph.prune.filter_pruner.FILTER_DIM)
                     axis=paddleslim.dygraph.prune.filter_pruner.FILTER_DIM)
 
 
+        if status == 'Quantized' or osp.exists(
+                osp.join(model_dir, "quant.yml")):
+            with open(osp.join(model_dir, "quant.yml")) as f:
+                quant_info = yaml.load(f.read(), Loader=yaml.Loader)
+                quant_config = quant_info['quant_config']
+                model.quantizer = paddleslim.QAT(quant_config)
+                model.quantizer.quantize(model.net)
+
         if status == 'Infer':
         if status == 'Infer':
             if model_info['Model'] in ['FasterRCNN', 'MaskRCNN']:
             if model_info['Model'] in ['FasterRCNN', 'MaskRCNN']:
                 net_state_dict = paddle.load(
                 net_state_dict = paddle.load(

+ 1 - 1
dygraph/tutorials/train/image_classification/alexnet.py

@@ -42,6 +42,6 @@ model.train(
     train_batch_size=32,
     train_batch_size=32,
     eval_dataset=eval_dataset,
     eval_dataset=eval_dataset,
     lr_decay_epochs=[4, 6, 8],
     lr_decay_epochs=[4, 6, 8],
-    learning_rate=0.01,
+    learning_rate=0.001,
     save_dir='output/alexnet',
     save_dir='output/alexnet',
     use_vdl=True)
     use_vdl=True)