jiangjiajun пре 5 година
родитељ
комит
f02f04766e

+ 15 - 17
paddlex/cv/models/classifier.py

@@ -129,9 +129,7 @@ class BaseClassifier(BaseAPI):
             ValueError: 模型从inference model进行加载。
         """
         if not self.trainable:
-            raise ValueError(
-                "Model is not trainable since it was loaded from a inference model."
-            )
+            raise ValueError("Model is not trainable from load_model method.")
         self.labels = train_dataset.labels
         if optimizer is None:
             num_steps_each_epoch = train_dataset.num_samples // train_batch_size
@@ -300,17 +298,18 @@ class ResNet101_vd(BaseClassifier):
     def __init__(self, num_classes=1000):
         super(ResNet101_vd, self).__init__(
             model_name='ResNet101_vd', num_classes=num_classes)
-        
-        
+
+
 class ResNet50_vd_ssld(BaseClassifier):
     def __init__(self, num_classes=1000):
-        super(ResNet50_vd_ssld, self).__init__(model_name='ResNet50_vd_ssld',
-                                               num_classes=num_classes)
-        
+        super(ResNet50_vd_ssld, self).__init__(
+            model_name='ResNet50_vd_ssld', num_classes=num_classes)
+
+
 class ResNet101_vd_ssld(BaseClassifier):
     def __init__(self, num_classes=1000):
-        super(ResNet101_vd_ssld, self).__init__(model_name='ResNet101_vd_ssld',
-                                               num_classes=num_classes)
+        super(ResNet101_vd_ssld, self).__init__(
+            model_name='ResNet101_vd_ssld', num_classes=num_classes)
 
 
 class DarkNet53(BaseClassifier):
@@ -341,19 +340,18 @@ class MobileNetV3_large(BaseClassifier):
     def __init__(self, num_classes=1000):
         super(MobileNetV3_large, self).__init__(
             model_name='MobileNetV3_large', num_classes=num_classes)
-        
-        
-        
+
+
 class MobileNetV3_small_ssld(BaseClassifier):
     def __init__(self, num_classes=1000):
-        super(MobileNetV3_small_ssld, self).__init__(model_name='MobileNetV3_small_ssld',
-                                                num_classes=num_classes)
+        super(MobileNetV3_small_ssld, self).__init__(
+            model_name='MobileNetV3_small_ssld', num_classes=num_classes)
 
 
 class MobileNetV3_large_ssld(BaseClassifier):
     def __init__(self, num_classes=1000):
-        super(MobileNetV3_large_ssld, self).__init__(model_name='MobileNetV3_large_ssld',
-                                                num_classes=num_classes)
+        super(MobileNetV3_large_ssld, self).__init__(
+            model_name='MobileNetV3_large_ssld', num_classes=num_classes)
 
 
 class Xception65(BaseClassifier):

+ 1 - 3
paddlex/cv/models/deeplabv3p.py

@@ -257,9 +257,7 @@ class DeepLabv3p(BaseAPI):
             ValueError: 模型从inference model进行加载。
         """
         if not self.trainable:
-            raise ValueError(
-                "Model is not trainable since it was loaded from a inference model."
-            )
+            raise ValueError("Model is not trainable from load_model method.")
 
         self.labels = train_dataset.labels
 

+ 1 - 3
paddlex/cv/models/faster_rcnn.py

@@ -203,9 +203,7 @@ class FasterRCNN(BaseAPI):
         assert metric in ['COCO', 'VOC'], "Metric only support 'VOC' or 'COCO'"
         self.metric = metric
         if not self.trainable:
-            raise ValueError(
-                "Model is not trainable since it was loaded from a inference model."
-            )
+            raise ValueError("Model is not trainable from load_model method.")
         self.labels = copy.deepcopy(train_dataset.labels)
         self.labels.insert(0, 'background')
         # 构建训练网络

+ 1 - 0
paddlex/cv/models/load_model.py

@@ -98,6 +98,7 @@ def load_model(model_dir):
                 model.__dict__[k] = v
 
     logging.info("Model[{}] loaded.".format(info['Model']))
+    model.trainable = False
     return model
 
 

+ 1 - 3
paddlex/cv/models/mask_rcnn.py

@@ -165,9 +165,7 @@ class MaskRCNN(FasterRCNN):
         assert metric in ['COCO', 'VOC'], "Metric only support 'VOC' or 'COCO'"
         self.metric = metric
         if not self.trainable:
-            raise Exception(
-                "Model is not trainable since it was loaded from a inference model."
-            )
+            raise Exception("Model is not trainable from load_model method.")
         self.labels = copy.deepcopy(train_dataset.labels)
         self.labels.insert(0, 'background')
         # 构建训练网络

+ 1 - 3
paddlex/cv/models/yolo_v3.py

@@ -194,9 +194,7 @@ class YOLOv3(BaseAPI):
             ValueError: 模型从inference model进行加载。
         """
         if not self.trainable:
-            raise ValueError(
-                "Model is not trainable since it was loaded from a inference model."
-            )
+            raise ValueError("Model is not trainable from load_model method.")
         if metric is None:
             if isinstance(train_dataset, paddlex.datasets.CocoDetection):
                 metric = 'COCO'