Browse Source

Merge pull request #1197 from will-jl944/dev_bug_fix

fix resume checkpoint bug
will-jl944 4 years ago
parent
commit
326eebd563
1 changed files with 3 additions and 0 deletions
  1. 3 0
      paddlex/utils/checkpoint.py

+ 3 - 0
paddlex/utils/checkpoint.py

@@ -424,6 +424,9 @@ def load_pretrain_weights(model, pretrain_weights=None, model_name=None):
 def load_optimizer(optimizer, state_dict_path):
     logging.info("Loading optimizer from {}".format(state_dict_path))
     optim_state_dict = paddle.load(state_dict_path)
+    for key in optimizer.state_dict().keys():
+        if key not in optim_state_dict.keys():
+            optim_state_dict[key] = optimizer.state_dict()[key]
     if 'last_epoch' in optim_state_dict:
         optim_state_dict.pop('last_epoch')
     optimizer.set_state_dict(optim_state_dict)