Parcourir la source

Merge pull request #956 from will-jl944/develop_jf

fix the bug that same pretrained weights being downloaded repeatedly
FlyingQianMM il y a 4 ans
Parent
commit
e8e0cf23a7
1 fichiers modifiés avec 2 ajouts et 1 suppressions
  1. 2 1
      dygraph/paddlex/utils/checkpoint.py

+ 2 - 1
dygraph/paddlex/utils/checkpoint.py

@@ -16,6 +16,7 @@ import os
 import os.path as osp
 import glob
 import paddle
+import paddlex
 import paddlex.utils.logging as logging
 from .download import download_and_decompress
 
@@ -352,7 +353,7 @@ def get_pretrain_weights(flag, class_name, save_dir, backbone_name=None):
         return flag
 
     # TODO: check flag
-    new_save_dir = save_dir
+    new_save_dir = getattr(paddlex, 'pretrain_dir', save_dir)
     if backbone_name is not None:
         weights_key = "{}_{}_{}".format(class_name, backbone_name, flag)
     else: