소스 검색

support no pretrained for cls

zhangyubo0722 1 년 전
부모
커밋
f78692923c
2개의 변경된 파일14개의 추가작업 그리고 8개의 파일을 삭제
  1. 1 2
      paddlex/modules/image_classification/trainer.py
  2. 13 6
      paddlex/repo_apis/PaddleClas_api/cls/config.py

+ 1 - 2
paddlex/modules/image_classification/trainer.py

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import json
 import shutil
 import paddle
@@ -64,7 +63,7 @@ class ClsTrainer(BaseTrainer):
                                        "ClsDataset")
         if self.train_config.num_classes is not None:
             self.pdx_config.update_num_classes(self.train_config.num_classes)
-        if self.train_config.pretrain_weight_path and self.train_config.pretrain_weight_path != "":
+        if self.train_config.pretrain_weight_path != "":
             self.pdx_config.update_pretrained_weights(
                 self.train_config.pretrain_weight_path)
 

+ 13 - 6
paddlex/repo_apis/PaddleClas_api/cls/config.py

@@ -147,15 +147,22 @@ class ClsConfig(BaseConfig):
             pretrained_model (str): the local path or url of pretrained weight file to set.
         """
         assert isinstance(
-            pretrained_model, (str, None)
+            pretrained_model, (str, type(None))
         ), "The 'pretrained_model' should be a string, indicating the path to the '*.pdparams' file, or 'None', \
 indicating that no pretrained model to be used."
 
-        if pretrained_model and not pretrained_model.startswith(
-            ('http://', 'https://')):
-            pretrained_model = abspath(
-                pretrained_model.replace(".pdparams", ""))
-        self.update([f'Global.pretrained_model={pretrained_model}'])
+        if pretrained_model is None:
+            self.update(['Global.pretrained_model=None'])
+            self.update(['Arch.pretrained=False'])
+        else:
+            if pretrained_model.lower() == "default":
+                self.update(['Global.pretrained_model=None'])
+                self.update(['Arch.pretrained=True'])
+            else:
+                if not pretrained_model.startswith(('http://', 'https://')):
+                    pretrained_model = abspath(
+                        pretrained_model.replace(".pdparams", ""))
+                self.update([f'Global.pretrained_model={pretrained_model}'])
 
     def update_num_classes(self, num_classes: int):
         """update classes number