Selaa lähdekoodia

fix bug when pretrain_weights is None

will-jl944 4 vuotta sitten
vanhempi
commit
60f9604eb4

+ 1 - 1
dygraph/paddlex/cv/models/classifier.py

@@ -245,7 +245,7 @@ class BaseClassifier(BaseModel):
                     "If don't want to use pretrain weights, "
                     "set pretrain_weights to be None.")
                 pretrain_weights = 'IMAGENET'
-        elif osp.exists(pretrain_weights):
+        elif pretrain_weights is not None and osp.exists(pretrain_weights):
             if osp.splitext(pretrain_weights)[-1] != '.pdparams':
                 logging.error(
                     "Invalid pretrain weights. Please specify a '.pdparams' file.",

+ 1 - 1
dygraph/paddlex/cv/models/detector.py

@@ -240,7 +240,7 @@ class BaseDetector(BaseModel):
                                 "If you don't want to use pretrain weights, "
                                 "set pretrain_weights to be None.".format(
                                     pretrain_weights))
-        elif osp.exists(pretrain_weights):
+        elif pretrain_weights is not None and osp.exists(pretrain_weights):
             if osp.splitext(pretrain_weights)[-1] != '.pdparams':
                 logging.error(
                     "Invalid pretrain weights. Please specify a '.pdparams' file.",

+ 1 - 1
dygraph/paddlex/cv/models/segmenter.py

@@ -241,7 +241,7 @@ class BaseSegmenter(BaseModel):
                                         0]))
                 pretrain_weights = seg_pretrain_weights_dict[self.model_name][
                     0]
-        elif osp.exists(pretrain_weights):
+        elif pretrain_weights is not None and osp.exists(pretrain_weights):
             if osp.splitext(pretrain_weights)[-1] != '.pdparams':
                 logging.error(
                     "Invalid pretrain weights. Please specify a '.pdparams' file.",