# solov2_head.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from six.moves import zip

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn.initializer import Normal, Constant

from paddlex.ppdet.core.workspace import register
from paddlex.ppdet.modeling.layers import ConvNormLayer, MaskMatrixNMS

__all__ = ['SOLOv2Head']

@register
class SOLOv2MaskHead(nn.Layer):
    """
    MaskHead of SOLOv2: fuses several neck levels into one unified mask
    feature map, which the dynamic kernels predicted by SOLOv2Head are
    later convolved with.

    Args:
        in_channels (int): The channel number of input Tensor.
        out_channels (int): The channel number of output Tensor.
        start_level (int): The position where the input starts.
        end_level (int): The position where the input ends.
        use_dcn_in_tower (bool): Whether to use dcn in tower or not.
    """

    def __init__(self,
                 in_channels=256,
                 mid_channels=128,
                 out_channels=256,
                 start_level=0,
                 end_level=3,
                 use_dcn_in_tower=False):
        super(SOLOv2MaskHead, self).__init__()
        assert start_level >= 0 and end_level >= start_level
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.mid_channels = mid_channels
        self.use_dcn_in_tower = use_dcn_in_tower
        # Number of neck levels consumed by this head.
        self.range_level = end_level - start_level + 1
        # TODO: add DeformConvNorm
        conv_type = [ConvNormLayer]
        self.conv_func = conv_type[0]
        if self.use_dcn_in_tower:
            # NOTE(review): conv_type holds a single entry, so this line
            # raises IndexError until a DCN conv layer is added (see TODO).
            self.conv_func = conv_type[1]
        self.convs_all_levels = []
        for i in range(start_level, end_level + 1):
            # Sublayer names follow the original parameter naming so that
            # pretrained weights keep loading correctly.
            conv_feat_name = 'mask_feat_head.convs_all_levels.{}'.format(i)
            conv_pre_feat = nn.Sequential()
            if i == start_level:
                # Finest level: a single 3x3 conv, no upsampling needed.
                conv_pre_feat.add_sublayer(
                    conv_feat_name + '.conv' + str(i),
                    self.conv_func(
                        ch_in=self.in_channels,
                        ch_out=self.mid_channels,
                        filter_size=3,
                        stride=1,
                        norm_type='gn'))
                self.add_sublayer('conv_pre_feat' + str(i), conv_pre_feat)
                self.convs_all_levels.append(conv_pre_feat)
            else:
                # Level i needs i conv+ReLU+2x-upsample stages to reach the
                # spatial resolution of the start level.
                for j in range(i):
                    ch_in = 0
                    if j == 0:
                        # The coarsest level receives 2 extra channels for the
                        # CoordConv x/y maps appended in forward().
                        ch_in = self.in_channels + 2 if i == end_level else self.in_channels
                    else:
                        ch_in = self.mid_channels
                    conv_pre_feat.add_sublayer(
                        conv_feat_name + '.conv' + str(j),
                        self.conv_func(
                            ch_in=ch_in,
                            ch_out=self.mid_channels,
                            filter_size=3,
                            stride=1,
                            norm_type='gn'))
                    conv_pre_feat.add_sublayer(
                        conv_feat_name + '.conv' + str(j) + 'act', nn.ReLU())
                    conv_pre_feat.add_sublayer(
                        'upsample' + str(i) + str(j),
                        nn.Upsample(
                            scale_factor=2, mode='bilinear'))
                self.add_sublayer('conv_pre_feat' + str(i), conv_pre_feat)
                self.convs_all_levels.append(conv_pre_feat)
        # Final 1x1 conv producing the unified mask feature map.
        conv_pred_name = 'mask_feat_head.conv_pred.0'
        self.conv_pred = self.add_sublayer(
            conv_pred_name,
            self.conv_func(
                ch_in=self.mid_channels,
                ch_out=self.out_channels,
                filter_size=1,
                stride=1,
                norm_type='gn'))

    def forward(self, inputs):
        """
        Get SOLOv2MaskHead output.

        Args:
            inputs (list[Tensor]): feature map from each neck level, each with
                shape [N, C, H, W].
        Returns:
            ins_pred (Tensor): unified mask feature map of this head.
        """
        feat_all_level = F.relu(self.convs_all_levels[0](inputs[0]))
        for i in range(1, self.range_level):
            input_p = inputs[i]
            if i == (self.range_level - 1):
                # Coarsest level: append normalized x/y coordinate channels
                # (CoordConv) before running its per-level tower.
                input_feat = input_p
                x_range = paddle.linspace(
                    -1, 1, paddle.shape(input_feat)[-1], dtype='float32')
                y_range = paddle.linspace(
                    -1, 1, paddle.shape(input_feat)[-2], dtype='float32')
                y, x = paddle.meshgrid([y_range, x_range])
                x = paddle.unsqueeze(x, [0, 1])
                y = paddle.unsqueeze(y, [0, 1])
                y = paddle.expand(
                    y, shape=[paddle.shape(input_feat)[0], 1, -1, -1])
                x = paddle.expand(
                    x, shape=[paddle.shape(input_feat)[0], 1, -1, -1])
                coord_feat = paddle.concat([x, y], axis=1)
                input_p = paddle.concat([input_p, coord_feat], axis=1)
            # Element-wise sum of all levels once upsampled to a common size.
            feat_all_level = paddle.add(feat_all_level,
                                        self.convs_all_levels[i](input_p))
        ins_pred = F.relu(self.conv_pred(feat_all_level))
        return ins_pred
  134. @register
  135. class SOLOv2Head(nn.Layer):
  136. """
  137. Head block for SOLOv2 network
  138. Args:
  139. num_classes (int): Number of output classes.
  140. in_channels (int): Number of input channels.
  141. seg_feat_channels (int): Num_filters of kernel & categroy branch convolution operation.
  142. stacked_convs (int): Times of convolution operation.
  143. num_grids (list[int]): List of feature map grids size.
  144. kernel_out_channels (int): Number of output channels in kernel branch.
  145. dcn_v2_stages (list): Which stage use dcn v2 in tower. It is between [0, stacked_convs).
  146. segm_strides (list[int]): List of segmentation area stride.
  147. solov2_loss (object): SOLOv2Loss instance.
  148. score_threshold (float): Threshold of categroy score.
  149. mask_nms (object): MaskMatrixNMS instance.
  150. """
  151. __inject__ = ['solov2_loss', 'mask_nms']
  152. __shared__ = ['num_classes']
  153. def __init__(self,
  154. num_classes=80,
  155. in_channels=256,
  156. seg_feat_channels=256,
  157. stacked_convs=4,
  158. num_grids=[40, 36, 24, 16, 12],
  159. kernel_out_channels=256,
  160. dcn_v2_stages=[],
  161. segm_strides=[8, 8, 16, 32, 32],
  162. solov2_loss=None,
  163. score_threshold=0.1,
  164. mask_threshold=0.5,
  165. mask_nms=None):
  166. super(SOLOv2Head, self).__init__()
  167. self.num_classes = num_classes
  168. self.in_channels = in_channels
  169. self.seg_num_grids = num_grids
  170. self.cate_out_channels = self.num_classes
  171. self.seg_feat_channels = seg_feat_channels
  172. self.stacked_convs = stacked_convs
  173. self.kernel_out_channels = kernel_out_channels
  174. self.dcn_v2_stages = dcn_v2_stages
  175. self.segm_strides = segm_strides
  176. self.solov2_loss = solov2_loss
  177. self.mask_nms = mask_nms
  178. self.score_threshold = score_threshold
  179. self.mask_threshold = mask_threshold
  180. conv_type = [ConvNormLayer]
  181. self.conv_func = conv_type[0]
  182. self.kernel_pred_convs = []
  183. self.cate_pred_convs = []
  184. for i in range(self.stacked_convs):
  185. if i in self.dcn_v2_stages:
  186. self.conv_func = conv_type[1]
  187. ch_in = self.in_channels + 2 if i == 0 else self.seg_feat_channels
  188. kernel_conv = self.add_sublayer(
  189. 'bbox_head.kernel_convs.' + str(i),
  190. self.conv_func(
  191. ch_in=ch_in,
  192. ch_out=self.seg_feat_channels,
  193. filter_size=3,
  194. stride=1,
  195. norm_type='gn'))
  196. self.kernel_pred_convs.append(kernel_conv)
  197. ch_in = self.in_channels if i == 0 else self.seg_feat_channels
  198. cate_conv = self.add_sublayer(
  199. 'bbox_head.cate_convs.' + str(i),
  200. self.conv_func(
  201. ch_in=ch_in,
  202. ch_out=self.seg_feat_channels,
  203. filter_size=3,
  204. stride=1,
  205. norm_type='gn'))
  206. self.cate_pred_convs.append(cate_conv)
  207. self.solo_kernel = self.add_sublayer(
  208. 'bbox_head.solo_kernel',
  209. nn.Conv2D(
  210. self.seg_feat_channels,
  211. self.kernel_out_channels,
  212. kernel_size=3,
  213. stride=1,
  214. padding=1,
  215. weight_attr=ParamAttr(initializer=Normal(
  216. mean=0., std=0.01)),
  217. bias_attr=True))
  218. self.solo_cate = self.add_sublayer(
  219. 'bbox_head.solo_cate',
  220. nn.Conv2D(
  221. self.seg_feat_channels,
  222. self.cate_out_channels,
  223. kernel_size=3,
  224. stride=1,
  225. padding=1,
  226. weight_attr=ParamAttr(initializer=Normal(
  227. mean=0., std=0.01)),
  228. bias_attr=ParamAttr(initializer=Constant(
  229. value=float(-np.log((1 - 0.01) / 0.01))))))
  230. def _points_nms(self, heat, kernel_size=2):
  231. hmax = F.max_pool2d(heat, kernel_size=kernel_size, stride=1, padding=1)
  232. keep = paddle.cast((hmax[:, :, :-1, :-1] == heat), 'float32')
  233. return heat * keep
  234. def _split_feats(self, feats):
  235. return (F.interpolate(
  236. feats[0],
  237. scale_factor=0.5,
  238. align_corners=False,
  239. align_mode=0,
  240. mode='bilinear'), feats[1], feats[2], feats[3], F.interpolate(
  241. feats[4],
  242. size=paddle.shape(feats[3])[-2:],
  243. mode='bilinear',
  244. align_corners=False,
  245. align_mode=0))
  246. def forward(self, input):
  247. """
  248. Get SOLOv2 head output
  249. Args:
  250. input (list): List of Tensors, output of backbone or neck stages
  251. Returns:
  252. cate_pred_list (list): Tensors of each category branch layer
  253. kernel_pred_list (list): Tensors of each kernel branch layer
  254. """
  255. feats = self._split_feats(input)
  256. cate_pred_list = []
  257. kernel_pred_list = []
  258. for idx in range(len(self.seg_num_grids)):
  259. cate_pred, kernel_pred = self._get_output_single(feats[idx], idx)
  260. cate_pred_list.append(cate_pred)
  261. kernel_pred_list.append(kernel_pred)
  262. return cate_pred_list, kernel_pred_list
  263. def _get_output_single(self, input, idx):
  264. ins_kernel_feat = input
  265. # CoordConv
  266. x_range = paddle.linspace(
  267. -1, 1, paddle.shape(ins_kernel_feat)[-1], dtype='float32')
  268. y_range = paddle.linspace(
  269. -1, 1, paddle.shape(ins_kernel_feat)[-2], dtype='float32')
  270. y, x = paddle.meshgrid([y_range, x_range])
  271. x = paddle.unsqueeze(x, [0, 1])
  272. y = paddle.unsqueeze(y, [0, 1])
  273. y = paddle.expand(
  274. y, shape=[paddle.shape(ins_kernel_feat)[0], 1, -1, -1])
  275. x = paddle.expand(
  276. x, shape=[paddle.shape(ins_kernel_feat)[0], 1, -1, -1])
  277. coord_feat = paddle.concat([x, y], axis=1)
  278. ins_kernel_feat = paddle.concat([ins_kernel_feat, coord_feat], axis=1)
  279. # kernel branch
  280. kernel_feat = ins_kernel_feat
  281. seg_num_grid = self.seg_num_grids[idx]
  282. kernel_feat = F.interpolate(
  283. kernel_feat,
  284. size=[seg_num_grid, seg_num_grid],
  285. mode='bilinear',
  286. align_corners=False,
  287. align_mode=0)
  288. cate_feat = kernel_feat[:, :-2, :, :]
  289. for kernel_layer in self.kernel_pred_convs:
  290. kernel_feat = F.relu(kernel_layer(kernel_feat))
  291. kernel_pred = self.solo_kernel(kernel_feat)
  292. # cate branch
  293. for cate_layer in self.cate_pred_convs:
  294. cate_feat = F.relu(cate_layer(cate_feat))
  295. cate_pred = self.solo_cate(cate_feat)
  296. if not self.training:
  297. cate_pred = self._points_nms(F.sigmoid(cate_pred), kernel_size=2)
  298. cate_pred = paddle.transpose(cate_pred, [0, 2, 3, 1])
  299. return cate_pred, kernel_pred
  300. def get_loss(self, cate_preds, kernel_preds, ins_pred, ins_labels,
  301. cate_labels, grid_order_list, fg_num):
  302. """
  303. Get loss of network of SOLOv2.
  304. Args:
  305. cate_preds (list): Tensor list of categroy branch output.
  306. kernel_preds (list): Tensor list of kernel branch output.
  307. ins_pred (list): Tensor list of instance branch output.
  308. ins_labels (list): List of instance labels pre batch.
  309. cate_labels (list): List of categroy labels pre batch.
  310. grid_order_list (list): List of index in pre grid.
  311. fg_num (int): Number of positive samples in a mini-batch.
  312. Returns:
  313. loss_ins (Tensor): The instance loss Tensor of SOLOv2 network.
  314. loss_cate (Tensor): The category loss Tensor of SOLOv2 network.
  315. """
  316. batch_size = paddle.shape(grid_order_list[0])[0]
  317. ins_pred_list = []
  318. for kernel_preds_level, grid_orders_level in zip(kernel_preds,
  319. grid_order_list):
  320. if grid_orders_level.shape[1] == 0:
  321. ins_pred_list.append(None)
  322. continue
  323. grid_orders_level = paddle.reshape(grid_orders_level, [-1])
  324. reshape_pred = paddle.reshape(
  325. kernel_preds_level,
  326. shape=(paddle.shape(kernel_preds_level)[0],
  327. paddle.shape(kernel_preds_level)[1], -1))
  328. reshape_pred = paddle.transpose(reshape_pred, [0, 2, 1])
  329. reshape_pred = paddle.reshape(
  330. reshape_pred, shape=(-1, paddle.shape(reshape_pred)[2]))
  331. gathered_pred = paddle.gather(reshape_pred, index=grid_orders_level)
  332. gathered_pred = paddle.reshape(
  333. gathered_pred,
  334. shape=[batch_size, -1, paddle.shape(gathered_pred)[1]])
  335. cur_ins_pred = ins_pred
  336. cur_ins_pred = paddle.reshape(
  337. cur_ins_pred,
  338. shape=(paddle.shape(cur_ins_pred)[0],
  339. paddle.shape(cur_ins_pred)[1], -1))
  340. ins_pred_conv = paddle.matmul(gathered_pred, cur_ins_pred)
  341. cur_ins_pred = paddle.reshape(
  342. ins_pred_conv,
  343. shape=(-1, paddle.shape(ins_pred)[-2],
  344. paddle.shape(ins_pred)[-1]))
  345. ins_pred_list.append(cur_ins_pred)
  346. num_ins = paddle.sum(fg_num)
  347. cate_preds = [
  348. paddle.reshape(
  349. paddle.transpose(cate_pred, [0, 2, 3, 1]),
  350. shape=(-1, self.cate_out_channels)) for cate_pred in cate_preds
  351. ]
  352. flatten_cate_preds = paddle.concat(cate_preds)
  353. new_cate_labels = []
  354. for cate_label in cate_labels:
  355. new_cate_labels.append(paddle.reshape(cate_label, shape=[-1]))
  356. cate_labels = paddle.concat(new_cate_labels)
  357. loss_ins, loss_cate = self.solov2_loss(
  358. ins_pred_list, ins_labels, flatten_cate_preds, cate_labels, num_ins)
  359. return {'loss_ins': loss_ins, 'loss_cate': loss_cate}
  360. def get_prediction(self, cate_preds, kernel_preds, seg_pred, im_shape,
  361. scale_factor):
  362. """
  363. Get prediction result of SOLOv2 network
  364. Args:
  365. cate_preds (list): List of Variables, output of categroy branch.
  366. kernel_preds (list): List of Variables, output of kernel branch.
  367. seg_pred (list): List of Variables, output of mask head stages.
  368. im_shape (Variables): [h, w] for input images.
  369. scale_factor (Variables): [scale, scale] for input images.
  370. Returns:
  371. seg_masks (Tensor): The prediction segmentation.
  372. cate_labels (Tensor): The prediction categroy label of each segmentation.
  373. seg_masks (Tensor): The prediction score of each segmentation.
  374. """
  375. num_levels = len(cate_preds)
  376. featmap_size = paddle.shape(seg_pred)[-2:]
  377. seg_masks_list = []
  378. cate_labels_list = []
  379. cate_scores_list = []
  380. cate_preds = [cate_pred * 1.0 for cate_pred in cate_preds]
  381. kernel_preds = [kernel_pred * 1.0 for kernel_pred in kernel_preds]
  382. # Currently only supports batch size == 1
  383. for idx in range(1):
  384. cate_pred_list = [
  385. paddle.reshape(
  386. cate_preds[i][idx], shape=(-1, self.cate_out_channels))
  387. for i in range(num_levels)
  388. ]
  389. seg_pred_list = seg_pred
  390. kernel_pred_list = [
  391. paddle.reshape(
  392. paddle.transpose(kernel_preds[i][idx], [1, 2, 0]),
  393. shape=(-1, self.kernel_out_channels))
  394. for i in range(num_levels)
  395. ]
  396. cate_pred_list = paddle.concat(cate_pred_list, axis=0)
  397. kernel_pred_list = paddle.concat(kernel_pred_list, axis=0)
  398. seg_masks, cate_labels, cate_scores = self.get_seg_single(
  399. cate_pred_list, seg_pred_list, kernel_pred_list, featmap_size,
  400. im_shape[idx], scale_factor[idx][0])
  401. bbox_num = paddle.shape(cate_labels)[0]
  402. return seg_masks, cate_labels, cate_scores, bbox_num
  403. def get_seg_single(self, cate_preds, seg_preds, kernel_preds, featmap_size,
  404. im_shape, scale_factor):
  405. h = paddle.cast(im_shape[0], 'int32')[0]
  406. w = paddle.cast(im_shape[1], 'int32')[0]
  407. upsampled_size_out = [featmap_size[0] * 4, featmap_size[1] * 4]
  408. y = paddle.zeros(shape=paddle.shape(cate_preds), dtype='float32')
  409. inds = paddle.where(cate_preds > self.score_threshold, cate_preds, y)
  410. inds = paddle.nonzero(inds)
  411. cate_preds = paddle.reshape(cate_preds, shape=[-1])
  412. # Prevent empty and increase fake data
  413. ind_a = paddle.cast(paddle.shape(kernel_preds)[0], 'int64')
  414. ind_b = paddle.zeros(shape=[1], dtype='int64')
  415. inds_end = paddle.unsqueeze(paddle.concat([ind_a, ind_b]), 0)
  416. inds = paddle.concat([inds, inds_end])
  417. kernel_preds_end = paddle.ones(
  418. shape=[1, self.kernel_out_channels], dtype='float32')
  419. kernel_preds = paddle.concat([kernel_preds, kernel_preds_end])
  420. cate_preds = paddle.concat(
  421. [cate_preds, paddle.zeros(
  422. shape=[1], dtype='float32')])
  423. # cate_labels & kernel_preds
  424. cate_labels = inds[:, 1]
  425. kernel_preds = paddle.gather(kernel_preds, index=inds[:, 0])
  426. cate_score_idx = paddle.add(inds[:, 0] * 80, cate_labels)
  427. cate_scores = paddle.gather(cate_preds, index=cate_score_idx)
  428. size_trans = np.power(self.seg_num_grids, 2)
  429. strides = []
  430. for _ind in range(len(self.segm_strides)):
  431. strides.append(
  432. paddle.full(
  433. shape=[int(size_trans[_ind])],
  434. fill_value=self.segm_strides[_ind],
  435. dtype="int32"))
  436. strides = paddle.concat(strides)
  437. strides = paddle.gather(strides, index=inds[:, 0])
  438. # mask encoding.
  439. kernel_preds = paddle.unsqueeze(kernel_preds, [2, 3])
  440. seg_preds = F.conv2d(seg_preds, kernel_preds)
  441. seg_preds = F.sigmoid(paddle.squeeze(seg_preds, [0]))
  442. seg_masks = seg_preds > self.mask_threshold
  443. seg_masks = paddle.cast(seg_masks, 'float32')
  444. sum_masks = paddle.sum(seg_masks, axis=[1, 2])
  445. y = paddle.zeros(shape=paddle.shape(sum_masks), dtype='float32')
  446. keep = paddle.where(sum_masks > strides, sum_masks, y)
  447. keep = paddle.nonzero(keep)
  448. keep = paddle.squeeze(keep, axis=[1])
  449. # Prevent empty and increase fake data
  450. keep_other = paddle.concat(
  451. [keep, paddle.cast(paddle.shape(sum_masks)[0] - 1, 'int64')])
  452. keep_scores = paddle.concat(
  453. [keep, paddle.cast(paddle.shape(sum_masks)[0], 'int64')])
  454. cate_scores_end = paddle.zeros(shape=[1], dtype='float32')
  455. cate_scores = paddle.concat([cate_scores, cate_scores_end])
  456. seg_masks = paddle.gather(seg_masks, index=keep_other)
  457. seg_preds = paddle.gather(seg_preds, index=keep_other)
  458. sum_masks = paddle.gather(sum_masks, index=keep_other)
  459. cate_labels = paddle.gather(cate_labels, index=keep_other)
  460. cate_scores = paddle.gather(cate_scores, index=keep_scores)
  461. # mask scoring.
  462. seg_mul = paddle.cast(seg_preds * seg_masks, 'float32')
  463. seg_scores = paddle.sum(seg_mul, axis=[1, 2]) / sum_masks
  464. cate_scores *= seg_scores
  465. # Matrix NMS
  466. seg_preds, cate_scores, cate_labels = self.mask_nms(
  467. seg_preds, seg_masks, cate_labels, cate_scores, sum_masks=sum_masks)
  468. ori_shape = im_shape[:2] / scale_factor + 0.5
  469. ori_shape = paddle.cast(ori_shape, 'int32')
  470. seg_preds = F.interpolate(
  471. paddle.unsqueeze(seg_preds, 0),
  472. size=upsampled_size_out,
  473. mode='bilinear',
  474. align_corners=False,
  475. align_mode=0)
  476. seg_preds = paddle.slice(
  477. seg_preds, axes=[2, 3], starts=[0, 0], ends=[h, w])
  478. seg_masks = paddle.squeeze(
  479. F.interpolate(
  480. seg_preds,
  481. size=ori_shape[:2],
  482. mode='bilinear',
  483. align_corners=False,
  484. align_mode=0),
  485. axis=[0])
  486. seg_masks = paddle.cast(seg_masks > self.mask_threshold, 'uint8')
  487. return seg_masks, cate_labels, cate_scores