Browse Source

support export to onnx

Channingss 5 years ago
parent
commit
4f51c4d086
2 changed files with 135 additions and 134 deletions
  1. 21 11
      paddlex/command.py
  2. 114 123
      paddlex/cv/models/base.py

+ 21 - 11
paddlex/command.py

@@ -60,25 +60,35 @@ def main():
         print("Repo: https://github.com/PaddlePaddle/PaddleX.git")
         print("Email: paddlex@baidu.com")
         return
+
     if args.export_inference:
         assert args.model_dir is not None, "--model_dir should be defined while exporting inference model"
         assert args.save_dir is not None, "--save_dir should be defined to save inference model"
-        fixed_input_shape = eval(args.fixed_input_shape)
-        assert len(
-            fixed_input_shape) == 2, "len of fixed input shape must == 2"
+
+        fixed_input_shape = None
+        if args.fixed_input_shape is not None:
+            fixed_input_shape = eval(args.fixed_input_shape)
+            assert len(
+                fixed_input_shape) == 2, "len of fixed input shape must == 2"
 
         model = pdx.load_model(args.model_dir, fixed_input_shape)
         model.export_inference_model(args.save_dir)
 
-   # if args.export_onnx:
-   #     assert args.model_dir is not None, "--model_dir should be defined while exporting onnx model"
-   #     assert args.save_dir is not None, "--save_dir should be defined to save onnx model"
-   #     fixed_input_shape = eval(args.fixed_input_shape)
-   #     assert len(
-   #         fixed_input_shape) == 2, "len of fixed input shape must == 2"
+    if args.export_onnx:
+        assert args.model_dir is not None, "--model_dir should be defined while exporting onnx model"
+        assert args.save_dir is not None, "--save_dir should be defined to save onnx model"
+
+        fixed_input_shape = None
+        if args.fixed_input_shape is not None:
+            fixed_input_shape = eval(args.fixed_input_shape)
+            assert len(
+                fixed_input_shape) == 2, "len of fixed input shape must == 2"
+
+        model = pdx.load_model(args.model_dir, fixed_input_shape)
 
-   #     model = pdx.load_model(args.model_dir, fixed_input_shape)
-   #     model.export_onnx_model(args.save_dir)
+        model_name = os.path.basename(args.model_dir.strip('/')).split('/')[-1]
+        onnx_name = model_name + '.onnx'
+        model.export_onnx_model(args.save_dir, onnx_name=onnx_name)
 
 
 if __name__ == "__main__":

+ 114 - 123
paddlex/cv/models/base.py

@@ -15,6 +15,7 @@
 from __future__ import absolute_import
 import paddle.fluid as fluid
 import os
+import sys
 import numpy as np
 import time
 import math
@@ -327,129 +328,119 @@ class BaseAPI:
         logging.info(
             "Model for inference deploy saved in {}.".format(save_dir))
 
-   # def export_onnx_model(self, save_dir, onnx_model=None):
-   #     from fluid.utils import op_io_info, init_name_prefix
-   #     from onnx import helper, checker
-   #     import fluid_onnx.ops as ops
-   #     from fluid_onnx.variables import paddle_variable_to_onnx_tensor, paddle_onnx_weight
-   #     from debug.model_check import debug_model, Tracke
-   #     place = fluid.CPUPlace()
-   #     exe = fluid.Executor(place)
-   #     inference_scope = fluid.core.Scope() 
-   #     with fluid.scope_guard(inference_scope):
-   #         test_input_names = [
-   #         var.name for var in list(self.test_inputs.values())
-   #         ]
-   #         inputs_outputs_list = ["fetch", "feed"] 
-   #         weights, weights_value_info = [], []
-   #         global_block = self.test_program.global_block()
-   #         for var_name in global_block.vars:
-   #             var = global_block.var(var_name)
-   #             if var_name not in feed_fetch_list\
-   #                 and var.persistable:
-   #                 weight, val_info = paddle_onnx_weight(
-   #                     var=var, scope=inference_scope)
-   #                 weights.append(weight)
-   #                 weights_value_info.append(val_info)
-   #         # Create inputs
-   #         inputs = [
-   #         paddle_variable_to_onnx_tensor(v, global_block)
-   #         for v in test_input_names
-   #         ]
-   #         print("load the model parameter done.")
-   #         onnx_nodes = []
-   #         op_check_list = []
-   #         op_trackers = []
-   #         nms_first_index = -1
-   #         nms_outputs = []
-   #         for block in inference_program.blocks:
-   #             for op in block.ops:
-   #                 if op.type in ops.node_maker:
-   #                     # TODO(kuke): deal with the corner case that vars in 
-   #                     #     different blocks have the same name
-   #                     node_proto = ops.node_maker[str(op.type)](operator=op,
-   #                                                               block=block)
-   #                     op_outputs = []
-   #                     last_node = None
-   #                     if isinstance(node_proto, tuple):
-   #                         onnx_nodes.extend(list(node_proto))
-   #                         last_node = list(node_proto)
-   #                     else:
-   #                         onnx_nodes.append(node_proto)
-   #                         last_node = [node_proto]
-   #                     tracker = Tracker(str(op.type), last_node)
-   #                     op_trackers.append(tracker)
-   #                     op_check_list.append(str(op.type))
-   #                     if op.type == "multiclass_nms" and nms_first_index < 0:
-   #                         nms_first_index = 0
-   #                     if nms_first_index >= 0:
-   #                         _, _, output_op = op_io_info(op)
-   #                         for output in output_op:
-   #                             nms_outputs.extend(output_op[output])
-   #                 else:
-   #                     if op.type not in ['feed', 'fetch']:
-   #                         op_check_list.append(op.type)
-   #         print('The operator sets to run test case.')
-   #         print(set(op_check_list))
-   #         # Create outputs
-   #         # Get the new names for outputs if they've been renamed in nodes' making
-   #         renamed_outputs = op_io_info.get_all_renamed_outputs()
-   #         test_outputs = list(self.test_outputs.values())
-   #         test_outputs_names = [var.name for var in self.test_outpus.values]
-   #         test_outputs_names = [
-   #             name if name not in renamed_outputs else renamed_outputs[name]
-   #             for name in test_outputs_names
-   #         ]
-   #         outputs = [
-   #             paddle_variable_to_onnx_tensor(v, global_block)
-   #             for v in test_outputs_names
-   #         ]
-   #         # Make graph
-   #         #model_name = os.path.basename(args.fluid_model.strip('/')).split('.')[0]
-   #         model_name = 'test' 
-   #         onnx_graph = helper.make_graph(
-   #             nodes=onnx_nodes,
-   #             name=model_name,
-   #             initializer=weights,
-   #             inputs=inputs + weights_value_info,
-   #             outputs=outputs)
-
-   #         # Make model
-   #         onnx_model = helper.make_model(onnx_graph, producer_name='PaddlePaddle')
-
-   #         # Model check
-   #         checker.check_model(onnx_model)
-
-   #         # Print model
-   #         #if to_print_model:
-   #         #    print("The converted model is:\n{}".format(onnx_model))
-   #         # Save converted model
-   #         
-   #         if onnx_model is not None:
-   #             try:
-   #                 onnx_model_file = osp.join(save_dir, onnx_model)
-   #                 with open(onnx_model_file, 'wb') as f:
-   #                     f.write(onnx_model.SerializeToString())
-   #                 print("Saved converted model to path: %s" % onnx_model_file)
-   #                 # If in debug mode, need to save op list, add we will check op 
-   #                 #if args.debug:
-   #                 #    op_check_list = list(set(op_check_list))
-   #                 #    check_outputs = []
-
-   #                 #    for node_proto in onnx_nodes:
-   #                 #        check_outputs.extend(node_proto.output)
-
-   #                 #    print("The num of %d operators need to check, and %d op outputs need to check."\
-   #                 #          %(len(op_check_list), len(check_outputs)))
-
-   #                 #    debug_model(op_check_list, op_trackers, nms_outputs, args)
-
-   #             except Exception as e:
-   #                 print(e)
-   #                 print(
-   #                     "Convert Failed! Please use the debug message to find error."
-   #                 )
-   #                 sys.exit(-1)
+    def export_onnx_model(self, save_dir, onnx_name=None):
+        from fluid.utils import op_io_info, init_name_prefix
+        from onnx import helper, checker
+        import fluid_onnx.ops as ops
+        from fluid_onnx.variables import paddle_variable_to_onnx_tensor, paddle_onnx_weight
+        from debug.model_check import debug_model, Tracker
+        place = fluid.CPUPlace()
+        exe = fluid.Executor(place)
+        inference_scope = fluid.global_scope()
+        with fluid.scope_guard(inference_scope):
+            test_input_names = [
+                var.name for var in list(self.test_inputs.values())
+            ]
+            inputs_outputs_list = ["fetch", "feed"]
+            weights, weights_value_info = [], []
+            global_block = self.test_prog.global_block()
+            for var_name in global_block.vars:
+                var = global_block.var(var_name)
+                if var_name not in test_input_names\
+                    and var.persistable:
+                    weight, val_info = paddle_onnx_weight(
+                        var=var, scope=inference_scope)
+                    weights.append(weight)
+                    weights_value_info.append(val_info)
+            # Create inputs
+            inputs = [
+                paddle_variable_to_onnx_tensor(v, global_block)
+                for v in test_input_names
+            ]
+            print("load the model parameter done.")
+            onnx_nodes = []
+            op_check_list = []
+            op_trackers = []
+            nms_first_index = -1
+            nms_outputs = []
+            for block in self.test_prog.blocks:
+                for op in block.ops:
+                    if op.type in ops.node_maker:
+                        # TODO(kuke): deal with the corner case that vars in
+                        #     different blocks have the same name
+                        node_proto = ops.node_maker[str(op.type)](
+                            operator=op, block=block)
+                        op_outputs = []
+                        last_node = None
+                        if isinstance(node_proto, tuple):
+                            onnx_nodes.extend(list(node_proto))
+                            last_node = list(node_proto)
+                        else:
+                            onnx_nodes.append(node_proto)
+                            last_node = [node_proto]
+                        tracker = Tracker(str(op.type), last_node)
+                        op_trackers.append(tracker)
+                        op_check_list.append(str(op.type))
+                        if op.type == "multiclass_nms" and nms_first_index < 0:
+                            nms_first_index = 0
+                        if nms_first_index >= 0:
+                            _, _, output_op = op_io_info(op)
+                            for output in output_op:
+                                nms_outputs.extend(output_op[output])
+                    else:
+                        if op.type not in ['feed', 'fetch']:
+                            op_check_list.append(op.type)
+            print('The operator sets to run test case.')
+            print(set(op_check_list))
+            # Create outputs
+            # Get the new names for outputs if they've been renamed in nodes' making
+            renamed_outputs = op_io_info.get_all_renamed_outputs()
+            test_outputs = list(self.test_outputs.values())
+            test_outputs_names = [
+                var.name for var in self.test_outputs.values()
+            ]
+            test_outputs_names = [
+                name if name not in renamed_outputs else renamed_outputs[name]
+                for name in test_outputs_names
+            ]
+            outputs = [
+                paddle_variable_to_onnx_tensor(v, global_block)
+                for v in test_outputs_names
+            ]
+            # Make graph
+            onnx_name = 'test'
+            onnx_graph = helper.make_graph(
+                nodes=onnx_nodes,
+                name=onnx_name,
+                initializer=weights,
+                inputs=inputs + weights_value_info,
+                outputs=outputs)
+
+            # Make model
+            onnx_model = helper.make_model(
+                onnx_graph, producer_name='PaddlePaddle')
+
+            # Model check
+            checker.check_model(onnx_model)
+
+            # Print model
+            #if to_print_model:
+            #    print("The converted model is:\n{}".format(onnx_model))
+            # Save converted model
+
+            if onnx_model is not None:
+                try:
+                    onnx_model_file = osp.join(save_dir, onnx_name)
+                    with open(onnx_model_file, 'wb') as f:
+                        f.write(onnx_model.SerializeToString())
+                    print(
+                        "Saved converted model to path: %s" % onnx_model_file)
+                except Exception as e:
+                    print(e)
+                    print(
+                        "Convert Failed! Please use the debug message to find error."
+                    )
+                    sys.exit(-1)
 
     def train_loop(self,
                    num_epochs,