gfl_head.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import math
  18. import numpy as np
  19. import paddle
  20. import paddle.nn as nn
  21. import paddle.nn.functional as F
  22. from paddle import ParamAttr
  23. from paddle.nn.initializer import Normal, Constant
  24. from paddlex.ppdet.core.workspace import register
  25. from paddlex.ppdet.modeling.layers import ConvNormLayer
  26. from paddlex.ppdet.modeling.bbox_utils import distance2bbox, bbox2distance
  27. from paddlex.ppdet.data.transform.atss_assigner import bbox_overlaps
  28. class ScaleReg(nn.Layer):
  29. """
  30. Parameter for scaling the regression outputs.
  31. """
  32. def __init__(self):
  33. super(ScaleReg, self).__init__()
  34. self.scale_reg = self.create_parameter(
  35. shape=[1],
  36. attr=ParamAttr(initializer=Constant(value=1.)),
  37. dtype="float32")
  38. def forward(self, inputs):
  39. out = inputs * self.scale_reg
  40. return out
  41. class Integral(nn.Layer):
  42. """A fixed layer for calculating integral result from distribution.
  43. This layer calculates the target location by :math: `sum{P(y_i) * y_i}`,
  44. P(y_i) denotes the softmax vector that represents the discrete distribution
  45. y_i denotes the discrete set, usually {0, 1, 2, ..., reg_max}
  46. Args:
  47. reg_max (int): The maximal value of the discrete set. Default: 16. You
  48. may want to reset it according to your new dataset or related
  49. settings.
  50. """
  51. def __init__(self, reg_max=16):
  52. super(Integral, self).__init__()
  53. self.reg_max = reg_max
  54. self.register_buffer(
  55. 'project', paddle.linspace(0, self.reg_max, self.reg_max + 1))
  56. def forward(self, x):
  57. """Forward feature from the regression head to get integral result of
  58. bounding box location.
  59. Args:
  60. x (Tensor): Features of the regression head, shape (N, 4*(n+1)),
  61. n is self.reg_max.
  62. Returns:
  63. x (Tensor): Integral result of box locations, i.e., distance
  64. offsets from the box center in four directions, shape (N, 4).
  65. """
  66. x = F.softmax(x.reshape([-1, self.reg_max + 1]), axis=1)
  67. x = F.linear(x, self.project).reshape([-1, 4])
  68. return x
  69. @register
  70. class DGQP(nn.Layer):
  71. """Distribution-Guided Quality Predictor of GFocal head
  72. Args:
  73. reg_topk (int): top-k statistics of distribution to guide LQE
  74. reg_channels (int): hidden layer unit to generate LQE
  75. add_mean (bool): Whether to calculate the mean of top-k statistics
  76. """
  77. def __init__(self, reg_topk=4, reg_channels=64, add_mean=True):
  78. super(DGQP, self).__init__()
  79. self.reg_topk = reg_topk
  80. self.reg_channels = reg_channels
  81. self.add_mean = add_mean
  82. self.total_dim = reg_topk
  83. if add_mean:
  84. self.total_dim += 1
  85. self.reg_conv1 = self.add_sublayer(
  86. 'dgqp_reg_conv1',
  87. nn.Conv2D(
  88. in_channels=4 * self.total_dim,
  89. out_channels=self.reg_channels,
  90. kernel_size=1,
  91. weight_attr=ParamAttr(initializer=Normal(
  92. mean=0., std=0.01)),
  93. bias_attr=ParamAttr(initializer=Constant(value=0))))
  94. self.reg_conv2 = self.add_sublayer(
  95. 'dgqp_reg_conv2',
  96. nn.Conv2D(
  97. in_channels=self.reg_channels,
  98. out_channels=1,
  99. kernel_size=1,
  100. weight_attr=ParamAttr(initializer=Normal(
  101. mean=0., std=0.01)),
  102. bias_attr=ParamAttr(initializer=Constant(value=0))))
  103. def forward(self, x):
  104. """Forward feature from the regression head to get integral result of
  105. bounding box location.
  106. Args:
  107. x (Tensor): Features of the regression head, shape (N, 4*(n+1)),
  108. n is self.reg_max.
  109. Returns:
  110. x (Tensor): Integral result of box locations, i.e., distance
  111. offsets from the box center in four directions, shape (N, 4).
  112. """
  113. N, _, H, W = x.shape[:]
  114. prob = F.softmax(x.reshape([N, 4, -1, H, W]), axis=2)
  115. prob_topk, _ = prob.topk(self.reg_topk, axis=2)
  116. if self.add_mean:
  117. stat = paddle.concat(
  118. [prob_topk, prob_topk.mean(
  119. axis=2, keepdim=True)], axis=2)
  120. else:
  121. stat = prob_topk
  122. y = F.relu(self.reg_conv1(stat.reshape([N, -1, H, W])))
  123. y = F.sigmoid(self.reg_conv2(y))
  124. return y
  125. @register
  126. class GFLHead(nn.Layer):
  127. """
  128. GFLHead
  129. Args:
  130. conv_feat (object): Instance of 'FCOSFeat'
  131. num_classes (int): Number of classes
  132. fpn_stride (list): The stride of each FPN Layer
  133. prior_prob (float): Used to set the bias init for the class prediction layer
  134. loss_qfl (object):
  135. loss_dfl (object):
  136. loss_bbox (object):
  137. reg_max: Max value of integral set :math: `{0, ..., reg_max}`
  138. n QFL setting. Default: 16.
  139. """
  140. __inject__ = [
  141. 'conv_feat', 'dgqp_module', 'loss_qfl', 'loss_dfl', 'loss_bbox', 'nms'
  142. ]
  143. __shared__ = ['num_classes']
  144. def __init__(self,
  145. conv_feat='FCOSFeat',
  146. dgqp_module=None,
  147. num_classes=80,
  148. fpn_stride=[8, 16, 32, 64, 128],
  149. prior_prob=0.01,
  150. loss_qfl='QualityFocalLoss',
  151. loss_dfl='DistributionFocalLoss',
  152. loss_bbox='GIoULoss',
  153. reg_max=16,
  154. feat_in_chan=256,
  155. nms=None,
  156. nms_pre=1000,
  157. cell_offset=0):
  158. super(GFLHead, self).__init__()
  159. self.conv_feat = conv_feat
  160. self.dgqp_module = dgqp_module
  161. self.num_classes = num_classes
  162. self.fpn_stride = fpn_stride
  163. self.prior_prob = prior_prob
  164. self.loss_qfl = loss_qfl
  165. self.loss_dfl = loss_dfl
  166. self.loss_bbox = loss_bbox
  167. self.reg_max = reg_max
  168. self.feat_in_chan = feat_in_chan
  169. self.nms = nms
  170. self.nms_pre = nms_pre
  171. self.cell_offset = cell_offset
  172. self.use_sigmoid = self.loss_qfl.use_sigmoid
  173. if self.use_sigmoid:
  174. self.cls_out_channels = self.num_classes
  175. else:
  176. self.cls_out_channels = self.num_classes + 1
  177. conv_cls_name = "gfl_head_cls"
  178. bias_init_value = -math.log((1 - self.prior_prob) / self.prior_prob)
  179. self.gfl_head_cls = self.add_sublayer(
  180. conv_cls_name,
  181. nn.Conv2D(
  182. in_channels=self.feat_in_chan,
  183. out_channels=self.cls_out_channels,
  184. kernel_size=3,
  185. stride=1,
  186. padding=1,
  187. weight_attr=ParamAttr(initializer=Normal(
  188. mean=0., std=0.01)),
  189. bias_attr=ParamAttr(
  190. initializer=Constant(value=bias_init_value))))
  191. conv_reg_name = "gfl_head_reg"
  192. self.gfl_head_reg = self.add_sublayer(
  193. conv_reg_name,
  194. nn.Conv2D(
  195. in_channels=self.feat_in_chan,
  196. out_channels=4 * (self.reg_max + 1),
  197. kernel_size=3,
  198. stride=1,
  199. padding=1,
  200. weight_attr=ParamAttr(initializer=Normal(
  201. mean=0., std=0.01)),
  202. bias_attr=ParamAttr(initializer=Constant(value=0))))
  203. self.scales_regs = []
  204. for i in range(len(self.fpn_stride)):
  205. lvl = int(math.log(int(self.fpn_stride[i]), 2))
  206. feat_name = 'p{}_feat'.format(lvl)
  207. scale_reg = self.add_sublayer(feat_name, ScaleReg())
  208. self.scales_regs.append(scale_reg)
  209. self.distribution_project = Integral(self.reg_max)
  210. def forward(self, fpn_feats):
  211. assert len(fpn_feats) == len(
  212. self.fpn_stride
  213. ), "The size of fpn_feats is not equal to size of fpn_stride"
  214. cls_logits_list = []
  215. bboxes_reg_list = []
  216. for scale_reg, fpn_feat in zip(self.scales_regs, fpn_feats):
  217. conv_cls_feat, conv_reg_feat = self.conv_feat(fpn_feat)
  218. cls_logits = self.gfl_head_cls(conv_cls_feat)
  219. bbox_reg = scale_reg(self.gfl_head_reg(conv_reg_feat))
  220. if self.dgqp_module:
  221. quality_score = self.dgqp_module(bbox_reg)
  222. cls_logits = F.sigmoid(cls_logits) * quality_score
  223. if not self.training:
  224. cls_logits = F.sigmoid(cls_logits.transpose([0, 2, 3, 1]))
  225. bbox_reg = bbox_reg.transpose([0, 2, 3, 1])
  226. cls_logits_list.append(cls_logits)
  227. bboxes_reg_list.append(bbox_reg)
  228. return (cls_logits_list, bboxes_reg_list)
  229. def _images_to_levels(self, target, num_level_anchors):
  230. """
  231. Convert targets by image to targets by feature level.
  232. """
  233. level_targets = []
  234. start = 0
  235. for n in num_level_anchors:
  236. end = start + n
  237. level_targets.append(target[:, start:end].squeeze(0))
  238. start = end
  239. return level_targets
  240. def _grid_cells_to_center(self, grid_cells):
  241. """
  242. Get center location of each gird cell
  243. Args:
  244. grid_cells: grid cells of a feature map
  245. Returns:
  246. center points
  247. """
  248. cells_cx = (grid_cells[:, 2] + grid_cells[:, 0]) / 2
  249. cells_cy = (grid_cells[:, 3] + grid_cells[:, 1]) / 2
  250. return paddle.stack([cells_cx, cells_cy], axis=-1)
  251. def get_loss(self, gfl_head_outs, gt_meta):
  252. cls_logits, bboxes_reg = gfl_head_outs
  253. num_level_anchors = [
  254. featmap.shape[-2] * featmap.shape[-1] for featmap in cls_logits
  255. ]
  256. grid_cells_list = self._images_to_levels(gt_meta['grid_cells'],
  257. num_level_anchors)
  258. labels_list = self._images_to_levels(gt_meta['labels'],
  259. num_level_anchors)
  260. label_weights_list = self._images_to_levels(gt_meta['label_weights'],
  261. num_level_anchors)
  262. bbox_targets_list = self._images_to_levels(gt_meta['bbox_targets'],
  263. num_level_anchors)
  264. num_total_pos = sum(gt_meta['pos_num'])
  265. try:
  266. num_total_pos = paddle.distributed.all_reduce(num_total_pos.clone(
  267. )) / paddle.distributed.get_world_size()
  268. except:
  269. num_total_pos = max(num_total_pos, 1)
  270. loss_bbox_list, loss_dfl_list, loss_qfl_list, avg_factor = [], [], [], []
  271. for cls_score, bbox_pred, grid_cells, labels, label_weights, bbox_targets, stride in zip(
  272. cls_logits, bboxes_reg, grid_cells_list, labels_list,
  273. label_weights_list, bbox_targets_list, self.fpn_stride):
  274. grid_cells = grid_cells.reshape([-1, 4])
  275. cls_score = cls_score.transpose([0, 2, 3, 1]).reshape(
  276. [-1, self.cls_out_channels])
  277. bbox_pred = bbox_pred.transpose([0, 2, 3, 1]).reshape(
  278. [-1, 4 * (self.reg_max + 1)])
  279. bbox_targets = bbox_targets.reshape([-1, 4])
  280. labels = labels.reshape([-1])
  281. label_weights = label_weights.reshape([-1])
  282. bg_class_ind = self.num_classes
  283. pos_inds = paddle.nonzero(
  284. paddle.logical_and((labels >= 0), (labels < bg_class_ind)),
  285. as_tuple=False).squeeze(1)
  286. score = np.zeros(labels.shape)
  287. if len(pos_inds) > 0:
  288. pos_bbox_targets = paddle.gather(
  289. bbox_targets, pos_inds, axis=0)
  290. pos_bbox_pred = paddle.gather(bbox_pred, pos_inds, axis=0)
  291. pos_grid_cells = paddle.gather(grid_cells, pos_inds, axis=0)
  292. pos_grid_cell_centers = self._grid_cells_to_center(
  293. pos_grid_cells) / stride
  294. weight_targets = F.sigmoid(cls_score.detach())
  295. weight_targets = paddle.gather(
  296. weight_targets.max(axis=1, keepdim=True), pos_inds, axis=0)
  297. pos_bbox_pred_corners = self.distribution_project(
  298. pos_bbox_pred)
  299. pos_decode_bbox_pred = distance2bbox(pos_grid_cell_centers,
  300. pos_bbox_pred_corners)
  301. pos_decode_bbox_targets = pos_bbox_targets / stride
  302. bbox_iou = bbox_overlaps(
  303. pos_decode_bbox_pred.detach().numpy(),
  304. pos_decode_bbox_targets.detach().numpy(),
  305. is_aligned=True)
  306. score[pos_inds.numpy()] = bbox_iou
  307. pred_corners = pos_bbox_pred.reshape([-1, self.reg_max + 1])
  308. target_corners = bbox2distance(pos_grid_cell_centers,
  309. pos_decode_bbox_targets,
  310. self.reg_max).reshape([-1])
  311. # regression loss
  312. loss_bbox = paddle.sum(
  313. self.loss_bbox(pos_decode_bbox_pred,
  314. pos_decode_bbox_targets) * weight_targets)
  315. # dfl loss
  316. loss_dfl = self.loss_dfl(
  317. pred_corners,
  318. target_corners,
  319. weight=weight_targets.expand([-1, 4]).reshape([-1]),
  320. avg_factor=4.0)
  321. else:
  322. loss_bbox = bbox_pred.sum() * 0
  323. loss_dfl = bbox_pred.sum() * 0
  324. weight_targets = paddle.to_tensor([0], dtype='float32')
  325. # qfl loss
  326. score = paddle.to_tensor(score)
  327. loss_qfl = self.loss_qfl(
  328. cls_score, (labels, score),
  329. weight=label_weights,
  330. avg_factor=num_total_pos)
  331. loss_bbox_list.append(loss_bbox)
  332. loss_dfl_list.append(loss_dfl)
  333. loss_qfl_list.append(loss_qfl)
  334. avg_factor.append(weight_targets.sum())
  335. avg_factor = sum(avg_factor)
  336. try:
  337. avg_factor = paddle.distributed.all_reduce(avg_factor.clone())
  338. avg_factor = paddle.clip(
  339. avg_factor / paddle.distributed.get_world_size(), min=1)
  340. except:
  341. avg_factor = max(avg_factor.item(), 1)
  342. if avg_factor <= 0:
  343. loss_qfl = paddle.to_tensor(
  344. 0, dtype='float32', stop_gradient=False)
  345. loss_bbox = paddle.to_tensor(
  346. 0, dtype='float32', stop_gradient=False)
  347. loss_dfl = paddle.to_tensor(
  348. 0, dtype='float32', stop_gradient=False)
  349. else:
  350. losses_bbox = list(map(lambda x: x / avg_factor, loss_bbox_list))
  351. losses_dfl = list(map(lambda x: x / avg_factor, loss_dfl_list))
  352. loss_qfl = sum(loss_qfl_list)
  353. loss_bbox = sum(losses_bbox)
  354. loss_dfl = sum(losses_dfl)
  355. loss_states = dict(
  356. loss_qfl=loss_qfl, loss_bbox=loss_bbox, loss_dfl=loss_dfl)
  357. return loss_states
  358. def get_single_level_center_point(self,
  359. featmap_size,
  360. stride,
  361. cell_offset=0):
  362. """
  363. Generate pixel centers of a single stage feature map.
  364. Args:
  365. featmap_size: height and width of the feature map
  366. stride: down sample stride of the feature map
  367. Returns:
  368. y and x of the center points
  369. """
  370. h, w = featmap_size
  371. x_range = (paddle.arange(w, dtype='float32') + cell_offset) * stride
  372. y_range = (paddle.arange(h, dtype='float32') + cell_offset) * stride
  373. y, x = paddle.meshgrid(y_range, x_range)
  374. y = y.flatten()
  375. x = x.flatten()
  376. return y, x
  377. def get_bboxes_single(self,
  378. cls_scores,
  379. bbox_preds,
  380. img_shape,
  381. scale_factor,
  382. rescale=True,
  383. cell_offset=0):
  384. assert len(cls_scores) == len(bbox_preds)
  385. mlvl_bboxes = []
  386. mlvl_scores = []
  387. for stride, cls_score, bbox_pred in zip(self.fpn_stride, cls_scores,
  388. bbox_preds):
  389. featmap_size = [
  390. paddle.shape(cls_score)[0], paddle.shape(cls_score)[1]
  391. ]
  392. y, x = self.get_single_level_center_point(
  393. featmap_size, stride, cell_offset=cell_offset)
  394. center_points = paddle.stack([x, y], axis=-1)
  395. scores = cls_score.reshape([-1, self.cls_out_channels])
  396. bbox_pred = self.distribution_project(bbox_pred) * stride
  397. if scores.shape[0] > self.nms_pre:
  398. max_scores = scores.max(axis=1)
  399. _, topk_inds = max_scores.topk(self.nms_pre)
  400. center_points = center_points.gather(topk_inds)
  401. bbox_pred = bbox_pred.gather(topk_inds)
  402. scores = scores.gather(topk_inds)
  403. bboxes = distance2bbox(
  404. center_points, bbox_pred, max_shape=img_shape)
  405. mlvl_bboxes.append(bboxes)
  406. mlvl_scores.append(scores)
  407. mlvl_bboxes = paddle.concat(mlvl_bboxes)
  408. if rescale:
  409. # [h_scale, w_scale] to [w_scale, h_scale, w_scale, h_scale]
  410. im_scale = paddle.concat([scale_factor[::-1], scale_factor[::-1]])
  411. mlvl_bboxes /= im_scale
  412. mlvl_scores = paddle.concat(mlvl_scores)
  413. mlvl_scores = mlvl_scores.transpose([1, 0])
  414. return mlvl_bboxes, mlvl_scores
  415. def decode(self, cls_scores, bbox_preds, im_shape, scale_factor,
  416. cell_offset):
  417. batch_bboxes = []
  418. batch_scores = []
  419. for img_id in range(cls_scores[0].shape[0]):
  420. num_levels = len(cls_scores)
  421. cls_score_list = [cls_scores[i][img_id] for i in range(num_levels)]
  422. bbox_pred_list = [bbox_preds[i][img_id] for i in range(num_levels)]
  423. bboxes, scores = self.get_bboxes_single(
  424. cls_score_list,
  425. bbox_pred_list,
  426. im_shape[img_id],
  427. scale_factor[img_id],
  428. cell_offset=cell_offset)
  429. batch_bboxes.append(bboxes)
  430. batch_scores.append(scores)
  431. batch_bboxes = paddle.stack(batch_bboxes, axis=0)
  432. batch_scores = paddle.stack(batch_scores, axis=0)
  433. return batch_bboxes, batch_scores
  434. def post_process(self, gfl_head_outs, im_shape, scale_factor):
  435. cls_scores, bboxes_reg = gfl_head_outs
  436. bboxes, score = self.decode(cls_scores, bboxes_reg, im_shape,
  437. scale_factor, self.cell_offset)
  438. bbox_pred, bbox_num, _ = self.nms(bboxes, score)
  439. return bbox_pred, bbox_num