|
|
@@ -37,9 +37,13 @@ class BaseClassifier(BaseAPI):
|
|
|
'MobileNetV1', 'MobileNetV2', 'Xception41',
|
|
|
'Xception65', 'Xception71']。默认为'ResNet50'。
|
|
|
num_classes (int): 类别数。默认为1000。
|
|
|
+ input_channel (int): 输入图像的通道数量。默认为3。
|
|
|
"""
|
|
|
|
|
|
- def __init__(self, model_name='ResNet50', num_classes=1000):
|
|
|
+ def __init__(self,
|
|
|
+ model_name='ResNet50',
|
|
|
+ num_classes=1000,
|
|
|
+ input_channel=3):
|
|
|
self.init_params = locals()
|
|
|
super(BaseClassifier, self).__init__('classifier')
|
|
|
if not hasattr(paddlex.cv.nets, str.lower(model_name)):
|
|
|
@@ -49,19 +53,23 @@ class BaseClassifier(BaseAPI):
|
|
|
self.labels = None
|
|
|
self.num_classes = num_classes
|
|
|
self.fixed_input_shape = None
|
|
|
+ self.input_channel = input_channel
|
|
|
|
|
|
def build_net(self, mode='train'):
|
|
|
if self.__class__.__name__ == "AlexNet":
|
|
|
assert self.fixed_input_shape is not None, "In AlexNet, input_shape should be defined, e.g. model = paddlex.cls.AlexNet(num_classes=1000, input_shape=[224, 224])"
|
|
|
if self.fixed_input_shape is not None:
|
|
|
input_shape = [
|
|
|
- None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
|
|
|
+ None, self.input_channel, self.fixed_input_shape[1],
|
|
|
+ self.fixed_input_shape[0]
|
|
|
]
|
|
|
image = fluid.data(
|
|
|
dtype='float32', shape=input_shape, name='image')
|
|
|
else:
|
|
|
image = fluid.data(
|
|
|
- dtype='float32', shape=[None, 3, None, None], name='image')
|
|
|
+ dtype='float32',
|
|
|
+ shape=[None, self.input_channel, None, None],
|
|
|
+ name='image')
|
|
|
if mode != 'test':
|
|
|
label = fluid.data(dtype='int64', shape=[None, 1], name='label')
|
|
|
model = getattr(paddlex.cv.nets, str.lower(self.model_name))
|
|
|
@@ -223,11 +231,13 @@ class BaseClassifier(BaseAPI):
|
|
|
tuple (metrics, eval_details): 当return_details为True时,增加返回dict,
|
|
|
包含关键字:'true_labels'、'pred_scores',分别代表真实类别id、每个类别的预测得分。
|
|
|
"""
|
|
|
+ input_channel = getattr(self, 'input_channel', 3)
|
|
|
arrange_transforms(
|
|
|
model_type=self.model_type,
|
|
|
class_name=self.__class__.__name__,
|
|
|
transforms=eval_dataset.transforms,
|
|
|
- mode='eval')
|
|
|
+ mode='eval',
|
|
|
+ input_channel=input_channel)
|
|
|
data_generator = eval_dataset.generator(
|
|
|
batch_size=batch_size, drop_last=False)
|
|
|
k = min(5, self.num_classes)
|
|
|
@@ -283,12 +293,14 @@ class BaseClassifier(BaseAPI):
|
|
|
transforms,
|
|
|
model_type,
|
|
|
class_name,
|
|
|
- thread_pool=None):
|
|
|
+ thread_pool=None,
|
|
|
+ input_channel=3):
|
|
|
arrange_transforms(
|
|
|
model_type=model_type,
|
|
|
class_name=class_name,
|
|
|
transforms=transforms,
|
|
|
- mode='test')
|
|
|
+ mode='test',
|
|
|
+ input_channel=input_channel)
|
|
|
if thread_pool is not None:
|
|
|
batch_data = thread_pool.map(transforms, images)
|
|
|
else:
|
|
|
@@ -334,8 +346,13 @@ class BaseClassifier(BaseAPI):
|
|
|
|
|
|
if transforms is None:
|
|
|
transforms = self.test_transforms
|
|
|
- im = BaseClassifier._preprocess(images, transforms, self.model_type,
|
|
|
- self.__class__.__name__)
|
|
|
+ input_channel = getattr(self, 'input_channel', 3)
|
|
|
+ im = BaseClassifier._preprocess(
|
|
|
+ images,
|
|
|
+ transforms,
|
|
|
+ self.model_type,
|
|
|
+ self.__class__.__name__,
|
|
|
+ input_channel=input_channel)
|
|
|
|
|
|
with fluid.scope_guard(self.scope):
|
|
|
result = self.exe.run(self.test_prog,
|
|
|
@@ -366,9 +383,14 @@ class BaseClassifier(BaseAPI):
|
|
|
|
|
|
if transforms is None:
|
|
|
transforms = self.test_transforms
|
|
|
+ input_channel = getattr(self, 'input_channel', 3)
|
|
|
im = BaseClassifier._preprocess(
|
|
|
- img_file_list, transforms, self.model_type,
|
|
|
- self.__class__.__name__, self.thread_pool)
|
|
|
+ img_file_list,
|
|
|
+ transforms,
|
|
|
+ self.model_type,
|
|
|
+ self.__class__.__name__,
|
|
|
+ self.thread_pool,
|
|
|
+ input_channel=input_channel)
|
|
|
|
|
|
with fluid.scope_guard(self.scope):
|
|
|
result = self.exe.run(self.test_prog,
|
|
|
@@ -470,109 +492,145 @@ class ResNet50_vd(BaseClassifier):
|
|
|
|
|
|
|
|
|
class ResNet101_vd(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, input_channel=3):
|
|
|
super(ResNet101_vd, self).__init__(
|
|
|
- model_name='ResNet101_vd', num_classes=num_classes)
|
|
|
+ model_name='ResNet101_vd',
|
|
|
+ num_classes=num_classes,
|
|
|
+ input_channel=input_channel)
|
|
|
|
|
|
|
|
|
class ResNet50_vd_ssld(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, input_channel=3):
|
|
|
super(ResNet50_vd_ssld, self).__init__(
|
|
|
- model_name='ResNet50_vd_ssld', num_classes=num_classes)
|
|
|
+ model_name='ResNet50_vd_ssld',
|
|
|
+ num_classes=num_classes,
|
|
|
+ input_channel=input_channel)
|
|
|
|
|
|
|
|
|
class ResNet101_vd_ssld(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, input_channel=3):
|
|
|
super(ResNet101_vd_ssld, self).__init__(
|
|
|
- model_name='ResNet101_vd_ssld', num_classes=num_classes)
|
|
|
+ model_name='ResNet101_vd_ssld',
|
|
|
+ num_classes=num_classes,
|
|
|
+ input_channel=input_channel)
|
|
|
|
|
|
|
|
|
class DarkNet53(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, input_channel=3):
|
|
|
super(DarkNet53, self).__init__(
|
|
|
- model_name='DarkNet53', num_classes=num_classes)
|
|
|
+ model_name='DarkNet53',
|
|
|
+ num_classes=num_classes,
|
|
|
+ input_channel=input_channel)
|
|
|
|
|
|
|
|
|
class MobileNetV1(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, input_channel=3):
|
|
|
super(MobileNetV1, self).__init__(
|
|
|
- model_name='MobileNetV1', num_classes=num_classes)
|
|
|
+ model_name='MobileNetV1',
|
|
|
+ num_classes=num_classes,
|
|
|
+ input_channel=input_channel)
|
|
|
|
|
|
|
|
|
class MobileNetV2(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, input_channel=3):
|
|
|
super(MobileNetV2, self).__init__(
|
|
|
- model_name='MobileNetV2', num_classes=num_classes)
|
|
|
+ model_name='MobileNetV2',
|
|
|
+ num_classes=num_classes,
|
|
|
+ input_channel=input_channel)
|
|
|
|
|
|
|
|
|
class MobileNetV3_small(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, input_channel=3):
|
|
|
super(MobileNetV3_small, self).__init__(
|
|
|
- model_name='MobileNetV3_small', num_classes=num_classes)
|
|
|
+ model_name='MobileNetV3_small',
|
|
|
+ num_classes=num_classes,
|
|
|
+ input_channel=input_channel)
|
|
|
|
|
|
|
|
|
class MobileNetV3_large(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, input_channel=3):
|
|
|
super(MobileNetV3_large, self).__init__(
|
|
|
- model_name='MobileNetV3_large', num_classes=num_classes)
|
|
|
+ model_name='MobileNetV3_large',
|
|
|
+ num_classes=num_classes,
|
|
|
+ input_channel=input_channel)
|
|
|
|
|
|
|
|
|
class MobileNetV3_small_ssld(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, input_channel=3):
|
|
|
super(MobileNetV3_small_ssld, self).__init__(
|
|
|
- model_name='MobileNetV3_small_ssld', num_classes=num_classes)
|
|
|
+ model_name='MobileNetV3_small_ssld',
|
|
|
+ num_classes=num_classes,
|
|
|
+ input_channel=input_channel)
|
|
|
|
|
|
|
|
|
class MobileNetV3_large_ssld(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, input_channel=3):
|
|
|
super(MobileNetV3_large_ssld, self).__init__(
|
|
|
- model_name='MobileNetV3_large_ssld', num_classes=num_classes)
|
|
|
+ model_name='MobileNetV3_large_ssld',
|
|
|
+ num_classes=num_classes,
|
|
|
+ input_channel=input_channel)
|
|
|
|
|
|
|
|
|
class Xception65(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, input_channel=3):
|
|
|
super(Xception65, self).__init__(
|
|
|
- model_name='Xception65', num_classes=num_classes)
|
|
|
+ model_name='Xception65',
|
|
|
+ num_classes=num_classes,
|
|
|
+ input_channel=input_channel)
|
|
|
|
|
|
|
|
|
class Xception41(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, input_channel=3):
|
|
|
super(Xception41, self).__init__(
|
|
|
- model_name='Xception41', num_classes=num_classes)
|
|
|
+ model_name='Xception41',
|
|
|
+ num_classes=num_classes,
|
|
|
+ input_channel=input_channel)
|
|
|
|
|
|
|
|
|
class DenseNet121(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, input_channel=3):
|
|
|
super(DenseNet121, self).__init__(
|
|
|
- model_name='DenseNet121', num_classes=num_classes)
|
|
|
+ model_name='DenseNet121',
|
|
|
+ num_classes=num_classes,
|
|
|
+ input_channel=input_channel)
|
|
|
|
|
|
|
|
|
class DenseNet161(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, input_channel=3):
|
|
|
super(DenseNet161, self).__init__(
|
|
|
- model_name='DenseNet161', num_classes=num_classes)
|
|
|
+ model_name='DenseNet161',
|
|
|
+ num_classes=num_classes,
|
|
|
+ input_channel=input_channel)
|
|
|
|
|
|
|
|
|
class DenseNet201(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, input_channel=3):
|
|
|
super(DenseNet201, self).__init__(
|
|
|
- model_name='DenseNet201', num_classes=num_classes)
|
|
|
+ model_name='DenseNet201',
|
|
|
+ num_classes=num_classes,
|
|
|
+ input_channel=input_channel)
|
|
|
|
|
|
|
|
|
class ShuffleNetV2(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, input_channel=3):
|
|
|
super(ShuffleNetV2, self).__init__(
|
|
|
- model_name='ShuffleNetV2', num_classes=num_classes)
|
|
|
+ model_name='ShuffleNetV2',
|
|
|
+ num_classes=num_classes,
|
|
|
+ input_channel=input_channel)
|
|
|
|
|
|
|
|
|
class HRNet_W18(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, input_channel=3):
|
|
|
super(HRNet_W18, self).__init__(
|
|
|
- model_name='HRNet_W18', num_classes=num_classes)
|
|
|
+ model_name='HRNet_W18',
|
|
|
+ num_classes=num_classes,
|
|
|
+ input_channel=input_channel)
|
|
|
|
|
|
|
|
|
class AlexNet(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000, input_shape=None):
|
|
|
+ def __init__(self, num_classes=1000, input_shape=None, input_channel=3):
|
|
|
super(AlexNet, self).__init__(
|
|
|
- model_name='AlexNet', num_classes=num_classes)
|
|
|
+ model_name='AlexNet',
|
|
|
+ num_classes=num_classes,
|
|
|
+ input_channel=input_channel)
|
|
|
self.fixed_input_shape = input_shape
|