solov2_head.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn.initializer import Normal, Constant

from paddlex.ppdet.modeling.layers import ConvNormLayer, MaskMatrixNMS, DropBlock
from paddlex.ppdet.core.workspace import register

from six.moves import zip

import numpy as np

__all__ = ['SOLOv2Head']
@register
class SOLOv2MaskHead(nn.Layer):
    """
    MaskHead of SOLOv2.

    Args:
        in_channels (int): The channel number of input Tensor.
        mid_channels (int): The channel number of the intermediate Tensors.
        out_channels (int): The channel number of output Tensor.
        start_level (int): The position where the input starts.
        end_level (int): The position where the input ends.
        use_dcn_in_tower (bool): Whether to use dcn in tower or not.
        norm_type (str): Normalization type used in ConvNormLayer, 'gn' by default.
    """
    __shared__ = ['norm_type']

    def __init__(self,
                 in_channels=256,
                 mid_channels=128,
                 out_channels=256,
                 start_level=0,
                 end_level=3,
                 use_dcn_in_tower=False,
                 norm_type='gn'):
        super(SOLOv2MaskHead, self).__init__()
        assert start_level >= 0 and end_level >= start_level
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.mid_channels = mid_channels
        self.use_dcn_in_tower = use_dcn_in_tower
        self.range_level = end_level - start_level + 1
        self.use_dcn = True if self.use_dcn_in_tower else False
        self.convs_all_levels = []
        self.norm_type = norm_type
        for i in range(start_level, end_level + 1):
            conv_feat_name = 'mask_feat_head.convs_all_levels.{}'.format(i)
            conv_pre_feat = nn.Sequential()
            if i == start_level:
                conv_pre_feat.add_sublayer(
                    conv_feat_name + '.conv' + str(i),
                    ConvNormLayer(
                        ch_in=self.in_channels,
                        ch_out=self.mid_channels,
                        filter_size=3,
                        stride=1,
                        use_dcn=self.use_dcn,
                        norm_type=self.norm_type))
                self.add_sublayer('conv_pre_feat' + str(i), conv_pre_feat)
                self.convs_all_levels.append(conv_pre_feat)
            else:
                for j in range(i):
                    ch_in = 0
                    if j == 0:
                        ch_in = self.in_channels + 2 if i == end_level else self.in_channels
                    else:
                        ch_in = self.mid_channels
                    conv_pre_feat.add_sublayer(
                        conv_feat_name + '.conv' + str(j),
                        ConvNormLayer(
                            ch_in=ch_in,
                            ch_out=self.mid_channels,
                            filter_size=3,
                            stride=1,
                            use_dcn=self.use_dcn,
                            norm_type=self.norm_type))
                    conv_pre_feat.add_sublayer(
                        conv_feat_name + '.conv' + str(j) + 'act', nn.ReLU())
                    conv_pre_feat.add_sublayer(
                        'upsample' + str(i) + str(j),
                        nn.Upsample(
                            scale_factor=2, mode='bilinear'))
                self.add_sublayer('conv_pre_feat' + str(i), conv_pre_feat)
                self.convs_all_levels.append(conv_pre_feat)

        conv_pred_name = 'mask_feat_head.conv_pred.0'
        self.conv_pred = self.add_sublayer(
            conv_pred_name,
            ConvNormLayer(
                ch_in=self.mid_channels,
                ch_out=self.out_channels,
                filter_size=1,
                stride=1,
                use_dcn=self.use_dcn,
                norm_type=self.norm_type))
    def forward(self, inputs):
        """
        Get SOLOv2MaskHead output.

        Args:
            inputs (list[Tensor]): feature maps from each neck level, each with shape [N, C, H, W]
        Returns:
            ins_pred (Tensor): Output of SOLOv2MaskHead
        """
        feat_all_level = F.relu(self.convs_all_levels[0](inputs[0]))
        for i in range(1, self.range_level):
            input_p = inputs[i]
            if i == (self.range_level - 1):
                input_feat = input_p
                x_range = paddle.linspace(
                    -1, 1, paddle.shape(input_feat)[-1], dtype='float32')
                y_range = paddle.linspace(
                    -1, 1, paddle.shape(input_feat)[-2], dtype='float32')
                y, x = paddle.meshgrid([y_range, x_range])
                x = paddle.unsqueeze(x, [0, 1])
                y = paddle.unsqueeze(y, [0, 1])
                y = paddle.expand(
                    y, shape=[paddle.shape(input_feat)[0], 1, -1, -1])
                x = paddle.expand(
                    x, shape=[paddle.shape(input_feat)[0], 1, -1, -1])
                coord_feat = paddle.concat([x, y], axis=1)
                input_p = paddle.concat([input_p, coord_feat], axis=1)
            feat_all_level = paddle.add(feat_all_level,
                                        self.convs_all_levels[i](input_p))
        ins_pred = F.relu(self.conv_pred(feat_all_level))
        return ins_pred
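
    # Note on the fusion above: every pyramid level is conv+upsampled by its own
    # tower to the resolution of the first (stride-4) level and summed into
    # `feat_all_level`; normalized (x, y) coordinate channels are appended only to
    # the last, coarsest level, which is the CoordConv trick used by SOLOv2.
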
@register
class SOLOv2Head(nn.Layer):
    """
    Head block for SOLOv2 network

    Args:
        num_classes (int): Number of output classes.
        in_channels (int): Number of input channels.
        seg_feat_channels (int): Number of filters in the kernel & category branch convolutions.
        stacked_convs (int): Number of stacked convolutions in each branch.
        num_grids (list[int]): List of feature map grid sizes.
        kernel_out_channels (int): Number of output channels in kernel branch.
        dcn_v2_stages (list): Which stages use dcn v2 in the tower; values lie in [0, stacked_convs).
        segm_strides (list[int]): List of segmentation area strides.
        solov2_loss (object): SOLOv2Loss instance.
        score_threshold (float): Threshold of category score.
        mask_threshold (float): Threshold for binarizing predicted masks.
        mask_nms (object): MaskMatrixNMS instance.
    """
    __inject__ = ['solov2_loss', 'mask_nms']
    __shared__ = ['norm_type', 'num_classes']

    def __init__(self,
                 num_classes=80,
                 in_channels=256,
                 seg_feat_channels=256,
                 stacked_convs=4,
                 num_grids=[40, 36, 24, 16, 12],
                 kernel_out_channels=256,
                 dcn_v2_stages=[],
                 segm_strides=[8, 8, 16, 32, 32],
                 solov2_loss=None,
                 score_threshold=0.1,
                 mask_threshold=0.5,
                 mask_nms=None,
                 norm_type='gn',
                 drop_block=False):
        super(SOLOv2Head, self).__init__()
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.seg_num_grids = num_grids
        self.cate_out_channels = self.num_classes
        self.seg_feat_channels = seg_feat_channels
        self.stacked_convs = stacked_convs
        self.kernel_out_channels = kernel_out_channels
        self.dcn_v2_stages = dcn_v2_stages
        self.segm_strides = segm_strides
        self.solov2_loss = solov2_loss
        self.mask_nms = mask_nms
        self.score_threshold = score_threshold
        self.mask_threshold = mask_threshold
        self.norm_type = norm_type
        self.drop_block = drop_block

        self.kernel_pred_convs = []
        self.cate_pred_convs = []
        for i in range(self.stacked_convs):
            use_dcn = True if i in self.dcn_v2_stages else False
            ch_in = self.in_channels + 2 if i == 0 else self.seg_feat_channels
            kernel_conv = self.add_sublayer(
                'bbox_head.kernel_convs.' + str(i),
                ConvNormLayer(
                    ch_in=ch_in,
                    ch_out=self.seg_feat_channels,
                    filter_size=3,
                    stride=1,
                    use_dcn=use_dcn,
                    norm_type=self.norm_type))
            self.kernel_pred_convs.append(kernel_conv)
            ch_in = self.in_channels if i == 0 else self.seg_feat_channels
            cate_conv = self.add_sublayer(
                'bbox_head.cate_convs.' + str(i),
                ConvNormLayer(
                    ch_in=ch_in,
                    ch_out=self.seg_feat_channels,
                    filter_size=3,
                    stride=1,
                    use_dcn=use_dcn,
                    norm_type=self.norm_type))
            self.cate_pred_convs.append(cate_conv)

        self.solo_kernel = self.add_sublayer(
            'bbox_head.solo_kernel',
            nn.Conv2D(
                self.seg_feat_channels,
                self.kernel_out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                weight_attr=ParamAttr(initializer=Normal(
                    mean=0., std=0.01)),
                bias_attr=True))
        self.solo_cate = self.add_sublayer(
            'bbox_head.solo_cate',
            nn.Conv2D(
                self.seg_feat_channels,
                self.cate_out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                weight_attr=ParamAttr(initializer=Normal(
                    mean=0., std=0.01)),
                bias_attr=ParamAttr(initializer=Constant(
                    value=float(-np.log((1 - 0.01) / 0.01))))))
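        # Note: solo_cate's bias is initialized to -log((1 - 0.01) / 0.01) ~ -4.6,
        # the usual focal-loss prior, so every category sigmoid starts near 0.01
        # and training is not swamped by the easy negatives on the S x S grids.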
        if self.drop_block and self.training:
            self.drop_block_fun = DropBlock(
                block_size=3, keep_prob=0.9, name='solo_cate.dropblock')
    def _points_nms(self, heat, kernel_size=2):
        hmax = F.max_pool2d(heat, kernel_size=kernel_size, stride=1, padding=1)
        keep = paddle.cast((hmax[:, :, :-1, :-1] == heat), 'float32')
        return heat * keep
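    # _points_nms keeps a grid cell only if it is a local maximum of the category
    # heat map: max_pool2d with stride 1 computes each cell's neighborhood maximum,
    # and the comparison zeroes every cell that does not equal that maximum. For
    # example, a patch [[0.9, 0.3], [0.2, 0.1]] keeps only the 0.9 entry, so
    # adjacent duplicate activations for the same object center are suppressed.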
    def _split_feats(self, feats):
        return (F.interpolate(
            feats[0],
            scale_factor=0.5,
            align_corners=False,
            align_mode=0,
            mode='bilinear'), feats[1], feats[2], feats[3], F.interpolate(
                feats[4],
                size=paddle.shape(feats[3])[-2:],
                mode='bilinear',
                align_corners=False,
                align_mode=0))
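    # _split_feats aligns the five neck levels with the five grid scales: the
    # finest level (feats[0]) is downsampled by half and the coarsest level
    # (feats[4]) is resized to the spatial size of feats[3], so the effective
    # strides roughly follow segm_strides = [8, 8, 16, 32, 32].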
    def forward(self, input):
        """
        Get SOLOv2 head output

        Args:
            input (list): List of Tensors, output of backbone or neck stages
        Returns:
            cate_pred_list (list): Tensors of each category branch layer
            kernel_pred_list (list): Tensors of each kernel branch layer
        """
        feats = self._split_feats(input)
        cate_pred_list = []
        kernel_pred_list = []
        for idx in range(len(self.seg_num_grids)):
            cate_pred, kernel_pred = self._get_output_single(feats[idx], idx)
            cate_pred_list.append(cate_pred)
            kernel_pred_list.append(kernel_pred)

        return cate_pred_list, kernel_pred_list
    def _get_output_single(self, input, idx):
        ins_kernel_feat = input
        # CoordConv
        x_range = paddle.linspace(
            -1, 1, paddle.shape(ins_kernel_feat)[-1], dtype='float32')
        y_range = paddle.linspace(
            -1, 1, paddle.shape(ins_kernel_feat)[-2], dtype='float32')
        y, x = paddle.meshgrid([y_range, x_range])
        x = paddle.unsqueeze(x, [0, 1])
        y = paddle.unsqueeze(y, [0, 1])
        y = paddle.expand(
            y, shape=[paddle.shape(ins_kernel_feat)[0], 1, -1, -1])
        x = paddle.expand(
            x, shape=[paddle.shape(ins_kernel_feat)[0], 1, -1, -1])
        coord_feat = paddle.concat([x, y], axis=1)
        ins_kernel_feat = paddle.concat([ins_kernel_feat, coord_feat], axis=1)

        # kernel branch
        kernel_feat = ins_kernel_feat
        seg_num_grid = self.seg_num_grids[idx]
        kernel_feat = F.interpolate(
            kernel_feat,
            size=[seg_num_grid, seg_num_grid],
            mode='bilinear',
            align_corners=False,
            align_mode=0)
        cate_feat = kernel_feat[:, :-2, :, :]

        for kernel_layer in self.kernel_pred_convs:
            kernel_feat = F.relu(kernel_layer(kernel_feat))
        if self.drop_block and self.training:
            kernel_feat = self.drop_block_fun(kernel_feat)
        kernel_pred = self.solo_kernel(kernel_feat)

        # cate branch
        for cate_layer in self.cate_pred_convs:
            cate_feat = F.relu(cate_layer(cate_feat))
        if self.drop_block and self.training:
            cate_feat = self.drop_block_fun(cate_feat)
        cate_pred = self.solo_cate(cate_feat)

        if not self.training:
            cate_pred = self._points_nms(F.sigmoid(cate_pred), kernel_size=2)
            cate_pred = paddle.transpose(cate_pred, [0, 2, 3, 1])
        return cate_pred, kernel_pred
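    # Per level, _get_output_single appends normalized (x, y) CoordConv channels,
    # resizes the features to that level's S x S grid, and returns:
    #   kernel_pred: [N, kernel_out_channels, S, S] - a dynamic 1x1 conv kernel per grid cell
    #   cate_pred:   [N, num_classes, S, S] during training, or [N, S, S, num_classes]
    #                after point NMS and transpose at inference time.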
    def get_loss(self, cate_preds, kernel_preds, ins_pred, ins_labels,
                 cate_labels, grid_order_list, fg_num):
        """
        Get loss of network of SOLOv2.

        Args:
            cate_preds (list): Tensor list of category branch output.
            kernel_preds (list): Tensor list of kernel branch output.
            ins_pred (list): Tensor list of instance branch output.
            ins_labels (list): List of instance labels per batch.
            cate_labels (list): List of category labels per batch.
            grid_order_list (list): List of positive sample indices per grid.
            fg_num (int): Number of positive samples in a mini-batch.
        Returns:
            loss_ins (Tensor): The instance loss Tensor of SOLOv2 network.
            loss_cate (Tensor): The category loss Tensor of SOLOv2 network.
        """
        batch_size = paddle.shape(grid_order_list[0])[0]
        ins_pred_list = []
        for kernel_preds_level, grid_orders_level in zip(kernel_preds,
                                                         grid_order_list):
            if grid_orders_level.shape[1] == 0:
                ins_pred_list.append(None)
                continue
            grid_orders_level = paddle.reshape(grid_orders_level, [-1])
            reshape_pred = paddle.reshape(
                kernel_preds_level,
                shape=(paddle.shape(kernel_preds_level)[0],
                       paddle.shape(kernel_preds_level)[1], -1))
            reshape_pred = paddle.transpose(reshape_pred, [0, 2, 1])
            reshape_pred = paddle.reshape(
                reshape_pred, shape=(-1, paddle.shape(reshape_pred)[2]))
            gathered_pred = paddle.gather(
                reshape_pred, index=grid_orders_level)
            gathered_pred = paddle.reshape(
                gathered_pred,
                shape=[batch_size, -1, paddle.shape(gathered_pred)[1]])
            cur_ins_pred = ins_pred
            cur_ins_pred = paddle.reshape(
                cur_ins_pred,
                shape=(paddle.shape(cur_ins_pred)[0],
                       paddle.shape(cur_ins_pred)[1], -1))
            ins_pred_conv = paddle.matmul(gathered_pred, cur_ins_pred)
            cur_ins_pred = paddle.reshape(
                ins_pred_conv,
                shape=(-1, paddle.shape(ins_pred)[-2],
                       paddle.shape(ins_pred)[-1]))
            ins_pred_list.append(cur_ins_pred)

        num_ins = paddle.sum(fg_num)
        cate_preds = [
            paddle.reshape(
                paddle.transpose(cate_pred, [0, 2, 3, 1]),
                shape=(-1, self.cate_out_channels)) for cate_pred in cate_preds
        ]
        flatten_cate_preds = paddle.concat(cate_preds)
        new_cate_labels = []
        for cate_label in cate_labels:
            new_cate_labels.append(paddle.reshape(cate_label, shape=[-1]))
        cate_labels = paddle.concat(new_cate_labels)

        loss_ins, loss_cate = self.solov2_loss(ins_pred_list, ins_labels,
                                               flatten_cate_preds, cate_labels,
                                               num_ins)

        return {'loss_ins': loss_ins, 'loss_cate': loss_cate}
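    # Shape walkthrough for the dynamic-kernel step in get_loss (per level):
    #   kernel_preds_level: [N, C, S, S] -> flattened to [N * S * S, C]
    #   gathered_pred:      [N, num_pos, C]   (positive grid cells only)
    #   ins_pred (mask features) reshaped to [N, C, H * W]
    #   matmul(gathered_pred, ins_pred) -> [N, num_pos, H * W], i.e. each
    #   predicted 1x1 kernel applied to the shared mask features, finally
    #   reshaped to [num_pos_total, H, W] for the instance loss in solov2_loss.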
    def get_prediction(self, cate_preds, kernel_preds, seg_pred, im_shape,
                       scale_factor):
        """
        Get prediction result of SOLOv2 network

        Args:
            cate_preds (list): List of Variables, output of category branch.
            kernel_preds (list): List of Variables, output of kernel branch.
            seg_pred (Tensor): Output Tensor of the mask head.
            im_shape (Variables): [h, w] for input images.
            scale_factor (Variables): [scale, scale] for input images.
        Returns:
            seg_masks (Tensor): The prediction segmentation.
            cate_labels (Tensor): The prediction category label of each segmentation.
            cate_scores (Tensor): The prediction score of each segmentation.
            bbox_num (Tensor): The number of predicted masks.
        """
        num_levels = len(cate_preds)
        featmap_size = paddle.shape(seg_pred)[-2:]
        seg_masks_list = []
        cate_labels_list = []
        cate_scores_list = []
        cate_preds = [cate_pred * 1.0 for cate_pred in cate_preds]
        kernel_preds = [kernel_pred * 1.0 for kernel_pred in kernel_preds]
        # Currently only supports batch size == 1
        for idx in range(1):
            cate_pred_list = [
                paddle.reshape(
                    cate_preds[i][idx], shape=(-1, self.cate_out_channels))
                for i in range(num_levels)
            ]
            seg_pred_list = seg_pred
            kernel_pred_list = [
                paddle.reshape(
                    paddle.transpose(kernel_preds[i][idx], [1, 2, 0]),
                    shape=(-1, self.kernel_out_channels))
                for i in range(num_levels)
            ]
            cate_pred_list = paddle.concat(cate_pred_list, axis=0)
            kernel_pred_list = paddle.concat(kernel_pred_list, axis=0)

            seg_masks, cate_labels, cate_scores = self.get_seg_single(
                cate_pred_list, seg_pred_list, kernel_pred_list, featmap_size,
                im_shape[idx], scale_factor[idx][0])
        bbox_num = paddle.shape(cate_labels)[0]
        return seg_masks, cate_labels, cate_scores, bbox_num
    def get_seg_single(self, cate_preds, seg_preds, kernel_preds, featmap_size,
                       im_shape, scale_factor):
        h = paddle.cast(im_shape[0], 'int32')[0]
        w = paddle.cast(im_shape[1], 'int32')[0]
        upsampled_size_out = [featmap_size[0] * 4, featmap_size[1] * 4]

        y = paddle.zeros(shape=paddle.shape(cate_preds), dtype='float32')
        inds = paddle.where(cate_preds > self.score_threshold, cate_preds, y)
        inds = paddle.nonzero(inds)
        cate_preds = paddle.reshape(cate_preds, shape=[-1])
        # Prevent empty and increase fake data
        ind_a = paddle.cast(paddle.shape(kernel_preds)[0], 'int64')
        ind_b = paddle.zeros(shape=[1], dtype='int64')
        inds_end = paddle.unsqueeze(paddle.concat([ind_a, ind_b]), 0)
        inds = paddle.concat([inds, inds_end])
        kernel_preds_end = paddle.ones(
            shape=[1, self.kernel_out_channels], dtype='float32')
        kernel_preds = paddle.concat([kernel_preds, kernel_preds_end])
        cate_preds = paddle.concat(
            [cate_preds, paddle.zeros(
                shape=[1], dtype='float32')])

        # cate_labels & kernel_preds
        cate_labels = inds[:, 1]
        kernel_preds = paddle.gather(kernel_preds, index=inds[:, 0])
        cate_score_idx = paddle.add(inds[:, 0] * self.cate_out_channels,
                                    cate_labels)
        cate_scores = paddle.gather(cate_preds, index=cate_score_idx)

        size_trans = np.power(self.seg_num_grids, 2)
        strides = []
        for _ind in range(len(self.segm_strides)):
            strides.append(
                paddle.full(
                    shape=[int(size_trans[_ind])],
                    fill_value=self.segm_strides[_ind],
                    dtype="int32"))
        strides = paddle.concat(strides)
        strides = paddle.concat(
            [strides, paddle.zeros(
                shape=[1], dtype='int32')])
        strides = paddle.gather(strides, index=inds[:, 0])

        # mask encoding.
        kernel_preds = paddle.unsqueeze(kernel_preds, [2, 3])
        seg_preds = F.conv2d(seg_preds, kernel_preds)
        seg_preds = F.sigmoid(paddle.squeeze(seg_preds, [0]))
        seg_masks = seg_preds > self.mask_threshold
        seg_masks = paddle.cast(seg_masks, 'float32')
        sum_masks = paddle.sum(seg_masks, axis=[1, 2])

        y = paddle.zeros(shape=paddle.shape(sum_masks), dtype='float32')
        keep = paddle.where(sum_masks > strides, sum_masks, y)
        keep = paddle.nonzero(keep)
        keep = paddle.squeeze(keep, axis=[1])
        # Prevent empty and increase fake data
        keep_other = paddle.concat(
            [keep, paddle.cast(paddle.shape(sum_masks)[0] - 1, 'int64')])
        keep_scores = paddle.concat(
            [keep, paddle.cast(paddle.shape(sum_masks)[0], 'int64')])
        cate_scores_end = paddle.zeros(shape=[1], dtype='float32')
        cate_scores = paddle.concat([cate_scores, cate_scores_end])

        seg_masks = paddle.gather(seg_masks, index=keep_other)
        seg_preds = paddle.gather(seg_preds, index=keep_other)
        sum_masks = paddle.gather(sum_masks, index=keep_other)
        cate_labels = paddle.gather(cate_labels, index=keep_other)
        cate_scores = paddle.gather(cate_scores, index=keep_scores)

        # mask scoring.
        seg_mul = paddle.cast(seg_preds * seg_masks, 'float32')
        seg_scores = paddle.sum(seg_mul, axis=[1, 2]) / sum_masks
        cate_scores *= seg_scores

        # Matrix NMS
        seg_preds, cate_scores, cate_labels = self.mask_nms(
            seg_preds,
            seg_masks,
            cate_labels,
            cate_scores,
            sum_masks=sum_masks)
        ori_shape = im_shape[:2] / scale_factor + 0.5
        ori_shape = paddle.cast(ori_shape, 'int32')
        seg_preds = F.interpolate(
            paddle.unsqueeze(seg_preds, 0),
            size=upsampled_size_out,
            mode='bilinear',
            align_corners=False,
            align_mode=0)
        seg_preds = paddle.slice(
            seg_preds, axes=[2, 3], starts=[0, 0], ends=[h, w])
        seg_masks = paddle.squeeze(
            F.interpolate(
                seg_preds,
                size=ori_shape[:2],
                mode='bilinear',
                align_corners=False,
                align_mode=0),
            axis=[0])
        seg_masks = paddle.cast(seg_masks > self.mask_threshold, 'uint8')
        return seg_masks, cate_labels, cate_scores
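
# Usage sketch (illustrative only, not part of this module): in the SOLOv2
# architecture both heads consume the neck features. `fpn_feats` and the loss /
# NMS objects below are placeholders; in practice the heads are built from a
# YAML config through the ppdet registry rather than instantiated by hand.
#
#   mask_head = SOLOv2MaskHead()                   # unified mask feature branch
#   head = SOLOv2Head(solov2_loss=..., mask_nms=...)
#   ins_pred = mask_head(fpn_feats)                # stride-4 mask features
#   cate_preds, kernel_preds = head(fpn_feats)     # per-level S x S predictions
#   # training:  head.get_loss(cate_preds, kernel_preds, ins_pred, ins_labels,
#   #                          cate_labels, grid_order_list, fg_num)
#   # inference: head.get_prediction(cate_preds, kernel_preds, ins_pred,
#   #                                im_shape, scale_factor)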