|
|
@@ -243,6 +243,11 @@ class BaseClassifier(BaseModel):
|
|
|
"If don't want to use pretrain weights, "
|
|
|
"set pretrain_weights to be None.")
|
|
|
pretrain_weights = 'IMAGENET'
|
|
|
+ elif osp.exists(pretrain_weights):
|
|
|
+ if osp.splitext(pretrain_weights)[-1] != '.pdparams':
|
|
|
+ logging.error(
|
|
|
+ "Invalid pretrain weights. Please specify a '.pdparams' file.",
|
|
|
+ exit=True)
|
|
|
pretrained_dir = osp.join(save_dir, 'pretrain')
|
|
|
self.net_initialize(
|
|
|
pretrain_weights=pretrain_weights, save_dir=pretrained_dir)
|