|
|
@@ -55,6 +55,8 @@ class BaseClassifier(BaseModel):
|
|
|
self.init_params.update(params)
|
|
|
if 'lr_mult_list' in self.init_params:
|
|
|
del self.init_params['lr_mult_list']
|
|
|
+ if 'with_net' in self.init_params:
|
|
|
+ del self.init_params['with_net']
|
|
|
super(BaseClassifier, self).__init__('classifier')
|
|
|
if not hasattr(architectures, model_name):
|
|
|
raise Exception("ERROR: There's no model named {}.".format(
|
|
|
@@ -65,7 +67,9 @@ class BaseClassifier(BaseModel):
|
|
|
self.num_classes = num_classes
|
|
|
for k, v in params.items():
|
|
|
setattr(self, k, v)
|
|
|
- self.net = self.build_net(**params)
|
|
|
+ if params.get('with_net', True):
|
|
|
+ params.pop('with_net', None)
|
|
|
+ self.net = self.build_net(**params)
|
|
|
|
|
|
def build_net(self, **params):
|
|
|
with paddle.utils.unique_name.guard():
|
|
|
@@ -469,93 +473,95 @@ class BaseClassifier(BaseModel):
|
|
|
|
|
|
|
|
|
class ResNet18(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, **params):
|
|
|
super(ResNet18, self).__init__(
|
|
|
- model_name='ResNet18', num_classes=num_classes)
|
|
|
+ model_name='ResNet18', num_classes=num_classes, **params)
|
|
|
|
|
|
|
|
|
class ResNet34(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, **params):
|
|
|
super(ResNet34, self).__init__(
|
|
|
- model_name='ResNet34', num_classes=num_classes)
|
|
|
+ model_name='ResNet34', num_classes=num_classes, **params)
|
|
|
|
|
|
|
|
|
class ResNet50(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, **params):
|
|
|
super(ResNet50, self).__init__(
|
|
|
- model_name='ResNet50', num_classes=num_classes)
|
|
|
+ model_name='ResNet50', num_classes=num_classes, **params)
|
|
|
|
|
|
|
|
|
class ResNet101(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, **params):
|
|
|
super(ResNet101, self).__init__(
|
|
|
- model_name='ResNet101', num_classes=num_classes)
|
|
|
+ model_name='ResNet101', num_classes=num_classes, **params)
|
|
|
|
|
|
|
|
|
class ResNet152(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, **params):
|
|
|
super(ResNet152, self).__init__(
|
|
|
- model_name='ResNet152', num_classes=num_classes)
|
|
|
+ model_name='ResNet152', num_classes=num_classes, **params)
|
|
|
|
|
|
|
|
|
class ResNet18_vd(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, **params):
|
|
|
super(ResNet18_vd, self).__init__(
|
|
|
- model_name='ResNet18_vd', num_classes=num_classes)
|
|
|
+ model_name='ResNet18_vd', num_classes=num_classes, **params)
|
|
|
|
|
|
|
|
|
class ResNet34_vd(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, **params):
|
|
|
super(ResNet34_vd, self).__init__(
|
|
|
- model_name='ResNet34_vd', num_classes=num_classes)
|
|
|
+ model_name='ResNet34_vd', num_classes=num_classes, **params)
|
|
|
|
|
|
|
|
|
class ResNet50_vd(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, **params):
|
|
|
super(ResNet50_vd, self).__init__(
|
|
|
- model_name='ResNet50_vd', num_classes=num_classes)
|
|
|
+ model_name='ResNet50_vd', num_classes=num_classes, **params)
|
|
|
|
|
|
|
|
|
class ResNet50_vd_ssld(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, **params):
|
|
|
super(ResNet50_vd_ssld, self).__init__(
|
|
|
model_name='ResNet50_vd',
|
|
|
num_classes=num_classes,
|
|
|
- lr_mult_list=[.1, .1, .2, .2, .3])
|
|
|
+ lr_mult_list=[.1, .1, .2, .2, .3],
|
|
|
+ **params)
|
|
|
self.model_name = 'ResNet50_vd_ssld'
|
|
|
|
|
|
|
|
|
class ResNet101_vd(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, **params):
|
|
|
super(ResNet101_vd, self).__init__(
|
|
|
- model_name='ResNet101_vd', num_classes=num_classes)
|
|
|
+ model_name='ResNet101_vd', num_classes=num_classes, **params)
|
|
|
|
|
|
|
|
|
class ResNet101_vd_ssld(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, **params):
|
|
|
super(ResNet101_vd_ssld, self).__init__(
|
|
|
model_name='ResNet101_vd',
|
|
|
num_classes=num_classes,
|
|
|
- lr_mult_list=[.1, .1, .2, .2, .3])
|
|
|
+ lr_mult_list=[.1, .1, .2, .2, .3],
|
|
|
+ **params)
|
|
|
self.model_name = 'ResNet101_vd_ssld'
|
|
|
|
|
|
|
|
|
class ResNet152_vd(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, **params):
|
|
|
super(ResNet152_vd, self).__init__(
|
|
|
- model_name='ResNet152_vd', num_classes=num_classes)
|
|
|
+ model_name='ResNet152_vd', num_classes=num_classes, **params)
|
|
|
|
|
|
|
|
|
class ResNet200_vd(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, **params):
|
|
|
super(ResNet200_vd, self).__init__(
|
|
|
- model_name='ResNet200_vd', num_classes=num_classes)
|
|
|
+ model_name='ResNet200_vd', num_classes=num_classes, **params)
|
|
|
|
|
|
|
|
|
class AlexNet(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, **params):
|
|
|
super(AlexNet, self).__init__(
|
|
|
- model_name='AlexNet', num_classes=num_classes)
|
|
|
+ model_name='AlexNet', num_classes=num_classes, **params)
|
|
|
|
|
|
def _get_test_inputs(self, image_shape):
|
|
|
if image_shape is not None:
|
|
|
@@ -580,13 +586,13 @@ class AlexNet(BaseClassifier):
|
|
|
|
|
|
|
|
|
class DarkNet53(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, **params):
|
|
|
super(DarkNet53, self).__init__(
|
|
|
- model_name='DarkNet53', num_classes=num_classes)
|
|
|
+ model_name='DarkNet53', num_classes=num_classes, **params)
|
|
|
|
|
|
|
|
|
class MobileNetV1(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000, scale=1.0):
|
|
|
+ def __init__(self, num_classes=1000, scale=1.0, **params):
|
|
|
supported_scale = [.25, .5, .75, 1.0]
|
|
|
if scale not in supported_scale:
|
|
|
logging.warning("scale={} is not supported by MobileNetV1, "
|
|
|
@@ -598,11 +604,11 @@ class MobileNetV1(BaseClassifier):
|
|
|
model_name = 'MobileNetV1_x' + str(scale).replace('.', '_')
|
|
|
self.scale = scale
|
|
|
super(MobileNetV1, self).__init__(
|
|
|
- model_name=model_name, num_classes=num_classes)
|
|
|
+ model_name=model_name, num_classes=num_classes, **params)
|
|
|
|
|
|
|
|
|
class MobileNetV2(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000, scale=1.0):
|
|
|
+ def __init__(self, num_classes=1000, scale=1.0, **params):
|
|
|
supported_scale = [.25, .5, .75, 1.0, 1.5, 2.0]
|
|
|
if scale not in supported_scale:
|
|
|
logging.warning("scale={} is not supported by MobileNetV2, "
|
|
|
@@ -613,11 +619,11 @@ class MobileNetV2(BaseClassifier):
|
|
|
else:
|
|
|
model_name = 'MobileNetV2_x' + str(scale).replace('.', '_')
|
|
|
super(MobileNetV2, self).__init__(
|
|
|
- model_name=model_name, num_classes=num_classes)
|
|
|
+ model_name=model_name, num_classes=num_classes, **params)
|
|
|
|
|
|
|
|
|
class MobileNetV3_small(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000, scale=1.0):
|
|
|
+ def __init__(self, num_classes=1000, scale=1.0, **params):
|
|
|
supported_scale = [.35, .5, .75, 1.0, 1.25]
|
|
|
if scale not in supported_scale:
|
|
|
logging.warning("scale={} is not supported by MobileNetV3_small, "
|
|
|
@@ -626,11 +632,11 @@ class MobileNetV3_small(BaseClassifier):
|
|
|
model_name = 'MobileNetV3_small_x' + str(float(scale)).replace('.',
|
|
|
'_')
|
|
|
super(MobileNetV3_small, self).__init__(
|
|
|
- model_name=model_name, num_classes=num_classes)
|
|
|
+ model_name=model_name, num_classes=num_classes, **params)
|
|
|
|
|
|
|
|
|
class MobileNetV3_small_ssld(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000, scale=1.0):
|
|
|
+ def __init__(self, num_classes=1000, scale=1.0, **params):
|
|
|
supported_scale = [.35, 1.0]
|
|
|
if scale not in supported_scale:
|
|
|
logging.warning(
|
|
|
@@ -640,12 +646,12 @@ class MobileNetV3_small_ssld(BaseClassifier):
|
|
|
model_name = 'MobileNetV3_small_x' + str(float(scale)).replace('.',
|
|
|
'_')
|
|
|
super(MobileNetV3_small_ssld, self).__init__(
|
|
|
- model_name=model_name, num_classes=num_classes)
|
|
|
+ model_name=model_name, num_classes=num_classes, **params)
|
|
|
self.model_name = model_name + '_ssld'
|
|
|
|
|
|
|
|
|
class MobileNetV3_large(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000, scale=1.0):
|
|
|
+ def __init__(self, num_classes=1000, scale=1.0, **params):
|
|
|
supported_scale = [.35, .5, .75, 1.0, 1.25]
|
|
|
if scale not in supported_scale:
|
|
|
logging.warning("scale={} is not supported by MobileNetV3_large, "
|
|
|
@@ -654,108 +660,110 @@ class MobileNetV3_large(BaseClassifier):
|
|
|
model_name = 'MobileNetV3_large_x' + str(float(scale)).replace('.',
|
|
|
'_')
|
|
|
super(MobileNetV3_large, self).__init__(
|
|
|
- model_name=model_name, num_classes=num_classes)
|
|
|
+ model_name=model_name, num_classes=num_classes, **params)
|
|
|
|
|
|
|
|
|
class MobileNetV3_large_ssld(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, **params):
|
|
|
super(MobileNetV3_large_ssld, self).__init__(
|
|
|
- model_name='MobileNetV3_large_x1_0', num_classes=num_classes)
|
|
|
+ model_name='MobileNetV3_large_x1_0',
|
|
|
+ num_classes=num_classes,
|
|
|
+ **params)
|
|
|
self.model_name = 'MobileNetV3_large_x1_0_ssld'
|
|
|
|
|
|
|
|
|
class DenseNet121(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, **params):
|
|
|
super(DenseNet121, self).__init__(
|
|
|
- model_name='DenseNet121', num_classes=num_classes)
|
|
|
+ model_name='DenseNet121', num_classes=num_classes, **params)
|
|
|
|
|
|
|
|
|
class DenseNet161(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, **params):
|
|
|
super(DenseNet161, self).__init__(
|
|
|
- model_name='DenseNet161', num_classes=num_classes)
|
|
|
+ model_name='DenseNet161', num_classes=num_classes, **params)
|
|
|
|
|
|
|
|
|
class DenseNet169(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, **params):
|
|
|
super(DenseNet169, self).__init__(
|
|
|
- model_name='DenseNet169', num_classes=num_classes)
|
|
|
+ model_name='DenseNet169', num_classes=num_classes, **params)
|
|
|
|
|
|
|
|
|
class DenseNet201(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, **params):
|
|
|
super(DenseNet201, self).__init__(
|
|
|
- model_name='DenseNet201', num_classes=num_classes)
|
|
|
+ model_name='DenseNet201', num_classes=num_classes, **params)
|
|
|
|
|
|
|
|
|
class DenseNet264(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, **params):
|
|
|
super(DenseNet264, self).__init__(
|
|
|
- model_name='DenseNet264', num_classes=num_classes)
|
|
|
+ model_name='DenseNet264', num_classes=num_classes, **params)
|
|
|
|
|
|
|
|
|
class HRNet_W18_C(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, **params):
|
|
|
super(HRNet_W18_C, self).__init__(
|
|
|
- model_name='HRNet_W18_C', num_classes=num_classes)
|
|
|
+ model_name='HRNet_W18_C', num_classes=num_classes, **params)
|
|
|
|
|
|
|
|
|
class HRNet_W30_C(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, **params):
|
|
|
super(HRNet_W30_C, self).__init__(
|
|
|
- model_name='HRNet_W30_C', num_classes=num_classes)
|
|
|
+ model_name='HRNet_W30_C', num_classes=num_classes, **params)
|
|
|
|
|
|
|
|
|
class HRNet_W32_C(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, **params):
|
|
|
super(HRNet_W32_C, self).__init__(
|
|
|
- model_name='HRNet_W32_C', num_classes=num_classes)
|
|
|
+ model_name='HRNet_W32_C', num_classes=num_classes, **params)
|
|
|
|
|
|
|
|
|
class HRNet_W40_C(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, **params):
|
|
|
super(HRNet_W40_C, self).__init__(
|
|
|
- model_name='HRNet_W40_C', num_classes=num_classes)
|
|
|
+ model_name='HRNet_W40_C', num_classes=num_classes, **params)
|
|
|
|
|
|
|
|
|
class HRNet_W44_C(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, **params):
|
|
|
super(HRNet_W44_C, self).__init__(
|
|
|
- model_name='HRNet_W44_C', num_classes=num_classes)
|
|
|
+ model_name='HRNet_W44_C', num_classes=num_classes, **params)
|
|
|
|
|
|
|
|
|
class HRNet_W48_C(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, **params):
|
|
|
super(HRNet_W48_C, self).__init__(
|
|
|
- model_name='HRNet_W48_C', num_classes=num_classes)
|
|
|
+ model_name='HRNet_W48_C', num_classes=num_classes, **params)
|
|
|
|
|
|
|
|
|
class HRNet_W64_C(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, **params):
|
|
|
super(HRNet_W64_C, self).__init__(
|
|
|
- model_name='HRNet_W64_C', num_classes=num_classes)
|
|
|
+ model_name='HRNet_W64_C', num_classes=num_classes, **params)
|
|
|
|
|
|
|
|
|
class Xception41(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, **params):
|
|
|
super(Xception41, self).__init__(
|
|
|
- model_name='Xception41', num_classes=num_classes)
|
|
|
+ model_name='Xception41', num_classes=num_classes, **params)
|
|
|
|
|
|
|
|
|
class Xception65(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, **params):
|
|
|
super(Xception65, self).__init__(
|
|
|
- model_name='Xception65', num_classes=num_classes)
|
|
|
+ model_name='Xception65', num_classes=num_classes, **params)
|
|
|
|
|
|
|
|
|
class Xception71(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, **params):
|
|
|
super(Xception71, self).__init__(
|
|
|
- model_name='Xception71', num_classes=num_classes)
|
|
|
+ model_name='Xception71', num_classes=num_classes, **params)
|
|
|
|
|
|
|
|
|
class ShuffleNetV2(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000, scale=1.0):
|
|
|
+ def __init__(self, num_classes=1000, scale=1.0, **params):
|
|
|
supported_scale = [.25, .33, .5, 1.0, 1.5, 2.0]
|
|
|
if scale not in supported_scale:
|
|
|
logging.warning("scale={} is not supported by ShuffleNetV2, "
|
|
|
@@ -763,7 +771,7 @@ class ShuffleNetV2(BaseClassifier):
|
|
|
scale = 1.0
|
|
|
model_name = 'ShuffleNetV2_x' + str(float(scale)).replace('.', '_')
|
|
|
super(ShuffleNetV2, self).__init__(
|
|
|
- model_name=model_name, num_classes=num_classes)
|
|
|
+ model_name=model_name, num_classes=num_classes, **params)
|
|
|
|
|
|
def _get_test_inputs(self, image_shape):
|
|
|
if image_shape is not None:
|
|
|
@@ -788,9 +796,9 @@ class ShuffleNetV2(BaseClassifier):
|
|
|
|
|
|
|
|
|
class ShuffleNetV2_swish(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, **params):
|
|
|
super(ShuffleNetV2_swish, self).__init__(
|
|
|
- model_name='ShuffleNetV2_x1_5', num_classes=num_classes)
|
|
|
+ model_name='ShuffleNetV2_x1_5', num_classes=num_classes, **params)
|
|
|
|
|
|
def _get_test_inputs(self, image_shape):
|
|
|
if image_shape is not None:
|