convertor.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. import paddle.fluid as fluid
  16. import os
  17. import sys
  18. import paddlex as pdx
  19. import paddlex.utils.logging as logging
  20. __all__ = ['export_onnx']
  21. def export_onnx(model_dir, save_dir, fixed_input_shape):
  22. assert len(fixed_input_shape) == 2, "len of fixed input shape must == 2"
  23. model = pdx.load_model(model_dir, fixed_input_shape)
  24. model_name = os.path.basename(model_dir.strip('/')).split('/')[-1]
  25. export_onnx_model(model, save_dir)
  26. def export_onnx_model(model, save_dir):
  27. support_list = [
  28. 'ResNet18', 'ResNet34', 'ResNet50', 'ResNet101', 'ResNet50_vd',
  29. 'ResNet101_vd', 'ResNet50_vd_ssld', 'ResNet101_vd_ssld', 'DarkNet53',
  30. 'MobileNetV1', 'MobileNetV2', 'DenseNet121', 'DenseNet161',
  31. 'DenseNet201'
  32. ]
  33. if model.__class__.__name__ not in support_list:
  34. raise Exception("Model: {} unsupport export to ONNX".format(
  35. model.__class__.__name__))
  36. try:
  37. from fluid.utils import op_io_info, init_name_prefix
  38. from onnx import helper, checker
  39. import fluid_onnx.ops as ops
  40. from fluid_onnx.variables import paddle_variable_to_onnx_tensor, paddle_onnx_weight
  41. from debug.model_check import debug_model, Tracker
  42. except Exception as e:
  43. logging.error(
  44. "Import Module Failed! Please install paddle2onnx. Related requirements see https://github.com/PaddlePaddle/paddle2onnx."
  45. )
  46. raise e
  47. place = fluid.CPUPlace()
  48. exe = fluid.Executor(place)
  49. inference_scope = fluid.global_scope()
  50. with fluid.scope_guard(inference_scope):
  51. test_input_names = [
  52. var.name for var in list(model.test_inputs.values())
  53. ]
  54. inputs_outputs_list = ["fetch", "feed"]
  55. weights, weights_value_info = [], []
  56. global_block = model.test_prog.global_block()
  57. for var_name in global_block.vars:
  58. var = global_block.var(var_name)
  59. if var_name not in test_input_names\
  60. and var.persistable:
  61. weight, val_info = paddle_onnx_weight(
  62. var=var, scope=inference_scope)
  63. weights.append(weight)
  64. weights_value_info.append(val_info)
  65. # Create inputs
  66. inputs = [
  67. paddle_variable_to_onnx_tensor(v, global_block)
  68. for v in test_input_names
  69. ]
  70. logging.INFO("load the model parameter done.")
  71. onnx_nodes = []
  72. op_check_list = []
  73. op_trackers = []
  74. nms_first_index = -1
  75. nms_outputs = []
  76. for block in model.test_prog.blocks:
  77. for op in block.ops:
  78. if op.type in ops.node_maker:
  79. # TODO: deal with the corner case that vars in
  80. # different blocks have the same name
  81. node_proto = ops.node_maker[str(op.type)](
  82. operator=op, block=block)
  83. op_outputs = []
  84. last_node = None
  85. if isinstance(node_proto, tuple):
  86. onnx_nodes.extend(list(node_proto))
  87. last_node = list(node_proto)
  88. else:
  89. onnx_nodes.append(node_proto)
  90. last_node = [node_proto]
  91. tracker = Tracker(str(op.type), last_node)
  92. op_trackers.append(tracker)
  93. op_check_list.append(str(op.type))
  94. if op.type == "multiclass_nms" and nms_first_index < 0:
  95. nms_first_index = 0
  96. if nms_first_index >= 0:
  97. _, _, output_op = op_io_info(op)
  98. for output in output_op:
  99. nms_outputs.extend(output_op[output])
  100. else:
  101. if op.type not in ['feed', 'fetch']:
  102. op_check_list.append(op.type)
  103. logging.info('The operator sets to run test case.')
  104. logging.info(set(op_check_list))
  105. # Create outputs
  106. # Get the new names for outputs if they've been renamed in nodes' making
  107. renamed_outputs = op_io_info.get_all_renamed_outputs()
  108. test_outputs = list(model.test_outputs.values())
  109. test_outputs_names = [var.name for var in model.test_outputs.values()]
  110. test_outputs_names = [
  111. name if name not in renamed_outputs else renamed_outputs[name]
  112. for name in test_outputs_names
  113. ]
  114. outputs = [
  115. paddle_variable_to_onnx_tensor(v, global_block)
  116. for v in test_outputs_names
  117. ]
  118. # Make graph
  119. onnx_name = 'paddlex.onnx'
  120. onnx_graph = helper.make_graph(
  121. nodes=onnx_nodes,
  122. name=onnx_name,
  123. initializer=weights,
  124. inputs=inputs + weights_value_info,
  125. outputs=outputs)
  126. # Make model
  127. onnx_model = helper.make_model(
  128. onnx_graph, producer_name='PaddlePaddle')
  129. # Model check
  130. checker.check_model(onnx_model)
  131. if onnx_model is not None:
  132. onnx_model_file = os.path.join(save_dir, onnx_name)
  133. if not os.path.exists(save_dir):
  134. os.mkdir(save_dir)
  135. with open(onnx_model_file, 'wb') as f:
  136. f.write(onnx_model.SerializeToString())
  137. logging.info("Saved converted model to path: %s" % onnx_model_file)