Ver Fonte

Merge pull request #956 from will-jl944/develop_jf

fix the bug that same pretrained weights being downloaded repeatedly
FlyingQianMM há 4 anos atrás
pai
commit
e8e0cf23a7
1 ficheiros alterados com 2 adições e 1 exclusões
  1. 2 1
      dygraph/paddlex/utils/checkpoint.py

+ 2 - 1
dygraph/paddlex/utils/checkpoint.py

@@ -16,6 +16,7 @@ import os
 import os.path as osp
 import glob
 import paddle
+import paddlex
 import paddlex.utils.logging as logging
 from .download import download_and_decompress
 
@@ -352,7 +353,7 @@ def get_pretrain_weights(flag, class_name, save_dir, backbone_name=None):
         return flag
 
     # TODO: check flag
-    new_save_dir = save_dir
+    new_save_dir = getattr(paddlex, 'pretrain_dir', save_dir)
     if backbone_name is not None:
         weights_key = "{}_{}_{}".format(class_name, backbone_name, flag)
     else: