|
|
@@ -129,9 +129,7 @@ class BaseClassifier(BaseAPI):
|
|
|
ValueError: 模型从inference model进行加载。
|
|
|
"""
|
|
|
if not self.trainable:
|
|
|
- raise ValueError(
|
|
|
- "Model is not trainable since it was loaded from a inference model."
|
|
|
- )
|
|
|
+ raise ValueError("Model is not trainable from load_model method.")
|
|
|
self.labels = train_dataset.labels
|
|
|
if optimizer is None:
|
|
|
num_steps_each_epoch = train_dataset.num_samples // train_batch_size
|
|
|
@@ -300,17 +298,18 @@ class ResNet101_vd(BaseClassifier):
|
|
|
def __init__(self, num_classes=1000):
|
|
|
super(ResNet101_vd, self).__init__(
|
|
|
model_name='ResNet101_vd', num_classes=num_classes)
|
|
|
-
|
|
|
-
|
|
|
+
|
|
|
+
|
|
|
class ResNet50_vd_ssld(BaseClassifier):
|
|
|
def __init__(self, num_classes=1000):
|
|
|
- super(ResNet50_vd_ssld, self).__init__(model_name='ResNet50_vd_ssld',
|
|
|
- num_classes=num_classes)
|
|
|
-
|
|
|
+ super(ResNet50_vd_ssld, self).__init__(
|
|
|
+ model_name='ResNet50_vd_ssld', num_classes=num_classes)
|
|
|
+
|
|
|
+
|
|
|
class ResNet101_vd_ssld(BaseClassifier):
|
|
|
def __init__(self, num_classes=1000):
|
|
|
- super(ResNet101_vd_ssld, self).__init__(model_name='ResNet101_vd_ssld',
|
|
|
- num_classes=num_classes)
|
|
|
+ super(ResNet101_vd_ssld, self).__init__(
|
|
|
+ model_name='ResNet101_vd_ssld', num_classes=num_classes)
|
|
|
|
|
|
|
|
|
class DarkNet53(BaseClassifier):
|
|
|
@@ -341,19 +340,18 @@ class MobileNetV3_large(BaseClassifier):
|
|
|
def __init__(self, num_classes=1000):
|
|
|
super(MobileNetV3_large, self).__init__(
|
|
|
model_name='MobileNetV3_large', num_classes=num_classes)
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
+
|
|
|
+
|
|
|
class MobileNetV3_small_ssld(BaseClassifier):
|
|
|
def __init__(self, num_classes=1000):
|
|
|
- super(MobileNetV3_small_ssld, self).__init__(model_name='MobileNetV3_small_ssld',
|
|
|
- num_classes=num_classes)
|
|
|
+ super(MobileNetV3_small_ssld, self).__init__(
|
|
|
+ model_name='MobileNetV3_small_ssld', num_classes=num_classes)
|
|
|
|
|
|
|
|
|
class MobileNetV3_large_ssld(BaseClassifier):
|
|
|
def __init__(self, num_classes=1000):
|
|
|
- super(MobileNetV3_large_ssld, self).__init__(model_name='MobileNetV3_large_ssld',
|
|
|
- num_classes=num_classes)
|
|
|
+ super(MobileNetV3_large_ssld, self).__init__(
|
|
|
+ model_name='MobileNetV3_large_ssld', num_classes=num_classes)
|
|
|
|
|
|
|
|
|
class Xception65(BaseClassifier):
|