jiangjiajun 5 سال پیش
والد
کامیت
2ef5e7aca8
1فایلهای تغییر یافته به همراه10 افزوده شده و 7 حذف شده
  1. 10 7
      paddlex/cv/models/utils/pretrain_weights.py

+ 10 - 7
paddlex/cv/models/utils/pretrain_weights.py

@@ -202,11 +202,11 @@ def get_pretrain_weights(flag, class_name, backbone, save_dir):
         assert backbone in image_pretrain, "There is not ImageNet pretrain weights for {}, you may try COCO.".format(
             backbone)
 
-        #        if backbone == 'AlexNet':
-        #            url = image_pretrain[backbone]
-        #            fname = osp.split(url)[-1].split('.')[0]
-        #            paddlex.utils.download_and_decompress(url, path=new_save_dir)
-        #            return osp.join(new_save_dir, fname)
+        if getattr(paddlex, 'gui_mode', False):
+            url = image_pretrain[backbone]
+            fname = osp.split(url)[-1].split('.')[0]
+            paddlex.utils.download_and_decompress(url, path=new_save_dir)
+            return osp.join(new_save_dir, fname)
         try:
             logging.info(
                 "Connecting PaddleHub server to get pretrain weights...")
@@ -241,8 +241,11 @@ def get_pretrain_weights(flag, class_name, backbone, save_dir):
         elif flag == 'CITYSCAPES':
             url = cityscapes_pretrain[backbone]
         fname = osp.split(url)[-1].split('.')[0]
-        #        paddlex.utils.download_and_decompress(url, path=new_save_dir)
-        #        return osp.join(new_save_dir, fname)
+
+        if getattr(paddlex, 'gui_mode', False):
+            paddlex.utils.download_and_decompress(url, path=new_save_dir)
+            return osp.join(new_save_dir, fname)
+
         try:
             logging.info(
                 "Connecting PaddleHub server to get pretrain weights...")