Jelajahi Sumber

fix resume checkpoint bug

will-jl944 4 tahun lalu
induk
melakukan
26e1f36e92
1 mengubah file dengan 3 tambahan dan 0 penghapusan
  1. 3 0
      paddlex/utils/checkpoint.py

+ 3 - 0
paddlex/utils/checkpoint.py

@@ -424,6 +424,9 @@ def load_pretrain_weights(model, pretrain_weights=None, model_name=None):
 def load_optimizer(optimizer, state_dict_path):
     logging.info("Loading optimizer from {}".format(state_dict_path))
     optim_state_dict = paddle.load(state_dict_path)
+    for key in optimizer.state_dict().keys():
+        if key not in optim_state_dict.keys():
+            optim_state_dict[key] = optimizer.state_dict()[key]
     if 'last_epoch' in optim_state_dict:
         optim_state_dict.pop('last_epoch')
     optimizer.set_state_dict(optim_state_dict)