Browse Source

adjust paddle2onnx usage

Channingss 5 years ago
parent
commit
b312229c62
2 changed files with 6 additions and 13 deletions
  1. 3 1
      paddlex/command.py
  2. 3 12
      paddlex/converter.py

+ 3 - 1
paddlex/command.py

@@ -161,6 +161,7 @@ def main():
         assert args.save_dir is not None, "--save_dir should be defined to create onnx model"
 
         model = pdx.load_model(args.model_dir)
+
         if model.status == "Normal" or model.status == "Prune":
             logging.error(
                 "Only support inference model, try to export model first as below,",
@@ -168,7 +169,8 @@ def main():
             logging.error(
                 "paddlex --export_inference --model_dir model_path --save_dir infer_model"
             )
-        pdx.converter.export_onnx_model(model, args.save_dir, args.onnx_opset)
+        save_file = os.path.join(args.save_dir, 'paddle2onnx_model.onnx') 
+        pdx.converter.export_onnx_model(model, save_file, args.onnx_opset)
 
     if args.data_conversion:
         assert args.source is not None, "--source should be defined while converting dataset"

+ 3 - 12
paddlex/converter.py

@@ -19,15 +19,6 @@ import sys
 import paddlex as pdx
 import paddlex.utils.logging as logging
 
-__all__ = ['export_onnx']
-
-
-def export_onnx(model_dir, save_dir, fixed_input_shape):
-    assert len(fixed_input_shape) == 2, "len of fixed input shape must == 2"
-    model = pdx.load_model(model_dir, fixed_input_shape)
-    model_name = os.path.basename(model_dir.strip('/')).split('/')[-1]
-    export_onnx_model(model, save_dir)
-
 
 class MultiClassNMS4OpenVINO():
     """
@@ -265,7 +256,7 @@ class MultiClassNMS4OpenVINO():
                 axes=[0])
     
 
-def export_onnx_model(model, save_dir, opset_version=10):
+def export_onnx_model(model, save_file, opset_version=10):
     if model.__class__.__name__ == "FastSCNN" or (
             model.model_type == "detector" and
             model.__class__.__name__ != "YOLOv3"):
@@ -284,9 +275,9 @@ def export_onnx_model(model, save_dir, opset_version=10):
         )
     
     p2o.register_op_mapper('multiclass_nms', MultiClassNMS4OpenVINO)
-
+    
     p2o.program2onnx(
         model.test_prog,
         scope=model.scope,
-        save_file=save_dir,
+        save_file=save_file,
         opset_version=opset_version)