|
@@ -19,15 +19,6 @@ import sys
|
|
|
import paddlex as pdx
|
|
import paddlex as pdx
|
|
|
import paddlex.utils.logging as logging
|
|
import paddlex.utils.logging as logging
|
|
|
|
|
|
|
|
-__all__ = ['export_onnx']
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
-def export_onnx(model_dir, save_dir, fixed_input_shape):
|
|
|
|
|
- assert len(fixed_input_shape) == 2, "len of fixed input shape must == 2"
|
|
|
|
|
- model = pdx.load_model(model_dir, fixed_input_shape)
|
|
|
|
|
- model_name = os.path.basename(model_dir.strip('/')).split('/')[-1]
|
|
|
|
|
- export_onnx_model(model, save_dir)
|
|
|
|
|
-
|
|
|
|
|
|
|
|
|
|
class MultiClassNMS4OpenVINO():
|
|
class MultiClassNMS4OpenVINO():
|
|
|
"""
|
|
"""
|
|
@@ -265,7 +256,7 @@ class MultiClassNMS4OpenVINO():
|
|
|
axes=[0])
|
|
axes=[0])
|
|
|
|
|
|
|
|
|
|
|
|
|
-def export_onnx_model(model, save_dir, opset_version=10):
|
|
|
|
|
|
|
+def export_onnx_model(model, save_file, opset_version=10):
|
|
|
if model.__class__.__name__ == "FastSCNN" or (
|
|
if model.__class__.__name__ == "FastSCNN" or (
|
|
|
model.model_type == "detector" and
|
|
model.model_type == "detector" and
|
|
|
model.__class__.__name__ != "YOLOv3"):
|
|
model.__class__.__name__ != "YOLOv3"):
|
|
@@ -284,9 +275,9 @@ def export_onnx_model(model, save_dir, opset_version=10):
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
p2o.register_op_mapper('multiclass_nms', MultiClassNMS4OpenVINO)
|
|
p2o.register_op_mapper('multiclass_nms', MultiClassNMS4OpenVINO)
|
|
|
-
|
|
|
|
|
|
|
+
|
|
|
p2o.program2onnx(
|
|
p2o.program2onnx(
|
|
|
model.test_prog,
|
|
model.test_prog,
|
|
|
scope=model.scope,
|
|
scope=model.scope,
|
|
|
- save_file=save_dir,
|
|
|
|
|
|
|
+ save_file=save_file,
|
|
|
opset_version=opset_version)
|
|
opset_version=opset_version)
|