Răsfoiți Sursa

append the classname with backbone_name to pretrain_dir for gui because the pretrain weight of a backbone for cls/det/seg is not same

FlyingQianMM 4 ani în urmă
părinte
comite
fe12272aa7
1 a modificat fișierele cu 6 adăugiri și 1 ștergeri
  1. 6 1
      dygraph/paddlex/utils/checkpoint.py

+ 6 - 1
dygraph/paddlex/utils/checkpoint.py

@@ -353,7 +353,12 @@ def get_pretrain_weights(flag, class_name, save_dir, backbone_name=None):
         return flag
 
     # TODO: check flag
-    new_save_dir = getattr(paddlex, 'pretrain_dir', save_dir)
+    new_save_dir = save_dir
+    if hasattr(paddlex, 'pretrain_dir'):
+        new_save_dir = paddlex.pretrain_dir
+        new_save_dir = osp.join(new_save_dir, class_name)
+        if backbone_name is not None:
+            new_save_dir = "{}_{}".format(new_save_dir, backbone_name)
     if backbone_name is not None:
         weights_key = "{}_{}_{}".format(class_name, backbone_name, flag)
     else: