|
|
@@ -160,7 +160,8 @@ def get_pretrain_weights(flag, class_name, backbone, save_dir):
|
|
|
logging.warning(warning_info.format(class_name, flag, 'IMAGENET'))
|
|
|
flag = 'IMAGENET'
|
|
|
elif class_name == 'FastSCNN':
|
|
|
- logging.warning(warning_info.format(class_name, flag, 'CITYSCAPES'))
|
|
|
+ logging.warning(
|
|
|
+ warning_info.format(class_name, flag, 'CITYSCAPES'))
|
|
|
flag = 'CITYSCAPES'
|
|
|
elif flag == 'CITYSCAPES':
|
|
|
model_name = '{}_{}'.format(class_name, backbone)
|
|
|
@@ -183,7 +184,8 @@ def get_pretrain_weights(flag, class_name, backbone, save_dir):
|
|
|
logging.warning(warning_info.format(class_name, flag, 'COCO'))
|
|
|
flag = 'COCO'
|
|
|
elif class_name == 'FastSCNN':
|
|
|
- logging.warning(warning_info.format(class_name, flag, 'CITYSCAPES'))
|
|
|
+ logging.warning(
|
|
|
+ warning_info.format(class_name, flag, 'CITYSCAPES'))
|
|
|
flag = 'CITYSCAPES'
|
|
|
elif flag == 'BAIDU10W':
|
|
|
if class_name not in ['ResNet50_vd']:
|
|
|
@@ -254,6 +256,8 @@ def get_pretrain_weights(flag, class_name, backbone, save_dir):
|
|
|
if getattr(paddlex, 'gui_mode', False):
|
|
|
paddlex.utils.download_and_decompress(url, path=new_save_dir)
|
|
|
return osp.join(new_save_dir, fname)
|
|
|
+
|
|
|
+ import paddlehub as hub
|
|
|
try:
|
|
|
logging.info(
|
|
|
"Connecting PaddleHub server to get pretrain weights...")
|