Sfoglia il codice sorgente

import from paddleseg submodule

will-jl944 4 anni fa
parent
commit
0f89bd6e79
3 ha cambiato i file con 18 aggiunte e 14 eliminazioni
  1. 1 1
      dygraph/PaddleSeg
  2. 15 13
      dygraph/paddlex/cv/models/segmenter.py
  3. 2 0
      dygraph/requirements.txt

+ 1 - 1
dygraph/PaddleSeg

@@ -1 +1 @@
-Subproject commit 2dc99b8dd4dfb1468eed238b46f395e3100695e1
+Subproject commit eb3a98f4c17c467ed451bee1bdc8690a9917d7c5

+ 15 - 13
dygraph/paddlex/cv/models/segmenter.py

@@ -19,15 +19,14 @@ from collections import OrderedDict
 import paddle
 import paddle.nn.functional as F
 from paddle.static import InputSpec
+import paddleseg
 import paddlex
-from paddlex.cv.nets.paddleseg import models
 from paddlex.cv.transforms import arrange_transforms
 from paddlex.utils import get_single_card_bs
 import paddlex.utils.logging as logging
 from .base import BaseModel
 from .utils import seg_metrics as metrics
 from paddlex.utils.checkpoint import seg_pretrain_weights_dict
-from paddlex.cv.nets.paddleseg.cvlibs import manager
 from paddlex.cv.transforms import Decode
 
 __all__ = ["UNet", "DeepLabV3P", "FastSCNN", "HRNet", "BiSeNetV2"]
@@ -41,7 +40,7 @@ class BaseSegmenter(BaseModel):
                  **params):
         self.init_params = locals()
         super(BaseSegmenter, self).__init__('segmenter')
-        if not hasattr(models, model_name):
+        if not hasattr(paddleseg.models, model_name):
             raise Exception("ERROR: There's no model named {}.".format(
                 model_name))
         self.model_name = model_name
@@ -54,8 +53,8 @@ class BaseSegmenter(BaseModel):
     def build_net(self, **params):
         # TODO: when using paddle.utils.unique_name.guard,
         # DeepLabv3p and HRNet will raise a error
-        net = models.__dict__[self.model_name](num_classes=self.num_classes,
-                                               **params)
+        net = paddleseg.models.__dict__[self.model_name](
+            num_classes=self.num_classes, **params)
         return net
 
     def get_test_inputs(self, image_shape):
@@ -101,15 +100,16 @@ class BaseSegmenter(BaseModel):
         if isinstance(self.use_mixed_loss, bool):
             if self.use_mixed_loss:
                 losses = [
-                    manager.LOSSES['CrossEntropyLoss'](),
-                    manager.LOSSES['LovaszSoftmaxLoss']()
+                    paddleseg.models.CrossEntropyLoss(),
+                    paddleseg.models.LovaszSoftmaxLoss()
                 ]
                 coef = [.8, .2]
                 loss_type = [
-                    manager.LOSSES['MixedLoss'](losses=losses, coef=coef)
+                    paddleseg.models.MixedLoss(
+                        losses=losses, coef=coef),
                 ]
             else:
-                loss_type = [manager.LOSSES['CrossEntropyLoss']()]
+                loss_type = [paddleseg.models.CrossEntropyLoss()]
         else:
             losses, coef = list(zip(*self.use_mixed_loss))
             if not set(losses).issubset(
@@ -117,9 +117,10 @@ class BaseSegmenter(BaseModel):
                 raise ValueError(
                     "Only 'CrossEntropyLoss', 'DiceLoss', 'LovaszSoftmaxLoss' are supported."
                 )
-            losses = [manager.LOSSES[loss]() for loss in losses]
+            losses = [getattr(paddleseg.models, loss)() for loss in losses]
             loss_type = [
-                manager.LOSSES['MixedLoss'](losses=losses, coef=list(coef))
+                paddleseg.models.MixedLoss(
+                    losses=losses, coef=list(coef))
             ]
         if self.model_name == 'FastSCNN':
             loss_type *= 2
@@ -447,7 +448,8 @@ class DeepLabV3P(BaseSegmenter):
             raise ValueError(
                 "backbone: {} is not supported. Please choose one of "
                 "('ResNet50_vd', 'ResNet101_vd')".format(backbone))
-        backbone = manager.BACKBONES[backbone](output_stride=output_stride)
+        backbone = getattr(paddleseg.models, backbone)(
+            output_stride=output_stride)
         params = {
             'backbone': backbone,
             'backbone_indices': backbone_indices,
@@ -486,7 +488,7 @@ class HRNet(BaseSegmenter):
                 "width={} is not supported, please choose from [18, 48]".
                 format(width))
         self.backbone_name = 'HRNet_W{}'.format(width)
-        backbone = manager.BACKBONES[self.backbone_name](
+        backbone = getattr(paddleseg.models, self.backbone_name)(
             align_corners=align_corners)
 
         params = {'backbone': backbone, 'align_corners': align_corners}

+ 2 - 0
dygraph/requirements.txt

@@ -9,3 +9,5 @@ shapely
 paddlepaddle-gpu==2.1.0
 opencv-python
 -r PaddleClas/requirements.txt
+-r ./PaddleSeg/requirements.txt
+./PaddleSeg