# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import six
import numpy as np
from numbers import Integral

import paddle
import paddle.nn as nn
from paddle import ParamAttr
from paddle import to_tensor
import paddle.nn.functional as F
from paddle.nn.initializer import Normal, Constant, XavierUniform
from paddle.regularizer import L2Decay
# LayerHelper and NumpyArrayInitializer are used by AnchorGrid.__call__ below
# but were missing from the original imports (fluid-era Paddle APIs).
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.initializer import NumpyArrayInitializer
from paddle.vision.ops import DeformConv2D

from paddlex.ppdet.core.workspace import register, serializable
from paddlex.ppdet.modeling.bbox_utils import delta2bbox
from . import ops
from .initializer import xavier_uniform_, constant_


def _to_list(l):
    if isinstance(l, (list, tuple)):
        return list(l)
    return [l]


class DeformableConvV2(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 weight_attr=None,
                 bias_attr=None,
                 lr_scale=1,
                 regularizer=None,
                 skip_quant=False,
                 dcn_bias_regularizer=L2Decay(0.),
                 dcn_bias_lr_scale=2.):
        super(DeformableConvV2, self).__init__()
        self.offset_channel = 2 * kernel_size**2
        self.mask_channel = kernel_size**2

        if lr_scale == 1 and regularizer is None:
            offset_bias_attr = ParamAttr(initializer=Constant(0.))
        else:
            offset_bias_attr = ParamAttr(
                initializer=Constant(0.),
                learning_rate=lr_scale,
                regularizer=regularizer)
        self.conv_offset = nn.Conv2D(
            in_channels,
            3 * kernel_size**2,
            kernel_size,
            stride=stride,
            padding=(kernel_size - 1) // 2,
            weight_attr=ParamAttr(initializer=Constant(0.0)),
            bias_attr=offset_bias_attr)
        if skip_quant:
            self.conv_offset.skip_quant = True

        if bias_attr:
            # in FCOS-DCN head, specifically need learning_rate and regularizer
            dcn_bias_attr = ParamAttr(
                initializer=Constant(value=0),
                regularizer=dcn_bias_regularizer,
                learning_rate=dcn_bias_lr_scale)
        else:
            # in ResNet backbone, do not need bias
            dcn_bias_attr = False
        self.conv_dcn = DeformConv2D(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=(kernel_size - 1) // 2 * dilation,
            dilation=dilation,
            groups=groups,
            weight_attr=weight_attr,
            bias_attr=dcn_bias_attr)

    def forward(self, x):
        offset_mask = self.conv_offset(x)
        offset, mask = paddle.split(
            offset_mask,
            num_or_sections=[self.offset_channel, self.mask_channel],
            axis=1)
        mask = F.sigmoid(mask)
        y = self.conv_dcn(x, offset, mask=mask)
        return y
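

# Illustrative usage (a sketch, not part of the original file; the shapes
# below are assumptions). Offsets and masks are predicted from the input by
# conv_offset, then fed to the modulated deformable convolution:
#
#     dcn = DeformableConvV2(in_channels=16, out_channels=32, kernel_size=3)
#     x = paddle.rand([2, 16, 40, 40])
#     y = dcn(x)  # [2, 32, 40, 40]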


class ConvNormLayer(nn.Layer):
    def __init__(self,
                 ch_in,
                 ch_out,
                 filter_size,
                 stride,
                 groups=1,
                 norm_type='bn',
                 norm_decay=0.,
                 norm_groups=32,
                 use_dcn=False,
                 bias_on=False,
                 lr_scale=1.,
                 freeze_norm=False,
                 initializer=Normal(
                     mean=0., std=0.01),
                 skip_quant=False,
                 dcn_lr_scale=2.,
                 dcn_regularizer=L2Decay(0.)):
        super(ConvNormLayer, self).__init__()
        assert norm_type in ['bn', 'sync_bn', 'gn']

        if bias_on:
            bias_attr = ParamAttr(
                initializer=Constant(value=0.), learning_rate=lr_scale)
        else:
            bias_attr = False

        if not use_dcn:
            self.conv = nn.Conv2D(
                in_channels=ch_in,
                out_channels=ch_out,
                kernel_size=filter_size,
                stride=stride,
                padding=(filter_size - 1) // 2,
                groups=groups,
                weight_attr=ParamAttr(
                    initializer=initializer, learning_rate=1.),
                bias_attr=bias_attr)
            if skip_quant:
                self.conv.skip_quant = True
        else:
            # in FCOS-DCN head, specifically need learning_rate and regularizer
            self.conv = DeformableConvV2(
                in_channels=ch_in,
                out_channels=ch_out,
                kernel_size=filter_size,
                stride=stride,
                padding=(filter_size - 1) // 2,
                groups=groups,
                weight_attr=ParamAttr(
                    initializer=initializer, learning_rate=1.),
                bias_attr=True,
                lr_scale=dcn_lr_scale,
                regularizer=dcn_regularizer,
                dcn_bias_regularizer=dcn_regularizer,
                dcn_bias_lr_scale=dcn_lr_scale,
                skip_quant=skip_quant)

        norm_lr = 0. if freeze_norm else 1.
        param_attr = ParamAttr(
            learning_rate=norm_lr,
            regularizer=L2Decay(norm_decay)
            if norm_decay is not None else None)
        bias_attr = ParamAttr(
            learning_rate=norm_lr,
            regularizer=L2Decay(norm_decay)
            if norm_decay is not None else None)
        if norm_type == 'bn':
            self.norm = nn.BatchNorm2D(
                ch_out, weight_attr=param_attr, bias_attr=bias_attr)
        elif norm_type == 'sync_bn':
            self.norm = nn.SyncBatchNorm(
                ch_out, weight_attr=param_attr, bias_attr=bias_attr)
        elif norm_type == 'gn':
            self.norm = nn.GroupNorm(
                num_groups=norm_groups,
                num_channels=ch_out,
                weight_attr=param_attr,
                bias_attr=bias_attr)

    def forward(self, inputs):
        out = self.conv(inputs)
        out = self.norm(out)
        return out
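

# Illustrative usage (a sketch, not part of the original file; values are
# assumptions): a conv followed by GroupNorm, spatial shape preserved:
#
#     conv_gn = ConvNormLayer(ch_in=64, ch_out=64, filter_size=3, stride=1,
#                             norm_type='gn', norm_groups=32)
#     x = paddle.rand([2, 64, 28, 28])
#     y = conv_gn(x)  # [2, 64, 28, 28]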


class LiteConv(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 with_act=True,
                 norm_type='sync_bn',
                 name=None):
        super(LiteConv, self).__init__()
        self.lite_conv = nn.Sequential()
        conv1 = ConvNormLayer(
            in_channels,
            in_channels,
            filter_size=5,
            stride=stride,
            groups=in_channels,
            norm_type=norm_type,
            initializer=XavierUniform())
        conv2 = ConvNormLayer(
            in_channels,
            out_channels,
            filter_size=1,
            stride=stride,
            norm_type=norm_type,
            initializer=XavierUniform())
        conv3 = ConvNormLayer(
            out_channels,
            out_channels,
            filter_size=1,
            stride=stride,
            norm_type=norm_type,
            initializer=XavierUniform())
        conv4 = ConvNormLayer(
            out_channels,
            out_channels,
            filter_size=5,
            stride=stride,
            groups=out_channels,
            norm_type=norm_type,
            initializer=XavierUniform())
        conv_list = [conv1, conv2, conv3, conv4]
        self.lite_conv.add_sublayer('conv1', conv1)
        self.lite_conv.add_sublayer('relu6_1', nn.ReLU6())
        self.lite_conv.add_sublayer('conv2', conv2)
        if with_act:
            self.lite_conv.add_sublayer('relu6_2', nn.ReLU6())
        self.lite_conv.add_sublayer('conv3', conv3)
        self.lite_conv.add_sublayer('relu6_3', nn.ReLU6())
        self.lite_conv.add_sublayer('conv4', conv4)
        if with_act:
            self.lite_conv.add_sublayer('relu6_4', nn.ReLU6())

    def forward(self, inputs):
        out = self.lite_conv(inputs)
        return out
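

# Illustrative usage (a sketch, not part of the original file; shapes are
# assumptions). The block is 5x5 depthwise -> 1x1 -> 1x1 -> 5x5 depthwise,
# each followed by ReLU6 (the 2nd and 4th only when with_act=True):
#
#     lite = LiteConv(in_channels=32, out_channels=64)
#     x = paddle.rand([2, 32, 20, 20])
#     y = lite(x)  # [2, 64, 20, 20]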


class DropBlock(nn.Layer):
    def __init__(self, block_size, keep_prob, name, data_format='NCHW'):
        """
        DropBlock layer, see https://arxiv.org/abs/1810.12890

        Args:
            block_size (int): block size
            keep_prob (float): keep probability
            name (str): layer name
            data_format (str): data format, NCHW or NHWC
        """
        super(DropBlock, self).__init__()
        self.block_size = block_size
        self.keep_prob = keep_prob
        self.name = name
        self.data_format = data_format

    def forward(self, x):
        if not self.training or self.keep_prob == 1:
            return x
        else:
            gamma = (1. - self.keep_prob) / (self.block_size**2)
            if self.data_format == 'NCHW':
                shape = x.shape[2:]
            else:
                shape = x.shape[1:3]
            for s in shape:
                gamma *= s / (s - self.block_size + 1)

            matrix = paddle.cast(paddle.rand(x.shape) < gamma, x.dtype)
            mask_inv = F.max_pool2d(
                matrix,
                self.block_size,
                stride=1,
                padding=self.block_size // 2,
                data_format=self.data_format)
            mask = 1. - mask_inv
            y = x * mask * (mask.numel() / mask.sum())
            return y
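

# Illustrative usage (a sketch, not part of the original file). DropBlock is
# the identity at inference time or when keep_prob == 1; in training it zeros
# contiguous block_size x block_size regions and rescales the survivors:
#
#     drop = DropBlock(block_size=3, keep_prob=0.9, name='db')
#     drop.train()
#     y = drop(paddle.rand([2, 8, 16, 16]))  # blocks zeroed, output rescaled
#     drop.eval()
#     y = drop(paddle.rand([2, 8, 16, 16]))  # returned unchanged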


@register
@serializable
class AnchorGeneratorSSD(object):
    def __init__(self,
                 steps=[8, 16, 32, 64, 100, 300],
                 aspect_ratios=[[2.], [2., 3.], [2., 3.], [2., 3.], [2.],
                                [2.]],
                 min_ratio=15,
                 max_ratio=90,
                 base_size=300,
                 min_sizes=[30.0, 60.0, 111.0, 162.0, 213.0, 264.0],
                 max_sizes=[60.0, 111.0, 162.0, 213.0, 264.0, 315.0],
                 offset=0.5,
                 flip=True,
                 clip=False,
                 min_max_aspect_ratios_order=False):
        self.steps = steps
        self.aspect_ratios = aspect_ratios
        self.min_ratio = min_ratio
        self.max_ratio = max_ratio
        self.base_size = base_size
        self.min_sizes = min_sizes
        self.max_sizes = max_sizes
        self.offset = offset
        self.flip = flip
        self.clip = clip
        self.min_max_aspect_ratios_order = min_max_aspect_ratios_order

        if self.min_sizes == [] and self.max_sizes == []:
            num_layer = len(aspect_ratios)
            step = int(
                math.floor((self.max_ratio - self.min_ratio) /
                           (num_layer - 2)))
            for ratio in six.moves.range(self.min_ratio, self.max_ratio + 1,
                                         step):
                self.min_sizes.append(self.base_size * ratio / 100.)
                self.max_sizes.append(self.base_size * (ratio + step) / 100.)
            self.min_sizes = [self.base_size * .10] + self.min_sizes
            self.max_sizes = [self.base_size * .20] + self.max_sizes

        self.num_priors = []
        for aspect_ratio, min_size, max_size in zip(
                aspect_ratios, self.min_sizes, self.max_sizes):
            if isinstance(min_size, (list, tuple)):
                self.num_priors.append(
                    len(_to_list(min_size)) + len(_to_list(max_size)))
            else:
                self.num_priors.append((len(aspect_ratio) * 2 + 1) * len(
                    _to_list(min_size)) + len(_to_list(max_size)))

    def __call__(self, inputs, image):
        boxes = []
        for input, min_size, max_size, aspect_ratio, step in zip(
                inputs, self.min_sizes, self.max_sizes, self.aspect_ratios,
                self.steps):
            box, _ = ops.prior_box(
                input=input,
                image=image,
                min_sizes=_to_list(min_size),
                max_sizes=_to_list(max_size),
                aspect_ratios=aspect_ratio,
                flip=self.flip,
                clip=self.clip,
                steps=[step, step],
                offset=self.offset,
                min_max_aspect_ratios_order=self.min_max_aspect_ratios_order)
            boxes.append(paddle.reshape(box, [-1, 4]))
        return boxes
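

# Worked example of the prior count per location, using the defaults above:
# for the first feature map, aspect_ratios=[2.] with flip=True yields ratios
# {1, 2, 1/2}, i.e. len(aspect_ratio) * 2 + 1 = 3 boxes per min_size, plus
# one box per max_size, so num_priors[0] = 3 * 1 + 1 = 4.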


@register
@serializable
class RCNNBox(object):
    __shared__ = ['num_classes']

    def __init__(self,
                 prior_box_var=[10., 10., 5., 5.],
                 code_type="decode_center_size",
                 box_normalized=False,
                 num_classes=80):
        super(RCNNBox, self).__init__()
        self.prior_box_var = prior_box_var
        self.code_type = code_type
        self.box_normalized = box_normalized
        self.num_classes = num_classes

    def __call__(self, bbox_head_out, rois, im_shape, scale_factor):
        bbox_pred = bbox_head_out[0]
        cls_prob = bbox_head_out[1]
        roi = rois[0]
        rois_num = rois[1]

        origin_shape = paddle.floor(im_shape / scale_factor + 0.5)
        scale_list = []
        origin_shape_list = []

        batch_size = 1
        if isinstance(roi, list):
            batch_size = len(roi)
        else:
            batch_size = paddle.slice(paddle.shape(im_shape), [0], [0], [1])
        # bbox_pred.shape: [N, C*4]
        for idx in range(batch_size):
            roi_per_im = roi[idx]
            rois_num_per_im = rois_num[idx]
            expand_im_shape = paddle.expand(im_shape[idx, :],
                                            [rois_num_per_im, 2])
            origin_shape_list.append(expand_im_shape)

        origin_shape = paddle.concat(origin_shape_list)

        # bbox_pred.shape: [N, C*4]
        # C=num_classes in faster/mask rcnn(bbox_head), C=1 in cascade rcnn(cascade_head)
        bbox = paddle.concat(roi)
        if bbox.shape[0] == 0:
            bbox = paddle.zeros([0, bbox_pred.shape[1]], dtype='float32')
        else:
            bbox = delta2bbox(bbox_pred, bbox, self.prior_box_var)
        scores = cls_prob[:, :-1]

        # bbox.shape: [N, C, 4]
        # bbox.shape[1] must be equal to scores.shape[1]
        bbox_num_class = bbox.shape[1]
        if bbox_num_class == 1:
            bbox = paddle.tile(bbox, [1, self.num_classes, 1])

        origin_h = paddle.unsqueeze(origin_shape[:, 0], axis=1)
        origin_w = paddle.unsqueeze(origin_shape[:, 1], axis=1)
        zeros = paddle.zeros_like(origin_h)
        x1 = paddle.maximum(paddle.minimum(bbox[:, :, 0], origin_w), zeros)
        y1 = paddle.maximum(paddle.minimum(bbox[:, :, 1], origin_h), zeros)
        x2 = paddle.maximum(paddle.minimum(bbox[:, :, 2], origin_w), zeros)
        y2 = paddle.maximum(paddle.minimum(bbox[:, :, 3], origin_h), zeros)
        bbox = paddle.stack([x1, y1, x2, y2], axis=-1)
        bboxes = (bbox, rois_num)
        return bboxes, scores


@register
@serializable
class MultiClassNMS(object):
    def __init__(self,
                 score_threshold=.05,
                 nms_top_k=-1,
                 keep_top_k=100,
                 nms_threshold=.5,
                 normalized=True,
                 nms_eta=1.0,
                 return_index=False,
                 return_rois_num=True):
        super(MultiClassNMS, self).__init__()
        self.score_threshold = score_threshold
        self.nms_top_k = nms_top_k
        self.keep_top_k = keep_top_k
        self.nms_threshold = nms_threshold
        self.normalized = normalized
        self.nms_eta = nms_eta
        self.return_index = return_index
        self.return_rois_num = return_rois_num

    def __call__(self, bboxes, score, background_label=-1):
        """
        bboxes (Tensor|List[Tensor]): 1. (Tensor) Predicted bboxes with shape
                                         [N, M, 4], N is the batch size and M
                                         is the number of bboxes
                                      2. (List[Tensor]) bboxes and bbox_num,
                                         bboxes have shape of [M, C, 4], C
                                         is the class number and bbox_num means
                                         the number of bboxes of each batch with
                                         shape [N,]
        score (Tensor): Predicted scores with shape [N, C, M] or [M, C]
        background_label (int): Ignore the background label; For example, RCNN
                                is num_classes and YOLO is -1.
        """
        kwargs = self.__dict__.copy()
        if isinstance(bboxes, tuple):
            bboxes, bbox_num = bboxes
            kwargs.update({'rois_num': bbox_num})
        if background_label > -1:
            kwargs.update({'background_label': background_label})
        return ops.multiclass_nms(bboxes, score, **kwargs)


@register
@serializable
class MatrixNMS(object):
    __append_doc__ = True

    def __init__(self,
                 score_threshold=.05,
                 post_threshold=.05,
                 nms_top_k=-1,
                 keep_top_k=100,
                 use_gaussian=False,
                 gaussian_sigma=2.,
                 normalized=False,
                 background_label=0):
        super(MatrixNMS, self).__init__()
        self.score_threshold = score_threshold
        self.post_threshold = post_threshold
        self.nms_top_k = nms_top_k
        self.keep_top_k = keep_top_k
        self.normalized = normalized
        self.use_gaussian = use_gaussian
        self.gaussian_sigma = gaussian_sigma
        self.background_label = background_label

    def __call__(self, bbox, score, *args):
        return ops.matrix_nms(
            bboxes=bbox,
            scores=score,
            score_threshold=self.score_threshold,
            post_threshold=self.post_threshold,
            nms_top_k=self.nms_top_k,
            keep_top_k=self.keep_top_k,
            use_gaussian=self.use_gaussian,
            gaussian_sigma=self.gaussian_sigma,
            background_label=self.background_label,
            normalized=self.normalized)


@register
@serializable
class YOLOBox(object):
    __shared__ = ['num_classes']

    def __init__(self,
                 num_classes=80,
                 conf_thresh=0.005,
                 downsample_ratio=32,
                 clip_bbox=True,
                 scale_x_y=1.):
        self.num_classes = num_classes
        self.conf_thresh = conf_thresh
        self.downsample_ratio = downsample_ratio
        self.clip_bbox = clip_bbox
        self.scale_x_y = scale_x_y

    def __call__(self,
                 yolo_head_out,
                 anchors,
                 im_shape,
                 scale_factor,
                 var_weight=None):
        boxes_list = []
        scores_list = []
        origin_shape = im_shape / scale_factor
        origin_shape = paddle.cast(origin_shape, 'int32')
        for i, head_out in enumerate(yolo_head_out):
            boxes, scores = ops.yolo_box(head_out, origin_shape, anchors[i],
                                         self.num_classes, self.conf_thresh,
                                         self.downsample_ratio // 2**i,
                                         self.clip_bbox, self.scale_x_y)
            boxes_list.append(boxes)
            scores_list.append(paddle.transpose(scores, perm=[0, 2, 1]))
        yolo_boxes = paddle.concat(boxes_list, axis=1)
        yolo_scores = paddle.concat(scores_list, axis=2)
        return yolo_boxes, yolo_scores
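

# Note on the per-level stride (a sketch, not part of the original file):
# head level i is decoded with stride downsample_ratio // 2**i, so a typical
# three-level YOLO head uses strides 32, 16 and 8:
#
#     yolo_box = YOLOBox(num_classes=80, conf_thresh=0.005)
#     # boxes, scores = yolo_box(yolo_head_out, anchors, im_shape, scale_factor)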


@register
@serializable
class SSDBox(object):
    def __init__(self, is_normalized=True):
        self.is_normalized = is_normalized
        self.norm_delta = float(not self.is_normalized)

    def __call__(self,
                 preds,
                 prior_boxes,
                 im_shape,
                 scale_factor,
                 var_weight=None):
        boxes, scores = preds
        outputs = []
        for box, score, prior_box in zip(boxes, scores, prior_boxes):
            pb_w = prior_box[:, 2] - prior_box[:, 0] + self.norm_delta
            pb_h = prior_box[:, 3] - prior_box[:, 1] + self.norm_delta
            pb_x = prior_box[:, 0] + pb_w * 0.5
            pb_y = prior_box[:, 1] + pb_h * 0.5
            out_x = pb_x + box[:, :, 0] * pb_w * 0.1
            out_y = pb_y + box[:, :, 1] * pb_h * 0.1
            out_w = paddle.exp(box[:, :, 2] * 0.2) * pb_w
            out_h = paddle.exp(box[:, :, 3] * 0.2) * pb_h

            if self.is_normalized:
                h = paddle.unsqueeze(
                    im_shape[:, 0] / scale_factor[:, 0], axis=-1)
                w = paddle.unsqueeze(
                    im_shape[:, 1] / scale_factor[:, 1], axis=-1)
                output = paddle.stack(
                    [(out_x - out_w / 2.) * w, (out_y - out_h / 2.) * h,
                     (out_x + out_w / 2.) * w, (out_y + out_h / 2.) * h],
                    axis=-1)
            else:
                output = paddle.stack(
                    [
                        out_x - out_w / 2., out_y - out_h / 2.,
                        out_x + out_w / 2. - 1., out_y + out_h / 2. - 1.
                    ],
                    axis=-1)
            outputs.append(output)
        boxes = paddle.concat(outputs, axis=1)

        scores = F.softmax(paddle.concat(scores, axis=1))
        scores = paddle.transpose(scores, [0, 2, 1])
        return boxes, scores


@register
@serializable
class AnchorGrid(object):
    """Generate anchor grid

    Args:
        image_size (int or list): input image size, may be a single integer or
            list of [h, w]. Default: 512
        min_level (int): min level of the feature pyramid. Default: 3
        max_level (int): max level of the feature pyramid. Default: 7
        anchor_base_scale: base anchor scale. Default: 4
        num_scales: number of anchor scales. Default: 3
        aspect_ratios: aspect ratios. Default: [[1, 1], [1.4, 0.7], [0.7, 1.4]]
    """

    def __init__(self,
                 image_size=512,
                 min_level=3,
                 max_level=7,
                 anchor_base_scale=4,
                 num_scales=3,
                 aspect_ratios=[[1, 1], [1.4, 0.7], [0.7, 1.4]]):
        super(AnchorGrid, self).__init__()
        if isinstance(image_size, Integral):
            self.image_size = [image_size, image_size]
        else:
            self.image_size = image_size
        for dim in self.image_size:
            assert dim % 2**max_level == 0, \
                "image size should be multiple of the max level stride"
        self.min_level = min_level
        self.max_level = max_level
        self.anchor_base_scale = anchor_base_scale
        self.num_scales = num_scales
        self.aspect_ratios = aspect_ratios

    @property
    def base_cell(self):
        if not hasattr(self, '_base_cell'):
            self._base_cell = self.make_cell()
        return self._base_cell

    def make_cell(self):
        scales = [2**(i / self.num_scales) for i in range(self.num_scales)]
        scales = np.array(scales)
        ratios = np.array(self.aspect_ratios)
        ws = np.outer(scales, ratios[:, 0]).reshape(-1, 1)
        hs = np.outer(scales, ratios[:, 1]).reshape(-1, 1)
        anchors = np.hstack((-0.5 * ws, -0.5 * hs, 0.5 * ws, 0.5 * hs))
        return anchors

    def make_grid(self, stride):
        cell = self.base_cell * stride * self.anchor_base_scale
        x_steps = np.arange(stride // 2, self.image_size[1], stride)
        y_steps = np.arange(stride // 2, self.image_size[0], stride)
        offset_x, offset_y = np.meshgrid(x_steps, y_steps)
        offset_x = offset_x.flatten()
        offset_y = offset_y.flatten()
        offsets = np.stack((offset_x, offset_y, offset_x, offset_y), axis=-1)
        offsets = offsets[:, np.newaxis, :]
        return (cell + offsets).reshape(-1, 4)

    def generate(self):
        return [
            self.make_grid(2**l)
            for l in range(self.min_level, self.max_level + 1)
        ]

    def __call__(self):
        if not hasattr(self, '_anchor_vars'):
            anchor_vars = []
            helper = LayerHelper('anchor_grid')
            for idx, l in enumerate(range(self.min_level, self.max_level + 1)):
                stride = 2**l
                anchors = self.make_grid(stride)
                var = helper.create_parameter(
                    attr=ParamAttr(name='anchors_{}'.format(idx)),
                    shape=anchors.shape,
                    dtype='float32',
                    stop_gradient=True,
                    default_initializer=NumpyArrayInitializer(anchors))
                anchor_vars.append(var)
                var.persistable = True
            self._anchor_vars = anchor_vars

        return self._anchor_vars
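

# Illustrative usage (a sketch, not part of the original file). `generate`
# is pure numpy and returns one anchor array per pyramid level:
#
#     grid = AnchorGrid(image_size=512, min_level=3, max_level=7)
#     anchors = grid.generate()
#     # len(anchors) == 5; for stride 8 there are 64 * 64 positions with
#     # 3 scales * 3 ratios each, so anchors[0].shape == (36864, 4)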


@register
@serializable
class FCOSBox(object):
    __shared__ = ['num_classes']

    def __init__(self, num_classes=80):
        super(FCOSBox, self).__init__()
        self.num_classes = num_classes

    def _merge_hw(self, inputs, ch_type="channel_first"):
        """
        Merge h and w of the feature map into one dimension.

        Args:
            inputs (Tensor): Tensor of the input feature map
            ch_type (str): "channel_first" or "channel_last" style
        Return:
            new_shape (Tensor): The new shape after h and w merged
        """
        shape_ = paddle.shape(inputs)
        bs, ch, hi, wi = shape_[0], shape_[1], shape_[2], shape_[3]
        img_size = hi * wi
        img_size.stop_gradient = True
        if ch_type == "channel_first":
            new_shape = paddle.concat([bs, ch, img_size])
        elif ch_type == "channel_last":
            new_shape = paddle.concat([bs, img_size, ch])
        else:
            raise KeyError("Wrong ch_type %s" % ch_type)
        new_shape.stop_gradient = True
        return new_shape

    def _postprocessing_by_level(self, locations, box_cls, box_reg, box_ctn,
                                 scale_factor):
        """
        Postprocess each layer of the output with corresponding locations.

        Args:
            locations (Tensor): anchor points for current layer, [H*W, 2]
            box_cls (Tensor): categories prediction, [N, C, H, W],
                C is the number of classes
            box_reg (Tensor): bounding box prediction, [N, 4, H, W]
            box_ctn (Tensor): centerness prediction, [N, 1, H, W]
            scale_factor (Tensor): [h_scale, w_scale] for input images
        Return:
            box_cls_ch_last (Tensor): score for each category, in [N, C, M],
                C is the number of classes and M is the number of anchor points
            box_reg_decoding (Tensor): decoded bounding box, in [N, M, 4],
                last dimension is [x1, y1, x2, y2]
        """
        act_shape_cls = self._merge_hw(box_cls)
        box_cls_ch_last = paddle.reshape(x=box_cls, shape=act_shape_cls)
        box_cls_ch_last = F.sigmoid(box_cls_ch_last)

        act_shape_reg = self._merge_hw(box_reg)
        box_reg_ch_last = paddle.reshape(x=box_reg, shape=act_shape_reg)
        box_reg_ch_last = paddle.transpose(box_reg_ch_last, perm=[0, 2, 1])
        box_reg_decoding = paddle.stack(
            [
                locations[:, 0] - box_reg_ch_last[:, :, 0],
                locations[:, 1] - box_reg_ch_last[:, :, 1],
                locations[:, 0] + box_reg_ch_last[:, :, 2],
                locations[:, 1] + box_reg_ch_last[:, :, 3]
            ],
            axis=1)
        box_reg_decoding = paddle.transpose(box_reg_decoding, perm=[0, 2, 1])

        act_shape_ctn = self._merge_hw(box_ctn)
        box_ctn_ch_last = paddle.reshape(x=box_ctn, shape=act_shape_ctn)
        box_ctn_ch_last = F.sigmoid(box_ctn_ch_last)

        # recover the location to original image
        im_scale = paddle.concat([scale_factor, scale_factor], axis=1)
        im_scale = paddle.expand(im_scale, [box_reg_decoding.shape[0], 4])
        im_scale = paddle.reshape(im_scale, [box_reg_decoding.shape[0], -1, 4])
        box_reg_decoding = box_reg_decoding / im_scale
        box_cls_ch_last = box_cls_ch_last * box_ctn_ch_last
        return box_cls_ch_last, box_reg_decoding

    def __call__(self, locations, cls_logits, bboxes_reg, centerness,
                 scale_factor):
        pred_boxes_ = []
        pred_scores_ = []
        for pts, cls, box, ctn in zip(locations, cls_logits, bboxes_reg,
                                      centerness):
            pred_scores_lvl, pred_boxes_lvl = self._postprocessing_by_level(
                pts, cls, box, ctn, scale_factor)
            pred_boxes_.append(pred_boxes_lvl)
            pred_scores_.append(pred_scores_lvl)
        pred_boxes = paddle.concat(pred_boxes_, axis=1)
        pred_scores = paddle.concat(pred_scores_, axis=2)
        return pred_boxes, pred_scores


@register
class TTFBox(object):
    __shared__ = ['down_ratio']

    def __init__(self, max_per_img=100, score_thresh=0.01, down_ratio=4):
        super(TTFBox, self).__init__()
        self.max_per_img = max_per_img
        self.score_thresh = score_thresh
        self.down_ratio = down_ratio

    def _simple_nms(self, heat, kernel=3):
        """
        Use maxpool to filter the max score, get local peaks.
        """
        pad = (kernel - 1) // 2
        hmax = F.max_pool2d(heat, kernel, stride=1, padding=pad)
        keep = paddle.cast(hmax == heat, 'float32')
        return heat * keep

    def _topk(self, scores):
        """
        Select top k scores and decode to get xy coordinates.
        """
        k = self.max_per_img
        shape_fm = paddle.shape(scores)
        shape_fm.stop_gradient = True
        cat, height, width = shape_fm[1], shape_fm[2], shape_fm[3]
        # batch size is 1
        scores_r = paddle.reshape(scores, [cat, -1])
        topk_scores, topk_inds = paddle.topk(scores_r, k)
        topk_ys = topk_inds // width
        topk_xs = topk_inds % width

        topk_score_r = paddle.reshape(topk_scores, [-1])
        topk_score, topk_ind = paddle.topk(topk_score_r, k)
        k_t = paddle.full(paddle.shape(topk_ind), k, dtype='int64')
        topk_clses = paddle.cast(paddle.floor_divide(topk_ind, k_t), 'float32')

        topk_inds = paddle.reshape(topk_inds, [-1])
        topk_ys = paddle.reshape(topk_ys, [-1, 1])
        topk_xs = paddle.reshape(topk_xs, [-1, 1])
        topk_inds = paddle.gather(topk_inds, topk_ind)
        topk_ys = paddle.gather(topk_ys, topk_ind)
        topk_xs = paddle.gather(topk_xs, topk_ind)
        return topk_score, topk_inds, topk_clses, topk_ys, topk_xs

    def _decode(self, hm, wh, im_shape, scale_factor):
        heatmap = F.sigmoid(hm)
        heat = self._simple_nms(heatmap)
        scores, inds, clses, ys, xs = self._topk(heat)
        ys = paddle.cast(ys, 'float32') * self.down_ratio
        xs = paddle.cast(xs, 'float32') * self.down_ratio
        scores = paddle.tensor.unsqueeze(scores, [1])
        clses = paddle.tensor.unsqueeze(clses, [1])

        wh_t = paddle.transpose(wh, [0, 2, 3, 1])
        wh = paddle.reshape(wh_t, [-1, paddle.shape(wh_t)[-1]])
        wh = paddle.gather(wh, inds)

        x1 = xs - wh[:, 0:1]
        y1 = ys - wh[:, 1:2]
        x2 = xs + wh[:, 2:3]
        y2 = ys + wh[:, 3:4]
        bboxes = paddle.concat([x1, y1, x2, y2], axis=1)

        scale_y = scale_factor[:, 0:1]
        scale_x = scale_factor[:, 1:2]
        scale_expand = paddle.concat(
            [scale_x, scale_y, scale_x, scale_y], axis=1)
        boxes_shape = paddle.shape(bboxes)
        boxes_shape.stop_gradient = True
        scale_expand = paddle.expand(scale_expand, shape=boxes_shape)
        bboxes = paddle.divide(bboxes, scale_expand)
        results = paddle.concat([clses, scores, bboxes], axis=1)
        # hack: append result with cls=-1 and score=1. to avoid all scores
        # are less than score_thresh which may cause error in gather.
        fill_r = paddle.to_tensor(np.array([[-1, 1, 0, 0, 0, 0]]))
        fill_r = paddle.cast(fill_r, results.dtype)
        results = paddle.concat([results, fill_r])
        scores = results[:, 1]
        valid_ind = paddle.nonzero(scores > self.score_thresh)
        results = paddle.gather(results, valid_ind)
        return results, paddle.shape(results)[0:1]

    def __call__(self, hm, wh, im_shape, scale_factor):
        results = []
        results_num = []
        for i in range(scale_factor.shape[0]):
            result, num = self._decode(hm[i:i + 1, ], wh[i:i + 1, ],
                                       im_shape[i:i + 1, ],
                                       scale_factor[i:i + 1, ])
            results.append(result)
            results_num.append(num)
        results = paddle.concat(results, axis=0)
        results_num = paddle.concat(results_num, axis=0)
        return results, results_num


@register
@serializable
class JDEBox(object):
    __shared__ = ['num_classes']

    def __init__(self, num_classes=1, conf_thresh=0.3, downsample_ratio=32):
        self.num_classes = num_classes
        self.conf_thresh = conf_thresh
        self.downsample_ratio = downsample_ratio

    def generate_anchor(self, nGh, nGw, anchor_wh):
        nA = len(anchor_wh)
        yv, xv = paddle.meshgrid([paddle.arange(nGh), paddle.arange(nGw)])
        mesh = paddle.stack(
            (xv, yv), axis=0).cast(dtype='float32')  # 2 x nGh x nGw
        meshs = paddle.tile(mesh, [nA, 1, 1, 1])

        anchor_offset_mesh = anchor_wh[:, :, None][:, :, :, None].repeat(
            int(nGh), axis=-2).repeat(
                int(nGw), axis=-1)
        anchor_offset_mesh = paddle.to_tensor(
            anchor_offset_mesh.astype(np.float32))
        # nA x 2 x nGh x nGw

        anchor_mesh = paddle.concat([meshs, anchor_offset_mesh], axis=1)
        anchor_mesh = paddle.transpose(anchor_mesh,
                                       [0, 2, 3, 1])  # nA x nGh x nGw x 4
        return anchor_mesh

    def decode_delta(self, delta, fg_anchor_list):
        px, py, pw, ph = fg_anchor_list[:, 0], fg_anchor_list[:, 1], \
            fg_anchor_list[:, 2], fg_anchor_list[:, 3]
        dx, dy, dw, dh = delta[:, 0], delta[:, 1], delta[:, 2], delta[:, 3]
        gx = pw * dx + px
        gy = ph * dy + py
        gw = pw * paddle.exp(dw)
        gh = ph * paddle.exp(dh)
        gx1 = gx - gw * 0.5
        gy1 = gy - gh * 0.5
        gx2 = gx + gw * 0.5
        gy2 = gy + gh * 0.5
        return paddle.stack([gx1, gy1, gx2, gy2], axis=1)

    def decode_delta_map(self, nA, nGh, nGw, delta_map, anchor_vec):
        anchor_mesh = self.generate_anchor(nGh, nGw, anchor_vec)
        anchor_mesh = paddle.unsqueeze(anchor_mesh, 0)
        pred_list = self.decode_delta(
            paddle.reshape(
                delta_map, shape=[-1, 4]),
            paddle.reshape(
                anchor_mesh, shape=[-1, 4]))
        pred_map = paddle.reshape(pred_list, shape=[nA * nGh * nGw, 4])
        return pred_map

    def _postprocessing_by_level(self, nA, stride, head_out, anchor_vec):
        boxes_shape = head_out.shape  # [nB, nA*6, nGh, nGw]
        nGh, nGw = boxes_shape[-2], boxes_shape[-1]
        nB = 1  # TODO: only support bs=1 now
        boxes_list, scores_list = [], []
        for idx in range(nB):
            p = paddle.reshape(
                head_out[idx], shape=[nA, self.num_classes + 5, nGh, nGw])
            p = paddle.transpose(p, perm=[0, 2, 3, 1])  # [nA, nGh, nGw, 6]
            delta_map = p[:, :, :, :4]
            boxes = self.decode_delta_map(nA, nGh, nGw, delta_map, anchor_vec)
            # [nA * nGh * nGw, 4]
            boxes_list.append(boxes * stride)

            p_conf = paddle.transpose(
                p[:, :, :, 4:6], perm=[3, 0, 1, 2])  # [2, nA, nGh, nGw]
            p_conf = F.softmax(
                p_conf, axis=0)[1, :, :, :].unsqueeze(-1)  # [nA, nGh, nGw, 1]
            scores = paddle.reshape(p_conf, shape=[nA * nGh * nGw, 1])
            scores_list.append(scores)

        boxes_results = paddle.stack(boxes_list)
        scores_results = paddle.stack(scores_list)
        return boxes_results, scores_results

    def __call__(self, yolo_head_out, anchors):
        bbox_pred_list = []
        for i, head_out in enumerate(yolo_head_out):
            stride = self.downsample_ratio // 2**i
            anc_w, anc_h = anchors[i][0::2], anchors[i][1::2]
            anchor_vec = np.stack((anc_w, anc_h), axis=1) / stride
            nA = len(anc_w)
            boxes, scores = self._postprocessing_by_level(nA, stride, head_out,
                                                          anchor_vec)
            bbox_pred_list.append(paddle.concat([boxes, scores], axis=-1))

        yolo_boxes_scores = paddle.concat(bbox_pred_list, axis=1)
        boxes_idx_over_conf_thr = paddle.nonzero(
            yolo_boxes_scores[:, :, -1] > self.conf_thresh)
        boxes_idx_over_conf_thr.stop_gradient = True
        return boxes_idx_over_conf_thr, yolo_boxes_scores


@register
@serializable
class MaskMatrixNMS(object):
    """
    Matrix NMS for multi-class masks.

    Args:
        update_threshold (float): Updated threshold of category score in the second round.
        pre_nms_top_n (int): Number of total instances to be kept per image before NMS.
        post_nms_top_n (int): Number of total instances to be kept per image after NMS.
        kernel (str): 'linear' or 'gaussian'.
        sigma (float): std in gaussian method.
    Input:
        seg_preds (Variable): shape (n, h, w), segmentation feature maps
        seg_masks (Variable): shape (n, h, w), segmentation feature maps
        cate_labels (Variable): shape (n), mask labels in descending order
        cate_scores (Variable): shape (n), mask scores in descending order
        sum_masks (Variable): a float tensor of the sum of seg_masks
    Returns:
        Variable: cate_scores, tensors of shape (n)
    """

    def __init__(self,
                 update_threshold=0.05,
                 pre_nms_top_n=500,
                 post_nms_top_n=100,
                 kernel='gaussian',
                 sigma=2.0):
        super(MaskMatrixNMS, self).__init__()
        self.update_threshold = update_threshold
        self.pre_nms_top_n = pre_nms_top_n
        self.post_nms_top_n = post_nms_top_n
        self.kernel = kernel
        self.sigma = sigma

    def _sort_score(self, scores, top_num):
        if paddle.shape(scores)[0] > top_num:
            return paddle.topk(scores, top_num)[1]
        else:
            return paddle.argsort(scores, descending=True)

    def __call__(self,
                 seg_preds,
                 seg_masks,
                 cate_labels,
                 cate_scores,
                 sum_masks=None):
        # sort and keep top nms_pre
        sort_inds = self._sort_score(cate_scores, self.pre_nms_top_n)
        seg_masks = paddle.gather(seg_masks, index=sort_inds)
        seg_preds = paddle.gather(seg_preds, index=sort_inds)
        sum_masks = paddle.gather(sum_masks, index=sort_inds)
        cate_scores = paddle.gather(cate_scores, index=sort_inds)
        cate_labels = paddle.gather(cate_labels, index=sort_inds)

        seg_masks = paddle.flatten(seg_masks, start_axis=1, stop_axis=-1)
        # inter.
        inter_matrix = paddle.mm(seg_masks,
                                 paddle.transpose(seg_masks, [1, 0]))
        n_samples = paddle.shape(cate_labels)
        # union.
        sum_masks_x = paddle.expand(sum_masks, shape=[n_samples, n_samples])
        # iou.
        iou_matrix = (inter_matrix / (
            sum_masks_x + paddle.transpose(sum_masks_x, [1, 0]) - inter_matrix))
        iou_matrix = paddle.triu(iou_matrix, diagonal=1)
        # label_specific matrix.
        cate_labels_x = paddle.expand(
            cate_labels, shape=[n_samples, n_samples])
        label_matrix = paddle.cast(
            (cate_labels_x == paddle.transpose(cate_labels_x, [1, 0])),
            'float32')
        label_matrix = paddle.triu(label_matrix, diagonal=1)

        # IoU compensation
        compensate_iou = paddle.max((iou_matrix * label_matrix), axis=0)
        compensate_iou = paddle.expand(
            compensate_iou, shape=[n_samples, n_samples])
        compensate_iou = paddle.transpose(compensate_iou, [1, 0])

        # IoU decay
        decay_iou = iou_matrix * label_matrix

        # matrix nms
        if self.kernel == 'gaussian':
            decay_matrix = paddle.exp(-1 * self.sigma * (decay_iou**2))
            compensate_matrix = paddle.exp(-1 * self.sigma *
                                           (compensate_iou**2))
            decay_coefficient = paddle.min(decay_matrix / compensate_matrix,
                                           axis=0)
        elif self.kernel == 'linear':
            decay_matrix = (1 - decay_iou) / (1 - compensate_iou)
            decay_coefficient = paddle.min(decay_matrix, axis=0)
        else:
            raise NotImplementedError

        # update the score.
        cate_scores = cate_scores * decay_coefficient
        y = paddle.zeros(shape=paddle.shape(cate_scores), dtype='float32')
        keep = paddle.where(cate_scores >= self.update_threshold, cate_scores,
                            y)
        keep = paddle.nonzero(keep)
        keep = paddle.squeeze(keep, axis=[1])
        # Prevent empty and increase fake data
        keep = paddle.concat(
            [keep, paddle.cast(paddle.shape(cate_scores)[0] - 1, 'int64')])

        seg_preds = paddle.gather(seg_preds, index=keep)
        cate_scores = paddle.gather(cate_scores, index=keep)
        cate_labels = paddle.gather(cate_labels, index=keep)

        # sort and keep top_k
        sort_inds = self._sort_score(cate_scores, self.post_nms_top_n)
        seg_preds = paddle.gather(seg_preds, index=sort_inds)
        cate_scores = paddle.gather(cate_scores, index=sort_inds)
        cate_labels = paddle.gather(cate_labels, index=sort_inds)
        return seg_preds, cate_scores, cate_labels


def Conv2d(in_channels,
           out_channels,
           kernel_size,
           stride=1,
           padding=0,
           dilation=1,
           groups=1,
           bias=True,
           weight_init=Normal(std=0.001),
           bias_init=Constant(0.)):
    weight_attr = paddle.framework.ParamAttr(initializer=weight_init)
    if bias:
        bias_attr = paddle.framework.ParamAttr(initializer=bias_init)
    else:
        bias_attr = False
    conv = nn.Conv2D(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        dilation,
        groups,
        weight_attr=weight_attr,
        bias_attr=bias_attr)
    return conv
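

# Illustrative usage (a sketch, not part of the original file; values are
# assumptions): the helper mirrors nn.Conv2D but takes torch-style
# `bias`/initializer arguments:
#
#     conv = Conv2d(16, 32, 3, padding=1, bias=False)
#     y = conv(paddle.rand([1, 16, 8, 8]))  # [1, 32, 8, 8]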


def ConvTranspose2d(in_channels,
                    out_channels,
                    kernel_size,
                    stride=1,
                    padding=0,
                    output_padding=0,
                    groups=1,
                    bias=True,
                    dilation=1,
                    weight_init=Normal(std=0.001),
                    bias_init=Constant(0.)):
    weight_attr = paddle.framework.ParamAttr(initializer=weight_init)
    if bias:
        bias_attr = paddle.framework.ParamAttr(initializer=bias_init)
    else:
        bias_attr = False
    # pass the trailing arguments by keyword so they cannot be silently
    # mismatched against nn.Conv2DTranspose's positional order
    conv = nn.Conv2DTranspose(
        in_channels,
        out_channels,
        kernel_size,
        stride=stride,
        padding=padding,
        output_padding=output_padding,
        dilation=dilation,
        groups=groups,
        weight_attr=weight_attr,
        bias_attr=bias_attr)
    return conv


def BatchNorm2d(num_features, eps=1e-05, momentum=0.9, affine=True):
    if not affine:
        weight_attr = False
        bias_attr = False
    else:
        weight_attr = None
        bias_attr = None
    batchnorm = nn.BatchNorm2D(
        num_features,
        momentum=momentum,
        epsilon=eps,
        weight_attr=weight_attr,
        bias_attr=bias_attr)
    return batchnorm


def ReLU():
    return nn.ReLU()


def Upsample(scale_factor=None, mode='nearest', align_corners=False):
    return nn.Upsample(None, scale_factor, mode, align_corners)


def MaxPool(kernel_size, stride, padding, ceil_mode=False):
    return nn.MaxPool2D(kernel_size, stride, padding, ceil_mode=ceil_mode)


class Concat(nn.Layer):
    def __init__(self, dim=0):
        super(Concat, self).__init__()
        self.dim = dim

    def forward(self, inputs):
        return paddle.concat(inputs, axis=self.dim)

    def extra_repr(self):
        return 'dim={}'.format(self.dim)


def _convert_attention_mask(attn_mask, dtype):
    """
    Convert the attention mask to the target dtype we expect.

    Parameters:
        attn_mask (Tensor, optional): A tensor used in multi-head attention
            to prevent attention to some unwanted positions, usually the
            paddings or the subsequent positions. It is a tensor with shape
            broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`.
            When the data type is bool, the unwanted positions have `False`
            values and the others have `True` values. When the data type is
            int, the unwanted positions have 0 values and the others have 1
            values. When the data type is float, the unwanted positions have
            `-INF` values and the others have 0 values. It can be None when
            nothing needs to be prevented from attending. Default None.
        dtype (VarType): The target type of `attn_mask` we expect.
    Returns:
        Tensor: A Tensor with shape same as input `attn_mask`, with data type `dtype`.
    """
    return nn.layer.transformer._convert_attention_mask(attn_mask, dtype)


class MultiHeadAttention(nn.Layer):
    """
    Attention maps queries and a set of key-value pairs to outputs, and
    Multi-Head Attention performs multiple attention computations in parallel
    to jointly attend to information from different representation subspaces.

    Please refer to `Attention Is All You Need <https://arxiv.org/pdf/1706.03762.pdf>`_
    for more details.

    Parameters:
        embed_dim (int): The expected feature size in the input and output.
        num_heads (int): The number of heads in multi-head attention.
        dropout (float, optional): The dropout probability used on attention
            weights to drop some attention targets. 0 for no dropout. Default 0.
        kdim (int, optional): The feature size in key. If None, assumed equal to
            `embed_dim`. Default None.
        vdim (int, optional): The feature size in value. If None, assumed equal to
            `embed_dim`. Default None.
        need_weights (bool, optional): Indicate whether to return the attention
            weights. Default False.

    Examples:
        .. code-block:: python

            import paddle

            # encoder input: [batch_size, sequence_length, d_model]
            query = paddle.rand((2, 4, 128))
            # self attention mask: [batch_size, num_heads, query_len, query_len]
            attn_mask = paddle.rand((2, 2, 4, 4))
            multi_head_attn = paddle.nn.MultiHeadAttention(128, 2)
            output = multi_head_attn(query, None, None, attn_mask=attn_mask)  # [2, 4, 128]
    """

    def __init__(self,
                 embed_dim,
                 num_heads,
                 dropout=0.,
                 kdim=None,
                 vdim=None,
                 need_weights=False):
        super(MultiHeadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.need_weights = need_weights

        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, \
            "embed_dim must be divisible by num_heads"

        if self._qkv_same_embed_dim:
            self.in_proj_weight = self.create_parameter(
                shape=[embed_dim, 3 * embed_dim],
                attr=None,
                dtype=self._dtype,
                is_bias=False)
            self.in_proj_bias = self.create_parameter(
                shape=[3 * embed_dim],
                attr=None,
                dtype=self._dtype,
                is_bias=True)
        else:
            self.q_proj = nn.Linear(embed_dim, embed_dim)
            self.k_proj = nn.Linear(self.kdim, embed_dim)
            self.v_proj = nn.Linear(self.vdim, embed_dim)

        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self._type_list = ('q_proj', 'k_proj', 'v_proj')

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)
            else:
                constant_(p)

    def compute_qkv(self, tensor, index):
        if self._qkv_same_embed_dim:
            tensor = F.linear(
                x=tensor,
                weight=self.in_proj_weight[:, index * self.embed_dim:(
                    index + 1) * self.embed_dim],
                bias=self.in_proj_bias[index * self.embed_dim:(index + 1) *
                                       self.embed_dim]
                if self.in_proj_bias is not None else None)
        else:
            tensor = getattr(self, self._type_list[index])(tensor)
        tensor = tensor.reshape(
            [0, 0, self.num_heads, self.head_dim]).transpose([0, 2, 1, 3])
        return tensor

    def forward(self, query, key=None, value=None, attn_mask=None):
        r"""
        Applies multi-head attention to map queries and a set of key-value pairs
        to outputs.

        Parameters:
            query (Tensor): The queries for multi-head attention. It is a
                tensor with shape `[batch_size, query_length, embed_dim]`. The
                data type should be float32 or float64.
            key (Tensor, optional): The keys for multi-head attention. It is
                a tensor with shape `[batch_size, key_length, kdim]`. The
                data type should be float32 or float64. If None, use `query` as
                `key`. Default None.
            value (Tensor, optional): The values for multi-head attention. It
                is a tensor with shape `[batch_size, value_length, vdim]`.
                The data type should be float32 or float64. If None, use `query` as
                `value`. Default None.
            attn_mask (Tensor, optional): A tensor used in multi-head attention
                to prevent attention to some unwanted positions, usually the
                paddings or the subsequent positions. It is a tensor with shape
                broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`.
                When the data type is bool, the unwanted positions have `False`
                values and the others have `True` values. When the data type is
                int, the unwanted positions have 0 values and the others have 1
                values. When the data type is float, the unwanted positions have
                `-INF` values and the others have 0 values. It can be None when
                nothing needs to be prevented from attending. Default None.

        Returns:
            Tensor|tuple: It is a tensor that has the same shape and data type \
                as `query`, representing attention output. Or a tuple if \
                `need_weights` is True or `cache` is not None. If `need_weights` \
                is True, except for attention output, the tuple also includes \
                the attention weights tensor shaped `[batch_size, num_heads, query_length, key_length]`. \
                If `cache` is not None, the tuple then includes the new cache \
                having the same type as `cache`, and if it is `StaticCache`, it \
                is same as the input `cache`, if it is `Cache`, the new cache \
                reserves tensors concatenating raw tensors with intermediate \
                results of the current query.
        """
        key = query if key is None else key
        value = query if value is None else value
        # compute q, k, v
        q, k, v = (self.compute_qkv(t, i)
                   for i, t in enumerate([query, key, value]))

        # scaled dot-product attention
        product = paddle.matmul(x=q, y=k, transpose_y=True)
        scaling = float(self.head_dim)**-0.5
        product = product * scaling

        if attn_mask is not None:
            # Support bool or int mask
            attn_mask = _convert_attention_mask(attn_mask, product.dtype)
            product = product + attn_mask
        weights = F.softmax(product)
        if self.dropout:
            weights = F.dropout(
                weights,
                self.dropout,
                training=self.training,
                mode="upscale_in_train")
        out = paddle.matmul(weights, v)

        # combine heads
        out = paddle.transpose(out, perm=[0, 2, 1, 3])
        out = paddle.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

        # project to output
        out = self.out_proj(out)

        outs = [out]
        if self.need_weights:
            outs.append(weights)
        return out if len(outs) == 1 else tuple(outs)
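

# Illustrative cross-attention usage (a sketch, not part of the original
# file; shapes are assumptions). Unlike the self-attention example in the
# docstring, queries here attend over a separate key/value sequence:
#
#     attn = MultiHeadAttention(embed_dim=128, num_heads=8)
#     q = paddle.rand([2, 4, 128])   # [batch, query_len, embed_dim]
#     kv = paddle.rand([2, 6, 128])  # [batch, key_len, embed_dim]
#     out = attn(q, kv, kv)          # [2, 4, 128]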