Răsfoiți Sursa

add coco pretrained weights for detection

FlyingQianMM 5 ani în urmă
părinte
comite
ee2d40d43e
2 a modificat fișierele cu 49 adăugiri și 4 ștergeri
  1. 1 1
      paddlex/cv/models/base.py
  2. 48 3
      paddlex/cv/models/utils/pretrain_weights.py

+ 1 - 1
paddlex/cv/models/base.py

@@ -201,7 +201,7 @@ class BaseAPI:
                 if backbone == "HRNet":
                     backbone = backbone + "_W{}".format(self.width)
             pretrain_weights = get_pretrain_weights(
-                pretrain_weights, self.model_type, backbone, pretrain_dir)
+                pretrain_weights, class_name, backbone, pretrain_dir)
         if startup_prog is None:
             startup_prog = fluid.default_startup_program()
         self.exe.run(startup_prog)

+ 48 - 3
paddlex/cv/models/utils/pretrain_weights.py

@@ -1,4 +1,5 @@
 import paddlex
+import paddlex.utils.logging as logging
 import paddlehub as hub
 import os
 import os.path as osp
@@ -73,16 +74,58 @@ image_pretrain = {
 }
 
 coco_pretrain = {
+    'YOLOv3_DarkNet53':
+    'https://paddlemodels.bj.bcebos.com/object_detection/yolov3_darknet.tar',
+    'YOLOv3_MobileNetV1':
+    'https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1.tar',
+    'YOLOv3_MobileNetV3_large':
+    'https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v3.pdparams',
+    'YOLOv3_ResNet34':
+    'https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r34.tar',
+    'YOLOv3_ResNet50_vd':
+    'https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r50vd_dcn.tar',
+    'FasterRCNN_ResNet50':
+    'https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r50_fpn_2x.tar',
+    'FasterRCNN_ResNet50_vd':
+    'https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r50_vd_fpn_2x.tar',
+    'FasterRCNN_ResNet101':
+    'https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r101_fpn_2x.tar',
+    'FasterRCNN_ResNet101_vd':
+    'https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r101_vd_fpn_2x.tar',
+    'FasterRCNN_HRNet_W18':
+    'https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_hrnetv2p_w18_2x.tar',
+    'MaskRCNN_ResNet50':
+    'https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_r50_fpn_2x.tar',
+    'MaskRCNN_ResNet50_vd':
+    'https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_r50_vd_fpn_2x.tar',
+    'MaskRCNN_ResNet101':
+    'https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r101_fpn_2x.tar',
+    'MaskRCNN_ResNet101_vd':
+    'https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_r101_vd_fpn_1x.tar',
     'UNet': 'https://paddleseg.bj.bcebos.com/models/unet_coco_v3.tgz'
 }
 
 
-def get_pretrain_weights(flag, model_type, backbone, save_dir):
+def get_pretrain_weights(flag, class_name, backbone, save_dir):
     if flag is None:
         return None
     elif osp.isdir(flag):
         return flag
-    elif flag == 'IMAGENET':
+    warning_info = "{} supports to be finetuned with weights pretrained on the IMAGENET dataset only, so pretrain_weights is forced to be set to IMAGENET"
+    if flag == 'COCO':
+        if class_name == "FasterRCNN" and backbone in ['ResNet18'] or \
+            class_name == "MaskRCNN" and backbone in ['ResNet18', 'HRNet_W18'] or \
+            class_name == 'DeepLabv3p' and backbone in ['Xception41', 'MobileNetV2_x0.25', 'MobileNetV2_x0.5', 'MobileNetV2_x1.5', 'MobileNetV2_x2.0']:
+            model_name = '{}_{}'.format(class_name, backbone)
+            logging.warning(warning_info.format(model_name))
+            flag = 'IMAGENET'
+        elif class_name == 'HRNet':
+            logging.warning(warning_info.format(class_name))
+            flag = 'IMAGENET'
+    if flag == 'CITYSCAPES':
+        model_name = '{}_{}'.format(class_name, backbone)
+
+    if flag == 'IMAGENET':
         new_save_dir = save_dir
         if hasattr(paddlex, 'pretrain_dir'):
             new_save_dir = paddlex.pretrain_dir
@@ -94,7 +137,7 @@ def get_pretrain_weights(flag, model_type, backbone, save_dir):
             backbone = 'MobileNetV3_small_x1_0_ssld'
         elif backbone == 'MobileNetV3_large_ssld':
             backbone = 'MobileNetV3_large_x1_0_ssld'
-        if model_type == 'detector':
+        if class_name in ['YOLOv3', 'FasterRCNN', 'MaskRCNN']:
             if backbone == 'ResNet50':
                 backbone = 'DetResNet50'
         assert backbone in image_pretrain, "There is not ImageNet pretrain weights for {}, you may try COCO.".format(
@@ -121,6 +164,8 @@ def get_pretrain_weights(flag, model_type, backbone, save_dir):
         new_save_dir = save_dir
         if hasattr(paddlex, 'pretrain_dir'):
             new_save_dir = paddlex.pretrain_dir
+        if class_name in ['YOLOv3', 'FasterRCNN', 'MaskRCNN']:
+            backbone = '{}_{}'.format(class_name, backbone)
         url = coco_pretrain[backbone]
         fname = osp.split(url)[-1].split('.')[0]
         #        paddlex.utils.download_and_decompress(url, path=new_save_dir)