|
|
@@ -199,7 +199,7 @@ class BaseAPI:
|
|
|
self.exe.run(startup_prog)
|
|
|
if pretrain_weights is not None:
|
|
|
logging.info(
|
|
|
- "Load pretrain weights from {}.".format(pretrain_weights))
|
|
|
+ "Load pretrain weights from {}.".format(pretrain_weights), use_color=True)
|
|
|
paddlex.utils.utils.load_pretrain_weights(
|
|
|
self.exe, self.train_prog, pretrain_weights, fuse_bn)
|
|
|
# 进行裁剪
|
|
|
@@ -211,7 +211,7 @@ class BaseAPI:
|
|
|
from .slim.prune import get_params_ratios, prune_program
|
|
|
logging.info(
|
|
|
"Start to prune program with eval_metric_loss = {}".format(
|
|
|
- eval_metric_loss))
|
|
|
+ eval_metric_loss), use_color=True)
|
|
|
origin_flops = paddleslim.analysis.flops(self.test_prog)
|
|
|
prune_params_ratios = get_params_ratios(
|
|
|
sensitivities_file, eval_metric_loss=eval_metric_loss)
|
|
|
@@ -220,7 +220,7 @@ class BaseAPI:
|
|
|
remaining_ratio = current_flops / origin_flops
|
|
|
logging.info(
|
|
|
"Finish prune program, before FLOPs:{}, after prune FLOPs:{}, remaining ratio:{}"
|
|
|
- .format(origin_flops, current_flops, remaining_ratio))
|
|
|
+ .format(origin_flops, current_flops, remaining_ratio), use_color=True)
|
|
|
self.status = 'Prune'
|
|
|
|
|
|
def get_model_info(self):
|