converter.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. import paddle.fluid as fluid
  16. import os
  17. import sys
  18. import paddlex as pdx
  19. import paddlex.utils.logging as logging
  20. class MultiClassNMS4OpenVINO():
  21. """
  22. Convert the paddle multiclass_nms to onnx op.
  23. This op is get the select boxes from origin boxes.
  24. """
  25. @classmethod
  26. def opset_10(cls, graph, node, **kw):
  27. from paddle2onnx.constant import dtypes
  28. import numpy as np
  29. result_name = node.output('Out', 0)
  30. background = node.attr('background_label')
  31. normalized = node.attr('normalized')
  32. if normalized == False:
  33. logging.warning(
  34. "The parameter normalized of multiclass_nms OP of Paddle is False, which has diff with ONNX." \
  35. " Please set normalized=True in multiclass_nms of Paddle, see doc Q1 in" \
  36. " https://github.com/PaddlePaddle/paddle2onnx/blob/develop/FAQ.md")
  37. #convert the paddle attribute to onnx tensor
  38. node_score_threshold = graph.make_node(
  39. 'Constant',
  40. inputs=[],
  41. dtype=dtypes.ONNX.FLOAT,
  42. value=[float(node.attr('score_threshold'))])
  43. node_iou_threshold = graph.make_node(
  44. 'Constant',
  45. inputs=[],
  46. dtype=dtypes.ONNX.FLOAT,
  47. value=[float(node.attr('nms_threshold'))])
  48. node_keep_top_k = graph.make_node(
  49. 'Constant',
  50. inputs=[],
  51. dtype=dtypes.ONNX.INT64,
  52. value=[np.int64(node.attr('keep_top_k'))])
  53. node_keep_top_k_2D = graph.make_node(
  54. 'Constant',
  55. inputs=[],
  56. dtype=dtypes.ONNX.INT64,
  57. dims=[1, 1],
  58. value=[node.attr('keep_top_k')])
  59. # the paddle data format is x1,y1,x2,y2
  60. kwargs = {'center_point_box': 0}
  61. node_select_nms= graph.make_node(
  62. 'NonMaxSuppression',
  63. inputs=[node.input('BBoxes', 0), node.input('Scores', 0), node_keep_top_k,\
  64. node_iou_threshold, node_score_threshold])
  65. # step 1 nodes select the nms class
  66. # create some const value to use
  67. node_const_value = [result_name+"@const_0",
  68. result_name+"@const_1",\
  69. result_name+"@const_2",\
  70. result_name+"@const_-1"]
  71. value_const_value = [0, 1, 2, -1]
  72. for name, value in zip(node_const_value, value_const_value):
  73. graph.make_node(
  74. 'Constant',
  75. layer_name=name,
  76. inputs=[],
  77. outputs=[name],
  78. dtype=dtypes.ONNX.INT64,
  79. value=[value])
  80. # In this code block, we will deocde the raw score data, reshape N * C * M to 1 * N*C*M
  81. # and the same time, decode the select indices to 1 * D, gather the select_indices
  82. node_gather_1 = graph.make_node(
  83. 'Gather',
  84. inputs=[node_select_nms, result_name + "@const_1"],
  85. axis=1)
  86. node_gather_1 = graph.make_node(
  87. 'Unsqueeze', inputs=[node_gather_1], axes=[0])
  88. node_gather_2 = graph.make_node(
  89. 'Gather',
  90. inputs=[node_select_nms, result_name + "@const_2"],
  91. axis=1)
  92. node_gather_2 = graph.make_node(
  93. 'Unsqueeze', inputs=[node_gather_2], axes=[0])
  94. # reshape scores N * C * M to (N*C*M) * 1
  95. node_reshape_scores_rank1 = graph.make_node(
  96. "Reshape",
  97. inputs=[node.input('Scores', 0), result_name + "@const_-1"])
  98. # get the shape of scores
  99. node_shape_scores = graph.make_node(
  100. 'Shape', inputs=node.input('Scores'))
  101. # gather the index: 2 shape of scores
  102. node_gather_scores_dim1 = graph.make_node(
  103. 'Gather',
  104. inputs=[node_shape_scores, result_name + "@const_2"],
  105. axis=0)
  106. # mul class * M
  107. node_mul_classnum_boxnum = graph.make_node(
  108. 'Mul', inputs=[node_gather_1, node_gather_scores_dim1])
  109. # add class * M * index
  110. node_add_class_M_index = graph.make_node(
  111. 'Add', inputs=[node_mul_classnum_boxnum, node_gather_2])
  112. # Squeeze the indices to 1 dim
  113. node_squeeze_select_index = graph.make_node(
  114. 'Squeeze', inputs=[node_add_class_M_index], axes=[0, 2])
  115. # gather the data from flatten scores
  116. node_gather_select_scores = graph.make_node(
  117. 'Gather',
  118. inputs=[node_reshape_scores_rank1, node_squeeze_select_index],
  119. axis=0)
  120. # get nums to input TopK
  121. node_shape_select_num = graph.make_node(
  122. 'Shape', inputs=[node_gather_select_scores])
  123. node_gather_select_num = graph.make_node(
  124. 'Gather',
  125. inputs=[node_shape_select_num, result_name + "@const_0"],
  126. axis=0)
  127. node_unsqueeze_select_num = graph.make_node(
  128. 'Unsqueeze', inputs=[node_gather_select_num], axes=[0])
  129. node_concat_topK_select_num = graph.make_node(
  130. 'Concat',
  131. inputs=[node_unsqueeze_select_num, node_keep_top_k_2D],
  132. axis=0)
  133. node_cast_concat_topK_select_num = graph.make_node(
  134. 'Cast', inputs=[node_concat_topK_select_num], to=6)
  135. # get min(topK, num_select)
  136. node_compare_topk_num_select = graph.make_node(
  137. 'ReduceMin', inputs=[node_cast_concat_topK_select_num], keepdims=0)
  138. # unsqueeze the indices to 1D tensor
  139. node_unsqueeze_topk_select_indices = graph.make_node(
  140. 'Unsqueeze', inputs=[node_compare_topk_num_select], axes=[0])
  141. # cast the indices to INT64
  142. node_cast_topk_indices = graph.make_node(
  143. 'Cast', inputs=[node_unsqueeze_topk_select_indices], to=7)
  144. # select topk scores indices
  145. outputs_topk_select_topk_indices = [result_name + "@topk_select_topk_values",\
  146. result_name + "@topk_select_topk_indices"]
  147. node_topk_select_topk_indices = graph.make_node(
  148. 'TopK',
  149. inputs=[node_gather_select_scores, node_cast_topk_indices],
  150. outputs=outputs_topk_select_topk_indices)
  151. # gather topk label, scores, boxes
  152. node_gather_topk_scores = graph.make_node(
  153. 'Gather',
  154. inputs=[
  155. node_gather_select_scores, outputs_topk_select_topk_indices[1]
  156. ],
  157. axis=0)
  158. node_gather_topk_class = graph.make_node(
  159. 'Gather',
  160. inputs=[node_gather_1, outputs_topk_select_topk_indices[1]],
  161. axis=1)
  162. # gather the boxes need to gather the boxes id, then get boxes
  163. node_gather_topk_boxes_id = graph.make_node(
  164. 'Gather',
  165. inputs=[node_gather_2, outputs_topk_select_topk_indices[1]],
  166. axis=1)
  167. # squeeze the gather_topk_boxes_id to 1 dim
  168. node_squeeze_topk_boxes_id = graph.make_node(
  169. 'Squeeze', inputs=[node_gather_topk_boxes_id], axes=[0, 2])
  170. node_gather_select_boxes = graph.make_node(
  171. 'Gather',
  172. inputs=[node.input('BBoxes', 0), node_squeeze_topk_boxes_id],
  173. axis=1)
  174. # concat the final result
  175. # before concat need to cast the class to float
  176. node_cast_topk_class = graph.make_node(
  177. 'Cast', inputs=[node_gather_topk_class], to=1)
  178. node_unsqueeze_topk_scores = graph.make_node(
  179. 'Unsqueeze', inputs=[node_gather_topk_scores], axes=[0, 2])
  180. inputs_concat_final_results = [node_cast_topk_class, node_unsqueeze_topk_scores, \
  181. node_gather_select_boxes]
  182. node_sort_by_socre_results = graph.make_node(
  183. 'Concat', inputs=inputs_concat_final_results, axis=2)
  184. # select topk classes indices
  185. node_squeeze_cast_topk_class = graph.make_node(
  186. 'Squeeze', inputs=[node_cast_topk_class], axes=[0, 2])
  187. node_neg_squeeze_cast_topk_class = graph.make_node(
  188. 'Neg', inputs=[node_squeeze_cast_topk_class])
  189. outputs_topk_select_classes_indices = [result_name + "@topk_select_topk_classes_scores",\
  190. result_name + "@topk_select_topk_classes_indices"]
  191. node_topk_select_topk_indices = graph.make_node(
  192. 'TopK',
  193. inputs=[node_neg_squeeze_cast_topk_class, node_cast_topk_indices],
  194. outputs=outputs_topk_select_classes_indices)
  195. node_concat_final_results = graph.make_node(
  196. 'Gather',
  197. inputs=[
  198. node_sort_by_socre_results,
  199. outputs_topk_select_classes_indices[1]
  200. ],
  201. axis=1)
  202. node_concat_final_results = graph.make_node(
  203. 'Squeeze',
  204. inputs=[node_concat_final_results],
  205. outputs=[node.output('Out', 0)],
  206. axes=[0])
  207. if node.type == 'multiclass_nms2':
  208. graph.make_node(
  209. 'Squeeze',
  210. inputs=[node_gather_2],
  211. outputs=node.output('Index'),
  212. axes=[0])
  213. def export_onnx_model(model, save_file, opset_version=10):
  214. if model.__class__.__name__ == "FastSCNN" or (
  215. model.model_type == "detector" and
  216. model.__class__.__name__ != "YOLOv3"):
  217. logging.error(
  218. "Only image classifier models, detection models(YOLOv3) and semantic segmentation models(except FastSCNN) are supported to export to ONNX"
  219. )
  220. try:
  221. import paddle2onnx
  222. except:
  223. logging.error(
  224. "You need to install paddle2onnx first, pip install paddle2onnx==0.4"
  225. )
  226. import paddle2onnx as p2o
  227. if p2o.__version__ != '0.4':
  228. logging.error(
  229. "You need install paddle2onnx==0.4, but the version of paddle2onnx is {}".
  230. format(p2o.__version__))
  231. if opset_version == 10 and model.__class__.__name__ == "YOLOv3":
  232. logging.warning(
  233. "Export for openVINO by default, the output of multiclass_nms exported to onnx will contains background. If you need onnx completely consistent with paddle, please use paddle2onnx to export"
  234. )
  235. p2o.register_op_mapper('multiclass_nms', MultiClassNMS4OpenVINO)
  236. p2o.program2onnx(
  237. model.test_prog,
  238. scope=model.scope,
  239. save_file=save_file,
  240. opset_version=opset_version)