|
|
@@ -245,26 +245,23 @@ class BaseAPI:
|
|
|
if startup_prog is None:
|
|
|
startup_prog = fluid.default_startup_program()
|
|
|
self.exe.run(startup_prog)
|
|
|
- if resume_checkpoint:
|
|
|
- logging.info(
|
|
|
- "Resume checkpoint from {}.".format(resume_checkpoint),
|
|
|
- use_color=True)
|
|
|
- paddlex.utils.utils.load_pretrain_weights(
|
|
|
- self.exe, self.train_prog, resume_checkpoint, resume=True)
|
|
|
- if not osp.exists(osp.join(resume_checkpoint, "model.yml")):
|
|
|
- raise Exception("There's not model.yml in {}".format(
|
|
|
- resume_checkpoint))
|
|
|
- with open(osp.join(resume_checkpoint, "model.yml")) as f:
|
|
|
- info = yaml.load(f.read(), Loader=yaml.Loader)
|
|
|
- self.completed_epochs = info['completed_epochs']
|
|
|
- elif pretrain_weights is not None:
|
|
|
+
|
|
|
+ if not resume_checkpoint and pretrain_weights:
|
|
|
logging.info(
|
|
|
"Load pretrain weights from {}.".format(pretrain_weights),
|
|
|
use_color=True)
|
|
|
paddlex.utils.utils.load_pretrain_weights(
|
|
|
self.exe, self.train_prog, pretrain_weights, fuse_bn)
|
|
|
+
|
|
|
# 进行裁剪
|
|
|
if sensitivities_file is not None:
|
|
|
+ import paddle
|
|
|
+ version = paddle.__version__.strip().split('.')
|
|
|
+ if version[0] == '2' or (version[0] == '0' and
|
|
|
+ hasattr(paddle, 'enable_static')):
|
|
|
+ raise Exception(
|
|
|
+ 'Model pruning is not ready when using paddle>=2.0.0, please downgrade paddle to 1.8.5.'
|
|
|
+ )
|
|
|
import paddleslim
|
|
|
from .slim.prune_config import get_sensitivities
|
|
|
sensitivities_file = get_sensitivities(sensitivities_file, self,
|
|
|
@@ -286,6 +283,19 @@ class BaseAPI:
|
|
|
use_color=True)
|
|
|
self.status = 'Prune'
|
|
|
|
|
|
+ if resume_checkpoint:
|
|
|
+ logging.info(
|
|
|
+ "Resume checkpoint from {}.".format(resume_checkpoint),
|
|
|
+ use_color=True)
|
|
|
+ paddlex.utils.utils.load_pretrain_weights(
|
|
|
+ self.exe, self.train_prog, resume_checkpoint, resume=True)
|
|
|
+ if not osp.exists(osp.join(resume_checkpoint, "model.yml")):
|
|
|
+ raise Exception("There's not model.yml in {}".format(
|
|
|
+ resume_checkpoint))
|
|
|
+ with open(osp.join(resume_checkpoint, "model.yml")) as f:
|
|
|
+ info = yaml.load(f.read(), Loader=yaml.Loader)
|
|
|
+ self.completed_epochs = info['completed_epochs']
|
|
|
+
|
|
|
def get_model_info(self):
|
|
|
info = dict()
|
|
|
info['version'] = paddlex.__version__
|