Browse Source

remove image_shape from quant_aware_train since no inference model is saved in qat process

will-jl944 4 years ago
parent
commit
2213969dce

+ 1 - 2
dygraph/paddlex/cv/models/base.py

@@ -443,7 +443,7 @@ class BaseModel:
             self.save_model(save_dir)
             logging.info("Pruned model is saved at {}".format(save_dir))
 
-    def _prepare_qat(self, quant_config, image_shape):
+    def _prepare_qat(self, quant_config):
         if quant_config is None:
             # default quantization configuration
             quant_config = {
@@ -475,7 +475,6 @@ class BaseModel:
         self.quantizer.quantize(self.net)
         logging.info("Model is ready for quantization-aware training.")
         self.status = 'Quantized'
-        self.test_inputs = self.get_test_inputs(image_shape)
 
     def _export_inference_model(self, save_dir, image_shape=[-1, -1]):
         save_dir = osp.join(save_dir, 'inference_model')

+ 1 - 2
dygraph/paddlex/cv/models/classifier.py

@@ -259,7 +259,6 @@ class BaseClassifier(BaseModel):
                           early_stop=False,
                           early_stop_patience=5,
                           use_vdl=True,
-                          infer_image_shape=[-1, -1],
                           quant_config=None):
         """
         Quantization-aware training.
@@ -289,7 +288,7 @@ class BaseClassifier(BaseModel):
                 configuration will be used. Defaults to None.
 
         """
-        self._prepare_qat(quant_config, infer_image_shape)
+        self._prepare_qat(quant_config)
         self.train(
             num_epochs=num_epochs,
             train_dataset=train_dataset,

+ 1 - 2
dygraph/paddlex/cv/models/detector.py

@@ -265,7 +265,6 @@ class BaseDetector(BaseModel):
                           early_stop=False,
                           early_stop_patience=5,
                           use_vdl=True,
-                          infer_image_shape=[-1, -1],
                           quant_config=None):
         """
         Quantization-aware training.
@@ -297,7 +296,7 @@ class BaseDetector(BaseModel):
                 configuration will be used. Defaults to None.
 
         """
-        self._prepare_qat(quant_config, infer_image_shape)
+        self._prepare_qat(quant_config)
         self.train(
             num_epochs=num_epochs,
             train_dataset=train_dataset,

+ 1 - 2
dygraph/paddlex/cv/models/segmenter.py

@@ -242,7 +242,6 @@ class BaseSegmenter(BaseModel):
                           early_stop=False,
                           early_stop_patience=5,
                           use_vdl=True,
-                          infer_image_shape=[-1, -1],
                           quant_config=None):
         """
         Quantization-aware training.
@@ -268,7 +267,7 @@ class BaseSegmenter(BaseModel):
                 configuration will be used. Defaults to None.
 
         """
-        self._prepare_qat(quant_config, infer_image_shape)
+        self._prepare_qat(quant_config)
         self.train(
             num_epochs=num_epochs,
             train_dataset=train_dataset,