فهرست منبع

Merge pull request #830 from will-jl944/develop_jf

raise error if pretrain weights is a directory instead of pdparams file
FlyingQianMM 4 سال پیش
والد
کامیت
4c918034ac
3فایلهای تغییر یافته به همراه15 افزوده شده و 0 حذف شده
  1. 5 0
      dygraph/paddlex/cv/models/classifier.py
  2. 5 0
      dygraph/paddlex/cv/models/detector.py
  3. 5 0
      dygraph/paddlex/cv/models/segmenter.py

+ 5 - 0
dygraph/paddlex/cv/models/classifier.py

@@ -243,6 +243,11 @@ class BaseClassifier(BaseModel):
                     "If don't want to use pretrain weights, "
                     "If don't want to use pretrain weights, "
                     "set pretrain_weights to be None.")
                     "set pretrain_weights to be None.")
                 pretrain_weights = 'IMAGENET'
                 pretrain_weights = 'IMAGENET'
+        elif osp.exists(pretrain_weights):
+            if osp.splitext(pretrain_weights)[-1] != '.pdparams':
+                logging.error(
+                    "Invalid pretrain weights. Please specify a '.pdparams' file.",
+                    exit=True)
         pretrained_dir = osp.join(save_dir, 'pretrain')
         pretrained_dir = osp.join(save_dir, 'pretrain')
         self.net_initialize(
         self.net_initialize(
             pretrain_weights=pretrain_weights, save_dir=pretrained_dir)
             pretrain_weights=pretrain_weights, save_dir=pretrained_dir)

+ 5 - 0
dygraph/paddlex/cv/models/detector.py

@@ -240,6 +240,11 @@ class BaseDetector(BaseModel):
                                 "If you don't want to use pretrain weights, "
                                 "If you don't want to use pretrain weights, "
                                 "set pretrain_weights to be None.".format(
                                 "set pretrain_weights to be None.".format(
                                     pretrain_weights))
                                     pretrain_weights))
+        elif osp.exists(pretrain_weights):
+            if osp.splitext(pretrain_weights)[-1] != '.pdparams':
+                logging.error(
+                    "Invalid pretrain weights. Please specify a '.pdparams' file.",
+                    exit=True)
         pretrained_dir = osp.join(save_dir, 'pretrain')
         pretrained_dir = osp.join(save_dir, 'pretrain')
         self.net_initialize(
         self.net_initialize(
             pretrain_weights=pretrain_weights, save_dir=pretrained_dir)
             pretrain_weights=pretrain_weights, save_dir=pretrained_dir)

+ 5 - 0
dygraph/paddlex/cv/models/segmenter.py

@@ -241,6 +241,11 @@ class BaseSegmenter(BaseModel):
                                         0]))
                                         0]))
                 pretrain_weights = seg_pretrain_weights_dict[self.model_name][
                 pretrain_weights = seg_pretrain_weights_dict[self.model_name][
                     0]
                     0]
+        elif osp.exists(pretrain_weights):
+            if osp.splitext(pretrain_weights)[-1] != '.pdparams':
+                logging.error(
+                    "Invalid pretrain weights. Please specify a '.pdparams' file.",
+                    exit=True)
         pretrained_dir = osp.join(save_dir, 'pretrain')
         pretrained_dir = osp.join(save_dir, 'pretrain')
         self.net_initialize(
         self.net_initialize(
             pretrain_weights=pretrain_weights, save_dir=pretrained_dir)
             pretrain_weights=pretrain_weights, save_dir=pretrained_dir)