@@ -200,18 +200,31 @@ class BaseAPI:
         self.exe.run(startup_prog)
         if pretrain_weights is not None:
             logging.info(
-                "Load pretrain weights from {}.".format(pretrain_weights))
+                "Load pretrain weights from {}.".format(pretrain_weights),
+                use_color=True)
             paddlex.utils.utils.load_pretrain_weights(
                 self.exe, self.train_prog, pretrain_weights, fuse_bn)
         # Perform pruning
         if sensitivities_file is not None:
+            import paddleslim
             from .slim.prune_config import get_sensitivities
             sensitivities_file = get_sensitivities(sensitivities_file, self,
                                                    save_dir)
             from .slim.prune import get_params_ratios, prune_program
+            logging.info(
+                "Start to prune program with eval_metric_loss = {}".format(
+                    eval_metric_loss),
+                use_color=True)
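+            # Record the FLOPs of the inference program before pruning so the
+            # reduction can be reported once pruning is done.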
+            origin_flops = paddleslim.analysis.flops(self.test_prog)
             prune_params_ratios = get_params_ratios(
                 sensitivities_file, eval_metric_loss=eval_metric_loss)
             prune_program(self, prune_params_ratios)
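+            # Re-measure FLOPs on the pruned program; remaining_ratio is the
+            # fraction of the original FLOPs kept after pruning.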
+            current_flops = paddleslim.analysis.flops(self.test_prog)
+            remaining_ratio = current_flops / origin_flops
+            logging.info(
+                "Finish prune program, before FLOPs:{}, after prune FLOPs:{}, remaining ratio:{}"
+                .format(origin_flops, current_flops, remaining_ratio),
+                use_color=True)
             self.status = 'Prune'
 
     def get_model_info(self):
@@ -223,6 +236,9 @@ class BaseAPI:
             del self.init_params['self']
         if '__class__' in self.init_params:
             del self.init_params['__class__']
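+        # 'model_name' is stripped from the recorded init params alongside
+        # 'self' and '__class__'.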
+        if 'model_name' in self.init_params:
+            del self.init_params['model_name']
+
         info['_init_params'] = self.init_params
 
         info['_Attributes']['num_classes'] = self.num_classes
@@ -256,7 +272,10 @@ class BaseAPI:
         if osp.exists(save_dir):
             os.remove(save_dir)
         os.makedirs(save_dir)
-        fluid.save(self.train_prog, osp.join(save_dir, 'model'))
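+        # When no training program exists (self.train_prog is None), fall
+        # back to saving the weights from the inference program.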
+        if self.train_prog is not None:
+            fluid.save(self.train_prog, osp.join(save_dir, 'model'))
+        else:
+            fluid.save(self.test_prog, osp.join(save_dir, 'model'))
         model_info = self.get_model_info()
         model_info['status'] = self.status
         with open(
@@ -328,140 +347,6 @@ class BaseAPI:
         logging.info(
             "Model for inference deploy saved in {}.".format(save_dir))
 
-    def export_onnx_model(self, save_dir, onnx_name=None):
-        support_list = ['ResNet18', 'ResNet34', 'ResNet50', 'ResNet101', 'ResNet50_vd',
-                        'ResNet101_vd', 'ResNet50_vd_ssld', 'ResNet101_vd_ssld', 'DarkNet53',
-                        'MobileNetV1', 'MobileNetV2', 'MobileNetV3_large', 'MobileNetV3_small',
-                        'MobileNetV3_large_ssld', 'MobileNetV3_small_ssld', 'Xception41',
-                        'Xception65', 'DenseNet121', 'DenseNet161', 'DenseNet201', 'ShuffleNetV2']
-        unsupport_list = []
-        if self.model_type in unsupport_list:
-            raise Exception("Model: {} unsupport export to ONNX"
-                            .format(self.model_type))
-        try:
-            from fluid.utils import op_io_info, init_name_prefix
-            from onnx import helper, checker
-            import fluid_onnx.ops as ops
-            from fluid_onnx.variables import paddle_variable_to_onnx_tensor, paddle_onnx_weight
-            from debug.model_check import debug_model, Tracker
-        except Exception as e:
-            print(e)
-            print(
-                "Import Module Failed! Please install paddle2onnx. Related "
-                "requirements see https://github.com/PaddlePaddle/paddle2onnx"
-            )
-            sys.exit(-1)
-
-        place = fluid.CPUPlace()
-        exe = fluid.Executor(place)
-        inference_scope = fluid.global_scope()
-        with fluid.scope_guard(inference_scope):
-            test_input_names = [
-                var.name for var in list(self.test_inputs.values())
-            ]
-            inputs_outputs_list = ["fetch", "feed"]
-            weights, weights_value_info = [], []
-            global_block = self.test_prog.global_block()
-            for var_name in global_block.vars:
-                var = global_block.var(var_name)
-                if var_name not in test_input_names \
-                        and var.persistable:
-                    weight, val_info = paddle_onnx_weight(
-                        var=var, scope=inference_scope)
-                    weights.append(weight)
-                    weights_value_info.append(val_info)
-            # Create inputs
-            inputs = [
-                paddle_variable_to_onnx_tensor(v, global_block)
-                for v in test_input_names
-            ]
-            print("load the model parameter done.")
-            onnx_nodes = []
-            op_check_list = []
-            op_trackers = []
-            nms_first_index = -1
-            nms_outputs = []
-            for block in self.test_prog.blocks:
-                for op in block.ops:
-                    if op.type in ops.node_maker:
-                        # TODO(kuke): deal with the corner case that vars in
-                        # different blocks have the same name
-                        node_proto = ops.node_maker[str(op.type)](
-                            operator=op, block=block)
-                        op_outputs = []
-                        last_node = None
-                        if isinstance(node_proto, tuple):
-                            onnx_nodes.extend(list(node_proto))
-                            last_node = list(node_proto)
-                        else:
-                            onnx_nodes.append(node_proto)
-                            last_node = [node_proto]
-                        tracker = Tracker(str(op.type), last_node)
-                        op_trackers.append(tracker)
-                        op_check_list.append(str(op.type))
-                        if op.type == "multiclass_nms" and nms_first_index < 0:
-                            nms_first_index = 0
-                        if nms_first_index >= 0:
-                            _, _, output_op = op_io_info(op)
-                            for output in output_op:
-                                nms_outputs.extend(output_op[output])
-                    else:
-                        if op.type not in ['feed', 'fetch']:
-                            op_check_list.append(op.type)
-            print('The operator sets to run test case.')
-            print(set(op_check_list))
-
-            # Create outputs
-            # Get the new names for outputs if they've been renamed in nodes' making
-            renamed_outputs = op_io_info.get_all_renamed_outputs()
-            test_outputs = list(self.test_outputs.values())
-            test_outputs_names = [
-                var.name for var in self.test_outputs.values()
-            ]
-            test_outputs_names = [
-                name if name not in renamed_outputs else renamed_outputs[name]
-                for name in test_outputs_names
-            ]
-            outputs = [
-                paddle_variable_to_onnx_tensor(v, global_block)
-                for v in test_outputs_names
-            ]
-            # Make graph
-            onnx_graph = helper.make_graph(
-                nodes=onnx_nodes,
-                name=onnx_name,
-                initializer=weights,
-                inputs=inputs + weights_value_info,
-                outputs=outputs)
-
-            # Make model
-            onnx_model = helper.make_model(
-                onnx_graph, producer_name='PaddlePaddle')
-
-            # Model check
-            checker.check_model(onnx_model)
-
-            # Print model
-            #if to_print_model:
-            #    print("The converted model is:\n{}".format(onnx_model))
-            # Save converted model
-
-            if onnx_model is not None:
-                try:
-                    onnx_model_file = osp.join(save_dir, onnx_name)
-                    if not os.path.exists(save_dir):
-                        os.mkdir(save_dir)
-                    with open(onnx_model_file, 'wb') as f:
-                        f.write(onnx_model.SerializeToString())
-                    print(
-                        "Saved converted model to path: %s" % onnx_model_file)
-                except Exception as e:
-                    print(e)
-                    print(
-                        "Convert Failed! Please use the debug message to find error."
-                    )
-                    sys.exit(-1)
-
     def train_loop(self,
                    num_epochs,
                    train_dataset,
@@ -539,7 +424,7 @@ class BaseAPI:
             earlystop = EarlyStop(early_stop_patience, thresh)
         best_accuracy_key = ""
         best_accuracy = -1.0
-        best_model_epoch = 1
+        best_model_epoch = -1
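+        # best_model_epoch == -1 means no evaluated best model yet; the
+        # best-model log further down is skipped until an evaluation sets it.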
         for i in range(num_epochs):
             records = list()
             step_start_time = time.time()
@@ -612,7 +497,7 @@ class BaseAPI:
             current_save_dir = osp.join(save_dir, "epoch_{}".format(i + 1))
             if not osp.isdir(current_save_dir):
                 os.makedirs(current_save_dir)
-            if eval_dataset is not None:
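+            # Evaluate only when the eval dataset actually contains samples.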
+            if eval_dataset is not None and eval_dataset.num_samples > 0:
                 self.eval_metrics, self.eval_details = self.evaluate(
                     eval_dataset=eval_dataset,
                     batch_size=eval_batch_size,
@@ -644,10 +529,11 @@ class BaseAPI:
                     self.save_model(save_dir=current_save_dir)
                 time_eval_one_epoch = time.time() - eval_epoch_start_time
                 eval_epoch_start_time = time.time()
-                logging.info(
-                    'Current evaluated best model in eval_dataset is epoch_{}, {}={}'
-                    .format(best_model_epoch, best_accuracy_key,
-                            best_accuracy))
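+                # Skip the best-model log until at least one evaluation has
+                # recorded a best epoch (best_model_epoch starts at -1).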
+                if best_model_epoch > 0:
+                    logging.info(
+                        'Current evaluated best model in eval_dataset is epoch_{}, {}={}'
+                        .format(best_model_epoch, best_accuracy_key,
+                                best_accuracy))
             if eval_dataset is not None and early_stop:
                 if earlystop(current_accuracy):
                     break