jiangjiajun пре 5 година
родитељ
комит
e3b27ac580
1 измењених фајлова са 37 додато и 41 уклоњено
  1. 37 41
      paddlex/cv/models/utils/pretrain_weights.py

+ 37 - 41
paddlex/cv/models/utils/pretrain_weights.py

@@ -1,5 +1,5 @@
 import paddlex
-#import paddlehub as hub
+import paddlehub as hub
 import os
 import os.path as osp
 
@@ -85,53 +85,49 @@ def get_pretrain_weights(flag, model_type, backbone, save_dir):
                 backbone = 'DetResNet50'
         assert backbone in image_pretrain, "There is not ImageNet pretrain weights for {}, you may try COCO.".format(
             backbone)
-        url = image_pretrain[backbone]
-        fname = osp.split(url)[-1].split('.')[0]
-        paddlex.utils.download_and_decompress(url, path=new_save_dir)
-        return osp.join(new_save_dir, fname)
-#        try:
-#            hub.download(backbone, save_path=new_save_dir)
-#        except Exception as e:
-#            if isinstance(e, hub.ResourceNotFoundError):
-#                raise Exception(
-#                    "Resource for backbone {} not found".format(backbone))
-#            elif isinstance(e, hub.ServerConnectionError):
-#                raise Exception(
-#                    "Cannot get reource for backbone {}, please check your internet connecgtion"
-#                    .format(backbone))
-#            else:
-#                raise Exception(
-#                    "Unexpected error, please make sure paddlehub >= 1.6.2")
-#        return osp.join(new_save_dir, backbone)
+        #        url = image_pretrain[backbone]
+        #        fname = osp.split(url)[-1].split('.')[0]
+        #        paddlex.utils.download_and_decompress(url, path=new_save_dir)
+        #        return osp.join(new_save_dir, fname)
+        try:
+            hub.download(backbone, save_path=new_save_dir)
+        except Exception as e:
+            if isinstance(e, hub.ResourceNotFoundError):
+                raise Exception("Resource for backbone {} not found".format(
+                    backbone))
+            elif isinstance(e, hub.ServerConnectionError):
+                raise Exception(
+                    "Cannot get reource for backbone {}, please check your internet connecgtion"
+                    .format(backbone))
+            else:
+                raise Exception(
+                    "Unexpected error, please make sure paddlehub >= 1.6.2")
+        return osp.join(new_save_dir, backbone)
     elif flag == 'COCO':
         new_save_dir = save_dir
         if hasattr(paddlex, 'pretrain_dir'):
             new_save_dir = paddlex.pretrain_dir
         url = coco_pretrain[backbone]
         fname = osp.split(url)[-1].split('.')[0]
-        paddlex.utils.download_and_decompress(url, path=new_save_dir)
-        return osp.join(new_save_dir, fname)
-
+        #        paddlex.utils.download_and_decompress(url, path=new_save_dir)
+        #        return osp.join(new_save_dir, fname)
 
-#        new_save_dir = save_dir
-#        if hasattr(paddlex, 'pretrain_dir'):
-#            new_save_dir = paddlex.pretrain_dir
-#        assert backbone in coco_pretrain, "There is not COCO pretrain weights for {}, you may try ImageNet.".format(
-#            backbone)
-#        try:
-#            hub.download(backbone, save_path=new_save_dir)
-#        except Exception as e:
-#            if isinstance(hub.ResourceNotFoundError):
-#                raise Exception(
-#                    "Resource for backbone {} not found".format(backbone))
-#            elif isinstance(hub.ServerConnectionError):
-#                raise Exception(
-#                    "Cannot get reource for backbone {}, please check your internet connecgtion"
-#                    .format(backbone))
-#            else:
-#                raise Exception(
-#                    "Unexpected error, please make sure paddlehub >= 1.6.2")
-#        return osp.join(new_save_dir, backbone)
+        assert backbone in coco_pretrain, "There is not COCO pretrain weights for {}, you may try ImageNet.".format(
+            backbone)
+        try:
+            hub.download(backbone, save_path=new_save_dir)
+        except Exception as e:
+            if isinstance(hub.ResourceNotFoundError):
+                raise Exception("Resource for backbone {} not found".format(
+                    backbone))
+            elif isinstance(hub.ServerConnectionError):
+                raise Exception(
+                    "Cannot get reource for backbone {}, please check your internet connecgtion"
+                    .format(backbone))
+            else:
+                raise Exception(
+                    "Unexpected error, please make sure paddlehub >= 1.6.2")
+        return osp.join(new_save_dir, backbone)
     else:
         raise Exception(
             "pretrain_weights need to be defined as directory path or `IMAGENET` or 'COCO' (download pretrain weights automatically)."