瀏覽代碼

fix the bug that same pretrained weights being downloaded repeatedly

will-jl944 4 年之前
父節點
當前提交
697f421223
共有 1 個文件被更改,包括 2 次插入1 次删除
  1. 2 1
      dygraph/paddlex/utils/checkpoint.py

+ 2 - 1
dygraph/paddlex/utils/checkpoint.py

@@ -16,6 +16,7 @@ import os
 import os.path as osp
 import glob
 import paddle
+import paddlex
 import paddlex.utils.logging as logging
 from .download import download_and_decompress
 
@@ -352,7 +353,7 @@ def get_pretrain_weights(flag, class_name, save_dir, backbone_name=None):
         return flag
 
     # TODO: check flag
-    new_save_dir = save_dir
+    new_save_dir = getattr(paddlex, 'pretrain_dir', save_dir)
     if backbone_name is not None:
         weights_key = "{}_{}_{}".format(class_name, backbone_name, flag)
     else: