layers.py 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. import six
  16. import numpy as np
  17. from numbers import Integral
  18. import paddle
  19. import paddle.nn as nn
  20. from paddle import ParamAttr
  21. from paddle import to_tensor
  22. from paddle.nn import Conv2D, BatchNorm2D, GroupNorm
  23. import paddle.nn.functional as F
  24. from paddle.nn.initializer import Normal, Constant, XavierUniform
  25. from paddle.regularizer import L2Decay
  26. from paddlex.ppdet.core.workspace import register, serializable
  27. from paddlex.ppdet.modeling.bbox_utils import delta2bbox
  28. from . import ops
  29. from paddle.vision.ops import DeformConv2D
  30. def _to_list(l):
  31. if isinstance(l, (list, tuple)):
  32. return list(l)
  33. return [l]
  34. class DeformableConvV2(nn.Layer):
  35. def __init__(self,
  36. in_channels,
  37. out_channels,
  38. kernel_size,
  39. stride=1,
  40. padding=0,
  41. dilation=1,
  42. groups=1,
  43. weight_attr=None,
  44. bias_attr=None,
  45. lr_scale=1,
  46. regularizer=None,
  47. skip_quant=False,
  48. dcn_bias_regularizer=L2Decay(0.),
  49. dcn_bias_lr_scale=2.):
  50. super(DeformableConvV2, self).__init__()
  51. self.offset_channel = 2 * kernel_size**2
  52. self.mask_channel = kernel_size**2
  53. if lr_scale == 1 and regularizer is None:
  54. offset_bias_attr = ParamAttr(initializer=Constant(0.))
  55. else:
  56. offset_bias_attr = ParamAttr(
  57. initializer=Constant(0.),
  58. learning_rate=lr_scale,
  59. regularizer=regularizer)
  60. self.conv_offset = nn.Conv2D(
  61. in_channels,
  62. 3 * kernel_size**2,
  63. kernel_size,
  64. stride=stride,
  65. padding=(kernel_size - 1) // 2,
  66. weight_attr=ParamAttr(initializer=Constant(0.0)),
  67. bias_attr=offset_bias_attr)
  68. if skip_quant:
  69. self.conv_offset.skip_quant = True
  70. if bias_attr:
  71. # in FCOS-DCN head, specifically need learning_rate and regularizer
  72. dcn_bias_attr = ParamAttr(
  73. initializer=Constant(value=0),
  74. regularizer=dcn_bias_regularizer,
  75. learning_rate=dcn_bias_lr_scale)
  76. else:
  77. # in ResNet backbone, do not need bias
  78. dcn_bias_attr = False
  79. self.conv_dcn = DeformConv2D(
  80. in_channels,
  81. out_channels,
  82. kernel_size,
  83. stride=stride,
  84. padding=(kernel_size - 1) // 2 * dilation,
  85. dilation=dilation,
  86. groups=groups,
  87. weight_attr=weight_attr,
  88. bias_attr=dcn_bias_attr)
  89. def forward(self, x):
  90. offset_mask = self.conv_offset(x)
  91. offset, mask = paddle.split(
  92. offset_mask,
  93. num_or_sections=[self.offset_channel, self.mask_channel],
  94. axis=1)
  95. mask = F.sigmoid(mask)
  96. y = self.conv_dcn(x, offset, mask=mask)
  97. return y
  98. class ConvNormLayer(nn.Layer):
  99. def __init__(self,
  100. ch_in,
  101. ch_out,
  102. filter_size,
  103. stride,
  104. groups=1,
  105. norm_type='bn',
  106. norm_decay=0.,
  107. norm_groups=32,
  108. use_dcn=False,
  109. bias_on=False,
  110. lr_scale=1.,
  111. freeze_norm=False,
  112. initializer=Normal(
  113. mean=0., std=0.01),
  114. skip_quant=False,
  115. dcn_lr_scale=2.,
  116. dcn_regularizer=L2Decay(0.)):
  117. super(ConvNormLayer, self).__init__()
  118. assert norm_type in ['bn', 'sync_bn', 'gn']
  119. if bias_on:
  120. bias_attr = ParamAttr(
  121. initializer=Constant(value=0.), learning_rate=lr_scale)
  122. else:
  123. bias_attr = False
  124. if not use_dcn:
  125. self.conv = nn.Conv2D(
  126. in_channels=ch_in,
  127. out_channels=ch_out,
  128. kernel_size=filter_size,
  129. stride=stride,
  130. padding=(filter_size - 1) // 2,
  131. groups=groups,
  132. weight_attr=ParamAttr(
  133. initializer=initializer, learning_rate=1.),
  134. bias_attr=bias_attr)
  135. if skip_quant:
  136. self.conv.skip_quant = True
  137. else:
  138. # in FCOS-DCN head, specifically need learning_rate and regularizer
  139. self.conv = DeformableConvV2(
  140. in_channels=ch_in,
  141. out_channels=ch_out,
  142. kernel_size=filter_size,
  143. stride=stride,
  144. padding=(filter_size - 1) // 2,
  145. groups=groups,
  146. weight_attr=ParamAttr(
  147. initializer=initializer, learning_rate=1.),
  148. bias_attr=True,
  149. lr_scale=dcn_lr_scale,
  150. regularizer=dcn_regularizer,
  151. skip_quant=skip_quant)
  152. norm_lr = 0. if freeze_norm else 1.
  153. param_attr = ParamAttr(
  154. learning_rate=norm_lr,
  155. regularizer=L2Decay(norm_decay) if norm_decay is not None else None)
  156. bias_attr = ParamAttr(
  157. learning_rate=norm_lr,
  158. regularizer=L2Decay(norm_decay) if norm_decay is not None else None)
  159. if norm_type == 'bn':
  160. self.norm = nn.BatchNorm2D(
  161. ch_out, weight_attr=param_attr, bias_attr=bias_attr)
  162. elif norm_type == 'sync_bn':
  163. self.norm = nn.SyncBatchNorm(
  164. ch_out, weight_attr=param_attr, bias_attr=bias_attr)
  165. elif norm_type == 'gn':
  166. self.norm = nn.GroupNorm(
  167. num_groups=norm_groups,
  168. num_channels=ch_out,
  169. weight_attr=param_attr,
  170. bias_attr=bias_attr)
  171. def forward(self, inputs):
  172. out = self.conv(inputs)
  173. out = self.norm(out)
  174. return out
  175. class LiteConv(nn.Layer):
  176. def __init__(self,
  177. in_channels,
  178. out_channels,
  179. stride=1,
  180. with_act=True,
  181. norm_type='sync_bn',
  182. name=None):
  183. super(LiteConv, self).__init__()
  184. self.lite_conv = nn.Sequential()
  185. conv1 = ConvNormLayer(
  186. in_channels,
  187. in_channels,
  188. filter_size=5,
  189. stride=stride,
  190. groups=in_channels,
  191. norm_type=norm_type,
  192. initializer=XavierUniform())
  193. conv2 = ConvNormLayer(
  194. in_channels,
  195. out_channels,
  196. filter_size=1,
  197. stride=stride,
  198. norm_type=norm_type,
  199. initializer=XavierUniform())
  200. conv3 = ConvNormLayer(
  201. out_channels,
  202. out_channels,
  203. filter_size=1,
  204. stride=stride,
  205. norm_type=norm_type,
  206. initializer=XavierUniform())
  207. conv4 = ConvNormLayer(
  208. out_channels,
  209. out_channels,
  210. filter_size=5,
  211. stride=stride,
  212. groups=out_channels,
  213. norm_type=norm_type,
  214. initializer=XavierUniform())
  215. conv_list = [conv1, conv2, conv3, conv4]
  216. self.lite_conv.add_sublayer('conv1', conv1)
  217. self.lite_conv.add_sublayer('relu6_1', nn.ReLU6())
  218. self.lite_conv.add_sublayer('conv2', conv2)
  219. if with_act:
  220. self.lite_conv.add_sublayer('relu6_2', nn.ReLU6())
  221. self.lite_conv.add_sublayer('conv3', conv3)
  222. self.lite_conv.add_sublayer('relu6_3', nn.ReLU6())
  223. self.lite_conv.add_sublayer('conv4', conv4)
  224. if with_act:
  225. self.lite_conv.add_sublayer('relu6_4', nn.ReLU6())
  226. def forward(self, inputs):
  227. out = self.lite_conv(inputs)
  228. return out
  229. @register
  230. @serializable
  231. class AnchorGeneratorSSD(object):
  232. def __init__(self,
  233. steps=[8, 16, 32, 64, 100, 300],
  234. aspect_ratios=[[2.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]],
  235. min_ratio=15,
  236. max_ratio=90,
  237. base_size=300,
  238. min_sizes=[30.0, 60.0, 111.0, 162.0, 213.0, 264.0],
  239. max_sizes=[60.0, 111.0, 162.0, 213.0, 264.0, 315.0],
  240. offset=0.5,
  241. flip=True,
  242. clip=False,
  243. min_max_aspect_ratios_order=False):
  244. self.steps = steps
  245. self.aspect_ratios = aspect_ratios
  246. self.min_ratio = min_ratio
  247. self.max_ratio = max_ratio
  248. self.base_size = base_size
  249. self.min_sizes = min_sizes
  250. self.max_sizes = max_sizes
  251. self.offset = offset
  252. self.flip = flip
  253. self.clip = clip
  254. self.min_max_aspect_ratios_order = min_max_aspect_ratios_order
  255. if self.min_sizes == [] and self.max_sizes == []:
  256. num_layer = len(aspect_ratios)
  257. step = int(
  258. math.floor(((self.max_ratio - self.min_ratio)) / (num_layer - 2
  259. )))
  260. for ratio in six.moves.range(self.min_ratio, self.max_ratio + 1,
  261. step):
  262. self.min_sizes.append(self.base_size * ratio / 100.)
  263. self.max_sizes.append(self.base_size * (ratio + step) / 100.)
  264. self.min_sizes = [self.base_size * .10] + self.min_sizes
  265. self.max_sizes = [self.base_size * .20] + self.max_sizes
  266. self.num_priors = []
  267. for aspect_ratio, min_size, max_size in zip(
  268. aspect_ratios, self.min_sizes, self.max_sizes):
  269. if isinstance(min_size, (list, tuple)):
  270. self.num_priors.append(
  271. len(_to_list(min_size)) + len(_to_list(max_size)))
  272. else:
  273. self.num_priors.append((len(aspect_ratio) * 2 + 1) * len(
  274. _to_list(min_size)) + len(_to_list(max_size)))
  275. def __call__(self, inputs, image):
  276. boxes = []
  277. for input, min_size, max_size, aspect_ratio, step in zip(
  278. inputs, self.min_sizes, self.max_sizes, self.aspect_ratios,
  279. self.steps):
  280. box, _ = ops.prior_box(
  281. input=input,
  282. image=image,
  283. min_sizes=_to_list(min_size),
  284. max_sizes=_to_list(max_size),
  285. aspect_ratios=aspect_ratio,
  286. flip=self.flip,
  287. clip=self.clip,
  288. steps=[step, step],
  289. offset=self.offset,
  290. min_max_aspect_ratios_order=self.min_max_aspect_ratios_order)
  291. boxes.append(paddle.reshape(box, [-1, 4]))
  292. return boxes
  293. @register
  294. @serializable
  295. class RCNNBox(object):
  296. __shared__ = ['num_classes']
  297. def __init__(self,
  298. prior_box_var=[10., 10., 5., 5.],
  299. code_type="decode_center_size",
  300. box_normalized=False,
  301. num_classes=80):
  302. super(RCNNBox, self).__init__()
  303. self.prior_box_var = prior_box_var
  304. self.code_type = code_type
  305. self.box_normalized = box_normalized
  306. self.num_classes = num_classes
  307. def __call__(self, bbox_head_out, rois, im_shape, scale_factor):
  308. bbox_pred = bbox_head_out[0]
  309. cls_prob = bbox_head_out[1]
  310. roi = rois[0]
  311. rois_num = rois[1]
  312. origin_shape = paddle.floor(im_shape / scale_factor + 0.5)
  313. scale_list = []
  314. origin_shape_list = []
  315. for idx, roi_per_im in enumerate(roi):
  316. rois_num_per_im = rois_num[idx]
  317. expand_im_shape = paddle.expand(im_shape[idx, :],
  318. [rois_num_per_im, 2])
  319. origin_shape_list.append(expand_im_shape)
  320. origin_shape = paddle.concat(origin_shape_list)
  321. # bbox_pred.shape: [N, C*4]
  322. # C=num_classes in faster/mask rcnn(bbox_head), C=1 in cascade rcnn(cascade_head)
  323. bbox = paddle.concat(roi)
  324. if bbox.shape[0] == 0:
  325. bbox = paddle.zeros([0, bbox_pred.shape[1]], dtype='float32')
  326. else:
  327. bbox = delta2bbox(bbox_pred, bbox, self.prior_box_var)
  328. scores = cls_prob[:, :-1]
  329. # bbox.shape: [N, C, 4]
  330. # bbox.shape[1] must be equal to scores.shape[1]
  331. bbox_num_class = bbox.shape[1]
  332. if bbox_num_class == 1:
  333. bbox = paddle.tile(bbox, [1, self.num_classes, 1])
  334. origin_h = paddle.unsqueeze(origin_shape[:, 0], axis=1)
  335. origin_w = paddle.unsqueeze(origin_shape[:, 1], axis=1)
  336. zeros = paddle.zeros_like(origin_h)
  337. x1 = paddle.maximum(paddle.minimum(bbox[:, :, 0], origin_w), zeros)
  338. y1 = paddle.maximum(paddle.minimum(bbox[:, :, 1], origin_h), zeros)
  339. x2 = paddle.maximum(paddle.minimum(bbox[:, :, 2], origin_w), zeros)
  340. y2 = paddle.maximum(paddle.minimum(bbox[:, :, 3], origin_h), zeros)
  341. bbox = paddle.stack([x1, y1, x2, y2], axis=-1)
  342. bboxes = (bbox, rois_num)
  343. return bboxes, scores
  344. @register
  345. @serializable
  346. class MultiClassNMS(object):
  347. def __init__(self,
  348. score_threshold=.05,
  349. nms_top_k=-1,
  350. keep_top_k=100,
  351. nms_threshold=.5,
  352. normalized=True,
  353. nms_eta=1.0,
  354. return_index=False,
  355. return_rois_num=True):
  356. super(MultiClassNMS, self).__init__()
  357. self.score_threshold = score_threshold
  358. self.nms_top_k = nms_top_k
  359. self.keep_top_k = keep_top_k
  360. self.nms_threshold = nms_threshold
  361. self.normalized = normalized
  362. self.nms_eta = nms_eta
  363. self.return_index = return_index
  364. self.return_rois_num = return_rois_num
  365. def __call__(self, bboxes, score, background_label=-1):
  366. """
  367. bboxes (Tensor|List[Tensor]): 1. (Tensor) Predicted bboxes with shape
  368. [N, M, 4], N is the batch size and M
  369. is the number of bboxes
  370. 2. (List[Tensor]) bboxes and bbox_num,
  371. bboxes have shape of [M, C, 4], C
  372. is the class number and bbox_num means
  373. the number of bboxes of each batch with
  374. shape [N,]
  375. score (Tensor): Predicted scores with shape [N, C, M] or [M, C]
  376. background_label (int): Ignore the background label; For example, RCNN
  377. is num_classes and YOLO is -1.
  378. """
  379. kwargs = self.__dict__.copy()
  380. if isinstance(bboxes, tuple):
  381. bboxes, bbox_num = bboxes
  382. kwargs.update({'rois_num': bbox_num})
  383. if background_label > -1:
  384. kwargs.update({'background_label': background_label})
  385. return ops.multiclass_nms(bboxes, score, **kwargs)
  386. @register
  387. @serializable
  388. class MatrixNMS(object):
  389. __append_doc__ = True
  390. def __init__(self,
  391. score_threshold=.05,
  392. post_threshold=.05,
  393. nms_top_k=-1,
  394. keep_top_k=100,
  395. use_gaussian=False,
  396. gaussian_sigma=2.,
  397. normalized=False,
  398. background_label=0):
  399. super(MatrixNMS, self).__init__()
  400. self.score_threshold = score_threshold
  401. self.post_threshold = post_threshold
  402. self.nms_top_k = nms_top_k
  403. self.keep_top_k = keep_top_k
  404. self.normalized = normalized
  405. self.use_gaussian = use_gaussian
  406. self.gaussian_sigma = gaussian_sigma
  407. self.background_label = background_label
  408. def __call__(self, bbox, score, *args):
  409. return ops.matrix_nms(
  410. bboxes=bbox,
  411. scores=score,
  412. score_threshold=self.score_threshold,
  413. post_threshold=self.post_threshold,
  414. nms_top_k=self.nms_top_k,
  415. keep_top_k=self.keep_top_k,
  416. use_gaussian=self.use_gaussian,
  417. gaussian_sigma=self.gaussian_sigma,
  418. background_label=self.background_label,
  419. normalized=self.normalized)
  420. @register
  421. @serializable
  422. class YOLOBox(object):
  423. __shared__ = ['num_classes']
  424. def __init__(self,
  425. num_classes=80,
  426. conf_thresh=0.005,
  427. downsample_ratio=32,
  428. clip_bbox=True,
  429. scale_x_y=1.):
  430. self.num_classes = num_classes
  431. self.conf_thresh = conf_thresh
  432. self.downsample_ratio = downsample_ratio
  433. self.clip_bbox = clip_bbox
  434. self.scale_x_y = scale_x_y
  435. def __call__(self,
  436. yolo_head_out,
  437. anchors,
  438. im_shape,
  439. scale_factor,
  440. var_weight=None):
  441. boxes_list = []
  442. scores_list = []
  443. origin_shape = im_shape / scale_factor
  444. origin_shape = paddle.cast(origin_shape, 'int32')
  445. for i, head_out in enumerate(yolo_head_out):
  446. boxes, scores = ops.yolo_box(head_out, origin_shape, anchors[i],
  447. self.num_classes, self.conf_thresh,
  448. self.downsample_ratio // 2**i,
  449. self.clip_bbox, self.scale_x_y)
  450. boxes_list.append(boxes)
  451. scores_list.append(paddle.transpose(scores, perm=[0, 2, 1]))
  452. yolo_boxes = paddle.concat(boxes_list, axis=1)
  453. yolo_scores = paddle.concat(scores_list, axis=2)
  454. return yolo_boxes, yolo_scores
  455. @register
  456. @serializable
  457. class SSDBox(object):
  458. def __init__(self, is_normalized=True):
  459. self.is_normalized = is_normalized
  460. self.norm_delta = float(not self.is_normalized)
  461. def __call__(self,
  462. preds,
  463. prior_boxes,
  464. im_shape,
  465. scale_factor,
  466. var_weight=None):
  467. boxes, scores = preds
  468. outputs = []
  469. for box, score, prior_box in zip(boxes, scores, prior_boxes):
  470. pb_w = prior_box[:, 2] - prior_box[:, 0] + self.norm_delta
  471. pb_h = prior_box[:, 3] - prior_box[:, 1] + self.norm_delta
  472. pb_x = prior_box[:, 0] + pb_w * 0.5
  473. pb_y = prior_box[:, 1] + pb_h * 0.5
  474. out_x = pb_x + box[:, :, 0] * pb_w * 0.1
  475. out_y = pb_y + box[:, :, 1] * pb_h * 0.1
  476. out_w = paddle.exp(box[:, :, 2] * 0.2) * pb_w
  477. out_h = paddle.exp(box[:, :, 3] * 0.2) * pb_h
  478. if self.is_normalized:
  479. h = paddle.unsqueeze(
  480. im_shape[:, 0] / scale_factor[:, 0], axis=-1)
  481. w = paddle.unsqueeze(
  482. im_shape[:, 1] / scale_factor[:, 1], axis=-1)
  483. output = paddle.stack(
  484. [(out_x - out_w / 2.) * w, (out_y - out_h / 2.) * h,
  485. (out_x + out_w / 2.) * w, (out_y + out_h / 2.) * h],
  486. axis=-1)
  487. else:
  488. output = paddle.stack(
  489. [
  490. out_x - out_w / 2., out_y - out_h / 2.,
  491. out_x + out_w / 2. - 1., out_y + out_h / 2. - 1.
  492. ],
  493. axis=-1)
  494. outputs.append(output)
  495. boxes = paddle.concat(outputs, axis=1)
  496. scores = F.softmax(paddle.concat(scores, axis=1))
  497. scores = paddle.transpose(scores, [0, 2, 1])
  498. return boxes, scores
  499. @register
  500. @serializable
  501. class AnchorGrid(object):
  502. """Generate anchor grid
  503. Args:
  504. image_size (int or list): input image size, may be a single integer or
  505. list of [h, w]. Default: 512
  506. min_level (int): min level of the feature pyramid. Default: 3
  507. max_level (int): max level of the feature pyramid. Default: 7
  508. anchor_base_scale: base anchor scale. Default: 4
  509. num_scales: number of anchor scales. Default: 3
  510. aspect_ratios: aspect ratios. default: [[1, 1], [1.4, 0.7], [0.7, 1.4]]
  511. """
  512. def __init__(self,
  513. image_size=512,
  514. min_level=3,
  515. max_level=7,
  516. anchor_base_scale=4,
  517. num_scales=3,
  518. aspect_ratios=[[1, 1], [1.4, 0.7], [0.7, 1.4]]):
  519. super(AnchorGrid, self).__init__()
  520. if isinstance(image_size, Integral):
  521. self.image_size = [image_size, image_size]
  522. else:
  523. self.image_size = image_size
  524. for dim in self.image_size:
  525. assert dim % 2 ** max_level == 0, \
  526. "image size should be multiple of the max level stride"
  527. self.min_level = min_level
  528. self.max_level = max_level
  529. self.anchor_base_scale = anchor_base_scale
  530. self.num_scales = num_scales
  531. self.aspect_ratios = aspect_ratios
  532. @property
  533. def base_cell(self):
  534. if not hasattr(self, '_base_cell'):
  535. self._base_cell = self.make_cell()
  536. return self._base_cell
  537. def make_cell(self):
  538. scales = [2**(i / self.num_scales) for i in range(self.num_scales)]
  539. scales = np.array(scales)
  540. ratios = np.array(self.aspect_ratios)
  541. ws = np.outer(scales, ratios[:, 0]).reshape(-1, 1)
  542. hs = np.outer(scales, ratios[:, 1]).reshape(-1, 1)
  543. anchors = np.hstack((-0.5 * ws, -0.5 * hs, 0.5 * ws, 0.5 * hs))
  544. return anchors
  545. def make_grid(self, stride):
  546. cell = self.base_cell * stride * self.anchor_base_scale
  547. x_steps = np.arange(stride // 2, self.image_size[1], stride)
  548. y_steps = np.arange(stride // 2, self.image_size[0], stride)
  549. offset_x, offset_y = np.meshgrid(x_steps, y_steps)
  550. offset_x = offset_x.flatten()
  551. offset_y = offset_y.flatten()
  552. offsets = np.stack((offset_x, offset_y, offset_x, offset_y), axis=-1)
  553. offsets = offsets[:, np.newaxis, :]
  554. return (cell + offsets).reshape(-1, 4)
  555. def generate(self):
  556. return [
  557. self.make_grid(2**l)
  558. for l in range(self.min_level, self.max_level + 1)
  559. ]
  560. def __call__(self):
  561. if not hasattr(self, '_anchor_vars'):
  562. anchor_vars = []
  563. helper = LayerHelper('anchor_grid')
  564. for idx, l in enumerate(range(self.min_level, self.max_level + 1)):
  565. stride = 2**l
  566. anchors = self.make_grid(stride)
  567. var = helper.create_parameter(
  568. attr=ParamAttr(name='anchors_{}'.format(idx)),
  569. shape=anchors.shape,
  570. dtype='float32',
  571. stop_gradient=True,
  572. default_initializer=NumpyArrayInitializer(anchors))
  573. anchor_vars.append(var)
  574. var.persistable = True
  575. self._anchor_vars = anchor_vars
  576. return self._anchor_vars
  577. @register
  578. @serializable
  579. class FCOSBox(object):
  580. __shared__ = ['num_classes']
  581. def __init__(self, num_classes=80):
  582. super(FCOSBox, self).__init__()
  583. self.num_classes = num_classes
  584. def _merge_hw(self, inputs, ch_type="channel_first"):
  585. """
  586. Merge h and w of the feature map into one dimension.
  587. Args:
  588. inputs (Tensor): Tensor of the input feature map
  589. ch_type (str): "channel_first" or "channel_last" style
  590. Return:
  591. new_shape (Tensor): The new shape after h and w merged
  592. """
  593. shape_ = paddle.shape(inputs)
  594. bs, ch, hi, wi = shape_[0], shape_[1], shape_[2], shape_[3]
  595. img_size = hi * wi
  596. img_size.stop_gradient = True
  597. if ch_type == "channel_first":
  598. new_shape = paddle.concat([bs, ch, img_size])
  599. elif ch_type == "channel_last":
  600. new_shape = paddle.concat([bs, img_size, ch])
  601. else:
  602. raise KeyError("Wrong ch_type %s" % ch_type)
  603. new_shape.stop_gradient = True
  604. return new_shape
  605. def _postprocessing_by_level(self, locations, box_cls, box_reg, box_ctn,
  606. scale_factor):
  607. """
  608. Postprocess each layer of the output with corresponding locations.
  609. Args:
  610. locations (Tensor): anchor points for current layer, [H*W, 2]
  611. box_cls (Tensor): categories prediction, [N, C, H, W],
  612. C is the number of classes
  613. box_reg (Tensor): bounding box prediction, [N, 4, H, W]
  614. box_ctn (Tensor): centerness prediction, [N, 1, H, W]
  615. scale_factor (Tensor): [h_scale, w_scale] for input images
  616. Return:
  617. box_cls_ch_last (Tensor): score for each category, in [N, C, M]
  618. C is the number of classes and M is the number of anchor points
  619. box_reg_decoding (Tensor): decoded bounding box, in [N, M, 4]
  620. last dimension is [x1, y1, x2, y2]
  621. """
  622. act_shape_cls = self._merge_hw(box_cls)
  623. box_cls_ch_last = paddle.reshape(x=box_cls, shape=act_shape_cls)
  624. box_cls_ch_last = F.sigmoid(box_cls_ch_last)
  625. act_shape_reg = self._merge_hw(box_reg)
  626. box_reg_ch_last = paddle.reshape(x=box_reg, shape=act_shape_reg)
  627. box_reg_ch_last = paddle.transpose(box_reg_ch_last, perm=[0, 2, 1])
  628. box_reg_decoding = paddle.stack(
  629. [
  630. locations[:, 0] - box_reg_ch_last[:, :, 0],
  631. locations[:, 1] - box_reg_ch_last[:, :, 1],
  632. locations[:, 0] + box_reg_ch_last[:, :, 2],
  633. locations[:, 1] + box_reg_ch_last[:, :, 3]
  634. ],
  635. axis=1)
  636. box_reg_decoding = paddle.transpose(box_reg_decoding, perm=[0, 2, 1])
  637. act_shape_ctn = self._merge_hw(box_ctn)
  638. box_ctn_ch_last = paddle.reshape(x=box_ctn, shape=act_shape_ctn)
  639. box_ctn_ch_last = F.sigmoid(box_ctn_ch_last)
  640. # recover the location to original image
  641. im_scale = paddle.concat([scale_factor, scale_factor], axis=1)
  642. box_reg_decoding = box_reg_decoding / im_scale
  643. box_cls_ch_last = box_cls_ch_last * box_ctn_ch_last
  644. return box_cls_ch_last, box_reg_decoding
  645. def __call__(self, locations, cls_logits, bboxes_reg, centerness,
  646. scale_factor):
  647. pred_boxes_ = []
  648. pred_scores_ = []
  649. for pts, cls, box, ctn in zip(locations, cls_logits, bboxes_reg,
  650. centerness):
  651. pred_scores_lvl, pred_boxes_lvl = self._postprocessing_by_level(
  652. pts, cls, box, ctn, scale_factor)
  653. pred_boxes_.append(pred_boxes_lvl)
  654. pred_scores_.append(pred_scores_lvl)
  655. pred_boxes = paddle.concat(pred_boxes_, axis=1)
  656. pred_scores = paddle.concat(pred_scores_, axis=2)
  657. return pred_boxes, pred_scores
  658. @register
  659. class TTFBox(object):
  660. __shared__ = ['down_ratio']
  661. def __init__(self, max_per_img=100, score_thresh=0.01, down_ratio=4):
  662. super(TTFBox, self).__init__()
  663. self.max_per_img = max_per_img
  664. self.score_thresh = score_thresh
  665. self.down_ratio = down_ratio
  666. def _simple_nms(self, heat, kernel=3):
  667. """
  668. Use maxpool to filter the max score, get local peaks.
  669. """
  670. pad = (kernel - 1) // 2
  671. hmax = F.max_pool2d(heat, kernel, stride=1, padding=pad)
  672. keep = paddle.cast(hmax == heat, 'float32')
  673. return heat * keep
  674. def _topk(self, scores):
  675. """
  676. Select top k scores and decode to get xy coordinates.
  677. """
  678. k = self.max_per_img
  679. shape_fm = paddle.shape(scores)
  680. shape_fm.stop_gradient = True
  681. cat, height, width = shape_fm[1], shape_fm[2], shape_fm[3]
  682. # batch size is 1
  683. scores_r = paddle.reshape(scores, [cat, -1])
  684. topk_scores, topk_inds = paddle.topk(scores_r, k)
  685. topk_scores, topk_inds = paddle.topk(scores_r, k)
  686. topk_ys = topk_inds // width
  687. topk_xs = topk_inds % width
  688. topk_score_r = paddle.reshape(topk_scores, [-1])
  689. topk_score, topk_ind = paddle.topk(topk_score_r, k)
  690. k_t = paddle.full(paddle.shape(topk_ind), k, dtype='int64')
  691. topk_clses = paddle.cast(paddle.floor_divide(topk_ind, k_t), 'float32')
  692. topk_inds = paddle.reshape(topk_inds, [-1])
  693. topk_ys = paddle.reshape(topk_ys, [-1, 1])
  694. topk_xs = paddle.reshape(topk_xs, [-1, 1])
  695. topk_inds = paddle.gather(topk_inds, topk_ind)
  696. topk_ys = paddle.gather(topk_ys, topk_ind)
  697. topk_xs = paddle.gather(topk_xs, topk_ind)
  698. return topk_score, topk_inds, topk_clses, topk_ys, topk_xs
  699. def __call__(self, hm, wh, im_shape, scale_factor):
  700. heatmap = F.sigmoid(hm)
  701. heat = self._simple_nms(heatmap)
  702. scores, inds, clses, ys, xs = self._topk(heat)
  703. ys = paddle.cast(ys, 'float32') * self.down_ratio
  704. xs = paddle.cast(xs, 'float32') * self.down_ratio
  705. scores = paddle.tensor.unsqueeze(scores, [1])
  706. clses = paddle.tensor.unsqueeze(clses, [1])
  707. wh_t = paddle.transpose(wh, [0, 2, 3, 1])
  708. wh = paddle.reshape(wh_t, [-1, paddle.shape(wh_t)[-1]])
  709. wh = paddle.gather(wh, inds)
  710. x1 = xs - wh[:, 0:1]
  711. y1 = ys - wh[:, 1:2]
  712. x2 = xs + wh[:, 2:3]
  713. y2 = ys + wh[:, 3:4]
  714. bboxes = paddle.concat([x1, y1, x2, y2], axis=1)
  715. scale_y = scale_factor[:, 0:1]
  716. scale_x = scale_factor[:, 1:2]
  717. scale_expand = paddle.concat(
  718. [scale_x, scale_y, scale_x, scale_y], axis=1)
  719. boxes_shape = paddle.shape(bboxes)
  720. boxes_shape.stop_gradient = True
  721. scale_expand = paddle.expand(scale_expand, shape=boxes_shape)
  722. bboxes = paddle.divide(bboxes, scale_expand)
  723. results = paddle.concat([clses, scores, bboxes], axis=1)
  724. # hack: append result with cls=-1 and score=1. to avoid all scores
  725. # are less than score_thresh which may cause error in gather.
  726. fill_r = paddle.to_tensor(np.array([[-1, 1, 0, 0, 0, 0]]))
  727. fill_r = paddle.cast(fill_r, results.dtype)
  728. results = paddle.concat([results, fill_r])
  729. scores = results[:, 1]
  730. valid_ind = paddle.nonzero(scores > self.score_thresh)
  731. results = paddle.gather(results, valid_ind)
  732. return results, paddle.shape(results)[0:1]
@register
@serializable
class JDEBox(object):
    """Decode per-level head outputs into candidate boxes and confidence scores.

    Args:
        num_classes (int): number of classes; each head output is reshaped so
            that every anchor has ``num_classes + 5`` channels.
        conf_thresh (float): boxes whose confidence score is not strictly
            greater than this value are dropped.
        downsample_ratio (int): stride of the first head level; level ``i``
            uses stride ``downsample_ratio // 2**i``.
    """
    __shared__ = ['num_classes']

    def __init__(self, num_classes=1, conf_thresh=0.3, downsample_ratio=32):
        self.num_classes = num_classes
        self.conf_thresh = conf_thresh
        self.downsample_ratio = downsample_ratio

    def generate_anchor(self, nGh, nGw, anchor_wh):
        """Build an anchor grid of shape [nA, nGh, nGw, 4].

        The last axis holds (x, y, w, h): the grid column index, the grid row
        index, and the anchor width/height taken from ``anchor_wh``.

        Args:
            nGh: grid height (converted with ``int()``).
            nGw: grid width (converted with ``int()``).
            anchor_wh (np.ndarray): [nA, 2] anchor (w, h) pairs; ``__call__``
                passes them already divided by the level stride.

        Returns:
            Tensor: float32 anchor grid of shape [nA, nGh, nGw, 4].
        """
        nA = len(anchor_wh)
        yv, xv = paddle.meshgrid([paddle.arange(nGh), paddle.arange(nGw)])
        mesh = paddle.stack(
            (xv, yv), axis=0).cast(dtype='float32')  # 2 x nGh x nGw
        meshs = paddle.tile(mesh, [nA, 1, 1, 1])  # nA x 2 x nGh x nGw
        # Broadcast each anchor's (w, h) over the whole grid (numpy repeat).
        anchor_offset_mesh = anchor_wh[:, :, None][:, :, :, None].repeat(
            int(nGh), axis=-2).repeat(
                int(nGw), axis=-1)
        anchor_offset_mesh = paddle.to_tensor(
            anchor_offset_mesh.astype(np.float32))
        # anchor_offset_mesh: nA x 2 x nGh x nGw -> concat gives nA x 4 x nGh x nGw
        anchor_mesh = paddle.concat([meshs, anchor_offset_mesh], axis=1)
        anchor_mesh = paddle.transpose(anchor_mesh,
                                       [0, 2, 3, 1])  # nA x nGh x nGw x 4
        return anchor_mesh

    def decode_delta(self, delta, fg_anchor_list):
        """Apply (dx, dy, dw, dh) regression deltas to anchors.

        Args:
            delta: [N, 4] deltas (dx, dy, dw, dh).
            fg_anchor_list: [N, 4] anchors as (center x, center y, w, h).

        Returns:
            Tensor: [N, 4] decoded boxes as (x1, y1, x2, y2).
        """
        px, py, pw, ph = fg_anchor_list[:, 0], fg_anchor_list[:,1], \
                        fg_anchor_list[:, 2], fg_anchor_list[:,3]
        dx, dy, dw, dh = delta[:, 0], delta[:, 1], delta[:, 2], delta[:, 3]
        # Center offsets scale with anchor size; sizes are decoded in log space.
        gx = pw * dx + px
        gy = ph * dy + py
        gw = pw * paddle.exp(dw)
        gh = ph * paddle.exp(dh)
        # Convert (center, size) to corner coordinates.
        gx1 = gx - gw * 0.5
        gy1 = gy - gh * 0.5
        gx2 = gx + gw * 0.5
        gy2 = gy + gh * 0.5
        return paddle.stack([gx1, gy1, gx2, gy2], axis=1)

    def decode_delta_map(self, delta_map, anchors):
        """Decode a [nB, nA, nGh, nGw, 4] delta map into [nB, nA*nGh*nGw, 4] boxes.

        Args:
            delta_map (Tensor): rank-5 regression deltas.
            anchors (np.ndarray): [nA, 2] anchor (w, h) pairs in grid units.

        Returns:
            Tensor: [nB, nA*nGh*nGw, 4] boxes (x1, y1, x2, y2) in grid units.
        """
        delta_map_shape = paddle.shape(delta_map)
        delta_map_shape.stop_gradient = True
        nB, nA, nGh, nGw, _ = delta_map_shape[:]
        anchor_mesh = self.generate_anchor(nGh, nGw, anchors)
        # only support bs=1
        # NOTE(review): the anchor mesh is not tiled over nB, so for nB > 1 the
        # two reshaped operands below have different row counts.
        anchor_mesh = paddle.unsqueeze(anchor_mesh, 0)
        pred_list = self.decode_delta(
            paddle.reshape(
                delta_map, shape=[-1, 4]),
            paddle.reshape(
                anchor_mesh, shape=[-1, 4]))
        pred_map = paddle.reshape(pred_list, shape=[nB, -1, 4])
        return pred_map

    def __call__(self, yolo_head_out, anchors):
        """Decode all head levels and keep boxes above ``conf_thresh``.

        Args:
            yolo_head_out (list[Tensor]): one head output per level. Each is
                reshaped to [nB, nA, num_classes + 5, nGh, nGw], so it is
                presumably laid out as [nB, nA * (num_classes + 5), nGh, nGw]
                -- TODO confirm against the head.
            anchors (list): per-level flat anchor lists [w0, h0, w1, h1, ...].
                They are divided by the level stride here, so presumably they
                are in input-image pixels.

        Returns:
            tuple: (boxes_idx, yolo_boxes_out, yolo_scores_out)
                boxes_idx: [K, 1] positions of the kept boxes along the
                    concatenated anchor axis.
                yolo_boxes_out: [nB, K / nB, 4] (x1, y1, x2, y2) boxes,
                    multiplied back by the level stride.
                yolo_scores_out: [nB, 1, K / nB] confidence scores.
            When no box passes the threshold, one dummy entry (index 0, a zero
            box and a zero score) is returned instead.
        """
        bbox_pred_list = []
        for i, head_out in enumerate(yolo_head_out):
            stride = self.downsample_ratio // 2**i
            anc_w, anc_h = anchors[i][0::2], anchors[i][1::2]
            # Anchors in grid units for this level.
            anchor_vec = np.stack((anc_w, anc_h), axis=1) / stride
            nA = len(anc_w)
            boxes_shape = paddle.shape(head_out)
            boxes_shape.stop_gradient = True
            nB, nGh, nGw = boxes_shape[0], boxes_shape[-2], boxes_shape[-1]
            p = head_out.reshape((nB, nA, self.num_classes + 5, nGh, nGw))
            # [nB, nA, nGh, nGw, num_classes + 5], e.g. [nB, 4, 19, 34, 6]
            p = paddle.transpose(p, perm=[0, 1, 3, 4, 2])
            p_box = p[:, :, :, :, :4]  # [nB, nA, nGh, nGw, 4]
            boxes = self.decode_delta_map(p_box, anchor_vec)  # [nB, nA*nGh*nGw, 4]
            boxes = boxes * stride
            # Channels 4:6 form a 2-way logit pair; after softmax the second
            # channel is used as the confidence score.
            p_conf = paddle.transpose(
                p[:, :, :, :, 4:6], perm=[0, 4, 1, 2, 3])  # [nB, 2, nA, nGh, nGw]
            p_conf = F.softmax(
                p_conf,
                axis=1)[:, 1, :, :, :].unsqueeze(-1)  # [nB, nA, nGh, nGw, 1]
            scores = paddle.reshape(p_conf, shape=[nB, -1, 1])
            bbox_pred_list.append(paddle.concat([boxes, scores], axis=-1))
        # [nB, total_anchors, 5]: (x1, y1, x2, y2, score) across all levels.
        yolo_boxes_pred = paddle.concat(bbox_pred_list, axis=1)
        # Rows of (batch index, anchor index) whose score passes the threshold.
        boxes_idx = paddle.nonzero(yolo_boxes_pred[:, :, -1] > self.conf_thresh)
        boxes_idx.stop_gradient = True
        if boxes_idx.shape[0] == 0:  # TODO: deploy
            # Nothing passed: return one dummy entry so callers get non-empty tensors.
            boxes_idx = paddle.to_tensor(np.array([[0]], dtype='int64'))
            yolo_boxes_out = paddle.to_tensor(
                np.array(
                    [[[0.0, 0.0, 0.0, 0.0]]], dtype='float32'))
            yolo_scores_out = paddle.to_tensor(
                np.array(
                    [[[0.0]]], dtype='float32'))
            return boxes_idx, yolo_boxes_out, yolo_scores_out
        yolo_boxes = paddle.gather_nd(yolo_boxes_pred, boxes_idx)
        yolo_boxes_out = paddle.reshape(yolo_boxes[:, :4], shape=[nB, -1, 4])
        yolo_scores_out = paddle.reshape(yolo_boxes[:, 4:5], shape=[nB, 1, -1])
        # Drop the batch column, keeping only the anchor position.
        boxes_idx = boxes_idx[:, 1:]
        return boxes_idx, yolo_boxes_out, yolo_scores_out  # e.g. [163, 1], [1, 163, 4], [1, 1, 163]
  823. @register
  824. @serializable
  825. class MaskMatrixNMS(object):
  826. """
  827. Matrix NMS for multi-class masks.
  828. Args:
  829. update_threshold (float): Updated threshold of categroy score in second time.
  830. pre_nms_top_n (int): Number of total instance to be kept per image before NMS
  831. post_nms_top_n (int): Number of total instance to be kept per image after NMS.
  832. kernel (str): 'linear' or 'gaussian'.
  833. sigma (float): std in gaussian method.
  834. Input:
  835. seg_preds (Variable): shape (n, h, w), segmentation feature maps
  836. seg_masks (Variable): shape (n, h, w), segmentation feature maps
  837. cate_labels (Variable): shape (n), mask labels in descending order
  838. cate_scores (Variable): shape (n), mask scores in descending order
  839. sum_masks (Variable): a float tensor of the sum of seg_masks
  840. Returns:
  841. Variable: cate_scores, tensors of shape (n)
  842. """
  843. def __init__(self,
  844. update_threshold=0.05,
  845. pre_nms_top_n=500,
  846. post_nms_top_n=100,
  847. kernel='gaussian',
  848. sigma=2.0):
  849. super(MaskMatrixNMS, self).__init__()
  850. self.update_threshold = update_threshold
  851. self.pre_nms_top_n = pre_nms_top_n
  852. self.post_nms_top_n = post_nms_top_n
  853. self.kernel = kernel
  854. self.sigma = sigma
  855. def _sort_score(self, scores, top_num):
  856. if paddle.shape(scores)[0] > top_num:
  857. return paddle.topk(scores, top_num)[1]
  858. else:
  859. return paddle.argsort(scores, descending=True)
  860. def __call__(self,
  861. seg_preds,
  862. seg_masks,
  863. cate_labels,
  864. cate_scores,
  865. sum_masks=None):
  866. # sort and keep top nms_pre
  867. sort_inds = self._sort_score(cate_scores, self.pre_nms_top_n)
  868. seg_masks = paddle.gather(seg_masks, index=sort_inds)
  869. seg_preds = paddle.gather(seg_preds, index=sort_inds)
  870. sum_masks = paddle.gather(sum_masks, index=sort_inds)
  871. cate_scores = paddle.gather(cate_scores, index=sort_inds)
  872. cate_labels = paddle.gather(cate_labels, index=sort_inds)
  873. seg_masks = paddle.flatten(seg_masks, start_axis=1, stop_axis=-1)
  874. # inter.
  875. inter_matrix = paddle.mm(seg_masks, paddle.transpose(seg_masks, [1, 0]))
  876. n_samples = paddle.shape(cate_labels)
  877. # union.
  878. sum_masks_x = paddle.expand(sum_masks, shape=[n_samples, n_samples])
  879. # iou.
  880. iou_matrix = (inter_matrix / (
  881. sum_masks_x + paddle.transpose(sum_masks_x, [1, 0]) - inter_matrix))
  882. iou_matrix = paddle.triu(iou_matrix, diagonal=1)
  883. # label_specific matrix.
  884. cate_labels_x = paddle.expand(cate_labels, shape=[n_samples, n_samples])
  885. label_matrix = paddle.cast(
  886. (cate_labels_x == paddle.transpose(cate_labels_x, [1, 0])),
  887. 'float32')
  888. label_matrix = paddle.triu(label_matrix, diagonal=1)
  889. # IoU compensation
  890. compensate_iou = paddle.max((iou_matrix * label_matrix), axis=0)
  891. compensate_iou = paddle.expand(
  892. compensate_iou, shape=[n_samples, n_samples])
  893. compensate_iou = paddle.transpose(compensate_iou, [1, 0])
  894. # IoU decay
  895. decay_iou = iou_matrix * label_matrix
  896. # matrix nms
  897. if self.kernel == 'gaussian':
  898. decay_matrix = paddle.exp(-1 * self.sigma * (decay_iou**2))
  899. compensate_matrix = paddle.exp(-1 * self.sigma *
  900. (compensate_iou**2))
  901. decay_coefficient = paddle.min(decay_matrix / compensate_matrix,
  902. axis=0)
  903. elif self.kernel == 'linear':
  904. decay_matrix = (1 - decay_iou) / (1 - compensate_iou)
  905. decay_coefficient = paddle.min(decay_matrix, axis=0)
  906. else:
  907. raise NotImplementedError
  908. # update the score.
  909. cate_scores = cate_scores * decay_coefficient
  910. y = paddle.zeros(shape=paddle.shape(cate_scores), dtype='float32')
  911. keep = paddle.where(cate_scores >= self.update_threshold, cate_scores,
  912. y)
  913. keep = paddle.nonzero(keep)
  914. keep = paddle.squeeze(keep, axis=[1])
  915. # Prevent empty and increase fake data
  916. keep = paddle.concat(
  917. [keep, paddle.cast(paddle.shape(cate_scores)[0] - 1, 'int64')])
  918. seg_preds = paddle.gather(seg_preds, index=keep)
  919. cate_scores = paddle.gather(cate_scores, index=keep)
  920. cate_labels = paddle.gather(cate_labels, index=keep)
  921. # sort and keep top_k
  922. sort_inds = self._sort_score(cate_scores, self.post_nms_top_n)
  923. seg_preds = paddle.gather(seg_preds, index=sort_inds)
  924. cate_scores = paddle.gather(cate_scores, index=sort_inds)
  925. cate_labels = paddle.gather(cate_labels, index=sort_inds)
  926. return seg_preds, cate_scores, cate_labels
  927. def Conv2d(in_channels,
  928. out_channels,
  929. kernel_size,
  930. stride=1,
  931. padding=0,
  932. dilation=1,
  933. groups=1,
  934. bias=True,
  935. weight_init=Normal(std=0.001),
  936. bias_init=Constant(0.)):
  937. weight_attr = paddle.framework.ParamAttr(initializer=weight_init)
  938. if bias:
  939. bias_attr = paddle.framework.ParamAttr(initializer=bias_init)
  940. else:
  941. bias_attr = False
  942. conv = nn.Conv2D(
  943. in_channels,
  944. out_channels,
  945. kernel_size,
  946. stride,
  947. padding,
  948. dilation,
  949. groups,
  950. weight_attr=weight_attr,
  951. bias_attr=bias_attr)
  952. return conv
  953. def ConvTranspose2d(in_channels,
  954. out_channels,
  955. kernel_size,
  956. stride=1,
  957. padding=0,
  958. output_padding=0,
  959. groups=1,
  960. bias=True,
  961. dilation=1,
  962. weight_init=Normal(std=0.001),
  963. bias_init=Constant(0.)):
  964. weight_attr = paddle.framework.ParamAttr(initializer=weight_init)
  965. if bias:
  966. bias_attr = paddle.framework.ParamAttr(initializer=bias_init)
  967. else:
  968. bias_attr = False
  969. conv = nn.Conv2DTranspose(
  970. in_channels,
  971. out_channels,
  972. kernel_size,
  973. stride,
  974. padding,
  975. output_padding,
  976. dilation,
  977. groups,
  978. weight_attr=weight_attr,
  979. bias_attr=bias_attr)
  980. return conv
  981. def BatchNorm2d(num_features, eps=1e-05, momentum=0.9, affine=True):
  982. if not affine:
  983. weight_attr = False
  984. bias_attr = False
  985. else:
  986. weight_attr = None
  987. bias_attr = None
  988. batchnorm = nn.BatchNorm2D(
  989. num_features,
  990. momentum,
  991. eps,
  992. weight_attr=weight_attr,
  993. bias_attr=bias_attr)
  994. return batchnorm
  995. def ReLU():
  996. return nn.ReLU()
  997. def Upsample(scale_factor=None, mode='nearest', align_corners=False):
  998. return nn.Upsample(None, scale_factor, mode, align_corners)
  999. def MaxPool(kernel_size, stride, padding, ceil_mode=False):
  1000. return nn.MaxPool2D(kernel_size, stride, padding, ceil_mode=ceil_mode)
  1001. class Concat(nn.Layer):
  1002. def __init__(self, dim=0):
  1003. super(Concat, self).__init__()
  1004. self.dim = dim
  1005. def forward(self, inputs):
  1006. return paddle.concat(inputs, axis=self.dim)
  1007. def extra_repr(self):
  1008. return 'dim={}'.format(self.dim)