瀏覽代碼

add imagenet weights for deeplabvep

will-jl944 4 年之前
父節點
當前提交
a9a677d983

+ 1 - 0
dygraph/examples/meter_reader/train_segmentation.py

@@ -49,5 +49,6 @@ model.train(
     train_dataset=train_dataset,
     train_dataset=train_dataset,
     train_batch_size=4,
     train_batch_size=4,
     eval_dataset=eval_dataset,
     eval_dataset=eval_dataset,
+    pretrain_weights='IMAGENET',
     learning_rate=0.1,
     learning_rate=0.1,
     save_dir='output/deeplabv3p_r50vd')
     save_dir='output/deeplabv3p_r50vd')

+ 10 - 3
dygraph/paddlex/cv/models/base.py

@@ -61,7 +61,8 @@ class BaseModel:
     def net_initialize(self,
     def net_initialize(self,
                        pretrain_weights=None,
                        pretrain_weights=None,
                        save_dir='.',
                        save_dir='.',
-                       resume_checkpoint=None):
+                       resume_checkpoint=None,
+                       is_backbone_weights=False):
         if pretrain_weights is not None and \
         if pretrain_weights is not None and \
                 not osp.exists(pretrain_weights):
                 not osp.exists(pretrain_weights):
             if not osp.isdir(save_dir):
             if not osp.isdir(save_dir):
@@ -79,8 +80,14 @@ class BaseModel:
                     save_dir,
                     save_dir,
                     backbone_name=backbone_name)
                     backbone_name=backbone_name)
         if pretrain_weights is not None:
         if pretrain_weights is not None:
-            load_pretrain_weights(
-                self.net, pretrain_weights, model_name=self.model_name)
+            if is_backbone_weights:
+                load_pretrain_weights(
+                    self.net.backbone,
+                    pretrain_weights,
+                    model_name=self.model_name)
+            else:
+                load_pretrain_weights(
+                    self.net, pretrain_weights, model_name=self.model_name)
         if resume_checkpoint is not None:
         if resume_checkpoint is not None:
             if not osp.exists(resume_checkpoint):
             if not osp.exists(resume_checkpoint):
                 logging.error(
                 logging.error(

+ 3 - 1
dygraph/paddlex/cv/models/segmenter.py

@@ -256,10 +256,12 @@ class BaseSegmenter(BaseModel):
                     "Invalid pretrain weights. Please specify a '.pdparams' file.",
                     "Invalid pretrain weights. Please specify a '.pdparams' file.",
                     exit=True)
                     exit=True)
         pretrained_dir = osp.join(save_dir, 'pretrain')
         pretrained_dir = osp.join(save_dir, 'pretrain')
+        is_backbone_weights = pretrain_weights == 'IMAGENET'
         self.net_initialize(
         self.net_initialize(
             pretrain_weights=pretrain_weights,
             pretrain_weights=pretrain_weights,
             save_dir=pretrained_dir,
             save_dir=pretrained_dir,
-            resume_checkpoint=resume_checkpoint)
+            resume_checkpoint=resume_checkpoint,
+            is_backbone_weights=is_backbone_weights)
 
 
         self.train_loop(
         self.train_loop(
             num_epochs=num_epochs,
             num_epochs=num_epochs,

+ 9 - 2
dygraph/paddlex/utils/checkpoint.py

@@ -14,13 +14,14 @@
 
 
 import os
 import os
 import os.path as osp
 import os.path as osp
+import glob
 import paddle
 import paddle
 import paddlex.utils.logging as logging
 import paddlex.utils.logging as logging
 from .download import download_and_decompress
 from .download import download_and_decompress
 
 
 seg_pretrain_weights_dict = {
 seg_pretrain_weights_dict = {
     'UNet': ['CITYSCAPES'],
     'UNet': ['CITYSCAPES'],
-    'DeepLabV3P': ['CITYSCAPES', 'PascalVOC'],
+    'DeepLabV3P': ['CITYSCAPES', 'PascalVOC', 'IMAGENET'],
     'FastSCNN': ['CITYSCAPES'],
     'FastSCNN': ['CITYSCAPES'],
     'HRNet': ['CITYSCAPES', 'PascalVOC'],
     'HRNet': ['CITYSCAPES', 'PascalVOC'],
     'BiSeNetV2': ['CITYSCAPES']
     'BiSeNetV2': ['CITYSCAPES']
@@ -254,7 +255,11 @@ imagenet_weights = {
     'MaskRCNN_ResNet101_fpn_IMAGENET':
     'MaskRCNN_ResNet101_fpn_IMAGENET':
     'https://paddledet.bj.bcebos.com/models/pretrained/ResNet101_pretrained.pdparams',
     'https://paddledet.bj.bcebos.com/models/pretrained/ResNet101_pretrained.pdparams',
     'MaskRCNN_ResNet101_vd_fpn_IMAGENET':
     'MaskRCNN_ResNet101_vd_fpn_IMAGENET':
-    'https://paddledet.bj.bcebos.com/models/pretrained/ResNet101_vd_pretrained.pdparams'
+    'https://paddledet.bj.bcebos.com/models/pretrained/ResNet101_vd_pretrained.pdparams',
+    'DeepLabV3P_ResNet50_vd_IMAGENET':
+    'https://bj.bcebos.com/paddleseg/dygraph/resnet50_vd_ssld_v2.tar.gz',
+    'DeepLabV3P_ResNet101_vd_IMAGENET':
+    'https://bj.bcebos.com/paddleseg/dygraph/resnet101_vd_ssld.tar.gz'
 }
 }
 
 
 pascalvoc_weights = {
 pascalvoc_weights = {
@@ -364,6 +369,8 @@ def get_pretrain_weights(flag, class_name, save_dir, backbone_name=None):
         raise ValueError('Given pretrained weights {} is undefined.'.format(
         raise ValueError('Given pretrained weights {} is undefined.'.format(
             flag))
             flag))
     fname = download_and_decompress(url, path=new_save_dir)
     fname = download_and_decompress(url, path=new_save_dir)
+    if osp.isdir(fname):
+        fname = glob.glob(osp.join(fname, '*.pdparams'))[0]
     return fname
     return fname
 
 
 
 

+ 10 - 5
dygraph/paddlex/utils/download.py

@@ -164,6 +164,7 @@ def decompress(fname):
 
 
     shutil.rmtree(fpath_tmp)
     shutil.rmtree(fpath_tmp)
     logging.debug("{} decompressed.".format(fname))
     logging.debug("{} decompressed.".format(fname))
+    return dst_dir
 
 
 
 
 def url2dir(url, path):
 def url2dir(url, path):
@@ -171,7 +172,7 @@ def url2dir(url, path):
     if url.endswith(('tgz', 'tar.gz', 'tar', 'zip')):
     if url.endswith(('tgz', 'tar.gz', 'tar', 'zip')):
         fname = osp.split(url)[-1]
         fname = osp.split(url)[-1]
         savepath = osp.join(path, fname)
         savepath = osp.join(path, fname)
-        decompress(savepath)
+        return decompress(savepath)
 
 
 
 
 def download_and_decompress(url, path='.'):
 def download_and_decompress(url, path='.'):
@@ -179,17 +180,21 @@ def download_and_decompress(url, path='.'):
     local_rank = paddle.distributed.get_rank()
     local_rank = paddle.distributed.get_rank()
     fname = osp.split(url)[-1]
     fname = osp.split(url)[-1]
     fullname = osp.join(path, fname)
     fullname = osp.join(path, fname)
-    if url.endswith(('tgz', 'tar.gz', 'tar', 'zip')):
-        fullname = osp.join(path, fname.split('.')[0])
+    # if url.endswith(('tgz', 'tar.gz', 'tar', 'zip')):
+    #     fullname = osp.join(path, fname.split('.')[0])
     if nranks <= 1:
     if nranks <= 1:
-        url2dir(url, path)
+        dst_dir = url2dir(url, path)
+        if dst_dir is not None:
+            fullname = dst_dir
     else:
     else:
         lock_path = fullname + '.lock'
         lock_path = fullname + '.lock'
         if not os.path.exists(fullname):
         if not os.path.exists(fullname):
             with open(lock_path, 'w'):
             with open(lock_path, 'w'):
                 os.utime(lock_path, None)
                 os.utime(lock_path, None)
             if local_rank == 0:
             if local_rank == 0:
-                url2dir(url, path)
+                dst_dir = url2dir(url, path)
+                if dst_dir is not None:
+                    fullname = dst_dir
                 os.remove(lock_path)
                 os.remove(lock_path)
             else:
             else:
                 while os.path.exists(lock_path):
                 while os.path.exists(lock_path):