|
@@ -24,6 +24,7 @@ from paddle.static import InputSpec
|
|
|
from paddlex.utils import logging, TrainingStats, DisablePrint
|
|
from paddlex.utils import logging, TrainingStats, DisablePrint
|
|
|
from paddlex.cv.models.base import BaseModel
|
|
from paddlex.cv.models.base import BaseModel
|
|
|
from paddlex.cv.transforms import arrange_transforms
|
|
from paddlex.cv.transforms import arrange_transforms
|
|
|
|
|
+from paddlex.cv.transforms.operators import Resize
|
|
|
|
|
|
|
|
with DisablePrint():
|
|
with DisablePrint():
|
|
|
from PaddleClas.ppcls.modeling import architectures
|
|
from PaddleClas.ppcls.modeling import architectures
|
|
@@ -51,7 +52,8 @@ class BaseClassifier(BaseModel):
|
|
|
def __init__(self, model_name='ResNet50', num_classes=1000, **params):
|
|
def __init__(self, model_name='ResNet50', num_classes=1000, **params):
|
|
|
self.init_params = locals()
|
|
self.init_params = locals()
|
|
|
self.init_params.update(params)
|
|
self.init_params.update(params)
|
|
|
- del self.init_params['params']
|
|
|
|
|
|
|
+ if 'lr_mult_list' in self.init_params:
|
|
|
|
|
+ del self.init_params['lr_mult_list']
|
|
|
super(BaseClassifier, self).__init__('classifier')
|
|
super(BaseClassifier, self).__init__('classifier')
|
|
|
if not hasattr(architectures, model_name):
|
|
if not hasattr(architectures, model_name):
|
|
|
raise Exception("ERROR: There's no model named {}.".format(
|
|
raise Exception("ERROR: There's no model named {}.".format(
|
|
@@ -70,10 +72,28 @@ class BaseClassifier(BaseModel):
|
|
|
class_dim=self.num_classes, **params)
|
|
class_dim=self.num_classes, **params)
|
|
|
return net
|
|
return net
|
|
|
|
|
|
|
|
- def get_test_inputs(self, image_shape):
|
|
|
|
|
|
|
+ def _fix_transforms_shape(self, image_shape):
|
|
|
|
|
+ if hasattr(self, 'test_transforms'):
|
|
|
|
|
+ if self.test_transforms is not None:
|
|
|
|
|
+ normalize_op_idx = len(self.test_transforms.transforms)
|
|
|
|
|
+ for idx, op in enumerate(self.test_transforms.transforms):
|
|
|
|
|
+ name = op.__class__.__name__
|
|
|
|
|
+ if name == 'Normalize':
|
|
|
|
|
+ normalize_op_idx = idx
|
|
|
|
|
+
|
|
|
|
|
+ self.test_transforms.transforms.insert(
|
|
|
|
|
+ normalize_op_idx, Resize(target_size=image_shape))
|
|
|
|
|
+
|
|
|
|
|
+ def _get_test_inputs(self, image_shape):
|
|
|
|
|
+ if image_shape is not None:
|
|
|
|
|
+ if len(image_shape) == 2:
|
|
|
|
|
+ image_shape = [None, 3] + image_shape
|
|
|
|
|
+ self._fix_transforms_shape(image_shape[-2:])
|
|
|
|
|
+ else:
|
|
|
|
|
+ image_shape = [None, 3, -1, -1]
|
|
|
input_spec = [
|
|
input_spec = [
|
|
|
InputSpec(
|
|
InputSpec(
|
|
|
- shape=[None, 3] + image_shape, name='image', dtype='float32')
|
|
|
|
|
|
|
+ shape=image_shape, name='image', dtype='float32')
|
|
|
]
|
|
]
|
|
|
return input_spec
|
|
return input_spec
|
|
|
|
|
|
|
@@ -445,15 +465,18 @@ class AlexNet(BaseClassifier):
|
|
|
model_name='AlexNet', num_classes=num_classes)
|
|
model_name='AlexNet', num_classes=num_classes)
|
|
|
|
|
|
|
|
def get_test_inputs(self, image_shape):
|
|
def get_test_inputs(self, image_shape):
|
|
|
- if image_shape == [-1, -1]:
|
|
|
|
|
|
|
+ if image_shape is not None:
|
|
|
|
|
+ if len(image_shape) == 2:
|
|
|
|
|
+ image_shape = [None, 3] + image_shape
|
|
|
|
|
+ else:
|
|
|
image_shape = [224, 224]
|
|
image_shape = [224, 224]
|
|
|
logging.info('When exporting inference model for {},'.format(
|
|
logging.info('When exporting inference model for {},'.format(
|
|
|
self.__class__.__name__
|
|
self.__class__.__name__
|
|
|
- ) + ' if image_shape is [-1, -1], it will be forcibly set to [224, 224]'
|
|
|
|
|
|
|
+ ) + ' if fixed_input_shape is not set, it will be forcibly set to [None, 3, 224, 224]'
|
|
|
)
|
|
)
|
|
|
input_spec = [
|
|
input_spec = [
|
|
|
InputSpec(
|
|
InputSpec(
|
|
|
- shape=[None, 3] + image_shape, name='image', dtype='float32')
|
|
|
|
|
|
|
+ shape=image_shape, name='image', dtype='float32')
|
|
|
]
|
|
]
|
|
|
return input_spec
|
|
return input_spec
|
|
|
|
|
|
|
@@ -623,15 +646,18 @@ class ShuffleNetV2(BaseClassifier):
|
|
|
model_name=model_name, num_classes=num_classes)
|
|
model_name=model_name, num_classes=num_classes)
|
|
|
|
|
|
|
|
def get_test_inputs(self, image_shape):
|
|
def get_test_inputs(self, image_shape):
|
|
|
- if image_shape == [-1, -1]:
|
|
|
|
|
|
|
+ if image_shape is not None:
|
|
|
|
|
+ if len(image_shape) == 2:
|
|
|
|
|
+ image_shape = [None, 3] + image_shape
|
|
|
|
|
+ else:
|
|
|
image_shape = [224, 224]
|
|
image_shape = [224, 224]
|
|
|
logging.info('When exporting inference model for {},'.format(
|
|
logging.info('When exporting inference model for {},'.format(
|
|
|
self.__class__.__name__
|
|
self.__class__.__name__
|
|
|
- ) + ' if image_shape is [-1, -1], it will be forcibly set to [224, 224]'
|
|
|
|
|
|
|
+ ) + ' if fixed_input_shape is not set, it will be forcibly set to [None, 3, 224, 224]'
|
|
|
)
|
|
)
|
|
|
input_spec = [
|
|
input_spec = [
|
|
|
InputSpec(
|
|
InputSpec(
|
|
|
- shape=[None, 3] + image_shape, name='image', dtype='float32')
|
|
|
|
|
|
|
+ shape=image_shape, name='image', dtype='float32')
|
|
|
]
|
|
]
|
|
|
return input_spec
|
|
return input_spec
|
|
|
|
|
|
|
@@ -642,14 +668,17 @@ class ShuffleNetV2_swish(BaseClassifier):
|
|
|
model_name='ShuffleNetV2_x1_5', num_classes=num_classes)
|
|
model_name='ShuffleNetV2_x1_5', num_classes=num_classes)
|
|
|
|
|
|
|
|
def get_test_inputs(self, image_shape):
|
|
def get_test_inputs(self, image_shape):
|
|
|
- if image_shape == [-1, -1]:
|
|
|
|
|
|
|
+ if image_shape is not None:
|
|
|
|
|
+ if len(image_shape) == 2:
|
|
|
|
|
+ image_shape = [None, 3] + image_shape
|
|
|
|
|
+ else:
|
|
|
image_shape = [224, 224]
|
|
image_shape = [224, 224]
|
|
|
logging.info('When exporting inference model for {},'.format(
|
|
logging.info('When exporting inference model for {},'.format(
|
|
|
self.__class__.__name__
|
|
self.__class__.__name__
|
|
|
- ) + ' if image_shape is [-1, -1], it will be forcibly set to [224, 224]'
|
|
|
|
|
|
|
+ ) + ' if fixed_input_shape is not set, it will be forcibly set to [None, 3, 224, 224]'
|
|
|
)
|
|
)
|
|
|
input_spec = [
|
|
input_spec = [
|
|
|
InputSpec(
|
|
InputSpec(
|
|
|
- shape=[None, 3] + image_shape, name='image', dtype='float32')
|
|
|
|
|
|
|
+ shape=image_shape, name='image', dtype='float32')
|
|
|
]
|
|
]
|
|
|
return input_spec
|
|
return input_spec
|