layers.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import six
import numpy as np
from numbers import Integral

import paddle
import paddle.nn as nn
from paddle import ParamAttr
from paddle import to_tensor
import paddle.nn.functional as F
from paddle.nn.initializer import Normal, Constant, XavierUniform
from paddle.regularizer import L2Decay
# LayerHelper and NumpyArrayInitializer are needed by AnchorGrid.__call__
# below but were missing from the original import list.
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.initializer import NumpyArrayInitializer
from paddle.vision.ops import DeformConv2D

from paddlex.ppdet.core.workspace import register, serializable
from paddlex.ppdet.modeling.bbox_utils import delta2bbox
from . import ops
from .initializer import xavier_uniform_, constant_


def _to_list(l):
    if isinstance(l, (list, tuple)):
        return list(l)
    return [l]


class DeformableConvV2(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 weight_attr=None,
                 bias_attr=None,
                 lr_scale=1,
                 regularizer=None,
                 skip_quant=False,
                 dcn_bias_regularizer=L2Decay(0.),
                 dcn_bias_lr_scale=2.):
        super(DeformableConvV2, self).__init__()
        self.offset_channel = 2 * kernel_size**2
        self.mask_channel = kernel_size**2

        if lr_scale == 1 and regularizer is None:
            offset_bias_attr = ParamAttr(initializer=Constant(0.))
        else:
            offset_bias_attr = ParamAttr(
                initializer=Constant(0.),
                learning_rate=lr_scale,
                regularizer=regularizer)
        self.conv_offset = nn.Conv2D(
            in_channels,
            3 * kernel_size**2,
            kernel_size,
            stride=stride,
            padding=(kernel_size - 1) // 2,
            weight_attr=ParamAttr(initializer=Constant(0.0)),
            bias_attr=offset_bias_attr)
        if skip_quant:
            self.conv_offset.skip_quant = True

        if bias_attr:
            # in FCOS-DCN head, specifically need learning_rate and regularizer
            dcn_bias_attr = ParamAttr(
                initializer=Constant(value=0),
                regularizer=dcn_bias_regularizer,
                learning_rate=dcn_bias_lr_scale)
        else:
            # in ResNet backbone, do not need bias
            dcn_bias_attr = False
        self.conv_dcn = DeformConv2D(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=(kernel_size - 1) // 2 * dilation,
            dilation=dilation,
            groups=groups,
            weight_attr=weight_attr,
            bias_attr=dcn_bias_attr)

    def forward(self, x):
        offset_mask = self.conv_offset(x)
        offset, mask = paddle.split(
            offset_mask,
            num_or_sections=[self.offset_channel, self.mask_channel],
            axis=1)
        mask = F.sigmoid(mask)
        y = self.conv_dcn(x, offset, mask=mask)
        return y
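
# Illustrative usage sketch (not part of the original file; shapes are
# hypothetical): conv_offset predicts 3*k*k channels per position, which are
# split into 2*k*k (x, y) offsets and k*k modulation masks for the DCNv2 conv.
#
#   dcn = DeformableConvV2(in_channels=64, out_channels=128, kernel_size=3)
#   y = dcn(paddle.rand([2, 64, 32, 32]))  # -> shape [2, 128, 32, 32]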


class ConvNormLayer(nn.Layer):
    def __init__(self,
                 ch_in,
                 ch_out,
                 filter_size,
                 stride,
                 groups=1,
                 norm_type='bn',
                 norm_decay=0.,
                 norm_groups=32,
                 use_dcn=False,
                 bias_on=False,
                 lr_scale=1.,
                 freeze_norm=False,
                 initializer=Normal(
                     mean=0., std=0.01),
                 skip_quant=False,
                 dcn_lr_scale=2.,
                 dcn_regularizer=L2Decay(0.)):
        super(ConvNormLayer, self).__init__()
        assert norm_type in ['bn', 'sync_bn', 'gn']

        if bias_on:
            bias_attr = ParamAttr(
                initializer=Constant(value=0.), learning_rate=lr_scale)
        else:
            bias_attr = False

        if not use_dcn:
            self.conv = nn.Conv2D(
                in_channels=ch_in,
                out_channels=ch_out,
                kernel_size=filter_size,
                stride=stride,
                padding=(filter_size - 1) // 2,
                groups=groups,
                weight_attr=ParamAttr(
                    initializer=initializer, learning_rate=1.),
                bias_attr=bias_attr)
            if skip_quant:
                self.conv.skip_quant = True
        else:
            # in FCOS-DCN head, specifically need learning_rate and regularizer
            self.conv = DeformableConvV2(
                in_channels=ch_in,
                out_channels=ch_out,
                kernel_size=filter_size,
                stride=stride,
                padding=(filter_size - 1) // 2,
                groups=groups,
                weight_attr=ParamAttr(
                    initializer=initializer, learning_rate=1.),
                bias_attr=True,
                lr_scale=dcn_lr_scale,
                regularizer=dcn_regularizer,
                dcn_bias_regularizer=dcn_regularizer,
                dcn_bias_lr_scale=dcn_lr_scale,
                skip_quant=skip_quant)

        norm_lr = 0. if freeze_norm else 1.
        param_attr = ParamAttr(
            learning_rate=norm_lr,
            regularizer=L2Decay(norm_decay)
            if norm_decay is not None else None)
        bias_attr = ParamAttr(
            learning_rate=norm_lr,
            regularizer=L2Decay(norm_decay)
            if norm_decay is not None else None)
        if norm_type in ['bn', 'sync_bn']:
            self.norm = nn.BatchNorm2D(
                ch_out, weight_attr=param_attr, bias_attr=bias_attr)
        elif norm_type == 'gn':
            self.norm = nn.GroupNorm(
                num_groups=norm_groups,
                num_channels=ch_out,
                weight_attr=param_attr,
                bias_attr=bias_attr)

    def forward(self, inputs):
        out = self.conv(inputs)
        out = self.norm(out)
        return out


class LiteConv(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 with_act=True,
                 norm_type='sync_bn',
                 name=None):
        super(LiteConv, self).__init__()
        self.lite_conv = nn.Sequential()
        conv1 = ConvNormLayer(
            in_channels,
            in_channels,
            filter_size=5,
            stride=stride,
            groups=in_channels,
            norm_type=norm_type,
            initializer=XavierUniform())
        conv2 = ConvNormLayer(
            in_channels,
            out_channels,
            filter_size=1,
            stride=stride,
            norm_type=norm_type,
            initializer=XavierUniform())
        conv3 = ConvNormLayer(
            out_channels,
            out_channels,
            filter_size=1,
            stride=stride,
            norm_type=norm_type,
            initializer=XavierUniform())
        conv4 = ConvNormLayer(
            out_channels,
            out_channels,
            filter_size=5,
            stride=stride,
            groups=out_channels,
            norm_type=norm_type,
            initializer=XavierUniform())
        conv_list = [conv1, conv2, conv3, conv4]
        self.lite_conv.add_sublayer('conv1', conv1)
        self.lite_conv.add_sublayer('relu6_1', nn.ReLU6())
        self.lite_conv.add_sublayer('conv2', conv2)
        if with_act:
            self.lite_conv.add_sublayer('relu6_2', nn.ReLU6())
        self.lite_conv.add_sublayer('conv3', conv3)
        self.lite_conv.add_sublayer('relu6_3', nn.ReLU6())
        self.lite_conv.add_sublayer('conv4', conv4)
        if with_act:
            self.lite_conv.add_sublayer('relu6_4', nn.ReLU6())

    def forward(self, inputs):
        out = self.lite_conv(inputs)
        return out


class DropBlock(nn.Layer):
    def __init__(self, block_size, keep_prob, name, data_format='NCHW'):
        """
        DropBlock layer, see https://arxiv.org/abs/1810.12890

        Args:
            block_size (int): block size
            keep_prob (float): keep probability
            name (str): layer name
            data_format (str): data format, NCHW or NHWC
        """
        super(DropBlock, self).__init__()
        self.block_size = block_size
        self.keep_prob = keep_prob
        self.name = name
        self.data_format = data_format

    def forward(self, x):
        if not self.training or self.keep_prob == 1:
            return x
        else:
            gamma = (1. - self.keep_prob) / (self.block_size**2)
            if self.data_format == 'NCHW':
                shape = x.shape[2:]
            else:
                shape = x.shape[1:3]
            for s in shape:
                gamma *= s / (s - self.block_size + 1)

            matrix = paddle.cast(paddle.rand(x.shape) < gamma, x.dtype)
            mask_inv = F.max_pool2d(
                matrix,
                self.block_size,
                stride=1,
                padding=self.block_size // 2,
                data_format=self.data_format)
            mask = 1. - mask_inv
            y = x * mask * (mask.numel() / mask.sum())
            return y
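
# Illustrative note (not part of the original file): gamma follows the paper,
# (1 - keep_prob) / block_size**2, scaled by s / (s - block_size + 1) per
# spatial dim so that roughly (1 - keep_prob) of the activations fall inside
# dropped blocks. A hypothetical call:
#
#   drop = DropBlock(block_size=3, keep_prob=0.9, name='drop_block')
#   drop.train()
#   y = drop(paddle.rand([2, 8, 16, 16]))  # same shape, rescaled by kept ratio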


@register
@serializable
class AnchorGeneratorSSD(object):
    def __init__(self,
                 steps=[8, 16, 32, 64, 100, 300],
                 aspect_ratios=[[2.], [2., 3.], [2., 3.], [2., 3.], [2.],
                                [2.]],
                 min_ratio=15,
                 max_ratio=90,
                 base_size=300,
                 min_sizes=[30.0, 60.0, 111.0, 162.0, 213.0, 264.0],
                 max_sizes=[60.0, 111.0, 162.0, 213.0, 264.0, 315.0],
                 offset=0.5,
                 flip=True,
                 clip=False,
                 min_max_aspect_ratios_order=False):
        self.steps = steps
        self.aspect_ratios = aspect_ratios
        self.min_ratio = min_ratio
        self.max_ratio = max_ratio
        self.base_size = base_size
        self.min_sizes = min_sizes
        self.max_sizes = max_sizes
        self.offset = offset
        self.flip = flip
        self.clip = clip
        self.min_max_aspect_ratios_order = min_max_aspect_ratios_order

        if self.min_sizes == [] and self.max_sizes == []:
            num_layer = len(aspect_ratios)
            step = int(
                math.floor(((self.max_ratio - self.min_ratio)) / (num_layer - 2
                                                                  )))
            for ratio in six.moves.range(self.min_ratio, self.max_ratio + 1,
                                         step):
                self.min_sizes.append(self.base_size * ratio / 100.)
                self.max_sizes.append(self.base_size * (ratio + step) / 100.)
            self.min_sizes = [self.base_size * .10] + self.min_sizes
            self.max_sizes = [self.base_size * .20] + self.max_sizes

        self.num_priors = []
        for aspect_ratio, min_size, max_size in zip(
                aspect_ratios, self.min_sizes, self.max_sizes):
            if isinstance(min_size, (list, tuple)):
                self.num_priors.append(
                    len(_to_list(min_size)) + len(_to_list(max_size)))
            else:
                self.num_priors.append((len(aspect_ratio) * 2 + 1) * len(
                    _to_list(min_size)) + len(_to_list(max_size)))

    def __call__(self, inputs, image):
        boxes = []
        for input, min_size, max_size, aspect_ratio, step in zip(
                inputs, self.min_sizes, self.max_sizes, self.aspect_ratios,
                self.steps):
            box, _ = ops.prior_box(
                input=input,
                image=image,
                min_sizes=_to_list(min_size),
                max_sizes=_to_list(max_size),
                aspect_ratios=aspect_ratio,
                flip=self.flip,
                clip=self.clip,
                steps=[step, step],
                offset=self.offset,
                min_max_aspect_ratios_order=self.min_max_aspect_ratios_order)
            boxes.append(paddle.reshape(box, [-1, 4]))
        return boxes
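
# Worked example of the num_priors arithmetic above (illustrative, using the
# defaults): for a layer with aspect_ratios=[2., 3.] and scalar sizes,
# (len(aspect_ratio) * 2 + 1) * 1 + 1 = (2 * 2 + 1) + 1 = 6 priors per
# position: one square prior at min_size, priors for ratios 2, 1/2, 3, 1/3
# (flip=True), plus one sqrt(min_size * max_size) prior from max_size.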


@register
@serializable
class RCNNBox(object):
    __shared__ = ['num_classes']

    def __init__(self,
                 prior_box_var=[10., 10., 5., 5.],
                 code_type="decode_center_size",
                 box_normalized=False,
                 num_classes=80):
        super(RCNNBox, self).__init__()
        self.prior_box_var = prior_box_var
        self.code_type = code_type
        self.box_normalized = box_normalized
        self.num_classes = num_classes

    def __call__(self, bbox_head_out, rois, im_shape, scale_factor):
        bbox_pred = bbox_head_out[0]
        cls_prob = bbox_head_out[1]
        roi = rois[0]
        rois_num = rois[1]

        origin_shape = paddle.floor(im_shape / scale_factor + 0.5)
        scale_list = []
        origin_shape_list = []

        batch_size = 1
        if isinstance(roi, list):
            batch_size = len(roi)
        else:
            batch_size = paddle.slice(paddle.shape(im_shape), [0], [0], [1])
        # bbox_pred.shape: [N, C*4]
        for idx in range(batch_size):
            roi_per_im = roi[idx]
            rois_num_per_im = rois_num[idx]
            expand_im_shape = paddle.expand(im_shape[idx, :],
                                            [rois_num_per_im, 2])
            origin_shape_list.append(expand_im_shape)

        origin_shape = paddle.concat(origin_shape_list)

        # bbox_pred.shape: [N, C*4]
        # C=num_classes in faster/mask rcnn(bbox_head), C=1 in cascade rcnn(cascade_head)
        bbox = paddle.concat(roi)
        if bbox.shape[0] == 0:
            bbox = paddle.zeros([0, bbox_pred.shape[1]], dtype='float32')
        else:
            bbox = delta2bbox(bbox_pred, bbox, self.prior_box_var)
        scores = cls_prob[:, :-1]

        # bbox.shape: [N, C, 4]
        # bbox.shape[1] must be equal to scores.shape[1]
        bbox_num_class = bbox.shape[1]
        if bbox_num_class == 1:
            bbox = paddle.tile(bbox, [1, self.num_classes, 1])

        origin_h = paddle.unsqueeze(origin_shape[:, 0], axis=1)
        origin_w = paddle.unsqueeze(origin_shape[:, 1], axis=1)
        zeros = paddle.zeros_like(origin_h)
        x1 = paddle.maximum(paddle.minimum(bbox[:, :, 0], origin_w), zeros)
        y1 = paddle.maximum(paddle.minimum(bbox[:, :, 1], origin_h), zeros)
        x2 = paddle.maximum(paddle.minimum(bbox[:, :, 2], origin_w), zeros)
        y2 = paddle.maximum(paddle.minimum(bbox[:, :, 3], origin_h), zeros)
        bbox = paddle.stack([x1, y1, x2, y2], axis=-1)
        bboxes = (bbox, rois_num)
        return bboxes, scores


@register
@serializable
class MultiClassNMS(object):
    def __init__(self,
                 score_threshold=.05,
                 nms_top_k=-1,
                 keep_top_k=100,
                 nms_threshold=.5,
                 normalized=True,
                 nms_eta=1.0,
                 return_index=False,
                 return_rois_num=True):
        super(MultiClassNMS, self).__init__()
        self.score_threshold = score_threshold
        self.nms_top_k = nms_top_k
        self.keep_top_k = keep_top_k
        self.nms_threshold = nms_threshold
        self.normalized = normalized
        self.nms_eta = nms_eta
        self.return_index = return_index
        self.return_rois_num = return_rois_num

    def __call__(self, bboxes, score, background_label=-1):
        """
        bboxes (Tensor|List[Tensor]): 1. (Tensor) Predicted bboxes with shape
                                         [N, M, 4], N is the batch size and M
                                         is the number of bboxes
                                      2. (List[Tensor]) bboxes and bbox_num,
                                         bboxes have shape of [M, C, 4], C
                                         is the class number and bbox_num means
                                         the number of bboxes of each batch with
                                         shape [N,]
        score (Tensor): Predicted scores with shape [N, C, M] or [M, C]
        background_label (int): Ignore the background label; For example, RCNN
                                is num_classes and YOLO is -1.
        """
        kwargs = self.__dict__.copy()
        if isinstance(bboxes, tuple):
            bboxes, bbox_num = bboxes
            kwargs.update({'rois_num': bbox_num})
        if background_label > -1:
            kwargs.update({'background_label': background_label})
        return ops.multiclass_nms(bboxes, score, **kwargs)
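
# Usage sketch (illustrative, not from the original file): __call__ forwards
# every attribute set in __init__ as a keyword argument of ops.multiclass_nms,
# and a (bboxes, bbox_num) tuple additionally forwards bbox_num as `rois_num`.
#
#   nms = MultiClassNMS(score_threshold=0.05, nms_threshold=0.5, keep_top_k=100)
#   out = nms(bboxes, scores)  # hypothetical tensors shaped as documented above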


@register
@serializable
class MatrixNMS(object):
    __append_doc__ = True

    def __init__(self,
                 score_threshold=.05,
                 post_threshold=.05,
                 nms_top_k=-1,
                 keep_top_k=100,
                 use_gaussian=False,
                 gaussian_sigma=2.,
                 normalized=False,
                 background_label=0):
        super(MatrixNMS, self).__init__()
        self.score_threshold = score_threshold
        self.post_threshold = post_threshold
        self.nms_top_k = nms_top_k
        self.keep_top_k = keep_top_k
        self.normalized = normalized
        self.use_gaussian = use_gaussian
        self.gaussian_sigma = gaussian_sigma
        self.background_label = background_label

    def __call__(self, bbox, score, *args):
        return ops.matrix_nms(
            bboxes=bbox,
            scores=score,
            score_threshold=self.score_threshold,
            post_threshold=self.post_threshold,
            nms_top_k=self.nms_top_k,
            keep_top_k=self.keep_top_k,
            use_gaussian=self.use_gaussian,
            gaussian_sigma=self.gaussian_sigma,
            background_label=self.background_label,
            normalized=self.normalized)


@register
@serializable
class YOLOBox(object):
    __shared__ = ['num_classes']

    def __init__(self,
                 num_classes=80,
                 conf_thresh=0.005,
                 downsample_ratio=32,
                 clip_bbox=True,
                 scale_x_y=1.):
        self.num_classes = num_classes
        self.conf_thresh = conf_thresh
        self.downsample_ratio = downsample_ratio
        self.clip_bbox = clip_bbox
        self.scale_x_y = scale_x_y

    def __call__(self,
                 yolo_head_out,
                 anchors,
                 im_shape,
                 scale_factor,
                 var_weight=None):
        boxes_list = []
        scores_list = []
        origin_shape = im_shape / scale_factor
        origin_shape = paddle.cast(origin_shape, 'int32')
        for i, head_out in enumerate(yolo_head_out):
            boxes, scores = ops.yolo_box(head_out, origin_shape, anchors[i],
                                         self.num_classes, self.conf_thresh,
                                         self.downsample_ratio // 2**i,
                                         self.clip_bbox, self.scale_x_y)
            boxes_list.append(boxes)
            scores_list.append(paddle.transpose(scores, perm=[0, 2, 1]))
        yolo_boxes = paddle.concat(boxes_list, axis=1)
        yolo_scores = paddle.concat(scores_list, axis=2)
        return yolo_boxes, yolo_scores
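
# Note (illustrative): the per-level stride is downsample_ratio // 2**i, i.e.
# 32, 16, 8 for a standard 3-level YOLO head, so each ops.yolo_box call decodes
# one FPN level before the results are concatenated along the box axis.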


@register
@serializable
class SSDBox(object):
    def __init__(self, is_normalized=True):
        self.is_normalized = is_normalized
        self.norm_delta = float(not self.is_normalized)

    def __call__(self,
                 preds,
                 prior_boxes,
                 im_shape,
                 scale_factor,
                 var_weight=None):
        boxes, scores = preds
        outputs = []
        for box, score, prior_box in zip(boxes, scores, prior_boxes):
            pb_w = prior_box[:, 2] - prior_box[:, 0] + self.norm_delta
            pb_h = prior_box[:, 3] - prior_box[:, 1] + self.norm_delta
            pb_x = prior_box[:, 0] + pb_w * 0.5
            pb_y = prior_box[:, 1] + pb_h * 0.5
            out_x = pb_x + box[:, :, 0] * pb_w * 0.1
            out_y = pb_y + box[:, :, 1] * pb_h * 0.1
            out_w = paddle.exp(box[:, :, 2] * 0.2) * pb_w
            out_h = paddle.exp(box[:, :, 3] * 0.2) * pb_h

            if self.is_normalized:
                h = paddle.unsqueeze(
                    im_shape[:, 0] / scale_factor[:, 0], axis=-1)
                w = paddle.unsqueeze(
                    im_shape[:, 1] / scale_factor[:, 1], axis=-1)
                output = paddle.stack(
                    [(out_x - out_w / 2.) * w, (out_y - out_h / 2.) * h,
                     (out_x + out_w / 2.) * w, (out_y + out_h / 2.) * h],
                    axis=-1)
            else:
                output = paddle.stack(
                    [
                        out_x - out_w / 2., out_y - out_h / 2.,
                        out_x + out_w / 2. - 1., out_y + out_h / 2. - 1.
                    ],
                    axis=-1)
            outputs.append(output)
        boxes = paddle.concat(outputs, axis=1)

        scores = F.softmax(paddle.concat(scores, axis=1))
        scores = paddle.transpose(scores, [0, 2, 1])

        return boxes, scores
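
# Note (illustrative): the decoding above hard-codes the usual SSD prior
# variances, 0.1 for center deltas and 0.2 for width/height deltas:
#   out_x = pb_x + dx * pb_w * 0.1,   out_w = pb_w * exp(0.2 * dw)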


@register
@serializable
class AnchorGrid(object):
    """Generate anchor grid

    Args:
        image_size (int or list): input image size, may be a single integer or
            list of [h, w]. Default: 512
        min_level (int): min level of the feature pyramid. Default: 3
        max_level (int): max level of the feature pyramid. Default: 7
        anchor_base_scale: base anchor scale. Default: 4
        num_scales: number of anchor scales. Default: 3
        aspect_ratios: aspect ratios. default: [[1, 1], [1.4, 0.7], [0.7, 1.4]]
    """

    def __init__(self,
                 image_size=512,
                 min_level=3,
                 max_level=7,
                 anchor_base_scale=4,
                 num_scales=3,
                 aspect_ratios=[[1, 1], [1.4, 0.7], [0.7, 1.4]]):
        super(AnchorGrid, self).__init__()
        if isinstance(image_size, Integral):
            self.image_size = [image_size, image_size]
        else:
            self.image_size = image_size
        for dim in self.image_size:
            assert dim % 2 ** max_level == 0, \
                "image size should be multiple of the max level stride"
        self.min_level = min_level
        self.max_level = max_level
        self.anchor_base_scale = anchor_base_scale
        self.num_scales = num_scales
        self.aspect_ratios = aspect_ratios

    @property
    def base_cell(self):
        if not hasattr(self, '_base_cell'):
            self._base_cell = self.make_cell()
        return self._base_cell

    def make_cell(self):
        scales = [2**(i / self.num_scales) for i in range(self.num_scales)]
        scales = np.array(scales)
        ratios = np.array(self.aspect_ratios)
        ws = np.outer(scales, ratios[:, 0]).reshape(-1, 1)
        hs = np.outer(scales, ratios[:, 1]).reshape(-1, 1)
        anchors = np.hstack((-0.5 * ws, -0.5 * hs, 0.5 * ws, 0.5 * hs))
        return anchors

    def make_grid(self, stride):
        cell = self.base_cell * stride * self.anchor_base_scale
        x_steps = np.arange(stride // 2, self.image_size[1], stride)
        y_steps = np.arange(stride // 2, self.image_size[0], stride)
        offset_x, offset_y = np.meshgrid(x_steps, y_steps)
        offset_x = offset_x.flatten()
        offset_y = offset_y.flatten()
        offsets = np.stack((offset_x, offset_y, offset_x, offset_y), axis=-1)
        offsets = offsets[:, np.newaxis, :]
        return (cell + offsets).reshape(-1, 4)

    def generate(self):
        return [
            self.make_grid(2**l)
            for l in range(self.min_level, self.max_level + 1)
        ]

    def __call__(self):
        if not hasattr(self, '_anchor_vars'):
            anchor_vars = []
            helper = LayerHelper('anchor_grid')
            for idx, l in enumerate(range(self.min_level, self.max_level + 1)):
                stride = 2**l
                anchors = self.make_grid(stride)
                var = helper.create_parameter(
                    attr=ParamAttr(name='anchors_{}'.format(idx)),
                    shape=anchors.shape,
                    dtype='float32',
                    stop_gradient=True,
                    default_initializer=NumpyArrayInitializer(anchors))
                anchor_vars.append(var)
                var.persistable = True
            self._anchor_vars = anchor_vars

        return self._anchor_vars
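
# Worked example (illustrative, using the defaults): each cell holds
# num_scales * len(aspect_ratios) = 3 * 3 = 9 anchors, and make_grid(stride)
# tiles that cell over every stride-spaced center, e.g. (512 // 8)**2 = 4096
# positions at min_level 3, giving 4096 * 9 anchors for that level.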


@register
@serializable
class FCOSBox(object):
    __shared__ = ['num_classes']

    def __init__(self, num_classes=80):
        super(FCOSBox, self).__init__()
        self.num_classes = num_classes

    def _merge_hw(self, inputs, ch_type="channel_first"):
        """
        Merge h and w of the feature map into one dimension.

        Args:
            inputs (Tensor): Tensor of the input feature map
            ch_type (str): "channel_first" or "channel_last" style
        Return:
            new_shape (Tensor): The new shape after h and w merged
        """
        shape_ = paddle.shape(inputs)
        bs, ch, hi, wi = shape_[0], shape_[1], shape_[2], shape_[3]
        img_size = hi * wi
        img_size.stop_gradient = True
        if ch_type == "channel_first":
            new_shape = paddle.concat([bs, ch, img_size])
        elif ch_type == "channel_last":
            new_shape = paddle.concat([bs, img_size, ch])
        else:
            raise KeyError("Wrong ch_type %s" % ch_type)
        new_shape.stop_gradient = True
        return new_shape

    def _postprocessing_by_level(self, locations, box_cls, box_reg, box_ctn,
                                 scale_factor):
        """
        Postprocess each layer of the output with corresponding locations.

        Args:
            locations (Tensor): anchor points for current layer, [H*W, 2]
            box_cls (Tensor): categories prediction, [N, C, H, W],
                C is the number of classes
            box_reg (Tensor): bounding box prediction, [N, 4, H, W]
            box_ctn (Tensor): centerness prediction, [N, 1, H, W]
            scale_factor (Tensor): [h_scale, w_scale] for input images
        Return:
            box_cls_ch_last (Tensor): score for each category, in [N, C, M],
                C is the number of classes and M is the number of anchor points
            box_reg_decoding (Tensor): decoded bounding box, in [N, M, 4],
                last dimension is [x1, y1, x2, y2]
        """
        act_shape_cls = self._merge_hw(box_cls)
        box_cls_ch_last = paddle.reshape(x=box_cls, shape=act_shape_cls)
        box_cls_ch_last = F.sigmoid(box_cls_ch_last)

        act_shape_reg = self._merge_hw(box_reg)
        box_reg_ch_last = paddle.reshape(x=box_reg, shape=act_shape_reg)
        box_reg_ch_last = paddle.transpose(box_reg_ch_last, perm=[0, 2, 1])
        box_reg_decoding = paddle.stack(
            [
                locations[:, 0] - box_reg_ch_last[:, :, 0],
                locations[:, 1] - box_reg_ch_last[:, :, 1],
                locations[:, 0] + box_reg_ch_last[:, :, 2],
                locations[:, 1] + box_reg_ch_last[:, :, 3]
            ],
            axis=1)
        box_reg_decoding = paddle.transpose(box_reg_decoding, perm=[0, 2, 1])

        act_shape_ctn = self._merge_hw(box_ctn)
        box_ctn_ch_last = paddle.reshape(x=box_ctn, shape=act_shape_ctn)
        box_ctn_ch_last = F.sigmoid(box_ctn_ch_last)

        # recover the location to original image
        im_scale = paddle.concat([scale_factor, scale_factor], axis=1)
        im_scale = paddle.expand(im_scale, [box_reg_decoding.shape[0], 4])
        im_scale = paddle.reshape(im_scale, [box_reg_decoding.shape[0], -1, 4])
        box_reg_decoding = box_reg_decoding / im_scale
        box_cls_ch_last = box_cls_ch_last * box_ctn_ch_last
        return box_cls_ch_last, box_reg_decoding

    def __call__(self, locations, cls_logits, bboxes_reg, centerness,
                 scale_factor):
        pred_boxes_ = []
        pred_scores_ = []
        for pts, cls, box, ctn in zip(locations, cls_logits, bboxes_reg,
                                      centerness):
            pred_scores_lvl, pred_boxes_lvl = self._postprocessing_by_level(
                pts, cls, box, ctn, scale_factor)
            pred_boxes_.append(pred_boxes_lvl)
            pred_scores_.append(pred_scores_lvl)
        pred_boxes = paddle.concat(pred_boxes_, axis=1)
        pred_scores = paddle.concat(pred_scores_, axis=2)
        return pred_boxes, pred_scores


@register
class TTFBox(object):
    __shared__ = ['down_ratio']

    def __init__(self, max_per_img=100, score_thresh=0.01, down_ratio=4):
        super(TTFBox, self).__init__()
        self.max_per_img = max_per_img
        self.score_thresh = score_thresh
        self.down_ratio = down_ratio

    def _simple_nms(self, heat, kernel=3):
        """
        Use maxpool to filter the max score, get local peaks.
        """
        pad = (kernel - 1) // 2
        hmax = F.max_pool2d(heat, kernel, stride=1, padding=pad)
        keep = paddle.cast(hmax == heat, 'float32')
        return heat * keep

    def _topk(self, scores):
        """
        Select top k scores and decode to get xy coordinates.
        """
        k = self.max_per_img
        shape_fm = paddle.shape(scores)
        shape_fm.stop_gradient = True
        cat, height, width = shape_fm[1], shape_fm[2], shape_fm[3]
        # batch size is 1
        scores_r = paddle.reshape(scores, [cat, -1])
        topk_scores, topk_inds = paddle.topk(scores_r, k)
        topk_ys = topk_inds // width
        topk_xs = topk_inds % width

        topk_score_r = paddle.reshape(topk_scores, [-1])
        topk_score, topk_ind = paddle.topk(topk_score_r, k)
        k_t = paddle.full(paddle.shape(topk_ind), k, dtype='int64')
        topk_clses = paddle.cast(paddle.floor_divide(topk_ind, k_t), 'float32')

        topk_inds = paddle.reshape(topk_inds, [-1])
        topk_ys = paddle.reshape(topk_ys, [-1, 1])
        topk_xs = paddle.reshape(topk_xs, [-1, 1])
        topk_inds = paddle.gather(topk_inds, topk_ind)
        topk_ys = paddle.gather(topk_ys, topk_ind)
        topk_xs = paddle.gather(topk_xs, topk_ind)

        return topk_score, topk_inds, topk_clses, topk_ys, topk_xs

    def _decode(self, hm, wh, im_shape, scale_factor):
        heatmap = F.sigmoid(hm)
        heat = self._simple_nms(heatmap)
        scores, inds, clses, ys, xs = self._topk(heat)
        ys = paddle.cast(ys, 'float32') * self.down_ratio
        xs = paddle.cast(xs, 'float32') * self.down_ratio
        scores = paddle.tensor.unsqueeze(scores, [1])
        clses = paddle.tensor.unsqueeze(clses, [1])

        wh_t = paddle.transpose(wh, [0, 2, 3, 1])
        wh = paddle.reshape(wh_t, [-1, paddle.shape(wh_t)[-1]])
        wh = paddle.gather(wh, inds)

        x1 = xs - wh[:, 0:1]
        y1 = ys - wh[:, 1:2]
        x2 = xs + wh[:, 2:3]
        y2 = ys + wh[:, 3:4]

        bboxes = paddle.concat([x1, y1, x2, y2], axis=1)

        scale_y = scale_factor[:, 0:1]
        scale_x = scale_factor[:, 1:2]
        scale_expand = paddle.concat(
            [scale_x, scale_y, scale_x, scale_y], axis=1)
        boxes_shape = paddle.shape(bboxes)
        boxes_shape.stop_gradient = True
        scale_expand = paddle.expand(scale_expand, shape=boxes_shape)
        bboxes = paddle.divide(bboxes, scale_expand)
        results = paddle.concat([clses, scores, bboxes], axis=1)
        # hack: append result with cls=-1 and score=1. to avoid all scores
        # are less than score_thresh which may cause error in gather.
        fill_r = paddle.to_tensor(np.array([[-1, 1, 0, 0, 0, 0]]))
        fill_r = paddle.cast(fill_r, results.dtype)
        results = paddle.concat([results, fill_r])
        scores = results[:, 1]
        valid_ind = paddle.nonzero(scores > self.score_thresh)
        results = paddle.gather(results, valid_ind)
        return results, paddle.shape(results)[0:1]

    def __call__(self, hm, wh, im_shape, scale_factor):
        results = []
        results_num = []
        for i in range(scale_factor.shape[0]):
            result, num = self._decode(hm[i:i + 1, ], wh[i:i + 1, ],
                                       im_shape[i:i + 1, ],
                                       scale_factor[i:i + 1, ])
            results.append(result)
            results_num.append(num)
        results = paddle.concat(results, axis=0)
        results_num = paddle.concat(results_num, axis=0)
        return results, results_num


@register
@serializable
class JDEBox(object):
    __shared__ = ['num_classes']

    def __init__(self, num_classes=1, conf_thresh=0.3, downsample_ratio=32):
        self.num_classes = num_classes
        self.conf_thresh = conf_thresh
        self.downsample_ratio = downsample_ratio

    def generate_anchor(self, nGh, nGw, anchor_wh):
        nA = len(anchor_wh)
        yv, xv = paddle.meshgrid([paddle.arange(nGh), paddle.arange(nGw)])
        mesh = paddle.stack(
            (xv, yv), axis=0).cast(dtype='float32')  # 2 x nGh x nGw
        meshs = paddle.tile(mesh, [nA, 1, 1, 1])

        anchor_offset_mesh = anchor_wh[:, :, None][:, :, :, None].repeat(
            int(nGh), axis=-2).repeat(
                int(nGw), axis=-1)
        anchor_offset_mesh = paddle.to_tensor(
            anchor_offset_mesh.astype(np.float32))
        # nA x 2 x nGh x nGw

        anchor_mesh = paddle.concat([meshs, anchor_offset_mesh], axis=1)
        anchor_mesh = paddle.transpose(anchor_mesh,
                                       [0, 2, 3, 1])  # (nA x nGh x nGw) x 4
        return anchor_mesh

    def decode_delta(self, delta, fg_anchor_list):
        px, py, pw, ph = fg_anchor_list[:, 0], fg_anchor_list[:, 1], \
            fg_anchor_list[:, 2], fg_anchor_list[:, 3]
        dx, dy, dw, dh = delta[:, 0], delta[:, 1], delta[:, 2], delta[:, 3]
        gx = pw * dx + px
        gy = ph * dy + py
        gw = pw * paddle.exp(dw)
        gh = ph * paddle.exp(dh)
        gx1 = gx - gw * 0.5
        gy1 = gy - gh * 0.5
        gx2 = gx + gw * 0.5
        gy2 = gy + gh * 0.5
        return paddle.stack([gx1, gy1, gx2, gy2], axis=1)

    def decode_delta_map(self, nA, nGh, nGw, delta_map, anchor_vec):
        anchor_mesh = self.generate_anchor(nGh, nGw, anchor_vec)
        anchor_mesh = paddle.unsqueeze(anchor_mesh, 0)
        pred_list = self.decode_delta(
            paddle.reshape(
                delta_map, shape=[-1, 4]),
            paddle.reshape(
                anchor_mesh, shape=[-1, 4]))
        pred_map = paddle.reshape(pred_list, shape=[nA * nGh * nGw, 4])
        return pred_map

    def _postprocessing_by_level(self, nA, stride, head_out, anchor_vec):
        boxes_shape = head_out.shape  # [nB, nA*6, nGh, nGw]
        nGh, nGw = boxes_shape[-2], boxes_shape[-1]
        nB = 1  # TODO: only support bs=1 now
        boxes_list, scores_list = [], []
        for idx in range(nB):
            p = paddle.reshape(
                head_out[idx], shape=[nA, self.num_classes + 5, nGh, nGw])
            p = paddle.transpose(p, perm=[0, 2, 3, 1])  # [nA, nGh, nGw, 6]
            delta_map = p[:, :, :, :4]
            boxes = self.decode_delta_map(nA, nGh, nGw, delta_map, anchor_vec)
            # [nA * nGh * nGw, 4]
            boxes_list.append(boxes * stride)

            p_conf = paddle.transpose(
                p[:, :, :, 4:6], perm=[3, 0, 1, 2])  # [2, nA, nGh, nGw]
            p_conf = F.softmax(
                p_conf, axis=0)[1, :, :, :].unsqueeze(-1)  # [nA, nGh, nGw, 1]
            scores = paddle.reshape(p_conf, shape=[nA * nGh * nGw, 1])
            scores_list.append(scores)

        boxes_results = paddle.stack(boxes_list)
        scores_results = paddle.stack(scores_list)
        return boxes_results, scores_results

    def __call__(self, yolo_head_out, anchors):
        bbox_pred_list = []
        for i, head_out in enumerate(yolo_head_out):
            stride = self.downsample_ratio // 2**i
            anc_w, anc_h = anchors[i][0::2], anchors[i][1::2]
            anchor_vec = np.stack((anc_w, anc_h), axis=1) / stride
            nA = len(anc_w)
            boxes, scores = self._postprocessing_by_level(nA, stride, head_out,
                                                          anchor_vec)
            bbox_pred_list.append(paddle.concat([boxes, scores], axis=-1))

        yolo_boxes_scores = paddle.concat(bbox_pred_list, axis=1)
        boxes_idx_over_conf_thr = paddle.nonzero(
            yolo_boxes_scores[:, :, -1] > self.conf_thresh)
        boxes_idx_over_conf_thr.stop_gradient = True

        return boxes_idx_over_conf_thr, yolo_boxes_scores


@register
@serializable
class MaskMatrixNMS(object):
    """
    Matrix NMS for multi-class masks.

    Args:
        update_threshold (float): Threshold used to update the category score
            in the second pass.
        pre_nms_top_n (int): Number of total instances to be kept per image
            before NMS.
        post_nms_top_n (int): Number of total instances to be kept per image
            after NMS.
        kernel (str): 'linear' or 'gaussian'.
        sigma (float): std in gaussian method.

    Input:
        seg_preds (Variable): shape (n, h, w), segmentation feature maps
        seg_masks (Variable): shape (n, h, w), segmentation feature maps
        cate_labels (Variable): shape (n), mask labels in descending order
        cate_scores (Variable): shape (n), mask scores in descending order
        sum_masks (Variable): a float tensor of the sum of seg_masks

    Returns:
        Variable: cate_scores, tensors of shape (n)
    """

    def __init__(self,
                 update_threshold=0.05,
                 pre_nms_top_n=500,
                 post_nms_top_n=100,
                 kernel='gaussian',
                 sigma=2.0):
        super(MaskMatrixNMS, self).__init__()
        self.update_threshold = update_threshold
        self.pre_nms_top_n = pre_nms_top_n
        self.post_nms_top_n = post_nms_top_n
        self.kernel = kernel
        self.sigma = sigma

    def _sort_score(self, scores, top_num):
        if paddle.shape(scores)[0] > top_num:
            return paddle.topk(scores, top_num)[1]
        else:
            return paddle.argsort(scores, descending=True)

    def __call__(self,
                 seg_preds,
                 seg_masks,
                 cate_labels,
                 cate_scores,
                 sum_masks=None):
        # sort and keep top nms_pre
        sort_inds = self._sort_score(cate_scores, self.pre_nms_top_n)
        seg_masks = paddle.gather(seg_masks, index=sort_inds)
        seg_preds = paddle.gather(seg_preds, index=sort_inds)
        sum_masks = paddle.gather(sum_masks, index=sort_inds)
        cate_scores = paddle.gather(cate_scores, index=sort_inds)
        cate_labels = paddle.gather(cate_labels, index=sort_inds)

        seg_masks = paddle.flatten(seg_masks, start_axis=1, stop_axis=-1)
        # inter.
        inter_matrix = paddle.mm(seg_masks,
                                 paddle.transpose(seg_masks, [1, 0]))
        n_samples = paddle.shape(cate_labels)
        # union.
        sum_masks_x = paddle.expand(sum_masks, shape=[n_samples, n_samples])
        # iou.
        iou_matrix = (inter_matrix / (
            sum_masks_x + paddle.transpose(sum_masks_x, [1, 0]) - inter_matrix)
                      )
        iou_matrix = paddle.triu(iou_matrix, diagonal=1)
        # label_specific matrix.
        cate_labels_x = paddle.expand(
            cate_labels, shape=[n_samples, n_samples])
        label_matrix = paddle.cast(
            (cate_labels_x == paddle.transpose(cate_labels_x, [1, 0])),
            'float32')
        label_matrix = paddle.triu(label_matrix, diagonal=1)

        # IoU compensation
        compensate_iou = paddle.max((iou_matrix * label_matrix), axis=0)
        compensate_iou = paddle.expand(
            compensate_iou, shape=[n_samples, n_samples])
        compensate_iou = paddle.transpose(compensate_iou, [1, 0])

        # IoU decay
        decay_iou = iou_matrix * label_matrix

        # matrix nms
        if self.kernel == 'gaussian':
            decay_matrix = paddle.exp(-1 * self.sigma * (decay_iou**2))
            compensate_matrix = paddle.exp(-1 * self.sigma *
                                           (compensate_iou**2))
            decay_coefficient = paddle.min(decay_matrix / compensate_matrix,
                                           axis=0)
        elif self.kernel == 'linear':
            decay_matrix = (1 - decay_iou) / (1 - compensate_iou)
            decay_coefficient = paddle.min(decay_matrix, axis=0)
        else:
            raise NotImplementedError

        # update the score.
        cate_scores = cate_scores * decay_coefficient
        y = paddle.zeros(shape=paddle.shape(cate_scores), dtype='float32')
        keep = paddle.where(cate_scores >= self.update_threshold, cate_scores,
                            y)
        keep = paddle.nonzero(keep)
        keep = paddle.squeeze(keep, axis=[1])
        # Prevent empty and increase fake data
        keep = paddle.concat(
            [keep, paddle.cast(paddle.shape(cate_scores)[0] - 1, 'int64')])

        seg_preds = paddle.gather(seg_preds, index=keep)
        cate_scores = paddle.gather(cate_scores, index=keep)
        cate_labels = paddle.gather(cate_labels, index=keep)

        # sort and keep top_k
        sort_inds = self._sort_score(cate_scores, self.post_nms_top_n)
        seg_preds = paddle.gather(seg_preds, index=sort_inds)
        cate_scores = paddle.gather(cate_scores, index=sort_inds)
        cate_labels = paddle.gather(cate_labels, index=sort_inds)
        return seg_preds, cate_scores, cate_labels
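
# Note (illustrative): with the 'gaussian' kernel the decay for mask j is
# min_i exp(-sigma * (iou_ij**2 - compensate_iou_i**2)), computed above as
# decay_matrix / compensate_matrix; the 'linear' kernel instead uses
# (1 - iou_ij) / (1 - compensate_iou_i). See the SOLOv2 matrix NMS formulation.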


# The helpers below are torch-style functional aliases over paddle.nn layers
# (note BatchNorm2d takes eps before momentum, mirroring the torch signature).
def Conv2d(in_channels,
           out_channels,
           kernel_size,
           stride=1,
           padding=0,
           dilation=1,
           groups=1,
           bias=True,
           weight_init=Normal(std=0.001),
           bias_init=Constant(0.)):
    weight_attr = paddle.framework.ParamAttr(initializer=weight_init)
    if bias:
        bias_attr = paddle.framework.ParamAttr(initializer=bias_init)
    else:
        bias_attr = False
    conv = nn.Conv2D(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        dilation,
        groups,
        weight_attr=weight_attr,
        bias_attr=bias_attr)
    return conv


def ConvTranspose2d(in_channels,
                    out_channels,
                    kernel_size,
                    stride=1,
                    padding=0,
                    output_padding=0,
                    groups=1,
                    bias=True,
                    dilation=1,
                    weight_init=Normal(std=0.001),
                    bias_init=Constant(0.)):
    weight_attr = paddle.framework.ParamAttr(initializer=weight_init)
    if bias:
        bias_attr = paddle.framework.ParamAttr(initializer=bias_init)
    else:
        bias_attr = False
    conv = nn.Conv2DTranspose(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        output_padding,
        dilation,
        groups,
        weight_attr=weight_attr,
        bias_attr=bias_attr)
    return conv


def BatchNorm2d(num_features, eps=1e-05, momentum=0.9, affine=True):
    if not affine:
        weight_attr = False
        bias_attr = False
    else:
        weight_attr = None
        bias_attr = None
    batchnorm = nn.BatchNorm2D(
        num_features,
        momentum,
        eps,
        weight_attr=weight_attr,
        bias_attr=bias_attr)
    return batchnorm


def ReLU():
    return nn.ReLU()


def Upsample(scale_factor=None, mode='nearest', align_corners=False):
    return nn.Upsample(None, scale_factor, mode, align_corners)


def MaxPool(kernel_size, stride, padding, ceil_mode=False):
    return nn.MaxPool2D(kernel_size, stride, padding, ceil_mode=ceil_mode)


class Concat(nn.Layer):
    def __init__(self, dim=0):
        super(Concat, self).__init__()
        self.dim = dim

    def forward(self, inputs):
        return paddle.concat(inputs, axis=self.dim)

    def extra_repr(self):
        return 'dim={}'.format(self.dim)


def _convert_attention_mask(attn_mask, dtype):
    """
    Convert the attention mask to the target dtype we expect.

    Parameters:
        attn_mask (Tensor, optional): A tensor used in multi-head attention
            to prevent attention to some unwanted positions, usually the
            paddings or the subsequent positions. It is a tensor with shape
            broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`.
            When the data type is bool, the unwanted positions have `False`
            values and the others have `True` values. When the data type is
            int, the unwanted positions have 0 values and the others have 1
            values. When the data type is float, the unwanted positions have
            `-INF` values and the others have 0 values. It can be None when
            nothing wanted or needed to be prevented attention to. Default None.
        dtype (VarType): The target type of `attn_mask` we expect.

    Returns:
        Tensor: A Tensor with shape same as input `attn_mask`, with data type `dtype`.
    """
    return nn.layer.transformer._convert_attention_mask(attn_mask, dtype)


class MultiHeadAttention(nn.Layer):
    """
    Attention maps queries and a set of key-value pairs to outputs, and
    Multi-Head Attention performs multiple parallel attentions to jointly
    attend to information from different representation subspaces.

    Please refer to `Attention Is All You Need <https://arxiv.org/pdf/1706.03762.pdf>`_
    for more details.

    Parameters:
        embed_dim (int): The expected feature size in the input and output.
        num_heads (int): The number of heads in multi-head attention.
        dropout (float, optional): The dropout probability used on attention
            weights to drop some attention targets. 0 for no dropout. Default 0
        kdim (int, optional): The feature size in key. If None, assumed equal to
            `embed_dim`. Default None.
        vdim (int, optional): The feature size in value. If None, assumed equal to
            `embed_dim`. Default None.
        need_weights (bool, optional): Indicate whether to return the attention
            weights. Default False.

    Examples:
        .. code-block:: python

            import paddle

            # encoder input: [batch_size, sequence_length, d_model]
            query = paddle.rand((2, 4, 128))
            # self attention mask: [batch_size, num_heads, query_len, query_len]
            attn_mask = paddle.rand((2, 2, 4, 4))
            multi_head_attn = paddle.nn.MultiHeadAttention(128, 2)
            output = multi_head_attn(query, None, None, attn_mask=attn_mask)  # [2, 4, 128]
    """

    def __init__(self,
                 embed_dim,
                 num_heads,
                 dropout=0.,
                 kdim=None,
                 vdim=None,
                 need_weights=False):
        super(MultiHeadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.need_weights = need_weights

        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        if self._qkv_same_embed_dim:
            self.in_proj_weight = self.create_parameter(
                shape=[embed_dim, 3 * embed_dim],
                attr=None,
                dtype=self._dtype,
                is_bias=False)
            self.in_proj_bias = self.create_parameter(
                shape=[3 * embed_dim],
                attr=None,
                dtype=self._dtype,
                is_bias=True)
        else:
            self.q_proj = nn.Linear(embed_dim, embed_dim)
            self.k_proj = nn.Linear(self.kdim, embed_dim)
            self.v_proj = nn.Linear(self.vdim, embed_dim)

        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self._type_list = ('q_proj', 'k_proj', 'v_proj')

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)
            else:
                constant_(p)

    def compute_qkv(self, tensor, index):
        if self._qkv_same_embed_dim:
            tensor = F.linear(
                x=tensor,
                weight=self.in_proj_weight[:, index * self.embed_dim:(
                    index + 1) * self.embed_dim],
                bias=self.in_proj_bias[index * self.embed_dim:(index + 1) *
                                       self.embed_dim]
                if self.in_proj_bias is not None else None)
        else:
            tensor = getattr(self, self._type_list[index])(tensor)
        tensor = tensor.reshape(
            [0, 0, self.num_heads, self.head_dim]).transpose([0, 2, 1, 3])
        return tensor

    def forward(self, query, key=None, value=None, attn_mask=None):
        r"""
        Applies multi-head attention to map queries and a set of key-value pairs
        to outputs.

        Parameters:
            query (Tensor): The queries for multi-head attention. It is a
                tensor with shape `[batch_size, query_length, embed_dim]`. The
                data type should be float32 or float64.
            key (Tensor, optional): The keys for multi-head attention. It is
                a tensor with shape `[batch_size, key_length, kdim]`. The
                data type should be float32 or float64. If None, use `query` as
                `key`. Default None.
            value (Tensor, optional): The values for multi-head attention. It
                is a tensor with shape `[batch_size, value_length, vdim]`.
                The data type should be float32 or float64. If None, use `query` as
                `value`. Default None.
            attn_mask (Tensor, optional): A tensor used in multi-head attention
                to prevent attention to some unwanted positions, usually the
                paddings or the subsequent positions. It is a tensor with shape
                broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`.
                When the data type is bool, the unwanted positions have `False`
                values and the others have `True` values. When the data type is
                int, the unwanted positions have 0 values and the others have 1
                values. When the data type is float, the unwanted positions have
                `-INF` values and the others have 0 values. It can be None when
                nothing wanted or needed to be prevented attention to. Default None.

        Returns:
            Tensor|tuple: It is a tensor that has the same shape and data type \
                as `query`, representing attention output. Or a tuple if \
                `need_weights` is True or `cache` is not None. If `need_weights` \
                is True, except for attention output, the tuple also includes \
                the attention weights tensor shaped `[batch_size, num_heads, query_length, key_length]`. \
                If `cache` is not None, the tuple then includes the new cache \
                having the same type as `cache`, and if it is `StaticCache`, it \
                is same as the input `cache`, if it is `Cache`, the new cache \
                reserves tensors concatenating raw tensors with intermediate \
                results of the current query.
        """
        key = query if key is None else key
        value = query if value is None else value
        # compute q, k, v
        q, k, v = (self.compute_qkv(t, i)
                   for i, t in enumerate([query, key, value]))

        # scaled dot product attention
        product = paddle.matmul(x=q, y=k, transpose_y=True)
        scaling = float(self.head_dim)**-0.5
        product = product * scaling

        if attn_mask is not None:
            # Support bool or int mask
            attn_mask = _convert_attention_mask(attn_mask, product.dtype)
            product = product + attn_mask
        weights = F.softmax(product)
        if self.dropout:
            weights = F.dropout(
                weights,
                self.dropout,
                training=self.training,
                mode="upscale_in_train")
        out = paddle.matmul(weights, v)

        # combine heads
        out = paddle.transpose(out, perm=[0, 2, 1, 3])
        out = paddle.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

        # project to output
        out = self.out_proj(out)

        outs = [out]
        if self.need_weights:
            outs.append(weights)
        return out if len(outs) == 1 else tuple(outs)


@register
class ConvMixer(nn.Layer):
    def __init__(
            self,
            dim,
            depth,
            kernel_size=3, ):
        super().__init__()
        self.dim = dim
        self.depth = depth
        self.kernel_size = kernel_size

        self.mixer = self.conv_mixer(dim, depth, kernel_size)

    def forward(self, x):
        return self.mixer(x)

    @staticmethod
    def conv_mixer(
            dim,
            depth,
            kernel_size, ):
        Seq = nn.Sequential
        ActBn = lambda x: Seq(x, nn.GELU(), nn.BatchNorm2D(dim))
        Residual = type('Residual', (Seq, ),
                        {'forward': lambda self, x: self[0](x) + x})
        return Seq(*[
            Seq(Residual(
                ActBn(
                    nn.Conv2D(
                        dim, dim, kernel_size, groups=dim, padding="same"))),
                ActBn(nn.Conv2D(dim, dim, 1))) for i in range(depth)
        ])
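
# Usage sketch (illustrative, not from the original file): each of the `depth`
# blocks is a residual depthwise conv followed by a pointwise conv, both
# wrapped in GELU + BatchNorm, as in the ConvMixer paper
# ("Patches Are All You Need?").
#
#   mixer = ConvMixer(dim=64, depth=4, kernel_size=3)
#   y = mixer(paddle.rand([2, 64, 32, 32]))  # channel count is preserved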