|
|
@@ -54,7 +54,7 @@ class BaseClassifier(BaseModel):
|
|
|
self.init_params.update(params)
|
|
|
del self.init_params['params']
|
|
|
super(BaseClassifier, self).__init__('classifier')
|
|
|
- if not hasattr(architectures, model_name.strip('_ssld')):
|
|
|
+ if not hasattr(architectures, model_name):
|
|
|
raise Exception("ERROR: There's no model named {}.".format(
|
|
|
model_name))
|
|
|
|
|
|
@@ -67,7 +67,7 @@ class BaseClassifier(BaseModel):
|
|
|
|
|
|
def build_net(self, **params):
|
|
|
with paddle.utils.unique_name.guard():
|
|
|
- net = architectures.__dict__[self.model_name.strip('_ssld')](
|
|
|
+ net = architectures.__dict__[self.model_name](
|
|
|
class_dim=self.num_classes, **params)
|
|
|
return net
|
|
|
|
|
|
@@ -407,9 +407,10 @@ class ResNet50_vd(BaseClassifier):
|
|
|
class ResNet50_vd_ssld(BaseClassifier):
|
|
|
def __init__(self, num_classes=1000):
|
|
|
super(ResNet50_vd_ssld, self).__init__(
|
|
|
- model_name='ResNet50_vd_ssld',
|
|
|
+ model_name='ResNet50_vd',
|
|
|
num_classes=num_classes,
|
|
|
lr_mult_list=[.1, .1, .2, .2, .3])
|
|
|
+ self.model_name = 'ResNet50_vd_ssld'
|
|
|
|
|
|
|
|
|
class ResNet101_vd(BaseClassifier):
|
|
|
@@ -421,9 +422,10 @@ class ResNet101_vd(BaseClassifier):
|
|
|
class ResNet101_vd_ssld(BaseClassifier):
|
|
|
def __init__(self, num_classes=1000):
|
|
|
super(ResNet101_vd_ssld, self).__init__(
|
|
|
- model_name='ResNet101_vd_ssld',
|
|
|
+ model_name='ResNet101_vd',
|
|
|
num_classes=num_classes,
|
|
|
lr_mult_list=[.1, .1, .2, .2, .3])
|
|
|
+ self.model_name = 'ResNet101_vd_ssld'
|
|
|
|
|
|
|
|
|
class ResNet152_vd(BaseClassifier):
|
|
|
@@ -517,10 +519,11 @@ class MobileNetV3_small_ssld(BaseClassifier):
|
|
|
"scale={} is not supported by MobileNetV3_small_ssld, "
|
|
|
"scale is forcibly set to 1.0".format(scale))
|
|
|
scale = 1.0
|
|
|
- model_name = 'MobileNetV3_small_x' + str(float(scale)).replace(
|
|
|
- '.', '_') + '_ssld'
|
|
|
+ model_name = 'MobileNetV3_small_x' + str(float(scale)).replace('.',
|
|
|
+ '_')
|
|
|
super(MobileNetV3_small_ssld, self).__init__(
|
|
|
model_name=model_name, num_classes=num_classes)
|
|
|
+ self.model_name = model_name + '_ssld'
|
|
|
|
|
|
|
|
|
class MobileNetV3_large(BaseClassifier):
|
|
|
@@ -539,7 +542,8 @@ class MobileNetV3_large(BaseClassifier):
|
|
|
class MobileNetV3_large_ssld(BaseClassifier):
|
|
|
def __init__(self, num_classes=1000):
|
|
|
super(MobileNetV3_large_ssld, self).__init__(
|
|
|
- model_name='MobileNetV3_large_x1_0_ssld', num_classes=num_classes)
|
|
|
+ model_name='MobileNetV3_large_x1_0', num_classes=num_classes)
|
|
|
+ self.model_name = 'MobileNetV3_large_x1_0_ssld'
|
|
|
|
|
|
|
|
|
class DenseNet121(BaseClassifier):
|