|
|
@@ -404,33 +404,43 @@ class BaseClassifier(BaseAPI):
|
|
|
|
|
|
|
|
|
class ResNet18(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, input_channel=3):
|
|
|
super(ResNet18, self).__init__(
|
|
|
- model_name='ResNet18', num_classes=num_classes)
|
|
|
+ model_name='ResNet18',
|
|
|
+ num_classes=num_classes,
|
|
|
+ input_channel=input_channel)
|
|
|
|
|
|
|
|
|
class ResNet34(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, input_channel=3):
|
|
|
super(ResNet34, self).__init__(
|
|
|
- model_name='ResNet34', num_classes=num_classes)
|
|
|
+ model_name='ResNet34',
|
|
|
+ num_classes=num_classes,
|
|
|
+ input_channel=input_channel)
|
|
|
|
|
|
|
|
|
class ResNet50(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, input_channel=3):
|
|
|
super(ResNet50, self).__init__(
|
|
|
- model_name='ResNet50', num_classes=num_classes)
|
|
|
+ model_name='ResNet50',
|
|
|
+ num_classes=num_classes,
|
|
|
+ input_channel=input_channel)
|
|
|
|
|
|
|
|
|
class ResNet101(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, input_channel=3):
|
|
|
super(ResNet101, self).__init__(
|
|
|
- model_name='ResNet101', num_classes=num_classes)
|
|
|
+ model_name='ResNet101',
|
|
|
+ num_classes=num_classes,
|
|
|
+ input_channel=input_channel)
|
|
|
|
|
|
|
|
|
class ResNet50_vd(BaseClassifier):
|
|
|
- def __init__(self, num_classes=1000):
|
|
|
+ def __init__(self, num_classes=1000, input_channel=3):
|
|
|
super(ResNet50_vd, self).__init__(
|
|
|
- model_name='ResNet50_vd', num_classes=num_classes)
|
|
|
+ model_name='ResNet50_vd',
|
|
|
+ num_classes=num_classes,
|
|
|
+ input_channel=input_channel)
|
|
|
|
|
|
def train(self,
|
|
|
num_epochs,
|