Explorar o código

Merge pull request #908 from will-jl944/develop_jf

use deepcopy in ema
FlyingQianMM %!s(int64=4) %!d(string=hai) anos
pai
achega
55b1e53b1f
Modificáronse 1 ficheiros con 1 adicións e 1 borrados
  1. 1 1
      dygraph/paddlex/cv/models/base.py

+ 1 - 1
dygraph/paddlex/cv/models/base.py

@@ -384,7 +384,7 @@ class BaseModel:
 
             # 每间隔save_interval_epochs, 在验证集上评估和对模型进行保存
             if ema is not None:
-                weight = self.net.state_dict()
+                weight = copy.deepcopy(self.net.state_dict())
                 self.net.set_state_dict(ema.apply())
             eval_epoch_tic = time.time()
             if (i + 1) % save_interval_epochs == 0 or i == num_epochs - 1: