convertor.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. import paddle.fluid as fluid
  16. import os
  17. import sys
  18. import paddlex as pdx
  19. import paddlex.utils.logging as logging
  20. __all__ = ['export_onnx']
  21. def export_onnx(model_dir, save_dir, fixed_input_shape):
  22. assert len(fixed_input_shape) == 2, "len of fixed input shape must == 2"
  23. model = pdx.load_model(model_dir, fixed_input_shape)
  24. model_name = os.path.basename(model_dir.strip('/')).split('/')[-1]
  25. export_onnx_model(model, save_dir)
  26. def export_onnx_model(model, save_dir, opset_version=10):
  27. if model.__class__.__name__ == "FastSCNN" or (
  28. model.model_type != "detector" and
  29. model.__class__.__name__ != "YOLOv3"):
  30. logging.error(
  31. "Only image classifier models, detection models(YOLOv3) and semantic segmentation models(except FastSCNN) are supported to export to ONNX"
  32. )
  33. try:
  34. import x2paddle
  35. if x2paddle.__version__ < '0.7.4':
  36. logging.error("You need to upgrade x2paddle >= 0.7.4")
  37. except:
  38. logging.error(
  39. "You need to install x2paddle first, pip install x2paddle>=0.7.4")
  40. if opset_version == 10 and model.__class__.__name__ == "YOLOv3":
  41. logging.warning(
  42. "Export for openVINO by default, the output of multiclass_nms exported to onnx will contains background. If you need onnx completely consistent with paddle, please use X2Paddle to export"
  43. )
  44. x2paddle.op_mapper.paddle2onnx.opset10.paddle_custom_layer.multiclass_nms.multiclass_nms = multiclass_nms_for_openvino
  45. from x2paddle.op_mapper.paddle2onnx.paddle_op_mapper import PaddleOpMapper
  46. mapper = PaddleOpMapper()
  47. mapper.convert(
  48. model.test_prog,
  49. save_dir,
  50. scope=model.scope,
  51. opset_version=opset_version)
def multiclass_nms_for_openvino(op, block):
    """Convert Paddle's multiclass_nms op into ONNX nodes aimed at OpenVINO.

    The generated subgraph:
      1. runs ONNX NonMaxSuppression on the op's BBoxes/Scores inputs;
      2. looks up the score of every selected (class, box) pair in the
         flattened Scores tensor;
      3. keeps the k = min(number selected, keep_top_k) highest-scoring pairs
         (TopK);
      4. builds one row per kept pair as [class, score, box...];
      5. reorders those rows by ascending class index and writes them to the
         op's ``Out`` output.

    ``background_label`` is read but never used, so background detections are
    NOT filtered out as Paddle's multiclass_nms would do.

    Args:
        op: the Paddle multiclass_nms operator. Its ``input_names``,
            ``output_names`` and ``attr_names`` are read; the inputs
            ``BBoxes``/``Scores``, the output ``Out`` and the attributes
            ``background_label``, ``normalized``, ``score_threshold``,
            ``nms_threshold`` and ``keep_top_k`` are used.
        block: not used in this function (presumably required by x2paddle's
            custom-layer calling convention — confirm).

    Returns:
        list: ONNX NodeProto objects in graph order; the last node produces
        ``op.output('Out')[0]``.
    """
    # NOTE(review): math, sys, core, fluid, helper and onnx_pb are imported but
    # not used below (nodes are built via onnx.helper).
    import math
    import sys
    import numpy as np
    import paddle.fluid.core as core
    import paddle.fluid as fluid
    import onnx
    import warnings
    from onnx import helper, onnx_pb
    # Collect the op's inputs, outputs and attributes into dicts keyed by name.
    inputs = dict()
    outputs = dict()
    attrs = dict()
    for name in op.input_names:
        inputs[name] = op.input(name)
    for name in op.output_names:
        outputs[name] = op.output(name)
    for name in op.attr_names:
        attrs[name] = op.attr(name)
    # All intermediate tensor names are derived from the output name.
    result_name = outputs['Out'][0]
    # NOTE(review): read but never used — background detections are kept.
    background = attrs['background_label']
    normalized = attrs['normalized']
    if normalized == False:
        warnings.warn(
            'The parameter normalized of multiclass_nms OP of Paddle is False, which has diff with ONNX. \
Please set normalized=True in multiclass_nms of Paddle'
        )
    # Turn the Paddle attributes into ONNX Constant tensors so they can be fed
    # to NonMaxSuppression / TopK as inputs.
    name_score_threshold = [outputs['Out'][0] + "@score_threshold"]
    name_iou_threshold = [outputs['Out'][0] + "@iou_threshold"]
    name_keep_top_k = [outputs['Out'][0] + '@keep_top_k']
    # Despite the '@keep_top_k_1D' suffix, this constant has dims [1, 1].
    name_keep_top_k_2D = [outputs['Out'][0] + '@keep_top_k_1D']
    # Scalar (dims=()) float constants for score_threshold and nms_threshold.
    node_score_threshold = onnx.helper.make_node(
        'Constant',
        inputs=[],
        outputs=name_score_threshold,
        value=onnx.helper.make_tensor(
            name=name_score_threshold[0] + "@const",
            data_type=onnx.TensorProto.FLOAT,
            dims=(),
            vals=[float(attrs['score_threshold'])]))
    node_iou_threshold = onnx.helper.make_node(
        'Constant',
        inputs=[],
        outputs=name_iou_threshold,
        value=onnx.helper.make_tensor(
            name=name_iou_threshold[0] + "@const",
            data_type=onnx.TensorProto.FLOAT,
            dims=(),
            vals=[float(attrs['nms_threshold'])]))
    # keep_top_k appears twice: as an int64 scalar (NonMaxSuppression's
    # max_output_boxes_per_class input) and as a [1, 1] tensor used later to
    # clamp the TopK size.
    node_keep_top_k = onnx.helper.make_node(
        'Constant',
        inputs=[],
        outputs=name_keep_top_k,
        value=onnx.helper.make_tensor(
            name=name_keep_top_k[0] + "@const",
            data_type=onnx.TensorProto.INT64,
            dims=(),
            vals=[np.int64(attrs['keep_top_k'])]))
    node_keep_top_k_2D = onnx.helper.make_node(
        'Constant',
        inputs=[],
        outputs=name_keep_top_k_2D,
        value=onnx.helper.make_tensor(
            name=name_keep_top_k_2D[0] + "@const",
            data_type=onnx.TensorProto.INT64,
            dims=[1, 1],
            vals=[np.int64(attrs['keep_top_k'])]))
    # the paddle data format is x1,y1,x2,y2
    # NOTE(review): `kwargs` is never passed to make_node; the node relies on
    # ONNX's default center_point_box=0, which is the same corner format.
    kwargs = {'center_point_box': 0}
    name_select_nms = [outputs['Out'][0] + "@select_index"]
    # NonMaxSuppression inputs: boxes, scores, max_output_boxes_per_class,
    # iou_threshold, score_threshold. Per the ONNX spec its output is
    # [S, 3] rows of (batch_index, class_index, box_index), S = num selected.
    node_select_nms= onnx.helper.make_node(
        'NonMaxSuppression',
        inputs=inputs['BBoxes'] + inputs['Scores'] + name_keep_top_k +\
            name_iou_threshold + name_score_threshold,
        outputs=name_select_nms)
    # Step 1: the constants and the NMS node itself.
    node_list = [
        node_score_threshold, node_iou_threshold, node_keep_top_k,
        node_keep_top_k_2D, node_select_nms
    ]
    # int64 constants 0, 1, 2 and -1, each of shape [1], used below as gather
    # indices and as the Reshape target shape.
    name_const_value = [result_name+"@const_0",
                        result_name+"@const_1",\
                        result_name+"@const_2",\
                        result_name+"@const_-1"]
    value_const_value = [0, 1, 2, -1]
    for name, value in zip(name_const_value, value_const_value):
        node = onnx.helper.make_node(
            'Constant',
            inputs=[],
            outputs=[name],
            value=onnx.helper.make_tensor(
                name=name + "@const",
                data_type=onnx.TensorProto.INT64,
                dims=[1],
                vals=[value]))
        node_list.append(node)
    # Decode the NMS result: split out the class-index and box-index columns
    # of the selected indices, and flatten Scores so each selected
    # (class, box) pair can be looked up with a single flat index.
    # Column 1 (class index): Gather -> [S, 1], Unsqueeze -> [1, S, 1].
    outputs_gather_1_ = [result_name + "@gather_1_"]
    node_gather_1_ = onnx.helper.make_node(
        'Gather',
        inputs=name_select_nms + [result_name + "@const_1"],
        outputs=outputs_gather_1_,
        axis=1)
    node_list.append(node_gather_1_)
    outputs_gather_1 = [result_name + "@gather_1"]
    node_gather_1 = onnx.helper.make_node(
        'Unsqueeze',
        inputs=outputs_gather_1_,
        outputs=outputs_gather_1,
        axes=[0])
    node_list.append(node_gather_1)
    # Column 2 (box index), shaped the same way: [1, S, 1].
    outputs_gather_2_ = [result_name + "@gather_2_"]
    node_gather_2_ = onnx.helper.make_node(
        'Gather',
        inputs=name_select_nms + [result_name + "@const_2"],
        outputs=outputs_gather_2_,
        axis=1)
    node_list.append(node_gather_2_)
    outputs_gather_2 = [result_name + "@gather_2"]
    node_gather_2 = onnx.helper.make_node(
        'Unsqueeze',
        inputs=outputs_gather_2_,
        outputs=outputs_gather_2,
        axes=[0])
    node_list.append(node_gather_2)
    # Flatten Scores (presumably N * C * M) to 1-D by reshaping to [-1].
    outputs_reshape_scores_rank1 = [result_name + "@reshape_scores_rank1"]
    node_reshape_scores_rank1 = onnx.helper.make_node(
        "Reshape",
        inputs=inputs['Scores'] + [result_name + "@const_-1"],
        outputs=outputs_reshape_scores_rank1)
    node_list.append(node_reshape_scores_rank1)
    # Runtime shape of Scores.
    outputs_shape_scores = [result_name + "@shape_scores"]
    node_shape_scores = onnx.helper.make_node(
        'Shape', inputs=inputs['Scores'], outputs=outputs_shape_scores)
    node_list.append(node_shape_scores)
    # Dimension 2 of Scores' shape (M, the per-class box count, if Scores is
    # N * C * M), as a [1] tensor.
    outputs_gather_scores_dim1 = [result_name + "@gather_scores_dim1"]
    node_gather_scores_dim1 = onnx.helper.make_node(
        'Gather',
        inputs=outputs_shape_scores + [result_name + "@const_2"],
        outputs=outputs_gather_scores_dim1,
        axis=0)
    node_list.append(node_gather_scores_dim1)
    # class_index * M
    outputs_mul_classnum_boxnum = [result_name + "@mul_classnum_boxnum"]
    node_mul_classnum_boxnum = onnx.helper.make_node(
        'Mul',
        inputs=outputs_gather_1 + outputs_gather_scores_dim1,
        outputs=outputs_mul_classnum_boxnum)
    node_list.append(node_mul_classnum_boxnum)
    # flat index = class_index * M + box_index.
    # NOTE(review): the batch index (column 0) is ignored, so this assumes a
    # batch size of 1 — confirm.
    outputs_add_class_M_index = [result_name + "@add_class_M_index"]
    node_add_class_M_index = onnx.helper.make_node(
        'Add',
        inputs=outputs_mul_classnum_boxnum + outputs_gather_2,
        outputs=outputs_add_class_M_index)
    node_list.append(node_add_class_M_index)
    # Squeeze the flat indices from [1, S, 1] to [S].
    outputs_squeeze_select_index = [result_name + "@squeeze_select_index"]
    node_squeeze_select_index = onnx.helper.make_node(
        'Squeeze',
        inputs=outputs_add_class_M_index,
        outputs=outputs_squeeze_select_index,
        axes=[0, 2])
    node_list.append(node_squeeze_select_index)
    # Scores of the S selected pairs, gathered from the flattened Scores: [S].
    outputs_gather_select_scores = [result_name + "@gather_select_scores"]
    node_gather_select_scores = onnx.helper.make_node('Gather',
                                                      inputs=outputs_reshape_scores_rank1 + \
                                                          outputs_squeeze_select_index,
                                                      outputs=outputs_gather_select_scores,
                                                      axis=0)
    node_list.append(node_gather_select_scores)
    # Compute k = min(S, keep_top_k) for TopK: take S from the shape of the
    # selected scores and shape it to [1, 1].
    outputs_shape_select_num = [result_name + "@shape_select_num"]
    node_shape_select_num = onnx.helper.make_node(
        'Shape',
        inputs=outputs_gather_select_scores,
        outputs=outputs_shape_select_num)
    node_list.append(node_shape_select_num)
    outputs_gather_select_num = [result_name + "@gather_select_num"]
    node_gather_select_num = onnx.helper.make_node(
        'Gather',
        inputs=outputs_shape_select_num + [result_name + "@const_0"],
        outputs=outputs_gather_select_num,
        axis=0)
    node_list.append(node_gather_select_num)
    outputs_unsqueeze_select_num = [result_name + "@unsqueeze_select_num"]
    node_unsqueeze_select_num = onnx.helper.make_node(
        'Unsqueeze',
        inputs=outputs_gather_select_num,
        outputs=outputs_unsqueeze_select_num,
        axes=[0])
    node_list.append(node_unsqueeze_select_num)
    # Stack S and keep_top_k into a [2, 1] tensor.
    outputs_concat_topK_select_num = [result_name + "@conat_topK_select_num"]
    node_conat_topK_select_num = onnx.helper.make_node(
        'Concat',
        inputs=outputs_unsqueeze_select_num + name_keep_top_k_2D,
        outputs=outputs_concat_topK_select_num,
        axis=0)
    node_list.append(node_conat_topK_select_num)
    # Cast to INT32 (to=6) before ReduceMin; the reason is not stated here
    # (presumably runtime support for ReduceMin on int32 — confirm).
    outputs_cast_concat_topK_select_num = [
        result_name + "@concat_topK_select_num"
    ]
    node_outputs_cast_concat_topK_select_num = onnx.helper.make_node(
        'Cast',
        inputs=outputs_concat_topK_select_num,
        outputs=outputs_cast_concat_topK_select_num,
        to=6)
    node_list.append(node_outputs_cast_concat_topK_select_num)
    # k = min(S, keep_top_k) as a scalar (keepdims=0).
    # NOTE(review): if keep_top_k is -1 (Paddle's "keep all"), k becomes -1,
    # which is not a valid TopK size — confirm exported models never use -1.
    outputs_compare_topk_num_select = [
        result_name + "@compare_topk_num_select"
    ]
    node_compare_topk_num_select = onnx.helper.make_node(
        'ReduceMin',
        inputs=outputs_cast_concat_topK_select_num,
        outputs=outputs_compare_topk_num_select,
        keepdims=0)
    node_list.append(node_compare_topk_num_select)
    # Reshape k to a [1] tensor, as TopK's K input requires.
    outputs_unsqueeze_topk_select_indices = [
        result_name + "@unsqueeze_topk_select_indices"
    ]
    node_unsqueeze_topk_select_indices = onnx.helper.make_node(
        'Unsqueeze',
        inputs=outputs_compare_topk_num_select,
        outputs=outputs_unsqueeze_topk_select_indices,
        axes=[0])
    node_list.append(node_unsqueeze_topk_select_indices)
    # Cast k back to INT64 (to=7); the ONNX spec types TopK's K as int64.
    outputs_cast_topk_indices = [result_name + "@cast_topk_indices"]
    node_cast_topk_indices = onnx.helper.make_node(
        'Cast',
        inputs=outputs_unsqueeze_topk_select_indices,
        outputs=outputs_cast_topk_indices,
        to=7)
    node_list.append(node_cast_topk_indices)
    # TopK over the S selected scores -> (values, indices), each of length k.
    outputs_topk_select_topk_indices = [result_name + "@topk_select_topk_values",\
                                        result_name + "@topk_select_topk_indices"]
    node_topk_select_topk_indices = onnx.helper.make_node(
        'TopK',
        inputs=outputs_gather_select_scores + outputs_cast_topk_indices,
        outputs=outputs_topk_select_topk_indices)
    node_list.append(node_topk_select_topk_indices)
    # Use the TopK indices to gather the kept scores ([k]), class indices
    # ([1, k, 1]) and box indices ([1, k, 1]).
    outputs_gather_topk_scores = [result_name + "@gather_topk_scores"]
    node_gather_topk_scores = onnx.helper.make_node(
        'Gather',
        inputs=outputs_gather_select_scores +
        [outputs_topk_select_topk_indices[1]],
        outputs=outputs_gather_topk_scores,
        axis=0)
    node_list.append(node_gather_topk_scores)
    outputs_gather_topk_class = [result_name + "@gather_topk_class"]
    node_gather_topk_class = onnx.helper.make_node(
        'Gather',
        inputs=outputs_gather_1 + [outputs_topk_select_topk_indices[1]],
        outputs=outputs_gather_topk_class,
        axis=1)
    node_list.append(node_gather_topk_class)
    # For the boxes, first gather the kept box indices, then the boxes.
    outputs_gather_topk_boxes_id = [result_name + "@gather_topk_boxes_id"]
    node_gather_topk_boxes_id = onnx.helper.make_node(
        'Gather',
        inputs=outputs_gather_2 + [outputs_topk_select_topk_indices[1]],
        outputs=outputs_gather_topk_boxes_id,
        axis=1)
    node_list.append(node_gather_topk_boxes_id)
    # Squeeze the kept box indices from [1, k, 1] to [k].
    outputs_squeeze_topk_boxes_id = [result_name + "@squeeze_topk_boxes_id"]
    node_squeeze_topk_boxes_id = onnx.helper.make_node(
        'Squeeze',
        inputs=outputs_gather_topk_boxes_id,
        outputs=outputs_squeeze_topk_boxes_id,
        axes=[0, 2])
    node_list.append(node_squeeze_topk_boxes_id)
    # Gather the kept boxes along axis 1 of BBoxes (presumably [1, M, 4] ->
    # [1, k, 4] — confirm the layout).
    outputs_gather_select_boxes = [result_name + "@gather_select_boxes"]
    node_gather_select_boxes = onnx.helper.make_node(
        'Gather',
        inputs=inputs['BBoxes'] + outputs_squeeze_topk_boxes_id,
        outputs=outputs_gather_select_boxes,
        axis=1)
    node_list.append(node_gather_select_boxes)
    # Assemble the result rows. The class index is cast to FLOAT (to=1) so
    # it can be concatenated with the float scores and boxes.
    outputs_cast_topk_class = [result_name + "@cast_topk_class"]
    node_cast_topk_class = onnx.helper.make_node(
        'Cast',
        inputs=outputs_gather_topk_class,
        outputs=outputs_cast_topk_class,
        to=1)
    node_list.append(node_cast_topk_class)
    # Scores [k] -> [1, k, 1] to line up with the class column.
    outputs_unsqueeze_topk_scores = [result_name + "@unsqueeze_topk_scores"]
    node_unsqueeze_topk_scores = onnx.helper.make_node(
        'Unsqueeze',
        inputs=outputs_gather_topk_scores,
        outputs=outputs_unsqueeze_topk_scores,
        axes=[0, 2])
    node_list.append(node_unsqueeze_topk_scores)
    # Concatenate along axis 2: each row is [class, score, box coordinates],
    # ordered by descending score.
    inputs_concat_final_results = outputs_cast_topk_class + outputs_unsqueeze_topk_scores +\
        outputs_gather_select_boxes
    outputs_sort_by_socre_results = [result_name + "@concat_topk_scores"]
    node_sort_by_socre_results = onnx.helper.make_node(
        'Concat',
        inputs=inputs_concat_final_results,
        outputs=outputs_sort_by_socre_results,
        axis=2)
    node_list.append(node_sort_by_socre_results)
    # Reorder the rows by class: squeeze the class column to [k], negate it,
    # and run TopK with the same k. The largest negated values are the
    # smallest class ids, so the indices come out in ascending class order.
    outputs_squeeze_cast_topk_class = [
        result_name + "@squeeze_cast_topk_class"
    ]
    node_squeeze_cast_topk_class = onnx.helper.make_node(
        'Squeeze',
        inputs=outputs_cast_topk_class,
        outputs=outputs_squeeze_cast_topk_class,
        axes=[0, 2])
    node_list.append(node_squeeze_cast_topk_class)
    outputs_neg_squeeze_cast_topk_class = [
        result_name + "@neg_squeeze_cast_topk_class"
    ]
    node_neg_squeeze_cast_topk_class = onnx.helper.make_node(
        'Neg',
        inputs=outputs_squeeze_cast_topk_class,
        outputs=outputs_neg_squeeze_cast_topk_class)
    node_list.append(node_neg_squeeze_cast_topk_class)
    outputs_topk_select_classes_indices = [result_name + "@topk_select_topk_classes_scores",\
                                           result_name + "@topk_select_topk_classes_indices"]
    # (Reuses the local name node_topk_select_topk_indices; the earlier TopK
    # node was already appended to node_list.)
    node_topk_select_topk_indices = onnx.helper.make_node(
        'TopK',
        inputs=outputs_neg_squeeze_cast_topk_class + outputs_cast_topk_indices,
        outputs=outputs_topk_select_classes_indices)
    node_list.append(node_topk_select_topk_indices)
    # Final node: gather the rows in class order along axis 1 into the op's
    # Out tensor.
    outputs_concat_final_results = outputs['Out']
    node_concat_final_results = onnx.helper.make_node(
        'Gather',
        inputs=outputs_sort_by_socre_results +
        [outputs_topk_select_classes_indices[1]],
        outputs=outputs_concat_final_results,
        axis=1)
    node_list.append(node_concat_final_results)
    return node_list
  402. axis=1)
  403. node_list.append(node_concat_final_results)
  404. return node_list