浏览代码

raise error if pretrain weights is a directory instead of pdparams file

will-jl944 4 年之前
父节点
当前提交
2d8af5b3c3

+ 5 - 0
dygraph/paddlex/cv/models/classifier.py

@@ -243,6 +243,11 @@ class BaseClassifier(BaseModel):
                     "If don't want to use pretrain weights, "
                     "set pretrain_weights to be None.")
                 pretrain_weights = 'IMAGENET'
+        elif osp.exists(pretrain_weights):
+            if osp.splitext(pretrain_weights)[-1] != '.pdparams':
+                logging.error(
+                    "Invalid pretrain weights. Please specify a '.pdparams' file.",
+                    exit=True)
         pretrained_dir = osp.join(save_dir, 'pretrain')
         self.net_initialize(
             pretrain_weights=pretrain_weights, save_dir=pretrained_dir)

+ 5 - 0
dygraph/paddlex/cv/models/detector.py

@@ -240,6 +240,11 @@ class BaseDetector(BaseModel):
                                 "If you don't want to use pretrain weights, "
                                 "set pretrain_weights to be None.".format(
                                     pretrain_weights))
+        elif osp.exists(pretrain_weights):
+            if osp.splitext(pretrain_weights)[-1] != '.pdparams':
+                logging.error(
+                    "Invalid pretrain weights. Please specify a '.pdparams' file.",
+                    exit=True)
         pretrained_dir = osp.join(save_dir, 'pretrain')
         self.net_initialize(
             pretrain_weights=pretrain_weights, save_dir=pretrained_dir)

+ 5 - 0
dygraph/paddlex/cv/models/segmenter.py

@@ -241,6 +241,11 @@ class BaseSegmenter(BaseModel):
                                         0]))
                 pretrain_weights = seg_pretrain_weights_dict[self.model_name][
                     0]
+        elif osp.exists(pretrain_weights):
+            if osp.splitext(pretrain_weights)[-1] != '.pdparams':
+                logging.error(
+                    "Invalid pretrain weights. Please specify a '.pdparams' file.",
+                    exit=True)
         pretrained_dir = osp.join(save_dir, 'pretrain')
         self.net_initialize(
             pretrain_weights=pretrain_weights, save_dir=pretrained_dir)