Browse Source

Merge pull request #496 from FlyingQianMM/develop_qh

import paddlehub to download COCO pretrained weights
Jason 4 years ago
parent
commit
9d81e2f786
1 changed files with 6 additions and 2 deletions
  1. 6 2
      paddlex/cv/models/utils/pretrain_weights.py

+ 6 - 2
paddlex/cv/models/utils/pretrain_weights.py

@@ -160,7 +160,8 @@ def get_pretrain_weights(flag, class_name, backbone, save_dir):
             logging.warning(warning_info.format(class_name, flag, 'IMAGENET'))
             flag = 'IMAGENET'
         elif class_name == 'FastSCNN':
-            logging.warning(warning_info.format(class_name, flag, 'CITYSCAPES'))
+            logging.warning(
+                warning_info.format(class_name, flag, 'CITYSCAPES'))
             flag = 'CITYSCAPES'
     elif flag == 'CITYSCAPES':
         model_name = '{}_{}'.format(class_name, backbone)
@@ -183,7 +184,8 @@ def get_pretrain_weights(flag, class_name, backbone, save_dir):
             logging.warning(warning_info.format(class_name, flag, 'COCO'))
             flag = 'COCO'
         elif class_name == 'FastSCNN':
-            logging.warning(warning_info.format(class_name, flag, 'CITYSCAPES'))
+            logging.warning(
+                warning_info.format(class_name, flag, 'CITYSCAPES'))
             flag = 'CITYSCAPES'
     elif flag == 'BAIDU10W':
         if class_name not in ['ResNet50_vd']:
@@ -254,6 +256,8 @@ def get_pretrain_weights(flag, class_name, backbone, save_dir):
         if getattr(paddlex, 'gui_mode', False):
             paddlex.utils.download_and_decompress(url, path=new_save_dir)
             return osp.join(new_save_dir, fname)
+
+        import paddlehub as hub
         try:
             logging.info(
                 "Connecting PaddleHub server to get pretrain weights...")