|
|
@@ -22,8 +22,8 @@ import paddle.nn.functional as F
|
|
|
from paddle.static import InputSpec
|
|
|
from paddlex.utils import logging, TrainingStats
|
|
|
from paddlex.cv.models.base import BaseModel
|
|
|
-from paddlex.cv.nets.ppcls.modeling import architectures
|
|
|
-from paddlex.cv.nets.ppcls.modeling.loss import CELoss
|
|
|
+from PaddleClas.ppcls.modeling import architectures
|
|
|
+from PaddleClas.ppcls.modeling.loss import CELoss
|
|
|
from paddlex.cv.transforms import arrange_transforms
|
|
|
|
|
|
__all__ = [
|
|
|
@@ -399,7 +399,10 @@ class ResNet50_vd(BaseClassifier):
|
|
|
class ResNet50_vd_ssld(BaseClassifier):
|
|
|
def __init__(self, num_classes=1000):
|
|
|
super(ResNet50_vd_ssld, self).__init__(
|
|
|
- model_name='ResNet50_vd_ssld', num_classes=num_classes)
|
|
|
+ model_name='ResNet50_vd',
|
|
|
+ num_classes=num_classes,
|
|
|
+ lr_mult_list=[.1, .1, .2, .2, .3])
|
|
|
+ self.model_name = 'ResNet50_vd_ssld'
|
|
|
|
|
|
|
|
|
class ResNet101_vd(BaseClassifier):
|
|
|
@@ -411,7 +414,10 @@ class ResNet101_vd(BaseClassifier):
|
|
|
class ResNet101_vd_ssld(BaseClassifier):
|
|
|
def __init__(self, num_classes=1000):
|
|
|
super(ResNet101_vd_ssld, self).__init__(
|
|
|
- model_name='ResNet101_vd_ssld', num_classes=num_classes)
|
|
|
+ model_name='ResNet101_vd_ssld',
|
|
|
+ num_classes=num_classes,
|
|
|
+ lr_mult_list=[.1, .1, .2, .2, .3])
|
|
|
+ self.model_name = 'ResNet101_vd_ssld'
|
|
|
|
|
|
|
|
|
class ResNet152_vd(BaseClassifier):
|
|
|
@@ -458,9 +464,13 @@ class MobileNetV1(BaseClassifier):
|
|
|
logging.warning("scale={} is not supported by MobileNetV1, "
|
|
|
"scale is forcibly set to 1.0".format(scale))
|
|
|
scale = 1.0
|
|
|
- params = {'scale': scale}
|
|
|
+ if scale == 1:
|
|
|
+ model_name = 'MobileNetV1'
|
|
|
+ else:
|
|
|
+ model_name = 'MobileNetV1_x' + str(scale).replace('.', '_')
|
|
|
+ self.scale = scale
|
|
|
super(MobileNetV1, self).__init__(
|
|
|
- model_name='MobileNetV1', num_classes=num_classes, **params)
|
|
|
+ model_name=model_name, num_classes=num_classes)
|
|
|
|
|
|
|
|
|
class MobileNetV2(BaseClassifier):
|
|
|
@@ -470,9 +480,12 @@ class MobileNetV2(BaseClassifier):
|
|
|
logging.warning("scale={} is not supported by MobileNetV2, "
|
|
|
"scale is forcibly set to 1.0".format(scale))
|
|
|
scale = 1.0
|
|
|
- params = {'scale': scale}
|
|
|
+ if scale == 1:
|
|
|
+ model_name = 'MobileNetV2'
|
|
|
+ else:
|
|
|
+ model_name = 'MobileNetV2_x' + str(scale).replace('.', '_')
|
|
|
super(MobileNetV2, self).__init__(
|
|
|
- model_name='MobileNetV2', num_classes=num_classes, **params)
|
|
|
+ model_name=model_name, num_classes=num_classes)
|
|
|
|
|
|
|
|
|
class MobileNetV3_small(BaseClassifier):
|
|
|
@@ -482,9 +495,10 @@ class MobileNetV3_small(BaseClassifier):
|
|
|
logging.warning("scale={} is not supported by MobileNetV3_small, "
|
|
|
"scale is forcibly set to 1.0".format(scale))
|
|
|
scale = 1.0
|
|
|
- params = {'scale': scale}
|
|
|
+ model_name = 'MobileNetV3_small_x' + str(float(scale)).replace('.',
|
|
|
+ '_')
|
|
|
super(MobileNetV3_small, self).__init__(
|
|
|
- model_name='MobileNetV3_small', num_classes=num_classes, **params)
|
|
|
+ model_name=model_name, num_classes=num_classes)
|
|
|
|
|
|
|
|
|
class MobileNetV3_large(BaseClassifier):
|
|
|
@@ -494,9 +508,10 @@ class MobileNetV3_large(BaseClassifier):
|
|
|
logging.warning("scale={} is not supported by MobileNetV3_large, "
|
|
|
"scale is forcibly set to 1.0".format(scale))
|
|
|
scale = 1.0
|
|
|
- params = {'scale': scale}
|
|
|
+ model_name = 'MobileNetV3_large_x' + str(float(scale)).replace('.',
|
|
|
+ '_')
|
|
|
super(MobileNetV3_large, self).__init__(
|
|
|
- model_name='MobileNetV3_large', num_classes=num_classes, **params)
|
|
|
+ model_name=model_name, num_classes=num_classes)
|
|
|
|
|
|
|
|
|
class DenseNet121(BaseClassifier):
|
|
|
@@ -596,9 +611,9 @@ class ShuffleNetV2(BaseClassifier):
|
|
|
logging.warning("scale={} is not supported by ShuffleNetV2, "
|
|
|
"scale is forcibly set to 1.0".format(scale))
|
|
|
scale = 1.0
|
|
|
- params = {'scale': scale}
|
|
|
+ model_name = 'ShuffleNetV2_x' + str(float(scale)).replace('.', '_')
|
|
|
super(ShuffleNetV2, self).__init__(
|
|
|
- model_name='ShuffleNetV2', num_classes=num_classes, **params)
|
|
|
+ model_name=model_name, num_classes=num_classes)
|
|
|
|
|
|
def get_test_inputs(self, image_shape):
|
|
|
if image_shape == [-1, -1]:
|