test_ops.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import print_function
  15. import os, sys
  16. # add python path of PadleDetection to sys.path
  17. parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 4)))
  18. if parent_path not in sys.path:
  19. sys.path.append(parent_path)
  20. import unittest
  21. import numpy as np
  22. import paddle
  23. import paddle.fluid as fluid
  24. from paddle.fluid.framework import Program, program_guard
  25. from paddle.fluid.dygraph import base
  26. import paddlex.ppdet.modeling.ops as ops
  27. from paddlex.ppdet.modeling.tests.test_base import LayerTest
  28. def make_rois(h, w, rois_num, output_size):
  29. rois = np.zeros((0, 4)).astype('float32')
  30. for roi_num in rois_num:
  31. roi = np.zeros((roi_num, 4)).astype('float32')
  32. roi[:, 0] = np.random.randint(0, h - output_size[0], size=roi_num)
  33. roi[:, 1] = np.random.randint(0, w - output_size[1], size=roi_num)
  34. roi[:, 2] = np.random.randint(roi[:, 0] + output_size[0], h)
  35. roi[:, 3] = np.random.randint(roi[:, 1] + output_size[1], w)
  36. rois = np.vstack((rois, roi))
  37. return rois
  38. def softmax(x):
  39. # clip to shiftx, otherwise, when calc loss with
  40. # log(exp(shiftx)), may get log(0)=INF
  41. shiftx = (x - np.max(x)).clip(-64.)
  42. exps = np.exp(shiftx)
  43. return exps / np.sum(exps)
  44. class TestCollectFpnProposals(LayerTest):
  45. def test_collect_fpn_proposals(self):
  46. multi_bboxes_np = []
  47. multi_scores_np = []
  48. rois_num_per_level_np = []
  49. for i in range(4):
  50. bboxes_np = np.random.rand(5, 4).astype('float32')
  51. scores_np = np.random.rand(5, 1).astype('float32')
  52. rois_num = np.array([2, 3]).astype('int32')
  53. multi_bboxes_np.append(bboxes_np)
  54. multi_scores_np.append(scores_np)
  55. rois_num_per_level_np.append(rois_num)
  56. with self.static_graph():
  57. multi_bboxes = []
  58. multi_scores = []
  59. rois_num_per_level = []
  60. for i in range(4):
  61. bboxes = paddle.static.data(
  62. name='rois' + str(i),
  63. shape=[5, 4],
  64. dtype='float32',
  65. lod_level=1)
  66. scores = paddle.static.data(
  67. name='scores' + str(i),
  68. shape=[5, 1],
  69. dtype='float32',
  70. lod_level=1)
  71. rois_num = paddle.static.data(
  72. name='rois_num' + str(i), shape=[None], dtype='int32')
  73. multi_bboxes.append(bboxes)
  74. multi_scores.append(scores)
  75. rois_num_per_level.append(rois_num)
  76. fpn_rois, rois_num = ops.collect_fpn_proposals(
  77. multi_bboxes,
  78. multi_scores,
  79. 2,
  80. 5,
  81. 10,
  82. rois_num_per_level=rois_num_per_level)
  83. feed = {}
  84. for i in range(4):
  85. feed['rois' + str(i)] = multi_bboxes_np[i]
  86. feed['scores' + str(i)] = multi_scores_np[i]
  87. feed['rois_num' + str(i)] = rois_num_per_level_np[i]
  88. fpn_rois_stat, rois_num_stat = self.get_static_graph_result(
  89. feed=feed, fetch_list=[fpn_rois, rois_num], with_lod=True)
  90. fpn_rois_stat = np.array(fpn_rois_stat)
  91. rois_num_stat = np.array(rois_num_stat)
  92. with self.dynamic_graph():
  93. multi_bboxes_dy = []
  94. multi_scores_dy = []
  95. rois_num_per_level_dy = []
  96. for i in range(4):
  97. bboxes_dy = base.to_variable(multi_bboxes_np[i])
  98. scores_dy = base.to_variable(multi_scores_np[i])
  99. rois_num_dy = base.to_variable(rois_num_per_level_np[i])
  100. multi_bboxes_dy.append(bboxes_dy)
  101. multi_scores_dy.append(scores_dy)
  102. rois_num_per_level_dy.append(rois_num_dy)
  103. fpn_rois_dy, rois_num_dy = ops.collect_fpn_proposals(
  104. multi_bboxes_dy,
  105. multi_scores_dy,
  106. 2,
  107. 5,
  108. 10,
  109. rois_num_per_level=rois_num_per_level_dy)
  110. fpn_rois_dy = fpn_rois_dy.numpy()
  111. rois_num_dy = rois_num_dy.numpy()
  112. self.assertTrue(np.array_equal(fpn_rois_stat, fpn_rois_dy))
  113. self.assertTrue(np.array_equal(rois_num_stat, rois_num_dy))
  114. def test_collect_fpn_proposals_error(self):
  115. def generate_input(bbox_type, score_type, name):
  116. multi_bboxes = []
  117. multi_scores = []
  118. for i in range(4):
  119. bboxes = paddle.static.data(
  120. name='rois' + name + str(i),
  121. shape=[10, 4],
  122. dtype=bbox_type,
  123. lod_level=1)
  124. scores = paddle.static.data(
  125. name='scores' + name + str(i),
  126. shape=[10, 1],
  127. dtype=score_type,
  128. lod_level=1)
  129. multi_bboxes.append(bboxes)
  130. multi_scores.append(scores)
  131. return multi_bboxes, multi_scores
  132. with self.static_graph():
  133. bbox1 = paddle.static.data(
  134. name='rois', shape=[5, 10, 4], dtype='float32', lod_level=1)
  135. score1 = paddle.static.data(
  136. name='scores', shape=[5, 10, 1], dtype='float32', lod_level=1)
  137. bbox2, score2 = generate_input('int32', 'float32', '2')
  138. self.assertRaises(
  139. TypeError,
  140. ops.collect_fpn_proposals,
  141. multi_rois=bbox1,
  142. multi_scores=score1,
  143. min_level=2,
  144. max_level=5,
  145. post_nms_top_n=2000)
  146. self.assertRaises(
  147. TypeError,
  148. ops.collect_fpn_proposals,
  149. multi_rois=bbox2,
  150. multi_scores=score2,
  151. min_level=2,
  152. max_level=5,
  153. post_nms_top_n=2000)
  154. paddle.disable_static()
  155. class TestDistributeFpnProposals(LayerTest):
  156. def test_distribute_fpn_proposals(self):
  157. rois_np = np.random.rand(10, 4).astype('float32')
  158. rois_num_np = np.array([4, 6]).astype('int32')
  159. with self.static_graph():
  160. rois = paddle.static.data(
  161. name='rois', shape=[10, 4], dtype='float32')
  162. rois_num = paddle.static.data(
  163. name='rois_num', shape=[None], dtype='int32')
  164. multi_rois, restore_ind, rois_num_per_level = ops.distribute_fpn_proposals(
  165. fpn_rois=rois,
  166. min_level=2,
  167. max_level=5,
  168. refer_level=4,
  169. refer_scale=224,
  170. rois_num=rois_num)
  171. fetch_list = multi_rois + [restore_ind] + rois_num_per_level
  172. output_stat = self.get_static_graph_result(
  173. feed={'rois': rois_np,
  174. 'rois_num': rois_num_np},
  175. fetch_list=fetch_list,
  176. with_lod=True)
  177. output_stat_np = []
  178. for output in output_stat:
  179. output_np = np.array(output)
  180. if len(output_np) > 0:
  181. output_stat_np.append(output_np)
  182. with self.dynamic_graph():
  183. rois_dy = base.to_variable(rois_np)
  184. rois_num_dy = base.to_variable(rois_num_np)
  185. multi_rois_dy, restore_ind_dy, rois_num_per_level_dy = ops.distribute_fpn_proposals(
  186. fpn_rois=rois_dy,
  187. min_level=2,
  188. max_level=5,
  189. refer_level=4,
  190. refer_scale=224,
  191. rois_num=rois_num_dy)
  192. output_dy = multi_rois_dy + [restore_ind_dy] + rois_num_per_level_dy
  193. output_dy_np = []
  194. for output in output_dy:
  195. output_np = output.numpy()
  196. if len(output_np) > 0:
  197. output_dy_np.append(output_np)
  198. for res_stat, res_dy in zip(output_stat_np, output_dy_np):
  199. self.assertTrue(np.array_equal(res_stat, res_dy))
  200. def test_distribute_fpn_proposals_error(self):
  201. with self.static_graph():
  202. fpn_rois = paddle.static.data(
  203. name='data_error', shape=[10, 4], dtype='int32', lod_level=1)
  204. self.assertRaises(
  205. TypeError,
  206. ops.distribute_fpn_proposals,
  207. fpn_rois=fpn_rois,
  208. min_level=2,
  209. max_level=5,
  210. refer_level=4,
  211. refer_scale=224)
  212. paddle.disable_static()
  213. class TestROIAlign(LayerTest):
  214. def test_roi_align(self):
  215. b, c, h, w = 2, 12, 20, 20
  216. inputs_np = np.random.rand(b, c, h, w).astype('float32')
  217. rois_num = [4, 6]
  218. output_size = (7, 7)
  219. rois_np = make_rois(h, w, rois_num, output_size)
  220. rois_num_np = np.array(rois_num).astype('int32')
  221. with self.static_graph():
  222. inputs = paddle.static.data(
  223. name='inputs', shape=[b, c, h, w], dtype='float32')
  224. rois = paddle.static.data(
  225. name='rois', shape=[10, 4], dtype='float32')
  226. rois_num = paddle.static.data(
  227. name='rois_num', shape=[None], dtype='int32')
  228. output = ops.roi_align(
  229. input=inputs,
  230. rois=rois,
  231. output_size=output_size,
  232. rois_num=rois_num)
  233. output_np, = self.get_static_graph_result(
  234. feed={
  235. 'inputs': inputs_np,
  236. 'rois': rois_np,
  237. 'rois_num': rois_num_np
  238. },
  239. fetch_list=output,
  240. with_lod=False)
  241. with self.dynamic_graph():
  242. inputs_dy = base.to_variable(inputs_np)
  243. rois_dy = base.to_variable(rois_np)
  244. rois_num_dy = base.to_variable(rois_num_np)
  245. output_dy = ops.roi_align(
  246. input=inputs_dy,
  247. rois=rois_dy,
  248. output_size=output_size,
  249. rois_num=rois_num_dy)
  250. output_dy_np = output_dy.numpy()
  251. self.assertTrue(np.array_equal(output_np, output_dy_np))
  252. def test_roi_align_error(self):
  253. with self.static_graph():
  254. inputs = paddle.static.data(
  255. name='inputs', shape=[2, 12, 20, 20], dtype='float32')
  256. rois = paddle.static.data(
  257. name='data_error', shape=[10, 4], dtype='int32', lod_level=1)
  258. self.assertRaises(
  259. TypeError,
  260. ops.roi_align,
  261. input=inputs,
  262. rois=rois,
  263. output_size=(7, 7))
  264. paddle.disable_static()
  265. class TestROIPool(LayerTest):
  266. def test_roi_pool(self):
  267. b, c, h, w = 2, 12, 20, 20
  268. inputs_np = np.random.rand(b, c, h, w).astype('float32')
  269. rois_num = [4, 6]
  270. output_size = (7, 7)
  271. rois_np = make_rois(h, w, rois_num, output_size)
  272. rois_num_np = np.array(rois_num).astype('int32')
  273. with self.static_graph():
  274. inputs = paddle.static.data(
  275. name='inputs', shape=[b, c, h, w], dtype='float32')
  276. rois = paddle.static.data(
  277. name='rois', shape=[10, 4], dtype='float32')
  278. rois_num = paddle.static.data(
  279. name='rois_num', shape=[None], dtype='int32')
  280. output, _ = ops.roi_pool(
  281. input=inputs,
  282. rois=rois,
  283. output_size=output_size,
  284. rois_num=rois_num)
  285. output_np, = self.get_static_graph_result(
  286. feed={
  287. 'inputs': inputs_np,
  288. 'rois': rois_np,
  289. 'rois_num': rois_num_np
  290. },
  291. fetch_list=[output],
  292. with_lod=False)
  293. with self.dynamic_graph():
  294. inputs_dy = base.to_variable(inputs_np)
  295. rois_dy = base.to_variable(rois_np)
  296. rois_num_dy = base.to_variable(rois_num_np)
  297. output_dy, _ = ops.roi_pool(
  298. input=inputs_dy,
  299. rois=rois_dy,
  300. output_size=output_size,
  301. rois_num=rois_num_dy)
  302. output_dy_np = output_dy.numpy()
  303. self.assertTrue(np.array_equal(output_np, output_dy_np))
  304. def test_roi_pool_error(self):
  305. with self.static_graph():
  306. inputs = paddle.static.data(
  307. name='inputs', shape=[2, 12, 20, 20], dtype='float32')
  308. rois = paddle.static.data(
  309. name='data_error', shape=[10, 4], dtype='int32', lod_level=1)
  310. self.assertRaises(
  311. TypeError,
  312. ops.roi_pool,
  313. input=inputs,
  314. rois=rois,
  315. output_size=(7, 7))
  316. paddle.disable_static()
  317. class TestIoUSimilarity(LayerTest):
  318. def test_iou_similarity(self):
  319. b, c, h, w = 2, 12, 20, 20
  320. inputs_np = np.random.rand(b, c, h, w).astype('float32')
  321. output_size = (7, 7)
  322. x_np = make_rois(h, w, [20], output_size)
  323. y_np = make_rois(h, w, [10], output_size)
  324. with self.static_graph():
  325. x = paddle.static.data(name='x', shape=[20, 4], dtype='float32')
  326. y = paddle.static.data(name='y', shape=[10, 4], dtype='float32')
  327. iou = ops.iou_similarity(x=x, y=y)
  328. iou_np, = self.get_static_graph_result(
  329. feed={
  330. 'x': x_np,
  331. 'y': y_np,
  332. }, fetch_list=[iou], with_lod=False)
  333. with self.dynamic_graph():
  334. x_dy = base.to_variable(x_np)
  335. y_dy = base.to_variable(y_np)
  336. iou_dy = ops.iou_similarity(x=x_dy, y=y_dy)
  337. iou_dy_np = iou_dy.numpy()
  338. self.assertTrue(np.array_equal(iou_np, iou_dy_np))
  339. class TestBipartiteMatch(LayerTest):
  340. def test_bipartite_match(self):
  341. distance = np.random.random((20, 10)).astype('float32')
  342. with self.static_graph():
  343. x = paddle.static.data(name='x', shape=[20, 10], dtype='float32')
  344. match_indices, match_dist = ops.bipartite_match(
  345. x, match_type='per_prediction', dist_threshold=0.5)
  346. match_indices_np, match_dist_np = self.get_static_graph_result(
  347. feed={'x': distance, },
  348. fetch_list=[match_indices, match_dist],
  349. with_lod=False)
  350. with self.dynamic_graph():
  351. x_dy = base.to_variable(distance)
  352. match_indices_dy, match_dist_dy = ops.bipartite_match(
  353. x_dy, match_type='per_prediction', dist_threshold=0.5)
  354. match_indices_dy_np = match_indices_dy.numpy()
  355. match_dist_dy_np = match_dist_dy.numpy()
  356. self.assertTrue(np.array_equal(match_indices_np, match_indices_dy_np))
  357. self.assertTrue(np.array_equal(match_dist_np, match_dist_dy_np))
  358. class TestYoloBox(LayerTest):
  359. def test_yolo_box(self):
  360. # x shape [N C H W], C=K * (5 + class_num), class_num=10, K=2
  361. np_x = np.random.random([1, 30, 7, 7]).astype('float32')
  362. np_origin_shape = np.array([[608, 608]], dtype='int32')
  363. class_num = 10
  364. conf_thresh = 0.01
  365. downsample_ratio = 32
  366. scale_x_y = 1.2
  367. # static
  368. with self.static_graph():
  369. # x shape [N C H W], C=K * (5 + class_num), class_num=10, K=2
  370. x = paddle.static.data(
  371. name='x', shape=[1, 30, 7, 7], dtype='float32')
  372. origin_shape = paddle.static.data(
  373. name='origin_shape', shape=[1, 2], dtype='int32')
  374. boxes, scores = ops.yolo_box(
  375. x,
  376. origin_shape, [10, 13, 30, 13],
  377. class_num,
  378. conf_thresh,
  379. downsample_ratio,
  380. scale_x_y=scale_x_y)
  381. boxes_np, scores_np = self.get_static_graph_result(
  382. feed={
  383. 'x': np_x,
  384. 'origin_shape': np_origin_shape,
  385. },
  386. fetch_list=[boxes, scores],
  387. with_lod=False)
  388. # dygraph
  389. with self.dynamic_graph():
  390. x_dy = fluid.layers.assign(np_x)
  391. origin_shape_dy = fluid.layers.assign(np_origin_shape)
  392. boxes_dy, scores_dy = ops.yolo_box(
  393. x_dy,
  394. origin_shape_dy, [10, 13, 30, 13],
  395. 10,
  396. 0.01,
  397. 32,
  398. scale_x_y=scale_x_y)
  399. boxes_dy_np = boxes_dy.numpy()
  400. scores_dy_np = scores_dy.numpy()
  401. self.assertTrue(np.array_equal(boxes_np, boxes_dy_np))
  402. self.assertTrue(np.array_equal(scores_np, scores_dy_np))
  403. def test_yolo_box_error(self):
  404. with self.static_graph():
  405. # x shape [N C H W], C=K * (5 + class_num), class_num=10, K=2
  406. x = paddle.static.data(
  407. name='x', shape=[1, 30, 7, 7], dtype='float32')
  408. origin_shape = paddle.static.data(
  409. name='origin_shape', shape=[1, 2], dtype='int32')
  410. self.assertRaises(
  411. TypeError,
  412. ops.yolo_box,
  413. x,
  414. origin_shape, [10, 13, 30, 13],
  415. 10.123,
  416. 0.01,
  417. 32,
  418. scale_x_y=1.2)
  419. paddle.disable_static()
  420. class TestPriorBox(LayerTest):
  421. def test_prior_box(self):
  422. input_np = np.random.rand(2, 10, 32, 32).astype('float32')
  423. image_np = np.random.rand(2, 10, 40, 40).astype('float32')
  424. min_sizes = [2, 4]
  425. with self.static_graph():
  426. input = paddle.static.data(
  427. name='input', shape=[2, 10, 32, 32], dtype='float32')
  428. image = paddle.static.data(
  429. name='image', shape=[2, 10, 40, 40], dtype='float32')
  430. box, var = ops.prior_box(
  431. input=input,
  432. image=image,
  433. min_sizes=min_sizes,
  434. clip=True,
  435. flip=True)
  436. box_np, var_np = self.get_static_graph_result(
  437. feed={
  438. 'input': input_np,
  439. 'image': image_np,
  440. },
  441. fetch_list=[box, var],
  442. with_lod=False)
  443. with self.dynamic_graph():
  444. inputs_dy = base.to_variable(input_np)
  445. image_dy = base.to_variable(image_np)
  446. box_dy, var_dy = ops.prior_box(
  447. input=inputs_dy,
  448. image=image_dy,
  449. min_sizes=min_sizes,
  450. clip=True,
  451. flip=True)
  452. box_dy_np = box_dy.numpy()
  453. var_dy_np = var_dy.numpy()
  454. self.assertTrue(np.array_equal(box_np, box_dy_np))
  455. self.assertTrue(np.array_equal(var_np, var_dy_np))
  456. def test_prior_box_error(self):
  457. with self.static_graph():
  458. input = paddle.static.data(
  459. name='input', shape=[2, 10, 32, 32], dtype='int32')
  460. image = paddle.static.data(
  461. name='image', shape=[2, 10, 40, 40], dtype='int32')
  462. self.assertRaises(
  463. TypeError,
  464. ops.prior_box,
  465. input=input,
  466. image=image,
  467. min_sizes=[2, 4],
  468. clip=True,
  469. flip=True)
  470. paddle.disable_static()
  471. class TestMulticlassNms(LayerTest):
  472. def test_multiclass_nms(self):
  473. boxes_np = np.random.rand(10, 81, 4).astype('float32')
  474. scores_np = np.random.rand(10, 81).astype('float32')
  475. rois_num_np = np.array([2, 8]).astype('int32')
  476. with self.static_graph():
  477. boxes = paddle.static.data(
  478. name='bboxes',
  479. shape=[None, 81, 4],
  480. dtype='float32',
  481. lod_level=1)
  482. scores = paddle.static.data(
  483. name='scores', shape=[None, 81], dtype='float32', lod_level=1)
  484. rois_num = paddle.static.data(
  485. name='rois_num', shape=[None], dtype='int32')
  486. output = ops.multiclass_nms(
  487. bboxes=boxes,
  488. scores=scores,
  489. background_label=0,
  490. score_threshold=0.5,
  491. nms_top_k=400,
  492. nms_threshold=0.3,
  493. keep_top_k=200,
  494. normalized=False,
  495. return_index=True,
  496. rois_num=rois_num)
  497. out_np, index_np, nms_rois_num_np = self.get_static_graph_result(
  498. feed={
  499. 'bboxes': boxes_np,
  500. 'scores': scores_np,
  501. 'rois_num': rois_num_np
  502. },
  503. fetch_list=output,
  504. with_lod=True)
  505. out_np = np.array(out_np)
  506. index_np = np.array(index_np)
  507. nms_rois_num_np = np.array(nms_rois_num_np)
  508. with self.dynamic_graph():
  509. boxes_dy = base.to_variable(boxes_np)
  510. scores_dy = base.to_variable(scores_np)
  511. rois_num_dy = base.to_variable(rois_num_np)
  512. out_dy, index_dy, nms_rois_num_dy = ops.multiclass_nms(
  513. bboxes=boxes_dy,
  514. scores=scores_dy,
  515. background_label=0,
  516. score_threshold=0.5,
  517. nms_top_k=400,
  518. nms_threshold=0.3,
  519. keep_top_k=200,
  520. normalized=False,
  521. return_index=True,
  522. rois_num=rois_num_dy)
  523. out_dy_np = out_dy.numpy()
  524. index_dy_np = index_dy.numpy()
  525. nms_rois_num_dy_np = nms_rois_num_dy.numpy()
  526. self.assertTrue(np.array_equal(out_np, out_dy_np))
  527. self.assertTrue(np.array_equal(index_np, index_dy_np))
  528. self.assertTrue(np.array_equal(nms_rois_num_np, nms_rois_num_dy_np))
  529. def test_multiclass_nms_error(self):
  530. with self.static_graph():
  531. boxes = paddle.static.data(
  532. name='bboxes', shape=[81, 4], dtype='float32', lod_level=1)
  533. scores = paddle.static.data(
  534. name='scores', shape=[81], dtype='float32', lod_level=1)
  535. rois_num = paddle.static.data(
  536. name='rois_num', shape=[40, 41], dtype='int32')
  537. self.assertRaises(
  538. TypeError,
  539. ops.multiclass_nms,
  540. boxes=boxes,
  541. scores=scores,
  542. background_label=0,
  543. score_threshold=0.5,
  544. nms_top_k=400,
  545. nms_threshold=0.3,
  546. keep_top_k=200,
  547. normalized=False,
  548. return_index=True,
  549. rois_num=rois_num)
  550. class TestMatrixNMS(LayerTest):
  551. def test_matrix_nms(self):
  552. N, M, C = 7, 1200, 21
  553. BOX_SIZE = 4
  554. nms_top_k = 400
  555. keep_top_k = 200
  556. score_threshold = 0.01
  557. post_threshold = 0.
  558. scores_np = np.random.random((N * M, C)).astype('float32')
  559. scores_np = np.apply_along_axis(softmax, 1, scores_np)
  560. scores_np = np.reshape(scores_np, (N, M, C))
  561. scores_np = np.transpose(scores_np, (0, 2, 1))
  562. boxes_np = np.random.random((N, M, BOX_SIZE)).astype('float32')
  563. boxes_np[:, :, 0:2] = boxes_np[:, :, 0:2] * 0.5
  564. boxes_np[:, :, 2:4] = boxes_np[:, :, 2:4] * 0.5 + 0.5
  565. with self.static_graph():
  566. boxes = paddle.static.data(
  567. name='boxes', shape=[N, M, BOX_SIZE], dtype='float32')
  568. scores = paddle.static.data(
  569. name='scores', shape=[N, C, M], dtype='float32')
  570. out, index, _ = ops.matrix_nms(
  571. bboxes=boxes,
  572. scores=scores,
  573. score_threshold=score_threshold,
  574. post_threshold=post_threshold,
  575. nms_top_k=nms_top_k,
  576. keep_top_k=keep_top_k,
  577. return_index=True)
  578. out_np, index_np = self.get_static_graph_result(
  579. feed={'boxes': boxes_np,
  580. 'scores': scores_np},
  581. fetch_list=[out, index],
  582. with_lod=True)
  583. with self.dynamic_graph():
  584. boxes_dy = base.to_variable(boxes_np)
  585. scores_dy = base.to_variable(scores_np)
  586. out_dy, index_dy, _ = ops.matrix_nms(
  587. bboxes=boxes_dy,
  588. scores=scores_dy,
  589. score_threshold=score_threshold,
  590. post_threshold=post_threshold,
  591. nms_top_k=nms_top_k,
  592. keep_top_k=keep_top_k,
  593. return_index=True)
  594. out_dy_np = out_dy.numpy()
  595. index_dy_np = index_dy.numpy()
  596. self.assertTrue(np.array_equal(out_np, out_dy_np))
  597. self.assertTrue(np.array_equal(index_np, index_dy_np))
  598. def test_matrix_nms_error(self):
  599. with self.static_graph():
  600. bboxes = paddle.static.data(
  601. name='bboxes', shape=[7, 1200, 4], dtype='float32')
  602. scores = paddle.static.data(
  603. name='data_error', shape=[7, 21, 1200], dtype='int32')
  604. self.assertRaises(
  605. TypeError,
  606. ops.matrix_nms,
  607. bboxes=bboxes,
  608. scores=scores,
  609. score_threshold=0.01,
  610. post_threshold=0.,
  611. nms_top_k=400,
  612. keep_top_k=200,
  613. return_index=True)
  614. paddle.disable_static()
  615. class TestBoxCoder(LayerTest):
  616. def test_box_coder(self):
  617. prior_box_np = np.random.random((81, 4)).astype('float32')
  618. prior_box_var_np = np.random.random((81, 4)).astype('float32')
  619. target_box_np = np.random.random((20, 81, 4)).astype('float32')
  620. # static
  621. with self.static_graph():
  622. prior_box = paddle.static.data(
  623. name='prior_box', shape=[81, 4], dtype='float32')
  624. prior_box_var = paddle.static.data(
  625. name='prior_box_var', shape=[81, 4], dtype='float32')
  626. target_box = paddle.static.data(
  627. name='target_box', shape=[20, 81, 4], dtype='float32')
  628. boxes = ops.box_coder(
  629. prior_box=prior_box,
  630. prior_box_var=prior_box_var,
  631. target_box=target_box,
  632. code_type="decode_center_size",
  633. box_normalized=False)
  634. boxes_np, = self.get_static_graph_result(
  635. feed={
  636. 'prior_box': prior_box_np,
  637. 'prior_box_var': prior_box_var_np,
  638. 'target_box': target_box_np,
  639. },
  640. fetch_list=[boxes],
  641. with_lod=False)
  642. # dygraph
  643. with self.dynamic_graph():
  644. prior_box_dy = base.to_variable(prior_box_np)
  645. prior_box_var_dy = base.to_variable(prior_box_var_np)
  646. target_box_dy = base.to_variable(target_box_np)
  647. boxes_dy = ops.box_coder(
  648. prior_box=prior_box_dy,
  649. prior_box_var=prior_box_var_dy,
  650. target_box=target_box_dy,
  651. code_type="decode_center_size",
  652. box_normalized=False)
  653. boxes_dy_np = boxes_dy.numpy()
  654. self.assertTrue(np.array_equal(boxes_np, boxes_dy_np))
  655. def test_box_coder_error(self):
  656. with self.static_graph():
  657. prior_box = paddle.static.data(
  658. name='prior_box', shape=[81, 4], dtype='int32')
  659. prior_box_var = paddle.static.data(
  660. name='prior_box_var', shape=[81, 4], dtype='float32')
  661. target_box = paddle.static.data(
  662. name='target_box', shape=[20, 81, 4], dtype='float32')
  663. self.assertRaises(TypeError, ops.box_coder, prior_box,
  664. prior_box_var, target_box)
  665. paddle.disable_static()
  666. class TestGenerateProposals(LayerTest):
  667. def test_generate_proposals(self):
  668. scores_np = np.random.rand(2, 3, 4, 4).astype('float32')
  669. bbox_deltas_np = np.random.rand(2, 12, 4, 4).astype('float32')
  670. im_shape_np = np.array([[8, 8], [6, 6]]).astype('float32')
  671. anchors_np = np.reshape(np.arange(4 * 4 * 3 * 4),
  672. [4, 4, 3, 4]).astype('float32')
  673. variances_np = np.ones((4, 4, 3, 4)).astype('float32')
  674. with self.static_graph():
  675. scores = paddle.static.data(
  676. name='scores', shape=[2, 3, 4, 4], dtype='float32')
  677. bbox_deltas = paddle.static.data(
  678. name='bbox_deltas', shape=[2, 12, 4, 4], dtype='float32')
  679. im_shape = paddle.static.data(
  680. name='im_shape', shape=[2, 2], dtype='float32')
  681. anchors = paddle.static.data(
  682. name='anchors', shape=[4, 4, 3, 4], dtype='float32')
  683. variances = paddle.static.data(
  684. name='var', shape=[4, 4, 3, 4], dtype='float32')
  685. rois, roi_probs, rois_num = ops.generate_proposals(
  686. scores,
  687. bbox_deltas,
  688. im_shape,
  689. anchors,
  690. variances,
  691. pre_nms_top_n=10,
  692. post_nms_top_n=5,
  693. return_rois_num=True)
  694. rois_stat, roi_probs_stat, rois_num_stat = self.get_static_graph_result(
  695. feed={
  696. 'scores': scores_np,
  697. 'bbox_deltas': bbox_deltas_np,
  698. 'im_shape': im_shape_np,
  699. 'anchors': anchors_np,
  700. 'var': variances_np
  701. },
  702. fetch_list=[rois, roi_probs, rois_num],
  703. with_lod=True)
  704. with self.dynamic_graph():
  705. scores_dy = base.to_variable(scores_np)
  706. bbox_deltas_dy = base.to_variable(bbox_deltas_np)
  707. im_shape_dy = base.to_variable(im_shape_np)
  708. anchors_dy = base.to_variable(anchors_np)
  709. variances_dy = base.to_variable(variances_np)
  710. rois, roi_probs, rois_num = ops.generate_proposals(
  711. scores_dy,
  712. bbox_deltas_dy,
  713. im_shape_dy,
  714. anchors_dy,
  715. variances_dy,
  716. pre_nms_top_n=10,
  717. post_nms_top_n=5,
  718. return_rois_num=True)
  719. rois_dy = rois.numpy()
  720. roi_probs_dy = roi_probs.numpy()
  721. rois_num_dy = rois_num.numpy()
  722. self.assertTrue(np.array_equal(np.array(rois_stat), rois_dy))
  723. self.assertTrue(np.array_equal(np.array(roi_probs_stat), roi_probs_dy))
  724. self.assertTrue(np.array_equal(np.array(rois_num_stat), rois_num_dy))
  725. if __name__ == '__main__':
  726. unittest.main()