|
|
@@ -353,7 +353,12 @@ def get_pretrain_weights(flag, class_name, save_dir, backbone_name=None):
|
|
|
return flag
|
|
|
|
|
|
# TODO: check flag
|
|
|
- new_save_dir = getattr(paddlex, 'pretrain_dir', save_dir)
|
|
|
+ new_save_dir = save_dir
|
|
|
+ if hasattr(paddlex, 'pretrain_dir'):
|
|
|
+ new_save_dir = paddlex.pretrain_dir
|
|
|
+ new_save_dir = osp.join(new_save_dir, class_name)
|
|
|
+ if backbone_name is not None:
|
|
|
+ new_save_dir = "{}_{}".format(new_save_dir, backbone_name)
|
|
|
if backbone_name is not None:
|
|
|
weights_key = "{}_{}_{}".format(class_name, backbone_name, flag)
|
|
|
else:
|