@@ -37,7 +37,7 @@ __all__ = [
     "DenseNet121", "DenseNet161", "DenseNet169", "DenseNet201", "DenseNet264",
     "HRNet_W18_C", "HRNet_W30_C", "HRNet_W32_C", "HRNet_W40_C", "HRNet_W44_C",
     "HRNet_W48_C", "HRNet_W64_C", "Xception41", "Xception65", "Xception71",
-    "ShuffleNetV2", "ShuffleNetV2_swish"
+    "ShuffleNetV2", "ShuffleNetV2_swish", "PPLCNet", "PPLCNet_ssld"
 ]
@@ -144,40 +144,78 @@ class BaseClassifier(BaseModel):
         return outputs
 
-    def default_optimizer(self, parameters, learning_rate, warmup_steps,
-                          warmup_start_lr, lr_decay_epochs, lr_decay_gamma,
-                          num_steps_each_epoch):
-        boundaries = [b * num_steps_each_epoch for b in lr_decay_epochs]
-        values = [
-            learning_rate * (lr_decay_gamma**i)
-            for i in range(len(lr_decay_epochs) + 1)
-        ]
-        scheduler = paddle.optimizer.lr.PiecewiseDecay(boundaries, values)
-        if warmup_steps > 0:
-            if warmup_steps > lr_decay_epochs[0] * num_steps_each_epoch:
+    def default_optimizer(self,
+                          parameters,
+                          learning_rate,
+                          warmup_steps,
+                          warmup_start_lr,
+                          lr_decay_epochs,
+                          lr_decay_gamma,
+                          num_steps_each_epoch,
+                          reg_coeff=1e-04,
+                          scheduler='Piecewise',
+                          num_epochs=None):
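+        # Choose the LR schedule: 'Piecewise' steps the LR down by
+        # `lr_decay_gamma` at each milestone in `lr_decay_epochs`, while
+        # 'Cosine' anneals it over `num_epochs` (which must then be set).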
+        if scheduler.lower() == 'piecewise':
+            if warmup_steps > 0 and warmup_steps > lr_decay_epochs[
+                    0] * num_steps_each_epoch:
+                logging.error(
+                    "In function train(), parameters must satisfy: "
+                    "warmup_steps <= lr_decay_epochs[0] * num_steps_each_epoch. "
+                    "See this doc for more information: "
+                    "https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/parameters.md",
+                    exit=False)
+                logging.error(
+                    "Either `warmup_steps` must be less than {} or `lr_decay_epochs[0]` must be greater than {}; "
+                    "please modify 'warmup_steps' or 'lr_decay_epochs' in the train function".
+                    format(lr_decay_epochs[0] * num_steps_each_epoch,
+                           warmup_steps // num_steps_each_epoch),
+                    exit=True)
+            boundaries = [b * num_steps_each_epoch for b in lr_decay_epochs]
+            values = [
+                learning_rate * (lr_decay_gamma**i)
+                for i in range(len(lr_decay_epochs) + 1)
+            ]
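+            # E.g. learning_rate=0.1, lr_decay_gamma=0.1 and
+            # lr_decay_epochs=(30, 60, 90) give values of
+            # [0.1, 0.01, 0.001, 0.0001] across the four step ranges.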
+            scheduler = paddle.optimizer.lr.PiecewiseDecay(boundaries, values)
+        elif scheduler.lower() == 'cosine':
+            if num_epochs is None:
                 logging.error(
-                    "In function train(), parameters should satisfy: "
-                    "warmup_steps <= lr_decay_epochs[0]*num_samples_in_train_dataset.",
+                    "`num_epochs` must be set when using the cosine annealing decay scheduler, but received {}".
+                    format(num_epochs),
                     exit=False)
+            if warmup_steps > 0 and warmup_steps > num_epochs * num_steps_each_epoch:
                 logging.error(
+                    "In function train(), parameters must satisfy: "
+                    "warmup_steps <= num_epochs * num_steps_each_epoch. "
                     "See this doc for more information: "
                     "https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/parameters.md",
                     exit=False)
                 logging.error(
-                    "warmup_steps should less than {} or lr_decay_epochs[0] greater than {}, "
-                    "please modify 'lr_decay_epochs' or 'warmup_steps' in train function".
-                    format(lr_decay_epochs[0] * num_steps_each_epoch,
-                           warmup_steps // num_steps_each_epoch))
+                    "`warmup_steps` must be less than the total number of steps ({}); "
+                    "please modify 'num_epochs' or 'warmup_steps' in the train function".
+                    format(num_epochs * num_steps_each_epoch),
+                    exit=True)
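+            # Warmup steps are consumed by LinearWarmup below, so the cosine
+            # period covers only the remaining steps; eta_min=0.0 anneals the
+            # LR to zero by the final step.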
+            T_max = num_epochs * num_steps_each_epoch - warmup_steps
+            scheduler = paddle.optimizer.lr.CosineAnnealingDecay(
+                learning_rate=learning_rate,
+                T_max=T_max,
+                eta_min=0.0,
+                last_epoch=-1)
+        else:
+            logging.error(
+                "Invalid learning rate scheduler: {}!".format(scheduler),
+                exit=True)
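+        # Whichever schedule was built, LinearWarmup wraps it and ramps the LR
+        # linearly from `warmup_start_lr` to `learning_rate` over the first
+        # `warmup_steps` steps before the decay schedule takes over.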
+        if warmup_steps > 0:
             scheduler = paddle.optimizer.lr.LinearWarmup(
                 learning_rate=scheduler,
                 warmup_steps=warmup_steps,
                 start_lr=warmup_start_lr,
                 end_lr=learning_rate)
+
         optimizer = paddle.optimizer.Momentum(
             scheduler,
             momentum=.9,
-            weight_decay=paddle.regularizer.L2Decay(coeff=1e-04),
+            weight_decay=paddle.regularizer.L2Decay(coeff=reg_coeff),
             parameters=parameters)
         return optimizer
@@ -836,3 +874,116 @@ class ShuffleNetV2_swish(BaseClassifier):
                 shape=image_shape, name='image', dtype='float32')
         ]
         return input_spec
+
+
+class PPLCNet(BaseClassifier):
+    def __init__(self, num_classes=1000, scale=1., **params):
+        supported_scale = [.25, .35, .5, .75, 1., 1.5, 2., 2.5]
+        if scale not in supported_scale:
+            logging.warning("scale={} is not supported by PPLCNet, "
+                            "scale is forcibly set to 1.0".format(scale))
+            scale = 1.0
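+        # Map the numeric scale onto the weight-file naming scheme, e.g.
+        # scale=0.75 -> 'PPLCNet_x0_75' and scale=1.0 -> 'PPLCNet_x1_0'.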
+        model_name = 'PPLCNet_x' + str(float(scale)).replace('.', '_')
+        super(PPLCNet, self).__init__(
+            model_name=model_name, num_classes=num_classes, **params)
+
+    def train(self,
+              num_epochs,
+              train_dataset,
+              train_batch_size=64,
+              eval_dataset=None,
+              optimizer=None,
+              save_interval_epochs=1,
+              log_interval_steps=10,
+              save_dir='output',
+              pretrain_weights='IMAGENET',
+              learning_rate=.1,
+              warmup_steps=0,
+              warmup_start_lr=0.0,
+              lr_decay_epochs=(30, 60, 90),
+              lr_decay_gamma=0.1,
+              label_smoothing=None,
+              early_stop=False,
+              early_stop_patience=5,
+              use_vdl=True,
+              resume_checkpoint=None):
+        """
+        Train the model.
+
+        Args:
+            num_epochs(int): The number of epochs.
+            train_dataset(paddlex.dataset): Training dataset.
+            train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
+            eval_dataset(paddlex.dataset, optional):
+                Evaluation dataset. If None, the model will not be evaluated during the training process. Defaults to None.
+            optimizer(paddle.optimizer.Optimizer or None, optional):
+                Optimizer used for training. If None, a default optimizer is used. Defaults to None.
+            save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
+            log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
+            save_dir(str, optional): Directory to save the model. Defaults to 'output'.
+            pretrain_weights(str or None, optional):
+                None or name/path of pretrained weights. If None, no pretrained weights will be loaded.
+                At most one of `resume_checkpoint` and `pretrain_weights` can be set simultaneously.
+                Defaults to 'IMAGENET'.
+            learning_rate(float, optional): Learning rate for training. Defaults to .1.
+            warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
+            warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0.0.
+            lr_decay_epochs(List[int] or Tuple[int], optional):
+                Epoch milestones for learning rate decay. Defaults to (30, 60, 90).
+            lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
+            label_smoothing(float, bool or None, optional): Whether to adopt label smoothing or not.
+                If float, the value refers to the epsilon coefficient of label smoothing. If False or None,
+                label smoothing will not be adopted. Otherwise, label smoothing is adopted with an epsilon
+                of 0.1. Defaults to None.
+            early_stop(bool, optional): Whether to adopt the early stop strategy. Defaults to False.
+            early_stop_patience(int, optional): Early stop patience. Defaults to 5.
+            use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
+            resume_checkpoint(str or None, optional): The path of the checkpoint to resume training from.
+                If None, no training checkpoint will be resumed. At most one of `resume_checkpoint` and
+                `pretrain_weights` can be set simultaneously. Defaults to None.
+
+        """
+        if optimizer is None:
+            num_steps_each_epoch = len(train_dataset) // train_batch_size
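+            # PPLCNet departs from the base defaults: cosine annealing instead
+            # of piecewise decay, and a lighter L2 coefficient (3e-5), in line
+            # with the PP-LCNet training recipe.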
+            optimizer = self.default_optimizer(
+                parameters=self.net.parameters(),
+                learning_rate=learning_rate,
+                warmup_steps=warmup_steps,
+                warmup_start_lr=warmup_start_lr,
+                lr_decay_epochs=lr_decay_epochs,
+                lr_decay_gamma=lr_decay_gamma,
+                num_steps_each_epoch=num_steps_each_epoch,
+                reg_coeff=3e-5,
+                scheduler='Cosine',
+                num_epochs=num_epochs)
+        super(PPLCNet, self).train(
+            num_epochs=num_epochs,
+            train_dataset=train_dataset,
+            train_batch_size=train_batch_size,
+            eval_dataset=eval_dataset,
+            optimizer=optimizer,
+            save_interval_epochs=save_interval_epochs,
+            log_interval_steps=log_interval_steps,
+            save_dir=save_dir,
+            pretrain_weights=pretrain_weights,
+            learning_rate=learning_rate,
+            warmup_steps=warmup_steps,
+            warmup_start_lr=warmup_start_lr,
+            lr_decay_epochs=lr_decay_epochs,
+            lr_decay_gamma=lr_decay_gamma,
+            label_smoothing=label_smoothing,
+            early_stop=early_stop,
+            early_stop_patience=early_stop_patience,
+            use_vdl=use_vdl,
+            resume_checkpoint=resume_checkpoint)
+
+
+class PPLCNet_ssld(PPLCNet):
+    def __init__(self, num_classes=1000, scale=1., **params):
+        supported_scale = [.5, 1., 2.5]
+        if scale not in supported_scale:
+            logging.warning("scale={} is not supported by PPLCNet_ssld, "
+                            "scale is forcibly set to 1.0".format(scale))
+            scale = 1.0
+        model_name = 'PPLCNet_x' + str(float(scale)).replace('.', '_')
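+        # `super(PPLCNet, self)` resolves past PPLCNet in the MRO straight to
+        # BaseClassifier, so PPLCNet.__init__ (and its broader scale check)
+        # is skipped here.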
+        super(PPLCNet, self).__init__(
+            model_name=model_name, num_classes=num_classes, **params)
+        self.model_name = model_name + '_ssld'
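+
+
+# Minimal usage sketch (hypothetical `train_dataset`; any
+# paddlex.datasets.ImageNet-style classification dataset works):
+#
+#     model = PPLCNet(num_classes=len(train_dataset.labels), scale=1.0)
+#     model.train(num_epochs=90, train_dataset=train_dataset)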