|
|
@@ -200,18 +200,31 @@ class BaseAPI:
|
|
|
self.exe.run(startup_prog)
|
|
|
if pretrain_weights is not None:
|
|
|
logging.info(
|
|
|
- "Load pretrain weights from {}.".format(pretrain_weights))
|
|
|
+ "Load pretrain weights from {}.".format(pretrain_weights),
|
|
|
+ use_color=True)
|
|
|
paddlex.utils.utils.load_pretrain_weights(
|
|
|
self.exe, self.train_prog, pretrain_weights, fuse_bn)
|
|
|
# 进行裁剪
|
|
|
if sensitivities_file is not None:
|
|
|
+ import paddleslim
|
|
|
from .slim.prune_config import get_sensitivities
|
|
|
sensitivities_file = get_sensitivities(sensitivities_file, self,
|
|
|
save_dir)
|
|
|
from .slim.prune import get_params_ratios, prune_program
|
|
|
+ logging.info(
|
|
|
+ "Start to prune program with eval_metric_loss = {}".format(
|
|
|
+ eval_metric_loss),
|
|
|
+ use_color=True)
|
|
|
+ origin_flops = paddleslim.analysis.flops(self.test_prog)
|
|
|
prune_params_ratios = get_params_ratios(
|
|
|
sensitivities_file, eval_metric_loss=eval_metric_loss)
|
|
|
prune_program(self, prune_params_ratios)
|
|
|
+ current_flops = paddleslim.analysis.flops(self.test_prog)
|
|
|
+ remaining_ratio = current_flops / origin_flops
|
|
|
+ logging.info(
|
|
|
+ "Finish prune program, before FLOPs:{}, after prune FLOPs:{}, remaining ratio:{}"
|
|
|
+ .format(origin_flops, current_flops, remaining_ratio),
|
|
|
+ use_color=True)
|
|
|
self.status = 'Prune'
|
|
|
|
|
|
def get_model_info(self):
|
|
|
@@ -259,7 +272,10 @@ class BaseAPI:
|
|
|
if osp.exists(save_dir):
|
|
|
os.remove(save_dir)
|
|
|
os.makedirs(save_dir)
|
|
|
- fluid.save(self.train_prog, osp.join(save_dir, 'model'))
|
|
|
+ if self.train_prog is not None:
|
|
|
+ fluid.save(self.train_prog, osp.join(save_dir, 'model'))
|
|
|
+ else:
|
|
|
+ fluid.save(self.test_prog, osp.join(save_dir, 'model'))
|
|
|
model_info = self.get_model_info()
|
|
|
model_info['status'] = self.status
|
|
|
with open(
|
|
|
@@ -408,7 +424,7 @@ class BaseAPI:
|
|
|
earlystop = EarlyStop(early_stop_patience, thresh)
|
|
|
best_accuracy_key = ""
|
|
|
best_accuracy = -1.0
|
|
|
- best_model_epoch = 1
|
|
|
+ best_model_epoch = -1
|
|
|
for i in range(num_epochs):
|
|
|
records = list()
|
|
|
step_start_time = time.time()
|
|
|
@@ -481,7 +497,7 @@ class BaseAPI:
|
|
|
current_save_dir = osp.join(save_dir, "epoch_{}".format(i + 1))
|
|
|
if not osp.isdir(current_save_dir):
|
|
|
os.makedirs(current_save_dir)
|
|
|
- if eval_dataset is not None:
|
|
|
+ if eval_dataset is not None and eval_dataset.num_samples > 0:
|
|
|
self.eval_metrics, self.eval_details = self.evaluate(
|
|
|
eval_dataset=eval_dataset,
|
|
|
batch_size=eval_batch_size,
|
|
|
@@ -513,10 +529,11 @@ class BaseAPI:
|
|
|
self.save_model(save_dir=current_save_dir)
|
|
|
time_eval_one_epoch = time.time() - eval_epoch_start_time
|
|
|
eval_epoch_start_time = time.time()
|
|
|
- logging.info(
|
|
|
- 'Current evaluated best model in eval_dataset is epoch_{}, {}={}'
|
|
|
- .format(best_model_epoch, best_accuracy_key,
|
|
|
- best_accuracy))
|
|
|
+ if best_model_epoch > 0:
|
|
|
+ logging.info(
|
|
|
+ 'Current evaluated best model in eval_dataset is epoch_{}, {}={}'
|
|
|
+ .format(best_model_epoch, best_accuracy_key,
|
|
|
+ best_accuracy))
|
|
|
if eval_dataset is not None and early_stop:
|
|
|
if earlystop(current_accuracy):
|
|
|
break
|