Parcourir la source

modify save model

jiangjiajun il y a 5 ans
Parent
commit
4c1373c8a3
1 fichiers modifiés avec 4 ajouts et 1 suppressions
  1. 4 1
      paddlex/cv/models/base.py

+ 4 - 1
paddlex/cv/models/base.py

@@ -255,7 +255,10 @@ class BaseAPI:
             if osp.exists(save_dir):
                 os.remove(save_dir)
             os.makedirs(save_dir)
-        fluid.save(self.train_prog, osp.join(save_dir, 'model'))
+        if self.train_prog is not None:
+            fluid.save(self.train_prog, osp.join(save_dir, 'model'))
+        else:
+            fluid.save(self.test_prog, osp.join(save_dir, 'model'))
         model_info = self.get_model_info()
         model_info['status'] = self.status
         with open(