瀏覽代碼

Merge pull request #1197 from will-jl944/dev_bug_fix

fix resume checkpoint bug
will-jl944 4 年之前
父節點
當前提交
326eebd563
共有 1 個文件被更改,包括 3 次插入0 次删除
  1. 3 0
      paddlex/utils/checkpoint.py

+ 3 - 0
paddlex/utils/checkpoint.py

@@ -424,6 +424,9 @@ def load_pretrain_weights(model, pretrain_weights=None, model_name=None):
 def load_optimizer(optimizer, state_dict_path):
     logging.info("Loading optimizer from {}".format(state_dict_path))
     optim_state_dict = paddle.load(state_dict_path)
+    for key in optimizer.state_dict().keys():
+        if key not in optim_state_dict.keys():
+            optim_state_dict[key] = optimizer.state_dict()[key]
     if 'last_epoch' in optim_state_dict:
         optim_state_dict.pop('last_epoch')
     optimizer.set_state_dict(optim_state_dict)