|
|
@@ -443,7 +443,7 @@ class BaseModel:
|
|
|
self.save_model(save_dir)
|
|
|
logging.info("Pruned model is saved at {}".format(save_dir))
|
|
|
|
|
|
- def _prepare_qat(self, quant_config, image_shape):
|
|
|
+ def _prepare_qat(self, quant_config):
|
|
|
if quant_config is None:
|
|
|
# default quantization configuration
|
|
|
quant_config = {
|
|
|
@@ -475,7 +475,6 @@ class BaseModel:
|
|
|
self.quantizer.quantize(self.net)
|
|
|
logging.info("Model is ready for quantization-aware training.")
|
|
|
self.status = 'Quantized'
|
|
|
- self.test_inputs = self.get_test_inputs(image_shape)
|
|
|
|
|
|
def _export_inference_model(self, save_dir, image_shape=[-1, -1]):
|
|
|
save_dir = osp.join(save_dir, 'inference_model')
|