converter.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. import paddle.fluid as fluid
  16. import os
  17. import sys
  18. import paddlex as pdx
  19. import paddlex.utils.logging as logging
  20. __all__ = ['export_onnx']
  21. def export_onnx(model_dir, save_dir, fixed_input_shape):
  22. assert len(fixed_input_shape) == 2, "len of fixed input shape must == 2"
  23. model = pdx.load_model(model_dir, fixed_input_shape)
  24. model_name = os.path.basename(model_dir.strip('/')).split('/')[-1]
  25. export_onnx_model(model, save_dir)
  26. def export_onnx_model(model, save_dir, opset_version=10):
  27. if model.__class__.__name__ == "FastSCNN" or (
  28. model.model_type == "detector" and
  29. model.__class__.__name__ != "YOLOv3"):
  30. logging.error(
  31. "Only image classifier models, detection models(YOLOv3) and semantic segmentation models(except FastSCNN) are supported to export to ONNX"
  32. )
  33. try:
  34. import paddle2onnx
  35. except:
  36. logging.error(
  37. "You need to install paddle2onnx first, pip install paddle2onnx")
  38. import paddle2onnx as p2o
  39. if opset_version == 10 and model.__class__.__name__ == "YOLOv3":
  40. logging.warning(
  41. "Export for openVINO by default, the output of multiclass_nms exported to onnx will contains background. If you need onnx completely consistent with paddle, please use paddle2onnx to export"
  42. )
  43. p2o.op_mapper.opset9.paddle_custom_layer.multiclass_nms.multiclass_nms = multiclass_nms_for_openvino
  44. mapper = p2o.PaddleOpMapper()
  45. mapper.convert(
  46. model.test_prog,
  47. save_dir,
  48. scope=model.scope,
  49. opset_version=opset_version)
  50. def multiclass_nms_for_openvino(op, block):
  51. """
  52. Convert the paddle multiclass_nms to onnx op.
  53. This op is get the select boxes from origin boxes.
  54. This op is for OpenVINO, which donn't support dynamic shape).
  55. """
  56. import math
  57. import sys
  58. import numpy as np
  59. import paddle.fluid.core as core
  60. import paddle.fluid as fluid
  61. import onnx
  62. import warnings
  63. from onnx import helper, onnx_pb
  64. inputs = dict()
  65. outputs = dict()
  66. attrs = dict()
  67. for name in op.input_names:
  68. inputs[name] = op.input(name)
  69. for name in op.output_names:
  70. outputs[name] = op.output(name)
  71. for name in op.attr_names:
  72. attrs[name] = op.attr(name)
  73. result_name = outputs['Out'][0]
  74. background = attrs['background_label']
  75. normalized = attrs['normalized']
  76. if normalized == False:
  77. warnings.warn(
  78. 'The parameter normalized of multiclass_nms OP of Paddle is False, which has diff with ONNX. \
  79. Please set normalized=True in multiclass_nms of Paddle'
  80. )
  81. #convert the paddle attribute to onnx tensor
  82. name_score_threshold = [outputs['Out'][0] + "@score_threshold"]
  83. name_iou_threshold = [outputs['Out'][0] + "@iou_threshold"]
  84. name_keep_top_k = [outputs['Out'][0] + '@keep_top_k']
  85. name_keep_top_k_2D = [outputs['Out'][0] + '@keep_top_k_1D']
  86. node_score_threshold = onnx.helper.make_node(
  87. 'Constant',
  88. inputs=[],
  89. outputs=name_score_threshold,
  90. value=onnx.helper.make_tensor(
  91. name=name_score_threshold[0] + "@const",
  92. data_type=onnx.TensorProto.FLOAT,
  93. dims=(),
  94. vals=[float(attrs['score_threshold'])]))
  95. node_iou_threshold = onnx.helper.make_node(
  96. 'Constant',
  97. inputs=[],
  98. outputs=name_iou_threshold,
  99. value=onnx.helper.make_tensor(
  100. name=name_iou_threshold[0] + "@const",
  101. data_type=onnx.TensorProto.FLOAT,
  102. dims=(),
  103. vals=[float(attrs['nms_threshold'])]))
  104. node_keep_top_k = onnx.helper.make_node(
  105. 'Constant',
  106. inputs=[],
  107. outputs=name_keep_top_k,
  108. value=onnx.helper.make_tensor(
  109. name=name_keep_top_k[0] + "@const",
  110. data_type=onnx.TensorProto.INT64,
  111. dims=(),
  112. vals=[np.int64(attrs['keep_top_k'])]))
  113. node_keep_top_k_2D = onnx.helper.make_node(
  114. 'Constant',
  115. inputs=[],
  116. outputs=name_keep_top_k_2D,
  117. value=onnx.helper.make_tensor(
  118. name=name_keep_top_k_2D[0] + "@const",
  119. data_type=onnx.TensorProto.INT64,
  120. dims=[1, 1],
  121. vals=[np.int64(attrs['keep_top_k'])]))
  122. # the paddle data format is x1,y1,x2,y2
  123. kwargs = {'center_point_box': 0}
  124. name_select_nms = [outputs['Out'][0] + "@select_index"]
  125. node_select_nms= onnx.helper.make_node(
  126. 'NonMaxSuppression',
  127. inputs=inputs['BBoxes'] + inputs['Scores'] + name_keep_top_k +\
  128. name_iou_threshold + name_score_threshold,
  129. outputs=name_select_nms)
  130. # step 1 nodes select the nms class
  131. node_list = [
  132. node_score_threshold, node_iou_threshold, node_keep_top_k,
  133. node_keep_top_k_2D, node_select_nms
  134. ]
  135. # create some const value to use
  136. name_const_value = [result_name+"@const_0",
  137. result_name+"@const_1",\
  138. result_name+"@const_2",\
  139. result_name+"@const_-1"]
  140. value_const_value = [0, 1, 2, -1]
  141. for name, value in zip(name_const_value, value_const_value):
  142. node = onnx.helper.make_node(
  143. 'Constant',
  144. inputs=[],
  145. outputs=[name],
  146. value=onnx.helper.make_tensor(
  147. name=name + "@const",
  148. data_type=onnx.TensorProto.INT64,
  149. dims=[1],
  150. vals=[value]))
  151. node_list.append(node)
  152. # In this code block, we will deocde the raw score data, reshape N * C * M to 1 * N*C*M
  153. # and the same time, decode the select indices to 1 * D, gather the select_indices
  154. outputs_gather_1_ = [result_name + "@gather_1_"]
  155. node_gather_1_ = onnx.helper.make_node(
  156. 'Gather',
  157. inputs=name_select_nms + [result_name + "@const_1"],
  158. outputs=outputs_gather_1_,
  159. axis=1)
  160. node_list.append(node_gather_1_)
  161. outputs_gather_1 = [result_name + "@gather_1"]
  162. node_gather_1 = onnx.helper.make_node(
  163. 'Unsqueeze',
  164. inputs=outputs_gather_1_,
  165. outputs=outputs_gather_1,
  166. axes=[0])
  167. node_list.append(node_gather_1)
  168. outputs_gather_2_ = [result_name + "@gather_2_"]
  169. node_gather_2_ = onnx.helper.make_node(
  170. 'Gather',
  171. inputs=name_select_nms + [result_name + "@const_2"],
  172. outputs=outputs_gather_2_,
  173. axis=1)
  174. node_list.append(node_gather_2_)
  175. outputs_gather_2 = [result_name + "@gather_2"]
  176. node_gather_2 = onnx.helper.make_node(
  177. 'Unsqueeze',
  178. inputs=outputs_gather_2_,
  179. outputs=outputs_gather_2,
  180. axes=[0])
  181. node_list.append(node_gather_2)
  182. # reshape scores N * C * M to (N*C*M) * 1
  183. outputs_reshape_scores_rank1 = [result_name + "@reshape_scores_rank1"]
  184. node_reshape_scores_rank1 = onnx.helper.make_node(
  185. "Reshape",
  186. inputs=inputs['Scores'] + [result_name + "@const_-1"],
  187. outputs=outputs_reshape_scores_rank1)
  188. node_list.append(node_reshape_scores_rank1)
  189. # get the shape of scores
  190. outputs_shape_scores = [result_name + "@shape_scores"]
  191. node_shape_scores = onnx.helper.make_node(
  192. 'Shape', inputs=inputs['Scores'], outputs=outputs_shape_scores)
  193. node_list.append(node_shape_scores)
  194. # gather the index: 2 shape of scores
  195. outputs_gather_scores_dim1 = [result_name + "@gather_scores_dim1"]
  196. node_gather_scores_dim1 = onnx.helper.make_node(
  197. 'Gather',
  198. inputs=outputs_shape_scores + [result_name + "@const_2"],
  199. outputs=outputs_gather_scores_dim1,
  200. axis=0)
  201. node_list.append(node_gather_scores_dim1)
  202. # mul class * M
  203. outputs_mul_classnum_boxnum = [result_name + "@mul_classnum_boxnum"]
  204. node_mul_classnum_boxnum = onnx.helper.make_node(
  205. 'Mul',
  206. inputs=outputs_gather_1 + outputs_gather_scores_dim1,
  207. outputs=outputs_mul_classnum_boxnum)
  208. node_list.append(node_mul_classnum_boxnum)
  209. # add class * M * index
  210. outputs_add_class_M_index = [result_name + "@add_class_M_index"]
  211. node_add_class_M_index = onnx.helper.make_node(
  212. 'Add',
  213. inputs=outputs_mul_classnum_boxnum + outputs_gather_2,
  214. outputs=outputs_add_class_M_index)
  215. node_list.append(node_add_class_M_index)
  216. # Squeeze the indices to 1 dim
  217. outputs_squeeze_select_index = [result_name + "@squeeze_select_index"]
  218. node_squeeze_select_index = onnx.helper.make_node(
  219. 'Squeeze',
  220. inputs=outputs_add_class_M_index,
  221. outputs=outputs_squeeze_select_index,
  222. axes=[0, 2])
  223. node_list.append(node_squeeze_select_index)
  224. # gather the data from flatten scores
  225. outputs_gather_select_scores = [result_name + "@gather_select_scores"]
  226. node_gather_select_scores = onnx.helper.make_node('Gather',
  227. inputs=outputs_reshape_scores_rank1 + \
  228. outputs_squeeze_select_index,
  229. outputs=outputs_gather_select_scores,
  230. axis=0)
  231. node_list.append(node_gather_select_scores)
  232. # get nums to input TopK
  233. outputs_shape_select_num = [result_name + "@shape_select_num"]
  234. node_shape_select_num = onnx.helper.make_node(
  235. 'Shape',
  236. inputs=outputs_gather_select_scores,
  237. outputs=outputs_shape_select_num)
  238. node_list.append(node_shape_select_num)
  239. outputs_gather_select_num = [result_name + "@gather_select_num"]
  240. node_gather_select_num = onnx.helper.make_node(
  241. 'Gather',
  242. inputs=outputs_shape_select_num + [result_name + "@const_0"],
  243. outputs=outputs_gather_select_num,
  244. axis=0)
  245. node_list.append(node_gather_select_num)
  246. outputs_unsqueeze_select_num = [result_name + "@unsqueeze_select_num"]
  247. node_unsqueeze_select_num = onnx.helper.make_node(
  248. 'Unsqueeze',
  249. inputs=outputs_gather_select_num,
  250. outputs=outputs_unsqueeze_select_num,
  251. axes=[0])
  252. node_list.append(node_unsqueeze_select_num)
  253. outputs_concat_topK_select_num = [result_name + "@conat_topK_select_num"]
  254. node_conat_topK_select_num = onnx.helper.make_node(
  255. 'Concat',
  256. inputs=outputs_unsqueeze_select_num + name_keep_top_k_2D,
  257. outputs=outputs_concat_topK_select_num,
  258. axis=0)
  259. node_list.append(node_conat_topK_select_num)
  260. outputs_cast_concat_topK_select_num = [
  261. result_name + "@concat_topK_select_num"
  262. ]
  263. node_outputs_cast_concat_topK_select_num = onnx.helper.make_node(
  264. 'Cast',
  265. inputs=outputs_concat_topK_select_num,
  266. outputs=outputs_cast_concat_topK_select_num,
  267. to=6)
  268. node_list.append(node_outputs_cast_concat_topK_select_num)
  269. # get min(topK, num_select)
  270. outputs_compare_topk_num_select = [
  271. result_name + "@compare_topk_num_select"
  272. ]
  273. node_compare_topk_num_select = onnx.helper.make_node(
  274. 'ReduceMin',
  275. inputs=outputs_cast_concat_topK_select_num,
  276. outputs=outputs_compare_topk_num_select,
  277. keepdims=0)
  278. node_list.append(node_compare_topk_num_select)
  279. # unsqueeze the indices to 1D tensor
  280. outputs_unsqueeze_topk_select_indices = [
  281. result_name + "@unsqueeze_topk_select_indices"
  282. ]
  283. node_unsqueeze_topk_select_indices = onnx.helper.make_node(
  284. 'Unsqueeze',
  285. inputs=outputs_compare_topk_num_select,
  286. outputs=outputs_unsqueeze_topk_select_indices,
  287. axes=[0])
  288. node_list.append(node_unsqueeze_topk_select_indices)
  289. # cast the indices to INT64
  290. outputs_cast_topk_indices = [result_name + "@cast_topk_indices"]
  291. node_cast_topk_indices = onnx.helper.make_node(
  292. 'Cast',
  293. inputs=outputs_unsqueeze_topk_select_indices,
  294. outputs=outputs_cast_topk_indices,
  295. to=7)
  296. node_list.append(node_cast_topk_indices)
  297. # select topk scores indices
  298. outputs_topk_select_topk_indices = [result_name + "@topk_select_topk_values",\
  299. result_name + "@topk_select_topk_indices"]
  300. node_topk_select_topk_indices = onnx.helper.make_node(
  301. 'TopK',
  302. inputs=outputs_gather_select_scores + outputs_cast_topk_indices,
  303. outputs=outputs_topk_select_topk_indices)
  304. node_list.append(node_topk_select_topk_indices)
  305. # gather topk label, scores, boxes
  306. outputs_gather_topk_scores = [result_name + "@gather_topk_scores"]
  307. node_gather_topk_scores = onnx.helper.make_node(
  308. 'Gather',
  309. inputs=outputs_gather_select_scores +
  310. [outputs_topk_select_topk_indices[1]],
  311. outputs=outputs_gather_topk_scores,
  312. axis=0)
  313. node_list.append(node_gather_topk_scores)
  314. outputs_gather_topk_class = [result_name + "@gather_topk_class"]
  315. node_gather_topk_class = onnx.helper.make_node(
  316. 'Gather',
  317. inputs=outputs_gather_1 + [outputs_topk_select_topk_indices[1]],
  318. outputs=outputs_gather_topk_class,
  319. axis=1)
  320. node_list.append(node_gather_topk_class)
  321. # gather the boxes need to gather the boxes id, then get boxes
  322. outputs_gather_topk_boxes_id = [result_name + "@gather_topk_boxes_id"]
  323. node_gather_topk_boxes_id = onnx.helper.make_node(
  324. 'Gather',
  325. inputs=outputs_gather_2 + [outputs_topk_select_topk_indices[1]],
  326. outputs=outputs_gather_topk_boxes_id,
  327. axis=1)
  328. node_list.append(node_gather_topk_boxes_id)
  329. # squeeze the gather_topk_boxes_id to 1 dim
  330. outputs_squeeze_topk_boxes_id = [result_name + "@squeeze_topk_boxes_id"]
  331. node_squeeze_topk_boxes_id = onnx.helper.make_node(
  332. 'Squeeze',
  333. inputs=outputs_gather_topk_boxes_id,
  334. outputs=outputs_squeeze_topk_boxes_id,
  335. axes=[0, 2])
  336. node_list.append(node_squeeze_topk_boxes_id)
  337. outputs_gather_select_boxes = [result_name + "@gather_select_boxes"]
  338. node_gather_select_boxes = onnx.helper.make_node(
  339. 'Gather',
  340. inputs=inputs['BBoxes'] + outputs_squeeze_topk_boxes_id,
  341. outputs=outputs_gather_select_boxes,
  342. axis=1)
  343. node_list.append(node_gather_select_boxes)
  344. # concat the final result
  345. # before concat need to cast the class to float
  346. outputs_cast_topk_class = [result_name + "@cast_topk_class"]
  347. node_cast_topk_class = onnx.helper.make_node(
  348. 'Cast',
  349. inputs=outputs_gather_topk_class,
  350. outputs=outputs_cast_topk_class,
  351. to=1)
  352. node_list.append(node_cast_topk_class)
  353. outputs_unsqueeze_topk_scores = [result_name + "@unsqueeze_topk_scores"]
  354. node_unsqueeze_topk_scores = onnx.helper.make_node(
  355. 'Unsqueeze',
  356. inputs=outputs_gather_topk_scores,
  357. outputs=outputs_unsqueeze_topk_scores,
  358. axes=[0, 2])
  359. node_list.append(node_unsqueeze_topk_scores)
  360. inputs_concat_final_results = outputs_cast_topk_class + outputs_unsqueeze_topk_scores +\
  361. outputs_gather_select_boxes
  362. outputs_sort_by_socre_results = [result_name + "@concat_topk_scores"]
  363. node_sort_by_socre_results = onnx.helper.make_node(
  364. 'Concat',
  365. inputs=inputs_concat_final_results,
  366. outputs=outputs_sort_by_socre_results,
  367. axis=2)
  368. node_list.append(node_sort_by_socre_results)
  369. # select topk classes indices
  370. outputs_squeeze_cast_topk_class = [
  371. result_name + "@squeeze_cast_topk_class"
  372. ]
  373. node_squeeze_cast_topk_class = onnx.helper.make_node(
  374. 'Squeeze',
  375. inputs=outputs_cast_topk_class,
  376. outputs=outputs_squeeze_cast_topk_class,
  377. axes=[0, 2])
  378. node_list.append(node_squeeze_cast_topk_class)
  379. outputs_neg_squeeze_cast_topk_class = [
  380. result_name + "@neg_squeeze_cast_topk_class"
  381. ]
  382. node_neg_squeeze_cast_topk_class = onnx.helper.make_node(
  383. 'Neg',
  384. inputs=outputs_squeeze_cast_topk_class,
  385. outputs=outputs_neg_squeeze_cast_topk_class)
  386. node_list.append(node_neg_squeeze_cast_topk_class)
  387. outputs_topk_select_classes_indices = [result_name + "@topk_select_topk_classes_scores",\
  388. result_name + "@topk_select_topk_classes_indices"]
  389. node_topk_select_topk_indices = onnx.helper.make_node(
  390. 'TopK',
  391. inputs=outputs_neg_squeeze_cast_topk_class + outputs_cast_topk_indices,
  392. outputs=outputs_topk_select_classes_indices)
  393. node_list.append(node_topk_select_topk_indices)
  394. outputs_concat_final_results = outputs['Out']
  395. node_concat_final_results = onnx.helper.make_node(
  396. 'Gather',
  397. inputs=outputs_sort_by_socre_results +
  398. [outputs_topk_select_classes_indices[1]],
  399. outputs=outputs_concat_final_results,
  400. axis=1)
  401. node_list.append(node_concat_final_results)
  402. return node_list