|
@@ -82,10 +82,11 @@ class BaseClassifier(BaseModel):
|
|
|
def _get_test_inputs(self, image_shape):
|
|
def _get_test_inputs(self, image_shape):
|
|
|
if image_shape is not None:
|
|
if image_shape is not None:
|
|
|
if len(image_shape) == 2:
|
|
if len(image_shape) == 2:
|
|
|
- image_shape = [None, 3] + image_shape
|
|
|
|
|
|
|
+ image_shape = [1, 3] + image_shape
|
|
|
self._fix_transforms_shape(image_shape[-2:])
|
|
self._fix_transforms_shape(image_shape[-2:])
|
|
|
else:
|
|
else:
|
|
|
image_shape = [None, 3, -1, -1]
|
|
image_shape = [None, 3, -1, -1]
|
|
|
|
|
+ self.fixed_input_shape = image_shape
|
|
|
input_spec = [
|
|
input_spec = [
|
|
|
InputSpec(
|
|
InputSpec(
|
|
|
shape=image_shape, name='image', dtype='float32')
|
|
shape=image_shape, name='image', dtype='float32')
|
|
@@ -191,7 +192,8 @@ class BaseClassifier(BaseModel):
|
|
|
lr_decay_gamma=0.1,
|
|
lr_decay_gamma=0.1,
|
|
|
early_stop=False,
|
|
early_stop=False,
|
|
|
early_stop_patience=5,
|
|
early_stop_patience=5,
|
|
|
- use_vdl=True):
|
|
|
|
|
|
|
+ use_vdl=True,
|
|
|
|
|
+ resume_checkpoint=None):
|
|
|
"""
|
|
"""
|
|
|
Train the model.
|
|
Train the model.
|
|
|
Args:
|
|
Args:
|
|
@@ -206,7 +208,9 @@ class BaseClassifier(BaseModel):
|
|
|
log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
|
|
log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
|
|
|
save_dir(str, optional): Directory to save the model. Defaults to 'output'.
|
|
save_dir(str, optional): Directory to save the model. Defaults to 'output'.
|
|
|
pretrain_weights(str or None, optional):
|
|
pretrain_weights(str or None, optional):
|
|
|
- None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to 'IMAGENET'.
|
|
|
|
|
|
|
+ None or name/path of pretrained weights. If None, no pretrained weights will be loaded.
|
|
|
|
|
+ At most one of `resume_checkpoint` and `pretrain_weights` can be set simultaneously.
|
|
|
|
|
+ Defaults to 'IMAGENET'.
|
|
|
learning_rate(float, optional): Learning rate for training. Defaults to .025.
|
|
learning_rate(float, optional): Learning rate for training. Defaults to .025.
|
|
|
warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
|
|
warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
|
|
|
warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0..
|
|
warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0..
|
|
@@ -216,8 +220,15 @@ class BaseClassifier(BaseModel):
|
|
|
early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
|
|
early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
|
|
|
early_stop_patience(int, optional): Early stop patience. Defaults to 5.
|
|
early_stop_patience(int, optional): Early stop patience. Defaults to 5.
|
|
|
use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
|
|
use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
|
|
|
|
|
+ resume_checkpoint(str or None, optional): The path of the checkpoint to resume training from.
|
|
|
|
|
+ If None, no training checkpoint will be resumed. At most one of `resume_checkpoint` and
|
|
|
|
|
+ `pretrain_weights` can be set simultaneously. Defaults to None.
|
|
|
|
|
|
|
|
"""
|
|
"""
|
|
|
|
|
+ if pretrain_weights is not None and resume_checkpoint is not None:
|
|
|
|
|
+ logging.error(
|
|
|
|
|
+ "pretrain_weights and resume_checkpoint cannot be set simultaneously.",
|
|
|
|
|
+ exit=True)
|
|
|
self.labels = train_dataset.labels
|
|
self.labels = train_dataset.labels
|
|
|
|
|
|
|
|
# build optimizer if not defined
|
|
# build optimizer if not defined
|
|
@@ -252,7 +263,9 @@ class BaseClassifier(BaseModel):
|
|
|
exit=True)
|
|
exit=True)
|
|
|
pretrained_dir = osp.join(save_dir, 'pretrain')
|
|
pretrained_dir = osp.join(save_dir, 'pretrain')
|
|
|
self.net_initialize(
|
|
self.net_initialize(
|
|
|
- pretrain_weights=pretrain_weights, save_dir=pretrained_dir)
|
|
|
|
|
|
|
+ pretrain_weights=pretrain_weights,
|
|
|
|
|
+ save_dir=pretrained_dir,
|
|
|
|
|
+ resume_checkpoint=resume_checkpoint)
|
|
|
|
|
|
|
|
# start train loop
|
|
# start train loop
|
|
|
self.train_loop(
|
|
self.train_loop(
|
|
@@ -284,6 +297,7 @@ class BaseClassifier(BaseModel):
|
|
|
early_stop=False,
|
|
early_stop=False,
|
|
|
early_stop_patience=5,
|
|
early_stop_patience=5,
|
|
|
use_vdl=True,
|
|
use_vdl=True,
|
|
|
|
|
+ resume_checkpoint=None,
|
|
|
quant_config=None):
|
|
quant_config=None):
|
|
|
"""
|
|
"""
|
|
|
Quantization-aware training.
|
|
Quantization-aware training.
|
|
@@ -309,6 +323,8 @@ class BaseClassifier(BaseModel):
|
|
|
use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
|
|
use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
|
|
|
quant_config(dict or None, optional): Quantization configuration. If None, a default rule of thumb
|
|
quant_config(dict or None, optional): Quantization configuration. If None, a default rule of thumb
|
|
|
configuration will be used. Defaults to None.
|
|
configuration will be used. Defaults to None.
|
|
|
|
|
+ resume_checkpoint(str or None, optional): The path of the checkpoint to resume quantization-aware training
|
|
|
|
|
+ from. If None, no training checkpoint will be resumed. Defaults to None.
|
|
|
|
|
|
|
|
"""
|
|
"""
|
|
|
self._prepare_qat(quant_config)
|
|
self._prepare_qat(quant_config)
|
|
@@ -329,7 +345,8 @@ class BaseClassifier(BaseModel):
|
|
|
lr_decay_gamma=lr_decay_gamma,
|
|
lr_decay_gamma=lr_decay_gamma,
|
|
|
early_stop=early_stop,
|
|
early_stop=early_stop,
|
|
|
early_stop_patience=early_stop_patience,
|
|
early_stop_patience=early_stop_patience,
|
|
|
- use_vdl=use_vdl)
|
|
|
|
|
|
|
+ use_vdl=use_vdl,
|
|
|
|
|
+ resume_checkpoint=resume_checkpoint)
|
|
|
|
|
|
|
|
def evaluate(self, eval_dataset, batch_size=1, return_details=False):
|
|
def evaluate(self, eval_dataset, batch_size=1, return_details=False):
|
|
|
"""
|
|
"""
|
|
@@ -554,7 +571,7 @@ class AlexNet(BaseClassifier):
|
|
|
'Please check image shape after transforms is [3, 224, 224], if not, fixed_input_shape '
|
|
'Please check image shape after transforms is [3, 224, 224], if not, fixed_input_shape '
|
|
|
+ 'should be specified manually.')
|
|
+ 'should be specified manually.')
|
|
|
self._fix_transforms_shape(image_shape[-2:])
|
|
self._fix_transforms_shape(image_shape[-2:])
|
|
|
-
|
|
|
|
|
|
|
+ self.fixed_input_shape = image_shape
|
|
|
input_spec = [
|
|
input_spec = [
|
|
|
InputSpec(
|
|
InputSpec(
|
|
|
shape=image_shape, name='image', dtype='float32')
|
|
shape=image_shape, name='image', dtype='float32')
|
|
@@ -762,6 +779,7 @@ class ShuffleNetV2(BaseClassifier):
|
|
|
'Please check image shape after transforms is [3, 224, 224], if not, fixed_input_shape '
|
|
'Please check image shape after transforms is [3, 224, 224], if not, fixed_input_shape '
|
|
|
+ 'should be specified manually.')
|
|
+ 'should be specified manually.')
|
|
|
self._fix_transforms_shape(image_shape[-2:])
|
|
self._fix_transforms_shape(image_shape[-2:])
|
|
|
|
|
+ self.fixed_input_shape = image_shape
|
|
|
input_spec = [
|
|
input_spec = [
|
|
|
InputSpec(
|
|
InputSpec(
|
|
|
shape=image_shape, name='image', dtype='float32')
|
|
shape=image_shape, name='image', dtype='float32')
|
|
@@ -788,6 +806,7 @@ class ShuffleNetV2_swish(BaseClassifier):
|
|
|
'Please check image shape after transforms is [3, 224, 224], if not, fixed_input_shape '
|
|
'Please check image shape after transforms is [3, 224, 224], if not, fixed_input_shape '
|
|
|
+ 'should be specified manually.')
|
|
+ 'should be specified manually.')
|
|
|
self._fix_transforms_shape(image_shape[-2:])
|
|
self._fix_transforms_shape(image_shape[-2:])
|
|
|
|
|
+ self.fixed_input_shape = image_shape
|
|
|
input_spec = [
|
|
input_spec = [
|
|
|
InputSpec(
|
|
InputSpec(
|
|
|
shape=image_shape, name='image', dtype='float32')
|
|
shape=image_shape, name='image', dtype='float32')
|