|
|
@@ -1,4 +1,5 @@
|
|
|
import paddlex
|
|
|
+import paddlex.utils.logging as logging
|
|
|
import paddlehub as hub
|
|
|
import os
|
|
|
import os.path as osp
|
|
|
@@ -73,16 +74,58 @@ image_pretrain = {
|
|
|
}
|
|
|
|
|
|
coco_pretrain = {
|
|
|
+ 'YOLOv3_DarkNet53':
|
|
|
+ 'https://paddlemodels.bj.bcebos.com/object_detection/yolov3_darknet.tar',
|
|
|
+ 'YOLOv3_MobileNetV1':
|
|
|
+ 'https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1.tar',
|
|
|
+ 'YOLOv3_MobileNetV3_large':
|
|
|
+ 'https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v3.pdparams',
|
|
|
+ 'YOLOv3_ResNet34':
|
|
|
+ 'https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r34.tar',
|
|
|
+ 'YOLOv3_ResNet50_vd':
|
|
|
+ 'https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r50vd_dcn.tar',
|
|
|
+ 'FasterRCNN_ResNet50':
|
|
|
+ 'https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r50_fpn_2x.tar',
|
|
|
+ 'FasterRCNN_ResNet50_vd':
|
|
|
+ 'https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r50_vd_fpn_2x.tar',
|
|
|
+ 'FasterRCNN_ResNet101':
|
|
|
+ 'https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r101_fpn_2x.tar',
|
|
|
+ 'FasterRCNN_ResNet101_vd':
|
|
|
+ 'https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r101_vd_fpn_2x.tar',
|
|
|
+ 'FasterRCNN_HRNet_W18':
|
|
|
+ 'https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_hrnetv2p_w18_2x.tar',
|
|
|
+ 'MaskRCNN_ResNet50':
|
|
|
+ 'https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_r50_fpn_2x.tar',
|
|
|
+ 'MaskRCNN_ResNet50_vd':
|
|
|
+ 'https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_r50_vd_fpn_2x.tar',
|
|
|
+ 'MaskRCNN_ResNet101':
|
|
|
+ 'https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r101_fpn_2x.tar',
|
|
|
+ 'MaskRCNN_ResNet101_vd':
|
|
|
+ 'https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_r101_vd_fpn_1x.tar',
|
|
|
'UNet': 'https://paddleseg.bj.bcebos.com/models/unet_coco_v3.tgz'
|
|
|
}
|
|
|
|
|
|
|
|
|
-def get_pretrain_weights(flag, model_type, backbone, save_dir):
|
|
|
+def get_pretrain_weights(flag, class_name, backbone, save_dir):
|
|
|
if flag is None:
|
|
|
return None
|
|
|
elif osp.isdir(flag):
|
|
|
return flag
|
|
|
- elif flag == 'IMAGENET':
|
|
|
+ warning_info = "{} supports to be finetuned with weights pretrained on the IMAGENET dataset only, so pretrain_weights is forced to be set to IMAGENET"
|
|
|
+ if flag == 'COCO':
|
|
|
+ if class_name == "FasterRCNN" and backbone in ['ResNet18'] or \
|
|
|
+ class_name == "MaskRCNN" and backbone in ['ResNet18', 'HRNet_W18'] or \
|
|
|
+ class_name == 'DeepLabv3p' and backbone in ['Xception41', 'MobileNetV2_x0.25', 'MobileNetV2_x0.5', 'MobileNetV2_x1.5', 'MobileNetV2_x2.0']:
|
|
|
+ model_name = '{}_{}'.format(class_name, backbone)
|
|
|
+ logging.warning(warning_info.format(model_name))
|
|
|
+ flag = 'IMAGENET'
|
|
|
+ elif class_name == 'HRNet':
|
|
|
+ logging.warning(warning_info.format(class_name))
|
|
|
+ flag = 'IMAGENET'
|
|
|
+ if flag == 'CITYSCAPES':
|
|
|
+ model_name = '{}_{}'.format(class_name, backbone)
|
|
|
+
|
|
|
+ if flag == 'IMAGENET':
|
|
|
new_save_dir = save_dir
|
|
|
if hasattr(paddlex, 'pretrain_dir'):
|
|
|
new_save_dir = paddlex.pretrain_dir
|
|
|
@@ -94,7 +137,7 @@ def get_pretrain_weights(flag, model_type, backbone, save_dir):
|
|
|
backbone = 'MobileNetV3_small_x1_0_ssld'
|
|
|
elif backbone == 'MobileNetV3_large_ssld':
|
|
|
backbone = 'MobileNetV3_large_x1_0_ssld'
|
|
|
- if model_type == 'detector':
|
|
|
+ if class_name in ['YOLOv3', 'FasterRCNN', 'MaskRCNN']:
|
|
|
if backbone == 'ResNet50':
|
|
|
backbone = 'DetResNet50'
|
|
|
assert backbone in image_pretrain, "There is not ImageNet pretrain weights for {}, you may try COCO.".format(
|
|
|
@@ -121,6 +164,8 @@ def get_pretrain_weights(flag, model_type, backbone, save_dir):
|
|
|
new_save_dir = save_dir
|
|
|
if hasattr(paddlex, 'pretrain_dir'):
|
|
|
new_save_dir = paddlex.pretrain_dir
|
|
|
+ if class_name in ['YOLOv3', 'FasterRCNN', 'MaskRCNN']:
|
|
|
+ backbone = '{}_{}'.format(class_name, backbone)
|
|
|
url = coco_pretrain[backbone]
|
|
|
fname = osp.split(url)[-1].split('.')[0]
|
|
|
# paddlex.utils.download_and_decompress(url, path=new_save_dir)
|