Selaa lähdekoodia

support resume checkpoint during model pruning

FlyingQianMM 4 vuotta sitten
vanhempi
commit
b712586c8e
1 muutettua tiedostoa jossa 23 lisäystä ja 13 poistoa
  1. 23 13
      paddlex/cv/models/base.py

+ 23 - 13
paddlex/cv/models/base.py

@@ -245,26 +245,23 @@ class BaseAPI:
         if startup_prog is None:
             startup_prog = fluid.default_startup_program()
         self.exe.run(startup_prog)
-        if resume_checkpoint:
-            logging.info(
-                "Resume checkpoint from {}.".format(resume_checkpoint),
-                use_color=True)
-            paddlex.utils.utils.load_pretrain_weights(
-                self.exe, self.train_prog, resume_checkpoint, resume=True)
-            if not osp.exists(osp.join(resume_checkpoint, "model.yml")):
-                raise Exception("There's not model.yml in {}".format(
-                    resume_checkpoint))
-            with open(osp.join(resume_checkpoint, "model.yml")) as f:
-                info = yaml.load(f.read(), Loader=yaml.Loader)
-                self.completed_epochs = info['completed_epochs']
-        elif pretrain_weights is not None:
+
+        if not resume_checkpoint and pretrain_weights:
             logging.info(
                 "Load pretrain weights from {}.".format(pretrain_weights),
                 use_color=True)
             paddlex.utils.utils.load_pretrain_weights(
                 self.exe, self.train_prog, pretrain_weights, fuse_bn)
+
         # 进行裁剪
         if sensitivities_file is not None:
+            import paddle
+            version = paddle.__version__.strip().split('.')
+            if version[0] == '2' or (version[0] == '0' and
+                                     hasattr(paddle, 'enable_static')):
+                raise Exception(
+                    'Model pruning is not ready when using paddle>=2.0.0, please downgrade paddle to 1.8.5.'
+                )
             import paddleslim
             from .slim.prune_config import get_sensitivities
             sensitivities_file = get_sensitivities(sensitivities_file, self,
@@ -286,6 +283,19 @@ class BaseAPI:
                 use_color=True)
             self.status = 'Prune'
 
+        if resume_checkpoint:
+            logging.info(
+                "Resume checkpoint from {}.".format(resume_checkpoint),
+                use_color=True)
+            paddlex.utils.utils.load_pretrain_weights(
+                self.exe, self.train_prog, resume_checkpoint, resume=True)
+            if not osp.exists(osp.join(resume_checkpoint, "model.yml")):
+                raise Exception("There's not model.yml in {}".format(
+                    resume_checkpoint))
+            with open(osp.join(resume_checkpoint, "model.yml")) as f:
+                info = yaml.load(f.read(), Loader=yaml.Loader)
+                self.completed_epochs = info['completed_epochs']
+
     def get_model_info(self):
         info = dict()
         info['version'] = paddlex.__version__