Prechádzať zdrojové kódy

import from PaddleClas submodule instead

will-jl944 4 rokov pred
rodič
commit
8fa1234d5d

+ 9 - 0
.gitmodules

@@ -0,0 +1,9 @@
+[submodule "dygraph/PaddleClas"]
+	path = dygraph/PaddleClas
+	url = https://github.com/PaddlePaddle/PaddleClas.git
+[submodule "dygraph/PaddleSeg"]
+	path = dygraph/PaddleSeg
+	url = https://github.com/PaddlePaddle/PaddleSeg.git
+[submodule "dygraph/PaddleDetection"]
+	path = dygraph/PaddleDetection
+	url = https://github.com/PaddlePaddle/PaddleDetection.git

+ 1 - 0
dygraph/PaddleClas

@@ -0,0 +1 @@
+Subproject commit b5de5322b9f40449db4d55078044a1e1b44c0644

+ 1 - 0
dygraph/PaddleDetection

@@ -0,0 +1 @@
+Subproject commit a065c4c0eb95e1a201615304eaa72f36bc91d699

+ 1 - 0
dygraph/PaddleSeg

@@ -0,0 +1 @@
+Subproject commit 56793289eda25dd2034160a6538fa43808c553f3

+ 7 - 7
dygraph/paddlex/cv/models/base.py

@@ -63,12 +63,8 @@ class BaseModel:
                     os.remove(save_dir)
                 os.makedirs(save_dir)
             if self.model_type == 'classifier':
-                scale = getattr(self, 'scale', None)
                 pretrain_weights = get_pretrain_weights(
-                    pretrain_weights,
-                    self.__class__.__name__,
-                    save_dir,
-                    scale=scale)
+                    pretrain_weights, self.model_name, save_dir)
             else:
                 backbone_name = getattr(self, 'backbone_name', None)
                 pretrain_weights = get_pretrain_weights(
@@ -163,8 +159,12 @@ class BaseModel:
     def build_data_loader(self, dataset, batch_size, mode='train'):
         batch_size_each_card = get_single_card_bs(batch_size=batch_size)
         if mode == 'eval':
-            batch_size = batch_size_each_card
-            total_steps = math.ceil(dataset.num_samples * 1.0 / batch_size)
+            if self.model_type == 'detector':
+                # detector only supports single card eval with batch size 1
+                total_steps = dataset.num_samples
+            else:
+                batch_size = batch_size_each_card
+                total_steps = math.ceil(dataset.num_samples * 1.0 / batch_size)
             logging.info(
                 "Start to evaluate(total_samples={}, total_steps={})...".
                 format(dataset.num_samples, total_steps))

+ 29 - 14
dygraph/paddlex/cv/models/classifier.py

@@ -22,8 +22,8 @@ import paddle.nn.functional as F
 from paddle.static import InputSpec
 from paddlex.utils import logging, TrainingStats
 from paddlex.cv.models.base import BaseModel
-from paddlex.cv.nets.ppcls.modeling import architectures
-from paddlex.cv.nets.ppcls.modeling.loss import CELoss
+from PaddleClas.ppcls.modeling import architectures
+from PaddleClas.ppcls.modeling.loss import CELoss
 from paddlex.cv.transforms import arrange_transforms
 
 __all__ = [
@@ -399,7 +399,10 @@ class ResNet50_vd(BaseClassifier):
 class ResNet50_vd_ssld(BaseClassifier):
     def __init__(self, num_classes=1000):
         super(ResNet50_vd_ssld, self).__init__(
-            model_name='ResNet50_vd_ssld', num_classes=num_classes)
+            model_name='ResNet50_vd',
+            num_classes=num_classes,
+            lr_mult_list=[.1, .1, .2, .2, .3])
+        self.model_name = 'ResNet50_vd_ssld'
 
 
 class ResNet101_vd(BaseClassifier):
@@ -411,7 +414,10 @@ class ResNet101_vd(BaseClassifier):
 class ResNet101_vd_ssld(BaseClassifier):
     def __init__(self, num_classes=1000):
         super(ResNet101_vd_ssld, self).__init__(
-            model_name='ResNet101_vd_ssld', num_classes=num_classes)
+            model_name='ResNet101_vd_ssld',
+            num_classes=num_classes,
+            lr_mult_list=[.1, .1, .2, .2, .3])
+        self.model_name = 'ResNet101_vd_ssld'
 
 
 class ResNet152_vd(BaseClassifier):
@@ -458,9 +464,13 @@ class MobileNetV1(BaseClassifier):
             logging.warning("scale={} is not supported by MobileNetV1, "
                             "scale is forcibly set to 1.0".format(scale))
             scale = 1.0
-        params = {'scale': scale}
+        if scale == 1:
+            model_name = 'MobileNetV1'
+        else:
+            model_name = 'MobileNetV1_x' + str(scale).replace('.', '_')
+        self.scale = scale
         super(MobileNetV1, self).__init__(
-            model_name='MobileNetV1', num_classes=num_classes, **params)
+            model_name=model_name, num_classes=num_classes)
 
 
 class MobileNetV2(BaseClassifier):
@@ -470,9 +480,12 @@ class MobileNetV2(BaseClassifier):
             logging.warning("scale={} is not supported by MobileNetV2, "
                             "scale is forcibly set to 1.0".format(scale))
             scale = 1.0
-        params = {'scale': scale}
+        if scale == 1:
+            model_name = 'MobileNetV2'
+        else:
+            model_name = 'MobileNetV2_x' + str(scale).replace('.', '_')
         super(MobileNetV2, self).__init__(
-            model_name='MobileNetV2', num_classes=num_classes, **params)
+            model_name=model_name, num_classes=num_classes)
 
 
 class MobileNetV3_small(BaseClassifier):
@@ -482,9 +495,10 @@ class MobileNetV3_small(BaseClassifier):
             logging.warning("scale={} is not supported by MobileNetV3_small, "
                             "scale is forcibly set to 1.0".format(scale))
             scale = 1.0
-        params = {'scale': scale}
+        model_name = 'MobileNetV3_small_x' + str(float(scale)).replace('.',
+                                                                       '_')
         super(MobileNetV3_small, self).__init__(
-            model_name='MobileNetV3_small', num_classes=num_classes, **params)
+            model_name=model_name, num_classes=num_classes)
 
 
 class MobileNetV3_large(BaseClassifier):
@@ -494,9 +508,10 @@ class MobileNetV3_large(BaseClassifier):
             logging.warning("scale={} is not supported by MobileNetV3_large, "
                             "scale is forcibly set to 1.0".format(scale))
             scale = 1.0
-        params = {'scale': scale}
+        model_name = 'MobileNetV3_large_x' + str(float(scale)).replace('.',
+                                                                       '_')
         super(MobileNetV3_large, self).__init__(
-            model_name='MobileNetV3_large', num_classes=num_classes, **params)
+            model_name=model_name, num_classes=num_classes)
 
 
 class DenseNet121(BaseClassifier):
@@ -596,9 +611,9 @@ class ShuffleNetV2(BaseClassifier):
             logging.warning("scale={} is not supported by ShuffleNetV2, "
                             "scale is forcibly set to 1.0".format(scale))
             scale = 1.0
-        params = {'scale': scale}
+        model_name = 'ShuffleNetV2_x' + str(float(scale)).replace('.', '_')
         super(ShuffleNetV2, self).__init__(
-            model_name='ShuffleNetV2', num_classes=num_classes, **params)
+            model_name=model_name, num_classes=num_classes)
 
     def get_test_inputs(self, image_shape):
         if image_shape == [-1, -1]:

+ 1 - 35
dygraph/paddlex/cv/models/detector.py

@@ -19,7 +19,7 @@ import copy
 import os
 import os.path as osp
 
-from paddle.io import DistributedBatchSampler
+from paddle.io import DistributedBatchSampler, DataLoader
 from paddle.static import InputSpec
 import paddlex
 import paddlex.utils.logging as logging
@@ -32,7 +32,6 @@ from paddlex.cv.transforms.operators import _NormalizeBox, _PadBox, _BboxXYXY2XY
 from paddlex.cv.transforms.batch_operators import BatchCompose, BatchRandomResize, BatchRandomResizeByShort, _BatchPadding, _Gt2YoloTarget, _Permute
 from paddlex.cv.transforms import arrange_transforms
 from .base import BaseModel
-from .utils.det_dataloader import BaseDataLoader
 from .utils.det_metrics import VOCMetric, COCOMetric
 from paddlex.utils.checkpoint import det_pretrain_weights_dict
 
@@ -75,39 +74,6 @@ class BaseDetector(BaseModel):
         backbone = backbones.__dict__[backbone_name](**params)
         return backbone
 
-    def build_data_loader(self, dataset, batch_size, mode='train'):
-        batch_size_each_card = get_single_card_bs(batch_size=batch_size)
-        if mode == 'eval':
-            # detector only supports single card eval with batch size 1
-            total_steps = dataset.num_samples
-            logging.info(
-                "Start to evaluate(total_samples={}, total_steps={})...".
-                format(dataset.num_samples, total_steps))
-        if dataset.num_samples < batch_size:
-            raise Exception(
-                'The volume of datset({}) must be larger than batch size({}).'
-                .format(dataset.num_samples, batch_size))
-
-        # TODO detection eval阶段需做判断
-        batch_sampler = DistributedBatchSampler(
-            dataset,
-            batch_size=batch_size_each_card,
-            shuffle=dataset.shuffle,
-            drop_last=mode == 'train')
-
-        shm_size = _get_shared_memory_size_in_M()
-        if shm_size is None or shm_size < 1024.:
-            use_shared_memory = False
-        else:
-            use_shared_memory = True
-
-        loader = BaseDataLoader(
-            dataset,
-            batch_sampler=batch_sampler,
-            use_shared_memory=use_shared_memory)
-
-        return loader
-
     def run(self, net, inputs, mode):
         net_out = net(inputs)
         if mode in ['train', 'eval']:

+ 0 - 52
dygraph/paddlex/cv/models/utils/det_dataloader.py

@@ -1,52 +0,0 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import six
-import sys
-from paddle.io import DataLoader
-
-
-class BaseDataLoader(object):
-    def __init__(self, dataset, batch_sampler, use_shared_memory):
-        self._batch_transforms = dataset.batch_transforms
-        self._batch_sampler = batch_sampler
-        self.dataset = dataset
-        self.dataloader = DataLoader(
-            dataset=self.dataset,
-            batch_sampler=self._batch_sampler,
-            collate_fn=self._batch_transforms,
-            num_workers=self.dataset.num_workers,
-            return_list=True,
-            use_shared_memory=use_shared_memory)
-        self.loader = iter(self.dataloader)
-
-    def __call__(self):
-        return self
-
-    def __len__(self):
-        return len(self._batch_sampler)
-
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        try:
-            return next(self.loader)
-        except StopIteration:
-            self.loader = iter(self.dataloader)
-            six.reraise(*sys.exc_info())
-
-    def next(self):
-        # python2 compatibility
-        return self.__next__()

+ 4 - 12
dygraph/paddlex/utils/checkpoint.py

@@ -102,7 +102,7 @@ imagenet_weights = {
     'https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet152_vd_pretrained.pdparams',
     'ResNet200_vd_IMAGENET':
     'https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet200_vd_pretrained.pdparams',
-    'MobileNetV1_x1_0_IMAGENET':
+    'MobileNetV1_IMAGENET':
     'https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV1_pretrained.pdparams',
     'MobileNetV1_x0_25_IMAGENET':
     'https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV1_x0_25_pretrained.pdparams',
@@ -110,7 +110,7 @@ imagenet_weights = {
     'https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV1_x0_5_pretrained.pdparams',
     'MobileNetV1_x0_75_IMAGENET':
     'https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV1_x0_75_pretrained.pdparams',
-    'MobileNetV2_x1_0_IMAGENET':
+    'MobileNetV2_IMAGENET':
     'https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV2_pretrained.pdparams',
     'MobileNetV2_x0_25_IMAGENET':
     'https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV2_x0_25_pretrained.pdparams',
@@ -327,11 +327,7 @@ coco_weights = {
 }
 
 
-def get_pretrain_weights(flag,
-                         class_name,
-                         save_dir,
-                         scale=None,
-                         backbone_name=None):
+def get_pretrain_weights(flag, class_name, save_dir, backbone_name=None):
     if flag is None:
         return None
     elif osp.isdir(flag):
@@ -341,11 +337,7 @@ def get_pretrain_weights(flag,
 
     # TODO: check flag
     new_save_dir = save_dir
-    if scale is not None:
-        weights_key = "{}_x{}_{}".format(class_name,
-                                         str(float(scale)).replace('.', '_'),
-                                         flag)
-    elif backbone_name is not None:
+    if backbone_name is not None:
         weights_key = "{}_{}_{}".format(class_name, backbone_name, flag)
     else:
         weights_key = "{}_{}".format(class_name, flag)