Просмотр исходного кода

Merge pull request #670 from FlyingQianMM/develop_qh

fix pretrain_weight directory whose compressed name and decompressed name are not same when gui_mode is True
FlyingQianMM 4 лет назад
Родитель
Сommit
53b4c5626c
1 измененных файлов с 21 добавлено и 3 удалено
  1. 21 3
      paddlex/cv/models/utils/pretrain_weights.py

+ 21 - 3
paddlex/cv/models/utils/pretrain_weights.py

@@ -215,7 +215,13 @@ def get_pretrain_weights(flag, class_name, backbone, save_dir):
             url = image_pretrain[backbone]
             fname = osp.split(url)[-1].split('.')[0]
             paddlex.utils.download_and_decompress(url, path=new_save_dir)
-            return osp.join(new_save_dir, fname)
+            if not osp.exists(osp.join(new_save_dir, fname)):
+                for f in os.listdir(new_save_dir):
+                    dir_name = osp.join(new_save_dir, f)
+                    if osp.isdir(dir_name) and fname.split('_')[0] in dir_name:
+                        return dir_name
+            else:
+                return osp.join(new_save_dir, fname)
 
         import paddlehub as hub
         try:
@@ -255,7 +261,13 @@ def get_pretrain_weights(flag, class_name, backbone, save_dir):
 
         if getattr(paddlex, 'gui_mode', False):
             paddlex.utils.download_and_decompress(url, path=new_save_dir)
-            return osp.join(new_save_dir, fname)
+            if not osp.exists(osp.join(new_save_dir, fname)):
+                for f in os.listdir(new_save_dir):
+                    dir_name = osp.join(new_save_dir, f)
+                    if osp.isdir(dir_name) and fname.split('_')[0] in dir_name:
+                        return dir_name
+            else:
+                return osp.join(new_save_dir, fname)
 
         import paddlehub as hub
         try:
@@ -288,7 +300,13 @@ def get_pretrain_weights(flag, class_name, backbone, save_dir):
 
         if getattr(paddlex, 'gui_mode', False):
             paddlex.utils.download_and_decompress(url, path=new_save_dir)
-            return osp.join(new_save_dir, fname)
+            if not osp.exists(osp.join(new_save_dir, fname)):
+                for f in os.listdir(new_save_dir):
+                    dir_name = osp.join(new_save_dir, f)
+                    if osp.isdir(dir_name) and fname.split('_')[0] in dir_name:
+                        return dir_name
+            else:
+                return osp.join(new_save_dir, fname)
 
         import paddlehub as hub
         try: