Selaa lähdekoodia

add input_channel for resetnet18-50

FlyingQianMM 4 vuotta sitten
vanhempi
commit
86443d0e38
1 muutettua tiedostoa jossa 20 lisäystä ja 10 poistoa
  1. 20 10
      paddlex/cv/models/classifier.py

+ 20 - 10
paddlex/cv/models/classifier.py

@@ -404,33 +404,43 @@ class BaseClassifier(BaseAPI):
 
 
 class ResNet18(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, input_channel=3):
         super(ResNet18, self).__init__(
-            model_name='ResNet18', num_classes=num_classes)
+            model_name='ResNet18',
+            num_classes=num_classes,
+            input_channel=input_channel)
 
 
 class ResNet34(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, input_channel=3):
         super(ResNet34, self).__init__(
-            model_name='ResNet34', num_classes=num_classes)
+            model_name='ResNet34',
+            num_classes=num_classes,
+            input_channel=input_channel)
 
 
 class ResNet50(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, input_channel=3):
         super(ResNet50, self).__init__(
-            model_name='ResNet50', num_classes=num_classes)
+            model_name='ResNet50',
+            num_classes=num_classes,
+            input_channel=input_channel)
 
 
 class ResNet101(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, input_channel=3):
         super(ResNet101, self).__init__(
-            model_name='ResNet101', num_classes=num_classes)
+            model_name='ResNet101',
+            num_classes=num_classes,
+            input_channel=input_channel)
 
 
 class ResNet50_vd(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, input_channel=3):
         super(ResNet50_vd, self).__init__(
-            model_name='ResNet50_vd', num_classes=num_classes)
+            model_name='ResNet50_vd',
+            num_classes=num_classes,
+            input_channel=input_channel)
 
     def train(self,
               num_epochs,