# s2anet_head.py
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. from paddle import ParamAttr
  16. import paddle.nn as nn
  17. import paddle.nn.functional as F
  18. from paddle.nn.initializer import Normal, Constant
  19. from paddlex.ppdet.core.workspace import register
  20. from paddlex.ppdet.modeling import bbox_utils
  21. from paddlex.ppdet.modeling.proposal_generator.target_layer import RBoxAssigner
  22. import numpy as np
  23. class S2ANetAnchorGenerator(nn.Layer):
  24. """
  25. AnchorGenerator by paddle
  26. """
  27. def __init__(self, base_size, scales, ratios, scale_major=True, ctr=None):
  28. super(S2ANetAnchorGenerator, self).__init__()
  29. self.base_size = base_size
  30. self.scales = paddle.to_tensor(scales)
  31. self.ratios = paddle.to_tensor(ratios)
  32. self.scale_major = scale_major
  33. self.ctr = ctr
  34. self.base_anchors = self.gen_base_anchors()
  35. @property
  36. def num_base_anchors(self):
  37. return self.base_anchors.shape[0]
  38. def gen_base_anchors(self):
  39. w = self.base_size
  40. h = self.base_size
  41. if self.ctr is None:
  42. x_ctr = 0.5 * (w - 1)
  43. y_ctr = 0.5 * (h - 1)
  44. else:
  45. x_ctr, y_ctr = self.ctr
  46. h_ratios = paddle.sqrt(self.ratios)
  47. w_ratios = 1 / h_ratios
  48. if self.scale_major:
  49. ws = (w * w_ratios[:] * self.scales[:]).reshape([-1])
  50. hs = (h * h_ratios[:] * self.scales[:]).reshape([-1])
  51. else:
  52. ws = (w * self.scales[:] * w_ratios[:]).reshape([-1])
  53. hs = (h * self.scales[:] * h_ratios[:]).reshape([-1])
  54. base_anchors = paddle.stack(
  55. [
  56. x_ctr - 0.5 * (ws - 1), y_ctr - 0.5 * (hs - 1),
  57. x_ctr + 0.5 * (ws - 1), y_ctr + 0.5 * (hs - 1)
  58. ],
  59. axis=-1)
  60. base_anchors = paddle.round(base_anchors)
  61. return base_anchors
  62. def _meshgrid(self, x, y, row_major=True):
  63. yy, xx = paddle.meshgrid(x, y)
  64. yy = yy.reshape([-1])
  65. xx = xx.reshape([-1])
  66. if row_major:
  67. return xx, yy
  68. else:
  69. return yy, xx
  70. def forward(self, featmap_size, stride=16):
  71. # featmap_size*stride project it to original area
  72. base_anchors = self.base_anchors
  73. feat_h = featmap_size[0]
  74. feat_w = featmap_size[1]
  75. shift_x = paddle.arange(0, feat_w, 1, 'int32') * stride
  76. shift_y = paddle.arange(0, feat_h, 1, 'int32') * stride
  77. shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
  78. shifts = paddle.stack([shift_xx, shift_yy, shift_xx, shift_yy], axis=-1)
  79. all_anchors = base_anchors[:, :] + shifts[:, :]
  80. all_anchors = all_anchors.reshape([feat_h * feat_w, 4])
  81. return all_anchors
  82. def valid_flags(self, featmap_size, valid_size):
  83. feat_h, feat_w = featmap_size
  84. valid_h, valid_w = valid_size
  85. assert valid_h <= feat_h and valid_w <= feat_w
  86. valid_x = paddle.zeros([feat_w], dtype='uint8')
  87. valid_y = paddle.zeros([feat_h], dtype='uint8')
  88. valid_x[:valid_w] = 1
  89. valid_y[:valid_h] = 1
  90. valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
  91. valid = valid_xx & valid_yy
  92. valid = valid[:, None].expand(
  93. [valid.size(0), self.num_base_anchors]).reshape([-1])
  94. return valid
class AlignConv(nn.Layer):
    """Deformable conv whose offsets re-sample features along rotated anchors.

    The offsets are computed (without gradient) so that the k x k sampling
    grid of the deformable convolution lands on each refined rotated anchor
    instead of the regular pixel grid.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, groups=1):
        super(AlignConv, self).__init__()
        self.kernel_size = kernel_size
        self.align_conv = paddle.vision.ops.DeformConv2D(
            in_channels,
            out_channels,
            kernel_size=self.kernel_size,
            padding=(self.kernel_size - 1) // 2,
            groups=groups,
            weight_attr=ParamAttr(initializer=Normal(0, 0.01)),
            bias_attr=None)

    @paddle.no_grad()
    def get_offset(self, anchors, featmap_size, stride):
        """
        Compute DeformConv2D offsets aligning the sampling grid to anchors.

        Args:
            anchors: [M,5] xc,yc,w,h,angle (in image coordinates;
                presumably M == feat_h * feat_w — one anchor per cell)
            featmap_size: (feat_h, feat_w)
            stride: 8
        Returns:
            offset tensor of shape [1, 2*k*k, feat_h, feat_w]
        """
        anchors = paddle.reshape(anchors, [-1, 5])  # (NA,5)
        dtype = anchors.dtype
        feat_h, feat_w = featmap_size
        pad = (self.kernel_size - 1) // 2
        # relative kernel coordinates, e.g. -1..1 for a 3x3 kernel
        idx = paddle.arange(-pad, pad + 1, dtype=dtype)
        yy, xx = paddle.meshgrid(idx, idx)
        xx = paddle.reshape(xx, [-1])
        yy = paddle.reshape(yy, [-1])

        # get sampling locations of default conv
        xc = paddle.arange(0, feat_w, dtype=dtype)
        yc = paddle.arange(0, feat_h, dtype=dtype)
        yc, xc = paddle.meshgrid(yc, xc)
        xc = paddle.reshape(xc, [-1, 1])
        yc = paddle.reshape(yc, [-1, 1])
        x_conv = xc + xx
        y_conv = yc + yy

        # get sampling locations of anchors
        # x_ctr, y_ctr, w, h, a = np.unbind(anchors, dim=1)
        x_ctr = anchors[:, 0]
        y_ctr = anchors[:, 1]
        w = anchors[:, 2]
        h = anchors[:, 3]
        a = anchors[:, 4]
        x_ctr = paddle.reshape(x_ctr, [x_ctr.shape[0], 1])
        y_ctr = paddle.reshape(y_ctr, [y_ctr.shape[0], 1])
        w = paddle.reshape(w, [w.shape[0], 1])
        h = paddle.reshape(h, [h.shape[0], 1])
        a = paddle.reshape(a, [a.shape[0], 1])
        # anchors are in image coordinates; convert to feature-map scale
        x_ctr = x_ctr / stride
        y_ctr = y_ctr / stride
        w_s = w / stride
        h_s = h / stride
        cos, sin = paddle.cos(a), paddle.sin(a)
        # one kernel cell spans (w/k, h/k) of the anchor; rotate by angle a
        dw, dh = w_s / self.kernel_size, h_s / self.kernel_size
        x, y = dw * xx, dh * yy
        xr = cos * x - sin * y
        yr = sin * x + cos * y
        x_anchor, y_anchor = xr + x_ctr, yr + y_ctr

        # get offset filed
        offset_x = x_anchor - x_conv
        offset_y = y_anchor - y_conv
        # x, y in anchors is opposite in image coordinates,
        # so we stack them with y, x other than x, y
        offset = paddle.stack([offset_y, offset_x], axis=-1)
        # NA,ks*ks*2
        # [NA, ks, ks, 2] --> [NA, ks*ks*2]
        offset = paddle.reshape(offset, [offset.shape[0], -1])
        # [NA, ks*ks*2] --> [ks*ks*2, NA]
        offset = paddle.transpose(offset, [1, 0])
        # [NA, ks*ks*2] --> [1, ks*ks*2, H, W]
        offset = paddle.reshape(offset, [1, -1, feat_h, feat_w])
        return offset

    def forward(self, x, refine_anchors, stride):
        # Re-sample x along the refined anchors, then ReLU.
        featmap_size = (x.shape[2], x.shape[3])
        offset = self.get_offset(refine_anchors, featmap_size, stride)
        x = F.relu(self.align_conv(x, offset))
        return x
  173. @register
  174. class S2ANetHead(nn.Layer):
  175. """
  176. S2Anet head
  177. Args:
  178. stacked_convs (int): number of stacked_convs
  179. feat_in (int): input channels of feat
  180. feat_out (int): output channels of feat
  181. num_classes (int): num_classes
  182. anchor_strides (list): stride of anchors
  183. anchor_scales (list): scale of anchors
  184. anchor_ratios (list): ratios of anchors
  185. target_means (list): target_means
  186. target_stds (list): target_stds
  187. align_conv_type (str): align_conv_type ['Conv', 'AlignConv']
  188. align_conv_size (int): kernel size of align_conv
  189. use_sigmoid_cls (bool): use sigmoid_cls or not
  190. reg_loss_weight (list): loss weight for regression
  191. """
  192. __shared__ = ['num_classes']
  193. __inject__ = ['anchor_assign']
  194. def __init__(self,
  195. stacked_convs=2,
  196. feat_in=256,
  197. feat_out=256,
  198. num_classes=15,
  199. anchor_strides=[8, 16, 32, 64, 128],
  200. anchor_scales=[4],
  201. anchor_ratios=[1.0],
  202. target_means=0.0,
  203. target_stds=1.0,
  204. align_conv_type='AlignConv',
  205. align_conv_size=3,
  206. use_sigmoid_cls=True,
  207. anchor_assign=RBoxAssigner().__dict__,
  208. reg_loss_weight=[1.0, 1.0, 1.0, 1.0, 1.0],
  209. cls_loss_weight=[1.0, 1.0]):
  210. super(S2ANetHead, self).__init__()
  211. self.stacked_convs = stacked_convs
  212. self.feat_in = feat_in
  213. self.feat_out = feat_out
  214. self.anchor_list = None
  215. self.anchor_scales = anchor_scales
  216. self.anchor_ratios = anchor_ratios
  217. self.anchor_strides = anchor_strides
  218. self.anchor_base_sizes = list(anchor_strides)
  219. self.target_means = target_means
  220. self.target_stds = target_stds
  221. assert align_conv_type in ['AlignConv', 'Conv', 'DCN']
  222. self.align_conv_type = align_conv_type
  223. self.align_conv_size = align_conv_size
  224. self.use_sigmoid_cls = use_sigmoid_cls
  225. self.cls_out_channels = num_classes if self.use_sigmoid_cls else 1
  226. self.sampling = False
  227. self.anchor_assign = anchor_assign
  228. self.reg_loss_weight = reg_loss_weight
  229. self.cls_loss_weight = cls_loss_weight
  230. self.s2anet_head_out = None
  231. # anchor
  232. self.anchor_generators = []
  233. for anchor_base in self.anchor_base_sizes:
  234. self.anchor_generators.append(
  235. S2ANetAnchorGenerator(anchor_base, anchor_scales,
  236. anchor_ratios))
  237. self.anchor_generators = paddle.nn.LayerList(self.anchor_generators)
  238. self.add_sublayer('s2anet_anchor_gen', self.anchor_generators)
  239. self.fam_cls_convs = nn.Sequential()
  240. self.fam_reg_convs = nn.Sequential()
  241. for i in range(self.stacked_convs):
  242. chan_in = self.feat_in if i == 0 else self.feat_out
  243. self.fam_cls_convs.add_sublayer(
  244. 'fam_cls_conv_{}'.format(i),
  245. nn.Conv2D(
  246. in_channels=chan_in,
  247. out_channels=self.feat_out,
  248. kernel_size=3,
  249. padding=1,
  250. weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
  251. bias_attr=ParamAttr(initializer=Constant(0))))
  252. self.fam_cls_convs.add_sublayer('fam_cls_conv_{}_act'.format(i),
  253. nn.ReLU())
  254. self.fam_reg_convs.add_sublayer(
  255. 'fam_reg_conv_{}'.format(i),
  256. nn.Conv2D(
  257. in_channels=chan_in,
  258. out_channels=self.feat_out,
  259. kernel_size=3,
  260. padding=1,
  261. weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
  262. bias_attr=ParamAttr(initializer=Constant(0))))
  263. self.fam_reg_convs.add_sublayer('fam_reg_conv_{}_act'.format(i),
  264. nn.ReLU())
  265. self.fam_reg = nn.Conv2D(
  266. self.feat_out,
  267. 5,
  268. 1,
  269. weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
  270. bias_attr=ParamAttr(initializer=Constant(0)))
  271. prior_prob = 0.01
  272. bias_init = float(-np.log((1 - prior_prob) / prior_prob))
  273. self.fam_cls = nn.Conv2D(
  274. self.feat_out,
  275. self.cls_out_channels,
  276. 1,
  277. weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
  278. bias_attr=ParamAttr(initializer=Constant(bias_init)))
  279. if self.align_conv_type == "AlignConv":
  280. self.align_conv = AlignConv(self.feat_out, self.feat_out,
  281. self.align_conv_size)
  282. elif self.align_conv_type == "Conv":
  283. self.align_conv = nn.Conv2D(
  284. self.feat_out,
  285. self.feat_out,
  286. self.align_conv_size,
  287. padding=(self.align_conv_size - 1) // 2,
  288. bias_attr=ParamAttr(initializer=Constant(0)))
  289. elif self.align_conv_type == "DCN":
  290. self.align_conv_offset = nn.Conv2D(
  291. self.feat_out,
  292. 2 * self.align_conv_size**2,
  293. 1,
  294. weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
  295. bias_attr=ParamAttr(initializer=Constant(0)))
  296. self.align_conv = paddle.vision.ops.DeformConv2D(
  297. self.feat_out,
  298. self.feat_out,
  299. self.align_conv_size,
  300. padding=(self.align_conv_size - 1) // 2,
  301. weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
  302. bias_attr=False)
  303. self.or_conv = nn.Conv2D(
  304. self.feat_out,
  305. self.feat_out,
  306. kernel_size=3,
  307. padding=1,
  308. weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
  309. bias_attr=ParamAttr(initializer=Constant(0)))
  310. # ODM
  311. self.odm_cls_convs = nn.Sequential()
  312. self.odm_reg_convs = nn.Sequential()
  313. for i in range(self.stacked_convs):
  314. ch_in = self.feat_out
  315. # ch_in = int(self.feat_out / 8) if i == 0 else self.feat_out
  316. self.odm_cls_convs.add_sublayer(
  317. 'odm_cls_conv_{}'.format(i),
  318. nn.Conv2D(
  319. in_channels=ch_in,
  320. out_channels=self.feat_out,
  321. kernel_size=3,
  322. stride=1,
  323. padding=1,
  324. weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
  325. bias_attr=ParamAttr(initializer=Constant(0))))
  326. self.odm_cls_convs.add_sublayer('odm_cls_conv_{}_act'.format(i),
  327. nn.ReLU())
  328. self.odm_reg_convs.add_sublayer(
  329. 'odm_reg_conv_{}'.format(i),
  330. nn.Conv2D(
  331. in_channels=self.feat_out,
  332. out_channels=self.feat_out,
  333. kernel_size=3,
  334. stride=1,
  335. padding=1,
  336. weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
  337. bias_attr=ParamAttr(initializer=Constant(0))))
  338. self.odm_reg_convs.add_sublayer('odm_reg_conv_{}_act'.format(i),
  339. nn.ReLU())
  340. self.odm_cls = nn.Conv2D(
  341. self.feat_out,
  342. self.cls_out_channels,
  343. 3,
  344. padding=1,
  345. weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
  346. bias_attr=ParamAttr(initializer=Constant(bias_init)))
  347. self.odm_reg = nn.Conv2D(
  348. self.feat_out,
  349. 5,
  350. 3,
  351. padding=1,
  352. weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
  353. bias_attr=ParamAttr(initializer=Constant(0)))
  354. self.featmap_size_list = []
  355. self.init_anchors_list = []
  356. self.rbox_anchors_list = []
  357. self.refine_anchor_list = []
  358. def forward(self, feats):
  359. fam_reg_branch_list = []
  360. fam_cls_branch_list = []
  361. odm_reg_branch_list = []
  362. odm_cls_branch_list = []
  363. fam_reg1_branch_list = []
  364. self.featmap_size_list = []
  365. self.init_anchors_list = []
  366. self.rbox_anchors_list = []
  367. self.refine_anchor_list = []
  368. for i, feat in enumerate(feats):
  369. # prepare anchor
  370. featmap_size = paddle.shape(feat)[-2:]
  371. self.featmap_size_list.append(featmap_size)
  372. init_anchors = self.anchor_generators[i](featmap_size,
  373. self.anchor_strides[i])
  374. init_anchors = paddle.reshape(
  375. init_anchors, [featmap_size[0] * featmap_size[1], 4])
  376. self.init_anchors_list.append(init_anchors)
  377. rbox_anchors = self.rect2rbox(init_anchors)
  378. self.rbox_anchors_list.append(rbox_anchors)
  379. fam_cls_feat = self.fam_cls_convs(feat)
  380. fam_cls = self.fam_cls(fam_cls_feat)
  381. # [N, CLS, H, W] --> [N, H, W, CLS]
  382. fam_cls = fam_cls.transpose([0, 2, 3, 1])
  383. fam_cls_reshape = paddle.reshape(
  384. fam_cls, [fam_cls.shape[0], -1, self.cls_out_channels])
  385. fam_cls_branch_list.append(fam_cls_reshape)
  386. fam_reg_feat = self.fam_reg_convs(feat)
  387. fam_reg = self.fam_reg(fam_reg_feat)
  388. # [N, 5, H, W] --> [N, H, W, 5]
  389. fam_reg = fam_reg.transpose([0, 2, 3, 1])
  390. fam_reg_reshape = paddle.reshape(fam_reg, [fam_reg.shape[0], -1, 5])
  391. fam_reg_branch_list.append(fam_reg_reshape)
  392. # refine anchors
  393. fam_reg1 = fam_reg.clone()
  394. fam_reg1.stop_gradient = True
  395. rbox_anchors.stop_gradient = True
  396. fam_reg1_branch_list.append(fam_reg1)
  397. refine_anchor = self.bbox_decode(
  398. fam_reg1, rbox_anchors, self.target_stds, self.target_means)
  399. self.refine_anchor_list.append(refine_anchor)
  400. if self.align_conv_type == 'AlignConv':
  401. align_feat = self.align_conv(feat,
  402. refine_anchor.clone(),
  403. self.anchor_strides[i])
  404. elif self.align_conv_type == 'DCN':
  405. align_offset = self.align_conv_offset(feat)
  406. align_feat = self.align_conv(feat, align_offset)
  407. elif self.align_conv_type == 'Conv':
  408. align_feat = self.align_conv(feat)
  409. or_feat = self.or_conv(align_feat)
  410. odm_reg_feat = or_feat
  411. odm_cls_feat = or_feat
  412. odm_reg_feat = self.odm_reg_convs(odm_reg_feat)
  413. odm_cls_feat = self.odm_cls_convs(odm_cls_feat)
  414. odm_cls_score = self.odm_cls(odm_cls_feat)
  415. # [N, CLS, H, W] --> [N, H, W, CLS]
  416. odm_cls_score = odm_cls_score.transpose([0, 2, 3, 1])
  417. odm_cls_score_reshape = paddle.reshape(
  418. odm_cls_score,
  419. [odm_cls_score.shape[0], -1, self.cls_out_channels])
  420. odm_cls_branch_list.append(odm_cls_score_reshape)
  421. odm_bbox_pred = self.odm_reg(odm_reg_feat)
  422. # [N, 5, H, W] --> [N, H, W, 5]
  423. odm_bbox_pred = odm_bbox_pred.transpose([0, 2, 3, 1])
  424. odm_bbox_pred_reshape = paddle.reshape(
  425. odm_bbox_pred, [odm_bbox_pred.shape[0], -1, 5])
  426. odm_reg_branch_list.append(odm_bbox_pred_reshape)
  427. self.s2anet_head_out = (fam_cls_branch_list, fam_reg_branch_list,
  428. odm_cls_branch_list, odm_reg_branch_list)
  429. return self.s2anet_head_out
  430. def rect2rbox(self, bboxes):
  431. """
  432. :param bboxes: shape (n, 4) (xmin, ymin, xmax, ymax)
  433. :return: dbboxes: shape (n, 5) (x_ctr, y_ctr, w, h, angle)
  434. """
  435. num_boxes = paddle.shape(bboxes)[0]
  436. x_ctr = (bboxes[:, 2] + bboxes[:, 0]) / 2.0
  437. y_ctr = (bboxes[:, 3] + bboxes[:, 1]) / 2.0
  438. edges1 = paddle.abs(bboxes[:, 2] - bboxes[:, 0])
  439. edges2 = paddle.abs(bboxes[:, 3] - bboxes[:, 1])
  440. rbox_w = paddle.maximum(edges1, edges2)
  441. rbox_h = paddle.minimum(edges1, edges2)
  442. # set angle
  443. inds = edges1 < edges2
  444. inds = paddle.cast(inds, 'int32')
  445. inds1 = inds * paddle.arange(0, num_boxes)
  446. rboxes_angle = inds1 * np.pi / 2.0
  447. rboxes = paddle.stack(
  448. (x_ctr, y_ctr, rbox_w, rbox_h, rboxes_angle), axis=1)
  449. return rboxes
  450. # deltas to rbox
  451. def delta2rbox(self, rrois, deltas, means, stds, wh_ratio_clip=1e-6):
  452. """
  453. :param rrois: (cx, cy, w, h, theta)
  454. :param deltas: (dx, dy, dw, dh, dtheta)
  455. :param means: means of anchor
  456. :param stds: stds of anchor
  457. :param wh_ratio_clip: clip threshold of wh_ratio
  458. :return:
  459. """
  460. deltas = paddle.reshape(deltas, [-1, 5])
  461. rrois = paddle.reshape(rrois, [-1, 5])
  462. pd_means = paddle.ones(shape=[5]) * means
  463. pd_stds = paddle.ones(shape=[5]) * stds
  464. denorm_deltas = deltas * pd_stds + pd_means
  465. dx = denorm_deltas[:, 0]
  466. dy = denorm_deltas[:, 1]
  467. dw = denorm_deltas[:, 2]
  468. dh = denorm_deltas[:, 3]
  469. dangle = denorm_deltas[:, 4]
  470. max_ratio = np.abs(np.log(wh_ratio_clip))
  471. dw = paddle.clip(dw, min=-max_ratio, max=max_ratio)
  472. dh = paddle.clip(dh, min=-max_ratio, max=max_ratio)
  473. rroi_x = rrois[:, 0]
  474. rroi_y = rrois[:, 1]
  475. rroi_w = rrois[:, 2]
  476. rroi_h = rrois[:, 3]
  477. rroi_angle = rrois[:, 4]
  478. gx = dx * rroi_w * paddle.cos(rroi_angle) - dy * rroi_h * paddle.sin(
  479. rroi_angle) + rroi_x
  480. gy = dx * rroi_w * paddle.sin(rroi_angle) + dy * rroi_h * paddle.cos(
  481. rroi_angle) + rroi_y
  482. gw = rroi_w * dw.exp()
  483. gh = rroi_h * dh.exp()
  484. ga = np.pi * dangle + rroi_angle
  485. ga = (ga + np.pi / 4) % np.pi - np.pi / 4
  486. bboxes = paddle.stack([gx, gy, gw, gh, ga], axis=-1)
  487. return bboxes
  488. def bbox_decode(self, bbox_preds, anchors, stds, means, wh_ratio_clip=1e-6):
  489. """decode bbox from deltas
  490. Args:
  491. bbox_preds: bbox_preds, shape=[N,H,W,5]
  492. anchors: anchors, shape=[H,W,5]
  493. return:
  494. bboxes: return decoded bboxes, shape=[N*H*W,5]
  495. """
  496. num_imgs, H, W, _ = bbox_preds.shape
  497. bbox_delta = paddle.reshape(bbox_preds, [-1, 5])
  498. bboxes = self.delta2rbox(anchors, bbox_delta, means, stds,
  499. wh_ratio_clip)
  500. return bboxes
  501. def get_prediction(self, nms_pre):
  502. refine_anchors = self.refine_anchor_list
  503. fam_cls_branch_list, fam_reg_branch_list, odm_cls_branch_list, odm_reg_branch_list = self.s2anet_head_out
  504. pred_scores, pred_bboxes = self.get_bboxes(
  505. odm_cls_branch_list,
  506. odm_reg_branch_list,
  507. refine_anchors,
  508. nms_pre,
  509. cls_out_channels=self.cls_out_channels,
  510. use_sigmoid_cls=self.use_sigmoid_cls)
  511. return pred_scores, pred_bboxes
  512. def smooth_l1_loss(self, pred, label, delta=1.0 / 9.0):
  513. """
  514. Args:
  515. pred: pred score
  516. label: label
  517. delta: delta
  518. Returns: loss
  519. """
  520. assert pred.shape == label.shape and label.numel() > 0
  521. assert delta > 0
  522. diff = paddle.abs(pred - label)
  523. loss = paddle.where(diff < delta, 0.5 * diff * diff / delta,
  524. diff - 0.5 * delta)
  525. return loss
  526. def get_fam_loss(self, fam_target, s2anet_head_out):
  527. (feat_labels, feat_label_weights, feat_bbox_targets, feat_bbox_weights,
  528. pos_inds, neg_inds) = fam_target
  529. fam_cls_score, fam_bbox_pred = s2anet_head_out
  530. # step1: sample count
  531. num_total_samples = len(pos_inds) + len(
  532. neg_inds) if self.sampling else len(pos_inds)
  533. num_total_samples = max(1, num_total_samples)
  534. # step2: calc cls loss
  535. feat_labels = feat_labels.reshape(-1)
  536. feat_label_weights = feat_label_weights.reshape(-1)
  537. fam_cls_score = paddle.squeeze(fam_cls_score, axis=0)
  538. fam_cls_score1 = fam_cls_score
  539. # gt_classes 0~14(data), feat_labels 0~14, sigmoid_focal_loss need class>=1
  540. feat_labels = feat_labels + 1
  541. feat_labels = paddle.to_tensor(feat_labels)
  542. feat_labels_one_hot = F.one_hot(feat_labels, self.cls_out_channels + 1)
  543. feat_labels_one_hot = feat_labels_one_hot[:, 1:]
  544. feat_labels_one_hot.stop_gradient = True
  545. num_total_samples = paddle.to_tensor(
  546. num_total_samples, dtype='float32', stop_gradient=True)
  547. fam_cls = F.sigmoid_focal_loss(
  548. fam_cls_score1,
  549. feat_labels_one_hot,
  550. normalizer=num_total_samples,
  551. reduction='none')
  552. feat_label_weights = feat_label_weights.reshape(
  553. feat_label_weights.shape[0], 1)
  554. feat_label_weights = np.repeat(
  555. feat_label_weights, self.cls_out_channels, axis=1)
  556. feat_label_weights = paddle.to_tensor(
  557. feat_label_weights, stop_gradient=True)
  558. fam_cls = fam_cls * feat_label_weights
  559. fam_cls_total = paddle.sum(fam_cls)
  560. # step3: regression loss
  561. feat_bbox_targets = paddle.to_tensor(
  562. feat_bbox_targets, dtype='float32', stop_gradient=True)
  563. feat_bbox_targets = paddle.reshape(feat_bbox_targets, [-1, 5])
  564. fam_bbox_pred = paddle.squeeze(fam_bbox_pred, axis=0)
  565. fam_bbox_pred = paddle.reshape(fam_bbox_pred, [-1, 5])
  566. fam_bbox = self.smooth_l1_loss(fam_bbox_pred, feat_bbox_targets)
  567. loss_weight = paddle.to_tensor(
  568. self.reg_loss_weight, dtype='float32', stop_gradient=True)
  569. fam_bbox = paddle.multiply(fam_bbox, loss_weight)
  570. feat_bbox_weights = paddle.to_tensor(
  571. feat_bbox_weights, stop_gradient=True)
  572. fam_bbox = fam_bbox * feat_bbox_weights
  573. fam_bbox_total = paddle.sum(fam_bbox) / num_total_samples
  574. fam_cls_loss_weight = paddle.to_tensor(
  575. self.cls_loss_weight[0], dtype='float32', stop_gradient=True)
  576. fam_cls_loss = fam_cls_total * fam_cls_loss_weight
  577. fam_reg_loss = paddle.add_n(fam_bbox_total)
  578. return fam_cls_loss, fam_reg_loss
  579. def get_odm_loss(self, odm_target, s2anet_head_out):
  580. (feat_labels, feat_label_weights, feat_bbox_targets, feat_bbox_weights,
  581. pos_inds, neg_inds) = odm_target
  582. odm_cls_score, odm_bbox_pred = s2anet_head_out
  583. # step1: sample count
  584. num_total_samples = len(pos_inds) + len(
  585. neg_inds) if self.sampling else len(pos_inds)
  586. num_total_samples = max(1, num_total_samples)
  587. # step2: calc cls loss
  588. feat_labels = feat_labels.reshape(-1)
  589. feat_label_weights = feat_label_weights.reshape(-1)
  590. odm_cls_score = paddle.squeeze(odm_cls_score, axis=0)
  591. odm_cls_score1 = odm_cls_score
  592. # gt_classes 0~14(data), feat_labels 0~14, sigmoid_focal_loss need class>=1
  593. # for debug 0426
  594. feat_labels = feat_labels + 1
  595. feat_labels = paddle.to_tensor(feat_labels)
  596. feat_labels_one_hot = F.one_hot(feat_labels, self.cls_out_channels + 1)
  597. feat_labels_one_hot = feat_labels_one_hot[:, 1:]
  598. feat_labels_one_hot.stop_gradient = True
  599. num_total_samples = paddle.to_tensor(
  600. num_total_samples, dtype='float32', stop_gradient=True)
  601. odm_cls = F.sigmoid_focal_loss(
  602. odm_cls_score1,
  603. feat_labels_one_hot,
  604. normalizer=num_total_samples,
  605. reduction='none')
  606. feat_label_weights = feat_label_weights.reshape(
  607. feat_label_weights.shape[0], 1)
  608. feat_label_weights = np.repeat(
  609. feat_label_weights, self.cls_out_channels, axis=1)
  610. feat_label_weights = paddle.to_tensor(
  611. feat_label_weights, stop_gradient=True)
  612. odm_cls = odm_cls * feat_label_weights
  613. odm_cls_total = paddle.sum(odm_cls)
  614. # step3: regression loss
  615. feat_bbox_targets = paddle.to_tensor(
  616. feat_bbox_targets, dtype='float32', stop_gradient=True)
  617. feat_bbox_targets = paddle.reshape(feat_bbox_targets, [-1, 5])
  618. odm_bbox_pred = paddle.squeeze(odm_bbox_pred, axis=0)
  619. odm_bbox_pred = paddle.reshape(odm_bbox_pred, [-1, 5])
  620. odm_bbox = self.smooth_l1_loss(odm_bbox_pred, feat_bbox_targets)
  621. loss_weight = paddle.to_tensor(
  622. self.reg_loss_weight, dtype='float32', stop_gradient=True)
  623. odm_bbox = paddle.multiply(odm_bbox, loss_weight)
  624. feat_bbox_weights = paddle.to_tensor(
  625. feat_bbox_weights, stop_gradient=True)
  626. odm_bbox = odm_bbox * feat_bbox_weights
  627. odm_bbox_total = paddle.sum(odm_bbox) / num_total_samples
  628. odm_cls_loss_weight = paddle.to_tensor(
  629. self.cls_loss_weight[0], dtype='float32', stop_gradient=True)
  630. odm_cls_loss = odm_cls_total * odm_cls_loss_weight
  631. odm_reg_loss = paddle.add_n(odm_bbox_total)
  632. return odm_cls_loss, odm_reg_loss
  633. def get_loss(self, inputs):
  634. # inputs: im_id image im_shape scale_factor gt_bbox gt_class is_crowd
  635. # compute loss
  636. fam_cls_loss_lst = []
  637. fam_reg_loss_lst = []
  638. odm_cls_loss_lst = []
  639. odm_reg_loss_lst = []
  640. im_shape = inputs['im_shape']
  641. for im_id in range(im_shape.shape[0]):
  642. np_im_shape = inputs['im_shape'][im_id].numpy()
  643. np_scale_factor = inputs['scale_factor'][im_id].numpy()
  644. # data_format: (xc, yc, w, h, theta)
  645. gt_bboxes = inputs['gt_rbox'][im_id].numpy()
  646. gt_labels = inputs['gt_class'][im_id].numpy()
  647. is_crowd = inputs['is_crowd'][im_id].numpy()
  648. gt_labels = gt_labels + 1
  649. # FAM
  650. for idx, rbox_anchors in enumerate(self.rbox_anchors_list):
  651. rbox_anchors = rbox_anchors.numpy()
  652. rbox_anchors = rbox_anchors.reshape(-1, 5)
  653. im_fam_target = self.anchor_assign(rbox_anchors, gt_bboxes,
  654. gt_labels, is_crowd)
  655. # feat
  656. fam_cls_feat = self.s2anet_head_out[0][idx][im_id]
  657. fam_reg_feat = self.s2anet_head_out[1][idx][im_id]
  658. im_s2anet_fam_feat = (fam_cls_feat, fam_reg_feat)
  659. im_fam_cls_loss, im_fam_reg_loss = self.get_fam_loss(
  660. im_fam_target, im_s2anet_fam_feat)
  661. fam_cls_loss_lst.append(im_fam_cls_loss)
  662. fam_reg_loss_lst.append(im_fam_reg_loss)
  663. # ODM
  664. for idx, refine_anchors in enumerate(self.refine_anchor_list):
  665. refine_anchors = refine_anchors.numpy()
  666. refine_anchors = refine_anchors.reshape(-1, 5)
  667. im_odm_target = self.anchor_assign(refine_anchors, gt_bboxes,
  668. gt_labels, is_crowd)
  669. odm_cls_feat = self.s2anet_head_out[2][idx][im_id]
  670. odm_reg_feat = self.s2anet_head_out[3][idx][im_id]
  671. im_s2anet_odm_feat = (odm_cls_feat, odm_reg_feat)
  672. im_odm_cls_loss, im_odm_reg_loss = self.get_odm_loss(
  673. im_odm_target, im_s2anet_odm_feat)
  674. odm_cls_loss_lst.append(im_odm_cls_loss)
  675. odm_reg_loss_lst.append(im_odm_reg_loss)
  676. fam_cls_loss = paddle.add_n(fam_cls_loss_lst)
  677. fam_reg_loss = paddle.add_n(fam_reg_loss_lst)
  678. odm_cls_loss = paddle.add_n(odm_cls_loss_lst)
  679. odm_reg_loss = paddle.add_n(odm_reg_loss_lst)
  680. return {
  681. 'fam_cls_loss': fam_cls_loss,
  682. 'fam_reg_loss': fam_reg_loss,
  683. 'odm_cls_loss': odm_cls_loss,
  684. 'odm_reg_loss': odm_reg_loss
  685. }
  686. def get_bboxes(self, cls_score_list, bbox_pred_list, mlvl_anchors, nms_pre,
  687. cls_out_channels, use_sigmoid_cls):
  688. assert len(cls_score_list) == len(bbox_pred_list) == len(mlvl_anchors)
  689. mlvl_bboxes = []
  690. mlvl_scores = []
  691. idx = 0
  692. for cls_score, bbox_pred, anchors in zip(cls_score_list, bbox_pred_list,
  693. mlvl_anchors):
  694. cls_score = paddle.reshape(cls_score, [-1, cls_out_channels])
  695. if use_sigmoid_cls:
  696. scores = F.sigmoid(cls_score)
  697. else:
  698. scores = F.softmax(cls_score, axis=-1)
  699. # bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 5)
  700. bbox_pred = paddle.transpose(bbox_pred, [1, 2, 0])
  701. bbox_pred = paddle.reshape(bbox_pred, [-1, 5])
  702. anchors = paddle.reshape(anchors, [-1, 5])
  703. if nms_pre > 0 and scores.shape[0] > nms_pre:
  704. # Get maximum scores for foreground classes.
  705. if use_sigmoid_cls:
  706. max_scores = paddle.max(scores, axis=1)
  707. else:
  708. max_scores = paddle.max(scores[:, 1:], axis=1)
  709. topk_val, topk_inds = paddle.topk(max_scores, nms_pre)
  710. anchors = paddle.gather(anchors, topk_inds)
  711. bbox_pred = paddle.gather(bbox_pred, topk_inds)
  712. scores = paddle.gather(scores, topk_inds)
  713. bboxes = self.delta2rbox(anchors, bbox_pred, self.target_means,
  714. self.target_stds)
  715. mlvl_bboxes.append(bboxes)
  716. mlvl_scores.append(scores)
  717. idx += 1
  718. mlvl_bboxes = paddle.concat(mlvl_bboxes, axis=0)
  719. mlvl_scores = paddle.concat(mlvl_scores)
  720. if use_sigmoid_cls:
  721. # Add a dummy background class to the front when using sigmoid
  722. padding = paddle.zeros(
  723. [mlvl_scores.shape[0], 1], dtype=mlvl_scores.dtype)
  724. mlvl_scores = paddle.concat([padding, mlvl_scores], axis=1)
  725. return mlvl_scores, mlvl_bboxes