# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import six
import numpy as np
from numbers import Integral

import paddle
import paddle.nn as nn
from paddle import ParamAttr
from paddle import to_tensor
import paddle.nn.functional as F
from paddle.nn.initializer import Normal, Constant, XavierUniform
from paddle.regularizer import L2Decay
from paddle.vision.ops import DeformConv2D
# LayerHelper and NumpyArrayInitializer are referenced by AnchorGrid.__call__
# below; these import paths assume the paddle.fluid compatibility layer.
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.initializer import NumpyArrayInitializer

from paddlex.ppdet.core.workspace import register, serializable
from paddlex.ppdet.modeling.bbox_utils import delta2bbox
from . import ops
from .initializer import xavier_uniform_, constant_


def _to_list(l):
    if isinstance(l, (list, tuple)):
        return list(l)
    return [l]


class DeformableConvV2(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 weight_attr=None,
                 bias_attr=None,
                 lr_scale=1,
                 regularizer=None,
                 skip_quant=False,
                 dcn_bias_regularizer=L2Decay(0.),
                 dcn_bias_lr_scale=2.):
        super(DeformableConvV2, self).__init__()
        self.offset_channel = 2 * kernel_size**2
        self.mask_channel = kernel_size**2

        if lr_scale == 1 and regularizer is None:
            offset_bias_attr = ParamAttr(initializer=Constant(0.))
        else:
            offset_bias_attr = ParamAttr(
                initializer=Constant(0.),
                learning_rate=lr_scale,
                regularizer=regularizer)
        self.conv_offset = nn.Conv2D(
            in_channels,
            3 * kernel_size**2,
            kernel_size,
            stride=stride,
            padding=(kernel_size - 1) // 2,
            weight_attr=ParamAttr(initializer=Constant(0.0)),
            bias_attr=offset_bias_attr)
        if skip_quant:
            self.conv_offset.skip_quant = True

        if bias_attr:
            # in FCOS-DCN head, specifically need learning_rate and regularizer
            dcn_bias_attr = ParamAttr(
                initializer=Constant(value=0),
                regularizer=dcn_bias_regularizer,
                learning_rate=dcn_bias_lr_scale)
        else:
            # in ResNet backbone, do not need bias
            dcn_bias_attr = False
        self.conv_dcn = DeformConv2D(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=(kernel_size - 1) // 2 * dilation,
            dilation=dilation,
            groups=groups,
            weight_attr=weight_attr,
            bias_attr=dcn_bias_attr)

    def forward(self, x):
        offset_mask = self.conv_offset(x)
        offset, mask = paddle.split(
            offset_mask,
            num_or_sections=[self.offset_channel, self.mask_channel],
            axis=1)
        mask = F.sigmoid(mask)
        y = self.conv_dcn(x, offset, mask=mask)
        return y
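

# Usage sketch (illustrative, not part of the upstream file): run a dummy
# NCHW feature map through a 3x3 DCNv2 block; the offset/mask branch is
# built internally from the same kernel size and stride.
def _example_deformable_conv_v2():
    x = paddle.rand([2, 16, 32, 32])  # [N, C_in, H, W], dummy input
    dcn = DeformableConvV2(in_channels=16, out_channels=32, kernel_size=3)
    y = dcn(x)  # -> [2, 32, 32, 32]
    return y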


class ConvNormLayer(nn.Layer):
    def __init__(self,
                 ch_in,
                 ch_out,
                 filter_size,
                 stride,
                 groups=1,
                 norm_type='bn',
                 norm_decay=0.,
                 norm_groups=32,
                 use_dcn=False,
                 bias_on=False,
                 lr_scale=1.,
                 freeze_norm=False,
                 initializer=Normal(
                     mean=0., std=0.01),
                 skip_quant=False,
                 dcn_lr_scale=2.,
                 dcn_regularizer=L2Decay(0.)):
        super(ConvNormLayer, self).__init__()
        assert norm_type in ['bn', 'sync_bn', 'gn']

        if bias_on:
            bias_attr = ParamAttr(
                initializer=Constant(value=0.), learning_rate=lr_scale)
        else:
            bias_attr = False

        if not use_dcn:
            self.conv = nn.Conv2D(
                in_channels=ch_in,
                out_channels=ch_out,
                kernel_size=filter_size,
                stride=stride,
                padding=(filter_size - 1) // 2,
                groups=groups,
                weight_attr=ParamAttr(
                    initializer=initializer, learning_rate=1.),
                bias_attr=bias_attr)
            if skip_quant:
                self.conv.skip_quant = True
        else:
            # in FCOS-DCN head, specifically need learning_rate and regularizer
            self.conv = DeformableConvV2(
                in_channels=ch_in,
                out_channels=ch_out,
                kernel_size=filter_size,
                stride=stride,
                padding=(filter_size - 1) // 2,
                groups=groups,
                weight_attr=ParamAttr(
                    initializer=initializer, learning_rate=1.),
                bias_attr=True,
                lr_scale=dcn_lr_scale,
                regularizer=dcn_regularizer,
                dcn_bias_regularizer=dcn_regularizer,
                dcn_bias_lr_scale=dcn_lr_scale,
                skip_quant=skip_quant)

        norm_lr = 0. if freeze_norm else 1.
        param_attr = ParamAttr(
            learning_rate=norm_lr,
            regularizer=L2Decay(norm_decay)
            if norm_decay is not None else None)
        bias_attr = ParamAttr(
            learning_rate=norm_lr,
            regularizer=L2Decay(norm_decay)
            if norm_decay is not None else None)
        if norm_type == 'bn':
            self.norm = nn.BatchNorm2D(
                ch_out, weight_attr=param_attr, bias_attr=bias_attr)
        elif norm_type == 'sync_bn':
            self.norm = nn.SyncBatchNorm(
                ch_out, weight_attr=param_attr, bias_attr=bias_attr)
        elif norm_type == 'gn':
            self.norm = nn.GroupNorm(
                num_groups=norm_groups,
                num_channels=ch_out,
                weight_attr=param_attr,
                bias_attr=bias_attr)

    def forward(self, inputs):
        out = self.conv(inputs)
        out = self.norm(out)
        return out
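

# Usage sketch (illustrative): a plain conv + BN block; norm_type and use_dcn
# only change which sublayers __init__ builds, the forward path is unchanged.
def _example_conv_norm_layer():
    block = ConvNormLayer(ch_in=16, ch_out=32, filter_size=3, stride=1)
    y = block(paddle.rand([2, 16, 32, 32]))  # -> [2, 32, 32, 32]
    return y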


class LiteConv(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 with_act=True,
                 norm_type='sync_bn',
                 name=None):
        super(LiteConv, self).__init__()
        self.lite_conv = nn.Sequential()
        conv1 = ConvNormLayer(
            in_channels,
            in_channels,
            filter_size=5,
            stride=stride,
            groups=in_channels,
            norm_type=norm_type,
            initializer=XavierUniform())
        conv2 = ConvNormLayer(
            in_channels,
            out_channels,
            filter_size=1,
            stride=stride,
            norm_type=norm_type,
            initializer=XavierUniform())
        conv3 = ConvNormLayer(
            out_channels,
            out_channels,
            filter_size=1,
            stride=stride,
            norm_type=norm_type,
            initializer=XavierUniform())
        conv4 = ConvNormLayer(
            out_channels,
            out_channels,
            filter_size=5,
            stride=stride,
            groups=out_channels,
            norm_type=norm_type,
            initializer=XavierUniform())
        conv_list = [conv1, conv2, conv3, conv4]
        self.lite_conv.add_sublayer('conv1', conv1)
        self.lite_conv.add_sublayer('relu6_1', nn.ReLU6())
        self.lite_conv.add_sublayer('conv2', conv2)
        if with_act:
            self.lite_conv.add_sublayer('relu6_2', nn.ReLU6())
        self.lite_conv.add_sublayer('conv3', conv3)
        self.lite_conv.add_sublayer('relu6_3', nn.ReLU6())
        self.lite_conv.add_sublayer('conv4', conv4)
        if with_act:
            self.lite_conv.add_sublayer('relu6_4', nn.ReLU6())

    def forward(self, inputs):
        out = self.lite_conv(inputs)
        return out
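

# Usage sketch (illustrative): the depthwise(5x5) -> pointwise(1x1) ->
# pointwise(1x1) -> depthwise(5x5) stack; 'bn' is used here instead of the
# default 'sync_bn' so the sketch also runs on a single CPU device.
def _example_lite_conv():
    block = LiteConv(in_channels=16, out_channels=32, norm_type='bn')
    y = block(paddle.rand([2, 16, 32, 32]))  # -> [2, 32, 32, 32]
    return y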


class DropBlock(nn.Layer):
    def __init__(self, block_size, keep_prob, name, data_format='NCHW'):
        """
        DropBlock layer, see https://arxiv.org/abs/1810.12890

        Args:
            block_size (int): block size
            keep_prob (float): keep probability
            name (str): layer name
            data_format (str): data format, NCHW or NHWC
        """
        super(DropBlock, self).__init__()
        self.block_size = block_size
        self.keep_prob = keep_prob
        self.name = name
        self.data_format = data_format

    def forward(self, x):
        if not self.training or self.keep_prob == 1:
            return x
        else:
            gamma = (1. - self.keep_prob) / (self.block_size**2)
            if self.data_format == 'NCHW':
                shape = x.shape[2:]
            else:
                shape = x.shape[1:3]
            for s in shape:
                gamma *= s / (s - self.block_size + 1)

            matrix = paddle.cast(
                paddle.rand(x.shape, x.dtype) < gamma, x.dtype)
            mask_inv = F.max_pool2d(
                matrix,
                self.block_size,
                stride=1,
                padding=self.block_size // 2,
                data_format=self.data_format)
            mask = 1. - mask_inv
            y = x * mask * (mask.numel() / mask.sum())
            return y
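

# Usage sketch (illustrative): DropBlock zeroes contiguous blocks of
# activations only in training mode; in eval mode it is the identity.
def _example_drop_block():
    db = DropBlock(block_size=3, keep_prob=0.9, name='db')
    x = paddle.rand([2, 8, 16, 16])
    y_train = db(x)  # blocks zeroed, surviving activations rescaled
    db.eval()
    y_eval = db(x)   # identical to x
    return y_train, y_eval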


@register
@serializable
class AnchorGeneratorSSD(object):
    def __init__(self,
                 steps=[8, 16, 32, 64, 100, 300],
                 aspect_ratios=[[2.], [2., 3.], [2., 3.], [2., 3.], [2.],
                                [2.]],
                 min_ratio=15,
                 max_ratio=90,
                 base_size=300,
                 min_sizes=[30.0, 60.0, 111.0, 162.0, 213.0, 264.0],
                 max_sizes=[60.0, 111.0, 162.0, 213.0, 264.0, 315.0],
                 offset=0.5,
                 flip=True,
                 clip=False,
                 min_max_aspect_ratios_order=False):
        self.steps = steps
        self.aspect_ratios = aspect_ratios
        self.min_ratio = min_ratio
        self.max_ratio = max_ratio
        self.base_size = base_size
        self.min_sizes = min_sizes
        self.max_sizes = max_sizes
        self.offset = offset
        self.flip = flip
        self.clip = clip
        self.min_max_aspect_ratios_order = min_max_aspect_ratios_order

        if self.min_sizes == [] and self.max_sizes == []:
            num_layer = len(aspect_ratios)
            step = int(
                math.floor((self.max_ratio - self.min_ratio) / (num_layer - 2
                                                                )))
            for ratio in six.moves.range(self.min_ratio, self.max_ratio + 1,
                                         step):
                self.min_sizes.append(self.base_size * ratio / 100.)
                self.max_sizes.append(self.base_size * (ratio + step) / 100.)
            self.min_sizes = [self.base_size * .10] + self.min_sizes
            self.max_sizes = [self.base_size * .20] + self.max_sizes

        self.num_priors = []
        for aspect_ratio, min_size, max_size in zip(
                aspect_ratios, self.min_sizes, self.max_sizes):
            if isinstance(min_size, (list, tuple)):
                self.num_priors.append(
                    len(_to_list(min_size)) + len(_to_list(max_size)))
            else:
                self.num_priors.append((len(aspect_ratio) * 2 + 1) * len(
                    _to_list(min_size)) + len(_to_list(max_size)))

    def __call__(self, inputs, image):
        boxes = []
        for input, min_size, max_size, aspect_ratio, step in zip(
                inputs, self.min_sizes, self.max_sizes, self.aspect_ratios,
                self.steps):
            box, _ = ops.prior_box(
                input=input,
                image=image,
                min_sizes=_to_list(min_size),
                max_sizes=_to_list(max_size),
                aspect_ratios=aspect_ratio,
                flip=self.flip,
                clip=self.clip,
                steps=[step, step],
                offset=self.offset,
                min_max_aspect_ratios_order=self.min_max_aspect_ratios_order)
            boxes.append(paddle.reshape(box, [-1, 4]))
        return boxes
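

# Usage sketch (illustrative): with the default SSD300 configuration each
# feature map gets (2 * len(aspect_ratio) + 1) + 1 priors per location, so
# num_priors comes out as [4, 6, 6, 6, 4, 4] for the six levels.
def _example_anchor_generator_ssd():
    gen = AnchorGeneratorSSD()
    return gen.num_priors  # [4, 6, 6, 6, 4, 4]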


@register
@serializable
class RCNNBox(object):
    __shared__ = ['num_classes']

    def __init__(self,
                 prior_box_var=[10., 10., 5., 5.],
                 code_type="decode_center_size",
                 box_normalized=False,
                 num_classes=80):
        super(RCNNBox, self).__init__()
        self.prior_box_var = prior_box_var
        self.code_type = code_type
        self.box_normalized = box_normalized
        self.num_classes = num_classes

    def __call__(self, bbox_head_out, rois, im_shape, scale_factor):
        bbox_pred = bbox_head_out[0]
        cls_prob = bbox_head_out[1]
        roi = rois[0]
        rois_num = rois[1]

        origin_shape = paddle.floor(im_shape / scale_factor + 0.5)
        scale_list = []
        origin_shape_list = []

        batch_size = 1
        if isinstance(roi, list):
            batch_size = len(roi)
        else:
            batch_size = paddle.slice(paddle.shape(im_shape), [0], [0], [1])
        # bbox_pred.shape: [N, C*4]
        for idx in range(batch_size):
            roi_per_im = roi[idx]
            rois_num_per_im = rois_num[idx]
            expand_im_shape = paddle.expand(im_shape[idx, :],
                                            [rois_num_per_im, 2])
            origin_shape_list.append(expand_im_shape)

        origin_shape = paddle.concat(origin_shape_list)

        # bbox_pred.shape: [N, C*4]
        # C=num_classes in faster/mask rcnn(bbox_head), C=1 in cascade rcnn(cascade_head)
        bbox = paddle.concat(roi)
        if bbox.shape[0] == 0:
            bbox = paddle.zeros([0, bbox_pred.shape[1]], dtype='float32')
        else:
            bbox = delta2bbox(bbox_pred, bbox, self.prior_box_var)
        scores = cls_prob[:, :-1]

        # bbox.shape: [N, C, 4]
        # bbox.shape[1] must be equal to scores.shape[1]
        bbox_num_class = bbox.shape[1]
        if bbox_num_class == 1:
            bbox = paddle.tile(bbox, [1, self.num_classes, 1])

        origin_h = paddle.unsqueeze(origin_shape[:, 0], axis=1)
        origin_w = paddle.unsqueeze(origin_shape[:, 1], axis=1)
        zeros = paddle.zeros_like(origin_h)
        x1 = paddle.maximum(paddle.minimum(bbox[:, :, 0], origin_w), zeros)
        y1 = paddle.maximum(paddle.minimum(bbox[:, :, 1], origin_h), zeros)
        x2 = paddle.maximum(paddle.minimum(bbox[:, :, 2], origin_w), zeros)
        y2 = paddle.maximum(paddle.minimum(bbox[:, :, 3], origin_h), zeros)
        bbox = paddle.stack([x1, y1, x2, y2], axis=-1)
        bboxes = (bbox, rois_num)
        return bboxes, scores


@register
@serializable
class MultiClassNMS(object):
    def __init__(self,
                 score_threshold=.05,
                 nms_top_k=-1,
                 keep_top_k=100,
                 nms_threshold=.5,
                 normalized=True,
                 nms_eta=1.0,
                 return_index=False,
                 return_rois_num=True):
        super(MultiClassNMS, self).__init__()
        self.score_threshold = score_threshold
        self.nms_top_k = nms_top_k
        self.keep_top_k = keep_top_k
        self.nms_threshold = nms_threshold
        self.normalized = normalized
        self.nms_eta = nms_eta
        self.return_index = return_index
        self.return_rois_num = return_rois_num

    def __call__(self, bboxes, score, background_label=-1):
        """
        bboxes (Tensor|List[Tensor]): 1. (Tensor) Predicted bboxes with shape
                                         [N, M, 4], N is the batch size and M
                                         is the number of bboxes
                                      2. (List[Tensor]) bboxes and bbox_num,
                                         bboxes have shape of [M, C, 4], C
                                         is the class number and bbox_num means
                                         the number of bboxes of each batch with
                                         shape [N,]
        score (Tensor): Predicted scores with shape [N, C, M] or [M, C]
        background_label (int): Label of the background class to ignore; for
                                example, it is num_classes for RCNN and -1 for
                                YOLO.
        """
        kwargs = self.__dict__.copy()
        if isinstance(bboxes, tuple):
            bboxes, bbox_num = bboxes
            kwargs.update({'rois_num': bbox_num})
        if background_label > -1:
            kwargs.update({'background_label': background_label})
        return ops.multiclass_nms(bboxes, score, **kwargs)


@register
@serializable
class MatrixNMS(object):
    __append_doc__ = True

    def __init__(self,
                 score_threshold=.05,
                 post_threshold=.05,
                 nms_top_k=-1,
                 keep_top_k=100,
                 use_gaussian=False,
                 gaussian_sigma=2.,
                 normalized=False,
                 background_label=0):
        super(MatrixNMS, self).__init__()
        self.score_threshold = score_threshold
        self.post_threshold = post_threshold
        self.nms_top_k = nms_top_k
        self.keep_top_k = keep_top_k
        self.normalized = normalized
        self.use_gaussian = use_gaussian
        self.gaussian_sigma = gaussian_sigma
        self.background_label = background_label

    def __call__(self, bbox, score, *args):
        return ops.matrix_nms(
            bboxes=bbox,
            scores=score,
            score_threshold=self.score_threshold,
            post_threshold=self.post_threshold,
            nms_top_k=self.nms_top_k,
            keep_top_k=self.keep_top_k,
            use_gaussian=self.use_gaussian,
            gaussian_sigma=self.gaussian_sigma,
            background_label=self.background_label,
            normalized=self.normalized)


@register
@serializable
class YOLOBox(object):
    __shared__ = ['num_classes']

    def __init__(self,
                 num_classes=80,
                 conf_thresh=0.005,
                 downsample_ratio=32,
                 clip_bbox=True,
                 scale_x_y=1.):
        self.num_classes = num_classes
        self.conf_thresh = conf_thresh
        self.downsample_ratio = downsample_ratio
        self.clip_bbox = clip_bbox
        self.scale_x_y = scale_x_y

    def __call__(self,
                 yolo_head_out,
                 anchors,
                 im_shape,
                 scale_factor,
                 var_weight=None):
        boxes_list = []
        scores_list = []
        origin_shape = im_shape / scale_factor
        origin_shape = paddle.cast(origin_shape, 'int32')
        for i, head_out in enumerate(yolo_head_out):
            boxes, scores = ops.yolo_box(head_out, origin_shape, anchors[i],
                                         self.num_classes, self.conf_thresh,
                                         self.downsample_ratio // 2**i,
                                         self.clip_bbox, self.scale_x_y)
            boxes_list.append(boxes)
            scores_list.append(paddle.transpose(scores, perm=[0, 2, 1]))
        yolo_boxes = paddle.concat(boxes_list, axis=1)
        yolo_scores = paddle.concat(scores_list, axis=2)
        return yolo_boxes, yolo_scores


@register
@serializable
class SSDBox(object):
    def __init__(self, is_normalized=True):
        self.is_normalized = is_normalized
        self.norm_delta = float(not self.is_normalized)

    def __call__(self,
                 preds,
                 prior_boxes,
                 im_shape,
                 scale_factor,
                 var_weight=None):
        boxes, scores = preds
        outputs = []
        for box, score, prior_box in zip(boxes, scores, prior_boxes):
            pb_w = prior_box[:, 2] - prior_box[:, 0] + self.norm_delta
            pb_h = prior_box[:, 3] - prior_box[:, 1] + self.norm_delta
            pb_x = prior_box[:, 0] + pb_w * 0.5
            pb_y = prior_box[:, 1] + pb_h * 0.5
            out_x = pb_x + box[:, :, 0] * pb_w * 0.1
            out_y = pb_y + box[:, :, 1] * pb_h * 0.1
            out_w = paddle.exp(box[:, :, 2] * 0.2) * pb_w
            out_h = paddle.exp(box[:, :, 3] * 0.2) * pb_h

            if self.is_normalized:
                h = paddle.unsqueeze(
                    im_shape[:, 0] / scale_factor[:, 0], axis=-1)
                w = paddle.unsqueeze(
                    im_shape[:, 1] / scale_factor[:, 1], axis=-1)
                output = paddle.stack(
                    [(out_x - out_w / 2.) * w, (out_y - out_h / 2.) * h,
                     (out_x + out_w / 2.) * w, (out_y + out_h / 2.) * h],
                    axis=-1)
            else:
                output = paddle.stack(
                    [
                        out_x - out_w / 2., out_y - out_h / 2.,
                        out_x + out_w / 2. - 1., out_y + out_h / 2. - 1.
                    ],
                    axis=-1)
            outputs.append(output)
        boxes = paddle.concat(outputs, axis=1)
        scores = F.softmax(paddle.concat(scores, axis=1))
        scores = paddle.transpose(scores, [0, 2, 1])
        return boxes, scores


@register
@serializable
class AnchorGrid(object):
    """Generate anchor grid

    Args:
        image_size (int or list): input image size, may be a single integer or
            list of [h, w]. Default: 512
        min_level (int): min level of the feature pyramid. Default: 3
        max_level (int): max level of the feature pyramid. Default: 7
        anchor_base_scale: base anchor scale. Default: 4
        num_scales: number of anchor scales. Default: 3
        aspect_ratios: aspect ratios. Default: [[1, 1], [1.4, 0.7], [0.7, 1.4]]
    """

    def __init__(self,
                 image_size=512,
                 min_level=3,
                 max_level=7,
                 anchor_base_scale=4,
                 num_scales=3,
                 aspect_ratios=[[1, 1], [1.4, 0.7], [0.7, 1.4]]):
        super(AnchorGrid, self).__init__()
        if isinstance(image_size, Integral):
            self.image_size = [image_size, image_size]
        else:
            self.image_size = image_size
        for dim in self.image_size:
            assert dim % 2 ** max_level == 0, \
                "image size should be multiple of the max level stride"
        self.min_level = min_level
        self.max_level = max_level
        self.anchor_base_scale = anchor_base_scale
        self.num_scales = num_scales
        self.aspect_ratios = aspect_ratios

    @property
    def base_cell(self):
        if not hasattr(self, '_base_cell'):
            self._base_cell = self.make_cell()
        return self._base_cell

    def make_cell(self):
        scales = [2**(i / self.num_scales) for i in range(self.num_scales)]
        scales = np.array(scales)
        ratios = np.array(self.aspect_ratios)
        ws = np.outer(scales, ratios[:, 0]).reshape(-1, 1)
        hs = np.outer(scales, ratios[:, 1]).reshape(-1, 1)
        anchors = np.hstack((-0.5 * ws, -0.5 * hs, 0.5 * ws, 0.5 * hs))
        return anchors

    def make_grid(self, stride):
        cell = self.base_cell * stride * self.anchor_base_scale
        x_steps = np.arange(stride // 2, self.image_size[1], stride)
        y_steps = np.arange(stride // 2, self.image_size[0], stride)
        offset_x, offset_y = np.meshgrid(x_steps, y_steps)
        offset_x = offset_x.flatten()
        offset_y = offset_y.flatten()
        offsets = np.stack((offset_x, offset_y, offset_x, offset_y), axis=-1)
        offsets = offsets[:, np.newaxis, :]
        return (cell + offsets).reshape(-1, 4)

    def generate(self):
        return [
            self.make_grid(2**l)
            for l in range(self.min_level, self.max_level + 1)
        ]

    def __call__(self):
        if not hasattr(self, '_anchor_vars'):
            anchor_vars = []
            helper = LayerHelper('anchor_grid')
            for idx, l in enumerate(range(self.min_level, self.max_level + 1)):
                stride = 2**l
                anchors = self.make_grid(stride)
                var = helper.create_parameter(
                    attr=ParamAttr(name='anchors_{}'.format(idx)),
                    shape=anchors.shape,
                    dtype='float32',
                    stop_gradient=True,
                    default_initializer=NumpyArrayInitializer(anchors))
                anchor_vars.append(var)
                var.persistable = True
            self._anchor_vars = anchor_vars
        return self._anchor_vars
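

# Usage sketch (illustrative): generate() is pure numpy; level l contributes
# (image_size / 2**l)**2 * num_scales * len(aspect_ratios) anchors, so the
# first (stride-8) level of a 512 input yields 64 * 64 * 9 = 36864 boxes.
def _example_anchor_grid():
    grid = AnchorGrid(image_size=512, min_level=3, max_level=7)
    anchors = grid.generate()  # list of 5 arrays, one per pyramid level
    return [a.shape for a in anchors]  # [(36864, 4), (9216, 4), ...]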


@register
@serializable
class FCOSBox(object):
    __shared__ = ['num_classes']

    def __init__(self, num_classes=80):
        super(FCOSBox, self).__init__()
        self.num_classes = num_classes

    def _merge_hw(self, inputs, ch_type="channel_first"):
        """
        Merge h and w of the feature map into one dimension.

        Args:
            inputs (Tensor): Tensor of the input feature map
            ch_type (str): "channel_first" or "channel_last" style

        Return:
            new_shape (Tensor): The new shape after h and w merged
        """
        shape_ = paddle.shape(inputs)
        bs, ch, hi, wi = shape_[0], shape_[1], shape_[2], shape_[3]
        img_size = hi * wi
        img_size.stop_gradient = True
        if ch_type == "channel_first":
            new_shape = paddle.concat([bs, ch, img_size])
        elif ch_type == "channel_last":
            new_shape = paddle.concat([bs, img_size, ch])
        else:
            raise KeyError("Wrong ch_type %s" % ch_type)
        new_shape.stop_gradient = True
        return new_shape

    def _postprocessing_by_level(self, locations, box_cls, box_reg, box_ctn,
                                 scale_factor):
        """
        Postprocess each layer of the output with corresponding locations.

        Args:
            locations (Tensor): anchor points for current layer, [H*W, 2]
            box_cls (Tensor): categories prediction, [N, C, H, W],
                C is the number of classes
            box_reg (Tensor): bounding box prediction, [N, 4, H, W]
            box_ctn (Tensor): centerness prediction, [N, 1, H, W]
            scale_factor (Tensor): [h_scale, w_scale] for input images

        Return:
            box_cls_ch_last (Tensor): score for each category, in [N, C, M],
                C is the number of classes and M is the number of anchor points
            box_reg_decoding (Tensor): decoded bounding box, in [N, M, 4],
                last dimension is [x1, y1, x2, y2]
        """
        act_shape_cls = self._merge_hw(box_cls)
        box_cls_ch_last = paddle.reshape(x=box_cls, shape=act_shape_cls)
        box_cls_ch_last = F.sigmoid(box_cls_ch_last)

        act_shape_reg = self._merge_hw(box_reg)
        box_reg_ch_last = paddle.reshape(x=box_reg, shape=act_shape_reg)
        box_reg_ch_last = paddle.transpose(box_reg_ch_last, perm=[0, 2, 1])
        box_reg_decoding = paddle.stack(
            [
                locations[:, 0] - box_reg_ch_last[:, :, 0],
                locations[:, 1] - box_reg_ch_last[:, :, 1],
                locations[:, 0] + box_reg_ch_last[:, :, 2],
                locations[:, 1] + box_reg_ch_last[:, :, 3]
            ],
            axis=1)
        box_reg_decoding = paddle.transpose(box_reg_decoding, perm=[0, 2, 1])

        act_shape_ctn = self._merge_hw(box_ctn)
        box_ctn_ch_last = paddle.reshape(x=box_ctn, shape=act_shape_ctn)
        box_ctn_ch_last = F.sigmoid(box_ctn_ch_last)

        # recover the location to original image
        im_scale = paddle.concat([scale_factor, scale_factor], axis=1)
        im_scale = paddle.expand(im_scale, [box_reg_decoding.shape[0], 4])
        im_scale = paddle.reshape(im_scale, [box_reg_decoding.shape[0], -1, 4])
        box_reg_decoding = box_reg_decoding / im_scale
        box_cls_ch_last = box_cls_ch_last * box_ctn_ch_last
        return box_cls_ch_last, box_reg_decoding

    def __call__(self, locations, cls_logits, bboxes_reg, centerness,
                 scale_factor):
        pred_boxes_ = []
        pred_scores_ = []
        for pts, cls, box, ctn in zip(locations, cls_logits, bboxes_reg,
                                      centerness):
            pred_scores_lvl, pred_boxes_lvl = self._postprocessing_by_level(
                pts, cls, box, ctn, scale_factor)
            pred_boxes_.append(pred_boxes_lvl)
            pred_scores_.append(pred_scores_lvl)
        pred_boxes = paddle.concat(pred_boxes_, axis=1)
        pred_scores = paddle.concat(pred_scores_, axis=2)
        return pred_boxes, pred_scores


@register
class TTFBox(object):
    __shared__ = ['down_ratio']

    def __init__(self, max_per_img=100, score_thresh=0.01, down_ratio=4):
        super(TTFBox, self).__init__()
        self.max_per_img = max_per_img
        self.score_thresh = score_thresh
        self.down_ratio = down_ratio

    def _simple_nms(self, heat, kernel=3):
        """
        Use maxpool to filter the max score, get local peaks.
        """
        pad = (kernel - 1) // 2
        hmax = F.max_pool2d(heat, kernel, stride=1, padding=pad)
        keep = paddle.cast(hmax == heat, 'float32')
        return heat * keep

    def _topk(self, scores):
        """
        Select top k scores and decode to get xy coordinates.
        """
        k = self.max_per_img
        shape_fm = paddle.shape(scores)
        shape_fm.stop_gradient = True
        cat, height, width = shape_fm[1], shape_fm[2], shape_fm[3]
        # batch size is 1
        scores_r = paddle.reshape(scores, [cat, -1])
        topk_scores, topk_inds = paddle.topk(scores_r, k)
        topk_ys = topk_inds // width
        topk_xs = topk_inds % width

        topk_score_r = paddle.reshape(topk_scores, [-1])
        topk_score, topk_ind = paddle.topk(topk_score_r, k)
        k_t = paddle.full(paddle.shape(topk_ind), k, dtype='int64')
        topk_clses = paddle.cast(paddle.floor_divide(topk_ind, k_t), 'float32')

        topk_inds = paddle.reshape(topk_inds, [-1])
        topk_ys = paddle.reshape(topk_ys, [-1, 1])
        topk_xs = paddle.reshape(topk_xs, [-1, 1])
        topk_inds = paddle.gather(topk_inds, topk_ind)
        topk_ys = paddle.gather(topk_ys, topk_ind)
        topk_xs = paddle.gather(topk_xs, topk_ind)
        return topk_score, topk_inds, topk_clses, topk_ys, topk_xs

    def _decode(self, hm, wh, im_shape, scale_factor):
        heatmap = F.sigmoid(hm)
        heat = self._simple_nms(heatmap)
        scores, inds, clses, ys, xs = self._topk(heat)
        ys = paddle.cast(ys, 'float32') * self.down_ratio
        xs = paddle.cast(xs, 'float32') * self.down_ratio
        scores = paddle.tensor.unsqueeze(scores, [1])
        clses = paddle.tensor.unsqueeze(clses, [1])

        wh_t = paddle.transpose(wh, [0, 2, 3, 1])
        wh = paddle.reshape(wh_t, [-1, paddle.shape(wh_t)[-1]])
        wh = paddle.gather(wh, inds)

        x1 = xs - wh[:, 0:1]
        y1 = ys - wh[:, 1:2]
        x2 = xs + wh[:, 2:3]
        y2 = ys + wh[:, 3:4]
        bboxes = paddle.concat([x1, y1, x2, y2], axis=1)

        scale_y = scale_factor[:, 0:1]
        scale_x = scale_factor[:, 1:2]
        scale_expand = paddle.concat(
            [scale_x, scale_y, scale_x, scale_y], axis=1)
        boxes_shape = paddle.shape(bboxes)
        boxes_shape.stop_gradient = True
        scale_expand = paddle.expand(scale_expand, shape=boxes_shape)
        bboxes = paddle.divide(bboxes, scale_expand)
        results = paddle.concat([clses, scores, bboxes], axis=1)
        # hack: append a result with cls=-1 and score=1. in case all scores
        # are less than score_thresh, which would cause an error in gather.
        fill_r = paddle.to_tensor(np.array([[-1, 1, 0, 0, 0, 0]]))
        fill_r = paddle.cast(fill_r, results.dtype)
        results = paddle.concat([results, fill_r])
        scores = results[:, 1]
        valid_ind = paddle.nonzero(scores > self.score_thresh)
        results = paddle.gather(results, valid_ind)
        return results, paddle.shape(results)[0:1]

    def __call__(self, hm, wh, im_shape, scale_factor):
        results = []
        results_num = []
        for i in range(scale_factor.shape[0]):
            result, num = self._decode(hm[i:i + 1, ], wh[i:i + 1, ],
                                       im_shape[i:i + 1, ],
                                       scale_factor[i:i + 1, ])
            results.append(result)
            results_num.append(num)
        results = paddle.concat(results, axis=0)
        results_num = paddle.concat(results_num, axis=0)
        return results, results_num
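

# Usage sketch (illustrative): _simple_nms keeps only local maxima of the
# heatmap, which is how center points are selected before _topk runs.
def _example_ttf_simple_nms():
    box = TTFBox()
    heat = paddle.rand([1, 80, 32, 32])  # [N, num_classes, H, W]
    peaks = box._simple_nms(heat)        # non-peak scores zeroed out
    return peaks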


@register
@serializable
class JDEBox(object):
    __shared__ = ['num_classes']

    def __init__(self, num_classes=1, conf_thresh=0.3, downsample_ratio=32):
        self.num_classes = num_classes
        self.conf_thresh = conf_thresh
        self.downsample_ratio = downsample_ratio

    def generate_anchor(self, nGh, nGw, anchor_wh):
        nA = len(anchor_wh)
        yv, xv = paddle.meshgrid([paddle.arange(nGh), paddle.arange(nGw)])
        mesh = paddle.stack(
            (xv, yv), axis=0).cast(dtype='float32')  # 2 x nGh x nGw
        meshs = paddle.tile(mesh, [nA, 1, 1, 1])

        anchor_offset_mesh = anchor_wh[:, :, None][:, :, :, None].repeat(
            int(nGh), axis=-2).repeat(
                int(nGw), axis=-1)
        anchor_offset_mesh = paddle.to_tensor(
            anchor_offset_mesh.astype(np.float32))
        # nA x 2 x nGh x nGw

        anchor_mesh = paddle.concat([meshs, anchor_offset_mesh], axis=1)
        anchor_mesh = paddle.transpose(anchor_mesh,
                                       [0, 2, 3, 1])  # (nA x nGh x nGw) x 4
        return anchor_mesh

    def decode_delta(self, delta, fg_anchor_list):
        px, py, pw, ph = fg_anchor_list[:, 0], fg_anchor_list[:, 1], \
            fg_anchor_list[:, 2], fg_anchor_list[:, 3]
        dx, dy, dw, dh = delta[:, 0], delta[:, 1], delta[:, 2], delta[:, 3]
        gx = pw * dx + px
        gy = ph * dy + py
        gw = pw * paddle.exp(dw)
        gh = ph * paddle.exp(dh)
        gx1 = gx - gw * 0.5
        gy1 = gy - gh * 0.5
        gx2 = gx + gw * 0.5
        gy2 = gy + gh * 0.5
        return paddle.stack([gx1, gy1, gx2, gy2], axis=1)

    def decode_delta_map(self, nA, nGh, nGw, delta_map, anchor_vec):
        anchor_mesh = self.generate_anchor(nGh, nGw, anchor_vec)
        anchor_mesh = paddle.unsqueeze(anchor_mesh, 0)
        pred_list = self.decode_delta(
            paddle.reshape(
                delta_map, shape=[-1, 4]),
            paddle.reshape(
                anchor_mesh, shape=[-1, 4]))
        pred_map = paddle.reshape(pred_list, shape=[nA * nGh * nGw, 4])
        return pred_map

    def _postprocessing_by_level(self, nA, stride, head_out, anchor_vec):
        boxes_shape = head_out.shape  # [nB, nA*6, nGh, nGw]
        nGh, nGw = boxes_shape[-2], boxes_shape[-1]
        nB = 1  # TODO: only support bs=1 now
        boxes_list, scores_list = [], []
        for idx in range(nB):
            p = paddle.reshape(
                head_out[idx], shape=[nA, self.num_classes + 5, nGh, nGw])
            p = paddle.transpose(p, perm=[0, 2, 3, 1])  # [nA, nGh, nGw, 6]
            delta_map = p[:, :, :, :4]
            boxes = self.decode_delta_map(nA, nGh, nGw, delta_map, anchor_vec)
            # [nA * nGh * nGw, 4]
            boxes_list.append(boxes * stride)

            p_conf = paddle.transpose(
                p[:, :, :, 4:6], perm=[3, 0, 1, 2])  # [2, nA, nGh, nGw]
            p_conf = F.softmax(
                p_conf, axis=0)[1, :, :, :].unsqueeze(-1)  # [nA, nGh, nGw, 1]
            scores = paddle.reshape(p_conf, shape=[nA * nGh * nGw, 1])
            scores_list.append(scores)

        boxes_results = paddle.stack(boxes_list)
        scores_results = paddle.stack(scores_list)
        return boxes_results, scores_results

    def __call__(self, yolo_head_out, anchors):
        bbox_pred_list = []
        for i, head_out in enumerate(yolo_head_out):
            stride = self.downsample_ratio // 2**i
            anc_w, anc_h = anchors[i][0::2], anchors[i][1::2]
            anchor_vec = np.stack((anc_w, anc_h), axis=1) / stride
            nA = len(anc_w)
            boxes, scores = self._postprocessing_by_level(nA, stride, head_out,
                                                          anchor_vec)
            bbox_pred_list.append(paddle.concat([boxes, scores], axis=-1))

        yolo_boxes_scores = paddle.concat(bbox_pred_list, axis=1)
        boxes_idx_over_conf_thr = paddle.nonzero(
            yolo_boxes_scores[:, :, -1] > self.conf_thresh)
        boxes_idx_over_conf_thr.stop_gradient = True
        return boxes_idx_over_conf_thr, yolo_boxes_scores
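

# Usage sketch (illustrative): decode_delta turns (dx, dy, dw, dh) offsets
# plus anchor centers/sizes into corner-format boxes; a zero delta recovers
# the anchor itself.
def _example_jde_decode_delta():
    box = JDEBox()
    anchors = paddle.to_tensor([[8., 8., 4., 4.]])  # cx, cy, w, h
    deltas = paddle.zeros([1, 4])                   # zero delta -> the anchor
    return box.decode_delta(deltas, anchors)        # [[6., 6., 10., 10.]]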


@register
@serializable
class MaskMatrixNMS(object):
    """
    Matrix NMS for multi-class masks.

    Args:
        update_threshold (float): Updated threshold of category score in the second pass.
        pre_nms_top_n (int): Number of total instances to be kept per image before NMS.
        post_nms_top_n (int): Number of total instances to be kept per image after NMS.
        kernel (str): 'linear' or 'gaussian'.
        sigma (float): std in gaussian method.

    Input:
        seg_preds (Variable): shape (n, h, w), segmentation feature maps
        seg_masks (Variable): shape (n, h, w), segmentation feature maps
        cate_labels (Variable): shape (n), mask labels in descending order
        cate_scores (Variable): shape (n), mask scores in descending order
        sum_masks (Variable): a float tensor of the sum of seg_masks

    Returns:
        Variable: cate_scores, tensors of shape (n)
    """

    def __init__(self,
                 update_threshold=0.05,
                 pre_nms_top_n=500,
                 post_nms_top_n=100,
                 kernel='gaussian',
                 sigma=2.0):
        super(MaskMatrixNMS, self).__init__()
        self.update_threshold = update_threshold
        self.pre_nms_top_n = pre_nms_top_n
        self.post_nms_top_n = post_nms_top_n
        self.kernel = kernel
        self.sigma = sigma

    def _sort_score(self, scores, top_num):
        if paddle.shape(scores)[0] > top_num:
            return paddle.topk(scores, top_num)[1]
        else:
            return paddle.argsort(scores, descending=True)

    def __call__(self,
                 seg_preds,
                 seg_masks,
                 cate_labels,
                 cate_scores,
                 sum_masks=None):
        # sort and keep top nms_pre
        sort_inds = self._sort_score(cate_scores, self.pre_nms_top_n)
        seg_masks = paddle.gather(seg_masks, index=sort_inds)
        seg_preds = paddle.gather(seg_preds, index=sort_inds)
        sum_masks = paddle.gather(sum_masks, index=sort_inds)
        cate_scores = paddle.gather(cate_scores, index=sort_inds)
        cate_labels = paddle.gather(cate_labels, index=sort_inds)

        seg_masks = paddle.flatten(seg_masks, start_axis=1, stop_axis=-1)
        # inter.
        inter_matrix = paddle.mm(seg_masks,
                                 paddle.transpose(seg_masks, [1, 0]))
        n_samples = paddle.shape(cate_labels)
        # union.
        sum_masks_x = paddle.expand(sum_masks, shape=[n_samples, n_samples])
        # iou.
        iou_matrix = (inter_matrix / (
            sum_masks_x + paddle.transpose(sum_masks_x, [1, 0]) - inter_matrix
        ))
        iou_matrix = paddle.triu(iou_matrix, diagonal=1)
        # label_specific matrix.
        cate_labels_x = paddle.expand(
            cate_labels, shape=[n_samples, n_samples])
        label_matrix = paddle.cast(
            (cate_labels_x == paddle.transpose(cate_labels_x, [1, 0])),
            'float32')
        label_matrix = paddle.triu(label_matrix, diagonal=1)

        # IoU compensation
        compensate_iou = paddle.max((iou_matrix * label_matrix), axis=0)
        compensate_iou = paddle.expand(
            compensate_iou, shape=[n_samples, n_samples])
        compensate_iou = paddle.transpose(compensate_iou, [1, 0])

        # IoU decay
        decay_iou = iou_matrix * label_matrix

        # matrix nms
        if self.kernel == 'gaussian':
            decay_matrix = paddle.exp(-1 * self.sigma * (decay_iou**2))
            compensate_matrix = paddle.exp(-1 * self.sigma *
                                           (compensate_iou**2))
            decay_coefficient = paddle.min(decay_matrix / compensate_matrix,
                                           axis=0)
        elif self.kernel == 'linear':
            decay_matrix = (1 - decay_iou) / (1 - compensate_iou)
            decay_coefficient = paddle.min(decay_matrix, axis=0)
        else:
            raise NotImplementedError

        # update the score.
        cate_scores = cate_scores * decay_coefficient
        y = paddle.zeros(shape=paddle.shape(cate_scores), dtype='float32')
        keep = paddle.where(cate_scores >= self.update_threshold, cate_scores,
                            y)
        keep = paddle.nonzero(keep)
        keep = paddle.squeeze(keep, axis=[1])
        # Prevent empty results by appending a fake trailing index
        keep = paddle.concat(
            [keep, paddle.cast(paddle.shape(cate_scores)[0] - 1, 'int64')])

        seg_preds = paddle.gather(seg_preds, index=keep)
        cate_scores = paddle.gather(cate_scores, index=keep)
        cate_labels = paddle.gather(cate_labels, index=keep)

        # sort and keep top_k
        sort_inds = self._sort_score(cate_scores, self.post_nms_top_n)
        seg_preds = paddle.gather(seg_preds, index=sort_inds)
        cate_scores = paddle.gather(cate_scores, index=sort_inds)
        cate_labels = paddle.gather(cate_labels, index=sort_inds)
        return seg_preds, cate_scores, cate_labels


def Conv2d(in_channels,
           out_channels,
           kernel_size,
           stride=1,
           padding=0,
           dilation=1,
           groups=1,
           bias=True,
           weight_init=Normal(std=0.001),
           bias_init=Constant(0.)):
    weight_attr = paddle.framework.ParamAttr(initializer=weight_init)
    if bias:
        bias_attr = paddle.framework.ParamAttr(initializer=bias_init)
    else:
        bias_attr = False
    conv = nn.Conv2D(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        dilation,
        groups,
        weight_attr=weight_attr,
        bias_attr=bias_attr)
    return conv
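

# Usage sketch (illustrative): these helpers mirror a torch-style constructor
# (bias flag, per-layer initializers) on top of paddle.nn layers.
def _example_conv2d_helper():
    conv = Conv2d(3, 8, kernel_size=3, padding=1, bias=True)
    y = conv(paddle.rand([1, 3, 16, 16]))  # -> [1, 8, 16, 16]
    return y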


def ConvTranspose2d(in_channels,
                    out_channels,
                    kernel_size,
                    stride=1,
                    padding=0,
                    output_padding=0,
                    groups=1,
                    bias=True,
                    dilation=1,
                    weight_init=Normal(std=0.001),
                    bias_init=Constant(0.)):
    weight_attr = paddle.framework.ParamAttr(initializer=weight_init)
    if bias:
        bias_attr = paddle.framework.ParamAttr(initializer=bias_init)
    else:
        bias_attr = False
    # dilation and groups are passed by keyword so the call does not depend
    # on their positional order in nn.Conv2DTranspose's signature.
    conv = nn.Conv2DTranspose(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        output_padding,
        dilation=dilation,
        groups=groups,
        weight_attr=weight_attr,
        bias_attr=bias_attr)
    return conv


def BatchNorm2d(num_features, eps=1e-05, momentum=0.9, affine=True):
    if not affine:
        weight_attr = False
        bias_attr = False
    else:
        weight_attr = None
        bias_attr = None
    batchnorm = nn.BatchNorm2D(
        num_features,
        momentum=momentum,
        epsilon=eps,
        weight_attr=weight_attr,
        bias_attr=bias_attr)
    return batchnorm


def ReLU():
    return nn.ReLU()


def Upsample(scale_factor=None, mode='nearest', align_corners=False):
    return nn.Upsample(None, scale_factor, mode, align_corners)


def MaxPool(kernel_size, stride, padding, ceil_mode=False):
    return nn.MaxPool2D(kernel_size, stride, padding, ceil_mode=ceil_mode)


class Concat(nn.Layer):
    def __init__(self, dim=0):
        super(Concat, self).__init__()
        self.dim = dim

    def forward(self, inputs):
        return paddle.concat(inputs, axis=self.dim)

    def extra_repr(self):
        return 'dim={}'.format(self.dim)
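

# Usage sketch (illustrative): Concat is a thin nn.Layer wrapper so that
# paddle.concat can appear inside layer-based model definitions.
def _example_concat():
    cat = Concat(dim=1)
    a, b = paddle.rand([1, 3, 8, 8]), paddle.rand([1, 5, 8, 8])
    return cat([a, b])  # -> [1, 8, 8, 8]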


def _convert_attention_mask(attn_mask, dtype):
    """
    Convert the attention mask to the target dtype we expect.

    Parameters:
        attn_mask (Tensor, optional): A tensor used in multi-head attention
            to prevent attention to some unwanted positions, usually the
            paddings or the subsequent positions. It is a tensor with shape
            broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`.
            When the data type is bool, the unwanted positions have `False`
            values and the others have `True` values. When the data type is
            int, the unwanted positions have 0 values and the others have 1
            values. When the data type is float, the unwanted positions have
            `-INF` values and the others have 0 values. It can be None when
            nothing needs to be prevented from attending. Default None.
        dtype (VarType): The target type of `attn_mask` we expect.

    Returns:
        Tensor: A Tensor with shape same as input `attn_mask`, with data type `dtype`.
    """
    return nn.layer.transformer._convert_attention_mask(attn_mask, dtype)


class MultiHeadAttention(nn.Layer):
    """
    Attention maps queries and a set of key-value pairs to outputs, and
    Multi-Head Attention performs multiple attention operations in parallel
    to jointly attend to information from different representation subspaces.

    Please refer to `Attention Is All You Need <https://arxiv.org/pdf/1706.03762.pdf>`_
    for more details.

    Parameters:
        embed_dim (int): The expected feature size in the input and output.
        num_heads (int): The number of heads in multi-head attention.
        dropout (float, optional): The dropout probability used on attention
            weights to drop some attention targets. 0 for no dropout. Default 0
        kdim (int, optional): The feature size in key. If None, assumed equal to
            `embed_dim`. Default None.
        vdim (int, optional): The feature size in value. If None, assumed equal to
            `embed_dim`. Default None.
        need_weights (bool, optional): Indicate whether to return the attention
            weights. Default False.

    Examples:
        .. code-block:: python

            import paddle

            # encoder input: [batch_size, sequence_length, d_model]
            query = paddle.rand((2, 4, 128))
            # self attention mask: [batch_size, num_heads, query_len, query_len]
            attn_mask = paddle.rand((2, 2, 4, 4))
            multi_head_attn = paddle.nn.MultiHeadAttention(128, 2)
            output = multi_head_attn(query, None, None, attn_mask=attn_mask)  # [2, 4, 128]
    """

    def __init__(self,
                 embed_dim,
                 num_heads,
                 dropout=0.,
                 kdim=None,
                 vdim=None,
                 need_weights=False):
        super(MultiHeadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.need_weights = need_weights

        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, \
            "embed_dim must be divisible by num_heads"

        if self._qkv_same_embed_dim:
            self.in_proj_weight = self.create_parameter(
                shape=[embed_dim, 3 * embed_dim],
                attr=None,
                dtype=self._dtype,
                is_bias=False)
            self.in_proj_bias = self.create_parameter(
                shape=[3 * embed_dim],
                attr=None,
                dtype=self._dtype,
                is_bias=True)
        else:
            self.q_proj = nn.Linear(embed_dim, embed_dim)
            self.k_proj = nn.Linear(self.kdim, embed_dim)
            self.v_proj = nn.Linear(self.vdim, embed_dim)

        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self._type_list = ('q_proj', 'k_proj', 'v_proj')

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)
            else:
                constant_(p)

    def compute_qkv(self, tensor, index):
        if self._qkv_same_embed_dim:
            tensor = F.linear(
                x=tensor,
                weight=self.in_proj_weight[:, index * self.embed_dim:(
                    index + 1) * self.embed_dim],
                bias=self.in_proj_bias[index * self.embed_dim:(index + 1) *
                                       self.embed_dim]
                if self.in_proj_bias is not None else None)
        else:
            tensor = getattr(self, self._type_list[index])(tensor)
        tensor = tensor.reshape(
            [0, 0, self.num_heads, self.head_dim]).transpose([0, 2, 1, 3])
        return tensor

    def forward(self, query, key=None, value=None, attn_mask=None):
        r"""
        Applies multi-head attention to map queries and a set of key-value pairs
        to outputs.

        Parameters:
            query (Tensor): The queries for multi-head attention. It is a
                tensor with shape `[batch_size, query_length, embed_dim]`. The
                data type should be float32 or float64.
            key (Tensor, optional): The keys for multi-head attention. It is
                a tensor with shape `[batch_size, key_length, kdim]`. The
                data type should be float32 or float64. If None, use `query` as
                `key`. Default None.
            value (Tensor, optional): The values for multi-head attention. It
                is a tensor with shape `[batch_size, value_length, vdim]`.
                The data type should be float32 or float64. If None, use `query` as
                `value`. Default None.
            attn_mask (Tensor, optional): A tensor used in multi-head attention
                to prevent attention to some unwanted positions, usually the
                paddings or the subsequent positions. It is a tensor with shape
                broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`.
                When the data type is bool, the unwanted positions have `False`
                values and the others have `True` values. When the data type is
                int, the unwanted positions have 0 values and the others have 1
                values. When the data type is float, the unwanted positions have
                `-INF` values and the others have 0 values. It can be None when
                nothing needs to be prevented from attending. Default None.

        Returns:
            Tensor|tuple: A tensor with the same shape and data type as `query`,
                representing the attention output. If `need_weights` is True,
                a tuple of (output, weights) is returned instead, where
                `weights` is the attention weights tensor shaped
                `[batch_size, num_heads, query_length, key_length]`.
        """
        key = query if key is None else key
        value = query if value is None else value
        # compute q, k, v
        q, k, v = (self.compute_qkv(t, i)
                   for i, t in enumerate([query, key, value]))

        # scaled dot-product attention
        product = paddle.matmul(x=q, y=k, transpose_y=True)
        scaling = float(self.head_dim)**-0.5
        product = product * scaling

        if attn_mask is not None:
            # Support bool or int mask
            attn_mask = _convert_attention_mask(attn_mask, product.dtype)
            product = product + attn_mask
        weights = F.softmax(product)
        if self.dropout:
            weights = F.dropout(
                weights,
                self.dropout,
                training=self.training,
                mode="upscale_in_train")
        out = paddle.matmul(weights, v)

        # combine heads
        out = paddle.transpose(out, perm=[0, 2, 1, 3])
        out = paddle.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

        # project to output
        out = self.out_proj(out)

        outs = [out]
        if self.need_weights:
            outs.append(weights)
        return out if len(outs) == 1 else tuple(outs)
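

# Usage sketch (illustrative): self-attention with this module; with the
# default need_weights=False only the attention output is returned.
def _example_multi_head_attention():
    mha = MultiHeadAttention(embed_dim=128, num_heads=2)
    query = paddle.rand([2, 4, 128])  # [batch, seq_len, embed_dim]
    out = mha(query)                  # key/value default to query -> [2, 4, 128]
    return out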