|
@@ -424,6 +424,9 @@ def load_pretrain_weights(model, pretrain_weights=None, model_name=None):
|
|
|
def load_optimizer(optimizer, state_dict_path):
|
|
def load_optimizer(optimizer, state_dict_path):
|
|
|
logging.info("Loading optimizer from {}".format(state_dict_path))
|
|
logging.info("Loading optimizer from {}".format(state_dict_path))
|
|
|
optim_state_dict = paddle.load(state_dict_path)
|
|
optim_state_dict = paddle.load(state_dict_path)
|
|
|
|
|
+ for key in optimizer.state_dict().keys():
|
|
|
|
|
+ if key not in optim_state_dict.keys():
|
|
|
|
|
+ optim_state_dict[key] = optimizer.state_dict()[key]
|
|
|
if 'last_epoch' in optim_state_dict:
|
|
if 'last_epoch' in optim_state_dict:
|
|
|
optim_state_dict.pop('last_epoch')
|
|
optim_state_dict.pop('last_epoch')
|
|
|
optimizer.set_state_dict(optim_state_dict)
|
|
optimizer.set_state_dict(optim_state_dict)
|