|
|
@@ -106,6 +106,8 @@ coco_pretrain = {
|
|
|
'https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_r101_fpn_1x.tar',
|
|
|
'MaskRCNN_ResNet101_vd_COCO':
|
|
|
'https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_r101_vd_fpn_1x.tar',
|
|
|
+ 'MaskRCNN_HRNet_W18_COCO':
|
|
|
+ 'https://bj.bcebos.com/paddlex/pretrained_weights/mask_rcnn_hrnetv2p_w18_2x.tar',
|
|
|
'UNet_COCO': 'https://paddleseg.bj.bcebos.com/models/unet_coco_v3.tgz',
|
|
|
'DeepLabv3p_MobileNetV2_x1.0_COCO':
|
|
|
'https://bj.bcebos.com/v1/paddleseg/deeplab_mobilenet_x1_0_coco.tgz',
|
|
|
@@ -135,7 +137,7 @@ def get_pretrain_weights(flag, class_name, backbone, save_dir):
|
|
|
warning_info = "{} does not support to be finetuned with weights pretrained on the {} dataset, so pretrain_weights is forced to be set to {}"
|
|
|
if flag == 'COCO':
|
|
|
if class_name == "FasterRCNN" and backbone in ['ResNet18'] or \
|
|
|
- class_name == "MaskRCNN" and backbone in ['ResNet18', 'HRNet_W18'] or \
|
|
|
+ class_name == "MaskRCNN" and backbone in ['ResNet18'] or \
|
|
|
class_name == 'DeepLabv3p' and backbone in ['Xception41', 'MobileNetV2_x0.25', 'MobileNetV2_x0.5', 'MobileNetV2_x1.5', 'MobileNetV2_x2.0']:
|
|
|
model_name = '{}_{}'.format(class_name, backbone)
|
|
|
logging.warning(warning_info.format(model_name, flag, 'IMAGENET'))
|
|
|
@@ -144,7 +146,8 @@ def get_pretrain_weights(flag, class_name, backbone, save_dir):
|
|
|
logging.warning(warning_info.format(class_name, flag, 'IMAGENET'))
|
|
|
flag = 'IMAGENET'
|
|
|
elif class_name == 'FastSCNN':
|
|
|
- logging.warning(warning_info.format(class_name, flag, 'CITYSCAPES'))
|
|
|
+ logging.warning(
|
|
|
+ warning_info.format(class_name, flag, 'CITYSCAPES'))
|
|
|
flag = 'CITYSCAPES'
|
|
|
elif flag == 'CITYSCAPES':
|
|
|
model_name = '{}_{}'.format(class_name, backbone)
|
|
|
@@ -167,7 +170,8 @@ def get_pretrain_weights(flag, class_name, backbone, save_dir):
|
|
|
logging.warning(warning_info.format(class_name, flag, 'COCO'))
|
|
|
flag = 'COCO'
|
|
|
elif class_name == 'FastSCNN':
|
|
|
- logging.warning(warning_info.format(class_name, flag, 'CITYSCAPES'))
|
|
|
+ logging.warning(
|
|
|
+ warning_info.format(class_name, flag, 'CITYSCAPES'))
|
|
|
flag = 'CITYSCAPES'
|
|
|
|
|
|
if flag == 'IMAGENET':
|