Bläddra i källkod

fix strip model name bug

will-jl944 4 år sedan
förälder
incheckning
239ab2d864
1 ändrade filer med 11 tillägg och 7 borttagningar
  1. 11 7
      dygraph/paddlex/cv/models/classifier.py

+ 11 - 7
dygraph/paddlex/cv/models/classifier.py

@@ -54,7 +54,7 @@ class BaseClassifier(BaseModel):
         self.init_params.update(params)
         self.init_params.update(params)
         del self.init_params['params']
         del self.init_params['params']
         super(BaseClassifier, self).__init__('classifier')
         super(BaseClassifier, self).__init__('classifier')
-        if not hasattr(architectures, model_name.strip('_ssld')):
+        if not hasattr(architectures, model_name):
             raise Exception("ERROR: There's no model named {}.".format(
             raise Exception("ERROR: There's no model named {}.".format(
                 model_name))
                 model_name))
 
 
@@ -67,7 +67,7 @@ class BaseClassifier(BaseModel):
 
 
     def build_net(self, **params):
     def build_net(self, **params):
         with paddle.utils.unique_name.guard():
         with paddle.utils.unique_name.guard():
-            net = architectures.__dict__[self.model_name.strip('_ssld')](
+            net = architectures.__dict__[self.model_name](
                 class_dim=self.num_classes, **params)
                 class_dim=self.num_classes, **params)
         return net
         return net
 
 
@@ -407,9 +407,10 @@ class ResNet50_vd(BaseClassifier):
 class ResNet50_vd_ssld(BaseClassifier):
 class ResNet50_vd_ssld(BaseClassifier):
     def __init__(self, num_classes=1000):
     def __init__(self, num_classes=1000):
         super(ResNet50_vd_ssld, self).__init__(
         super(ResNet50_vd_ssld, self).__init__(
-            model_name='ResNet50_vd_ssld',
+            model_name='ResNet50_vd',
             num_classes=num_classes,
             num_classes=num_classes,
             lr_mult_list=[.1, .1, .2, .2, .3])
             lr_mult_list=[.1, .1, .2, .2, .3])
+        self.model_name = 'ResNet50_vd_ssld'
 
 
 
 
 class ResNet101_vd(BaseClassifier):
 class ResNet101_vd(BaseClassifier):
@@ -421,9 +422,10 @@ class ResNet101_vd(BaseClassifier):
 class ResNet101_vd_ssld(BaseClassifier):
 class ResNet101_vd_ssld(BaseClassifier):
     def __init__(self, num_classes=1000):
     def __init__(self, num_classes=1000):
         super(ResNet101_vd_ssld, self).__init__(
         super(ResNet101_vd_ssld, self).__init__(
-            model_name='ResNet101_vd_ssld',
+            model_name='ResNet101_vd',
             num_classes=num_classes,
             num_classes=num_classes,
             lr_mult_list=[.1, .1, .2, .2, .3])
             lr_mult_list=[.1, .1, .2, .2, .3])
+        self.model_name = 'ResNet101_vd_ssld'
 
 
 
 
 class ResNet152_vd(BaseClassifier):
 class ResNet152_vd(BaseClassifier):
@@ -517,10 +519,11 @@ class MobileNetV3_small_ssld(BaseClassifier):
                 "scale={} is not supported by MobileNetV3_small_ssld, "
                 "scale={} is not supported by MobileNetV3_small_ssld, "
                 "scale is forcibly set to 1.0".format(scale))
                 "scale is forcibly set to 1.0".format(scale))
             scale = 1.0
             scale = 1.0
-        model_name = 'MobileNetV3_small_x' + str(float(scale)).replace(
-            '.', '_') + '_ssld'
+        model_name = 'MobileNetV3_small_x' + str(float(scale)).replace('.',
+                                                                       '_')
         super(MobileNetV3_small_ssld, self).__init__(
         super(MobileNetV3_small_ssld, self).__init__(
             model_name=model_name, num_classes=num_classes)
             model_name=model_name, num_classes=num_classes)
+        self.model_name = model_name + '_ssld'
 
 
 
 
 class MobileNetV3_large(BaseClassifier):
 class MobileNetV3_large(BaseClassifier):
@@ -539,7 +542,8 @@ class MobileNetV3_large(BaseClassifier):
 class MobileNetV3_large_ssld(BaseClassifier):
 class MobileNetV3_large_ssld(BaseClassifier):
     def __init__(self, num_classes=1000):
     def __init__(self, num_classes=1000):
         super(MobileNetV3_large_ssld, self).__init__(
         super(MobileNetV3_large_ssld, self).__init__(
-            model_name='MobileNetV3_large_x1_0_ssld', num_classes=num_classes)
+            model_name='MobileNetV3_large_x1_0', num_classes=num_classes)
+        self.model_name = 'MobileNetV3_large_x1_0_ssld'
 
 
 
 
 class DenseNet121(BaseClassifier):
 class DenseNet121(BaseClassifier):