Browse Source

add imagnet pretrain weigths downlink for deeplabv3p

FlyingQianMM 4 năm trước cách đây
mục cha
commit
3db888f2ce

+ 1 - 1
dygraph/paddlex/cv/models/segmenter.py

@@ -208,7 +208,7 @@ class BaseSegmenter(BaseModel):
             log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
             save_dir(str, optional): Directory to save the model. Defaults to 'output'.
             pretrain_weights(str or None, optional):
-                None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to 'IMAGENET'.
+                None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to 'CITYSCAPES'.
             learning_rate(float, optional): Learning rate for training. Defaults to .025.
             lr_decay_power(float, optional): Learning decay power. Defaults to .9.
             early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.

+ 6 - 2
dygraph/paddlex/utils/checkpoint.py

@@ -20,7 +20,7 @@ from .download import download_and_decompress
 
 seg_pretrain_weights_dict = {
     'UNet': ['CITYSCAPES'],
-    'DeepLabV3P': ['CITYSCAPES', 'PascalVOC'],
+    'DeepLabV3P': ['CITYSCAPES', 'PascalVOC', 'IMAGENET'],
     'FastSCNN': ['CITYSCAPES'],
     'HRNet': ['CITYSCAPES', 'PascalVOC'],
     'BiSeNetV2': ['CITYSCAPES']
@@ -254,7 +254,11 @@ imagenet_weights = {
     'MaskRCNN_ResNet101_fpn_IMAGENET':
     'https://paddledet.bj.bcebos.com/models/pretrained/ResNet101_pretrained.pdparams',
     'MaskRCNN_ResNet101_vd_fpn_IMAGENET':
-    'https://paddledet.bj.bcebos.com/models/pretrained/ResNet101_vd_pretrained.pdparams'
+    'https://paddledet.bj.bcebos.com/models/pretrained/ResNet101_vd_pretrained.pdparams',
+    'DeepLabV3P_ResNet50_vd_IMAGENET':
+    'https://bj.bcebos.com/paddleseg/dygraph/resnet50_vd_ssld.tar.gz',
+    'DeepLabV3P_ResNet101_vd_IMAGENET':
+    'https://bj.bcebos.com/paddleseg/dygraph/resnet101_vd_ssld.tar.gz'
 }
 
 pascalvoc_weights = {