Sfoglia il codice sorgente

support export_onnx

Channingss 5 anni fa
parent
commit
0702f0ea72
4 ha cambiato i file con 14 aggiunte e 129 eliminazioni
  1. 7 3
      paddlex/__init__.py
  2. 1 5
      paddlex/command.py
  3. 3 115
      paddlex/cv/models/base.py
  4. 3 6
      paddlex/cv/models/load_model.py

+ 7 - 3
paddlex/__init__.py

@@ -19,18 +19,22 @@ from . import det
 from . import seg
 from . import cls
 from . import slim
+from . import convertor
 
 try:
     import pycocotools
 except:
-    print("[WARNING] pycocotools is not installed, detection model is not available now.")
-    print("[WARNING] pycocotools install: https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/install.md")
+    print(
+        "[WARNING] pycocotools is not installed, detection model is not available now."
+    )
+    print(
+        "[WARNING] pycocotools install: https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/install.md"
+    )
 
 import paddlehub as hub
 if hub.version.hub_version < '1.6.2':
     raise Exception("[ERROR] paddlehub >= 1.6.2 is required")
 
-
 env_info = get_environ_info()
 load_model = cv.models.load_model
 datasets = cv.datasets

+ 1 - 5
paddlex/command.py

@@ -83,12 +83,8 @@ def main():
             fixed_input_shape = eval(args.fixed_input_shape)
             assert len(
                 fixed_input_shape) == 2, "len of fixed input shape must == 2"
-
         model = pdx.load_model(args.model_dir, fixed_input_shape)
-
-        model_name = os.path.basename(args.model_dir.strip('/')).split('/')[-1]
-        onnx_name = model_name + '.onnx'
-        model.export_onnx_model(args.save_dir, onnx_name=onnx_name)
+        pdx.convertor.export_onnx_model(model, args.save_dir)
 
 
 if __name__ == "__main__":

+ 3 - 115
paddlex/cv/models/base.py

@@ -223,6 +223,9 @@ class BaseAPI:
             del self.init_params['self']
         if '__class__' in self.init_params:
             del self.init_params['__class__']
+        if 'model_name' in self.init_params:
+            del self.init_params['model_name']
+
         info['_init_params'] = self.init_params
 
         info['_Attributes']['num_classes'] = self.num_classes
@@ -328,121 +331,6 @@ class BaseAPI:
         logging.info(
             "Model for inference deploy saved in {}.".format(save_dir))
 
-    def export_onnx_model(self, save_dir, onnx_name=None):
-        from fluid.utils import op_io_info, init_name_prefix
-        from onnx import helper, checker
-        import fluid_onnx.ops as ops
-        from fluid_onnx.variables import paddle_variable_to_onnx_tensor, paddle_onnx_weight
-        from debug.model_check import debug_model, Tracker
-        place = fluid.CPUPlace()
-        exe = fluid.Executor(place)
-        inference_scope = fluid.global_scope()
-        with fluid.scope_guard(inference_scope):
-            test_input_names = [
-                var.name for var in list(self.test_inputs.values())
-            ]
-            inputs_outputs_list = ["fetch", "feed"]
-            weights, weights_value_info = [], []
-            global_block = self.test_prog.global_block()
-            for var_name in global_block.vars:
-                var = global_block.var(var_name)
-                if var_name not in test_input_names\
-                    and var.persistable:
-                    weight, val_info = paddle_onnx_weight(
-                        var=var, scope=inference_scope)
-                    weights.append(weight)
-                    weights_value_info.append(val_info)
-            # Create inputs
-            inputs = [
-                paddle_variable_to_onnx_tensor(v, global_block)
-                for v in test_input_names
-            ]
-            print("load the model parameter done.")
-            onnx_nodes = []
-            op_check_list = []
-            op_trackers = []
-            nms_first_index = -1
-            nms_outputs = []
-            for block in self.test_prog.blocks:
-                for op in block.ops:
-                    if op.type in ops.node_maker:
-                        # TODO(kuke): deal with the corner case that vars in
-                        #     different blocks have the same name
-                        node_proto = ops.node_maker[str(op.type)](
-                            operator=op, block=block)
-                        op_outputs = []
-                        last_node = None
-                        if isinstance(node_proto, tuple):
-                            onnx_nodes.extend(list(node_proto))
-                            last_node = list(node_proto)
-                        else:
-                            onnx_nodes.append(node_proto)
-                            last_node = [node_proto]
-                        tracker = Tracker(str(op.type), last_node)
-                        op_trackers.append(tracker)
-                        op_check_list.append(str(op.type))
-                        if op.type == "multiclass_nms" and nms_first_index < 0:
-                            nms_first_index = 0
-                        if nms_first_index >= 0:
-                            _, _, output_op = op_io_info(op)
-                            for output in output_op:
-                                nms_outputs.extend(output_op[output])
-                    else:
-                        if op.type not in ['feed', 'fetch']:
-                            op_check_list.append(op.type)
-            print('The operator sets to run test case.')
-            print(set(op_check_list))
-            # Create outputs
-            # Get the new names for outputs if they've been renamed in nodes' making
-            renamed_outputs = op_io_info.get_all_renamed_outputs()
-            test_outputs = list(self.test_outputs.values())
-            test_outputs_names = [
-                var.name for var in self.test_outputs.values()
-            ]
-            test_outputs_names = [
-                name if name not in renamed_outputs else renamed_outputs[name]
-                for name in test_outputs_names
-            ]
-            outputs = [
-                paddle_variable_to_onnx_tensor(v, global_block)
-                for v in test_outputs_names
-            ]
-            # Make graph
-            onnx_graph = helper.make_graph(
-                nodes=onnx_nodes,
-                name=onnx_name,
-                initializer=weights,
-                inputs=inputs + weights_value_info,
-                outputs=outputs)
-
-            # Make model
-            onnx_model = helper.make_model(
-                onnx_graph, producer_name='PaddlePaddle')
-
-            # Model check
-            checker.check_model(onnx_model)
-
-            # Print model
-            #if to_print_model:
-            #    print("The converted model is:\n{}".format(onnx_model))
-            # Save converted model
-
-            if onnx_model is not None:
-                try:
-                    onnx_model_file = osp.join(save_dir, onnx_name)
-                    if not os.path.exists(save_dir):
-                        os.mkdir(save_dir)
-                    with open(onnx_model_file, 'wb') as f:
-                        f.write(onnx_model.SerializeToString())
-                    print(
-                        "Saved converted model to path: %s" % onnx_model_file)
-                except Exception as e:
-                    print(e)
-                    print(
-                        "Convert Failed! Please use the debug message to find error."
-                    )
-                    sys.exit(-1)
-
     def train_loop(self,
                    num_epochs,
                    train_dataset,

+ 3 - 6
paddlex/cv/models/load_model.py

@@ -38,12 +38,9 @@ def load_model(model_dir, fixed_input_shape=None):
     if not hasattr(paddlex.cv.models, info['Model']):
         raise Exception("There's no attribute {} in paddlex.cv.models".format(
             info['Model']))
-
-    if info['_Attributes']['model_type'] == 'classifier':
-        model = paddlex.cv.models.BaseClassifier(**info['_init_params'])
-    else:
-        model = getattr(paddlex.cv.models,
-                        info['Model'])(**info['_init_params'])
+    if 'model_name' in info['_init_params']:
+        del info['_init_params']['model_name']
+    model = getattr(paddlex.cv.models, info['Model'])(**info['_init_params'])
     model.fixed_input_shape = fixed_input_shape
     if status == "Normal" or \
             status == "Prune" or status == "fluid.save":