Ver código fonte

modify save model

jiangjiajun 5 anos atrás
pai
commit
4c1373c8a3
1 arquivos alterados com 4 adições e 1 exclusões
  1. 4 1
      paddlex/cv/models/base.py

+ 4 - 1
paddlex/cv/models/base.py

@@ -255,7 +255,10 @@ class BaseAPI:
             if osp.exists(save_dir):
                 os.remove(save_dir)
             os.makedirs(save_dir)
-        fluid.save(self.train_prog, osp.join(save_dir, 'model'))
+        if self.train_prog is not None:
+            fluid.save(self.train_prog, osp.join(save_dir, 'model'))
+        else:
+            fluid.save(self.test_prog, osp.join(save_dir, 'model'))
         model_info = self.get_model_info()
         model_info['status'] = self.status
         with open(