target_layer.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import sys
  15. import paddle
  16. from paddlex.ppdet.core.workspace import register, serializable
  17. from .target import rpn_anchor_target, generate_proposal_target, generate_mask_target, libra_generate_proposal_target
  18. import numpy as np
  19. @register
  20. @serializable
  21. class RPNTargetAssign(object):
  22. """
  23. RPN targets assignment module
  24. The assignment consists of three steps:
  25. 1. Match anchor and ground-truth box, label the anchor with foreground
  26. or background sample
  27. 2. Sample anchors to keep the properly ratio between foreground and
  28. background
  29. 3. Generate the targets for classification and regression branch
  30. Args:
  31. batch_size_per_im (int): Total number of RPN samples per image.
  32. default 256
  33. fg_fraction (float): Fraction of anchors that is labeled
  34. foreground, default 0.5
  35. positive_overlap (float): Minimum overlap required between an anchor
  36. and ground-truth box for the (anchor, gt box) pair to be
  37. a foreground sample. default 0.7
  38. negative_overlap (float): Maximum overlap allowed between an anchor
  39. and ground-truth box for the (anchor, gt box) pair to be
  40. a background sample. default 0.3
  41. ignore_thresh(float): Threshold for ignoring the is_crowd ground-truth
  42. if the value is larger than zero.
  43. use_random (bool): Use random sampling to choose foreground and
  44. background boxes, default true.
  45. """
  46. def __init__(self,
  47. batch_size_per_im=256,
  48. fg_fraction=0.5,
  49. positive_overlap=0.7,
  50. negative_overlap=0.3,
  51. ignore_thresh=-1.,
  52. use_random=True):
  53. super(RPNTargetAssign, self).__init__()
  54. self.batch_size_per_im = batch_size_per_im
  55. self.fg_fraction = fg_fraction
  56. self.positive_overlap = positive_overlap
  57. self.negative_overlap = negative_overlap
  58. self.ignore_thresh = ignore_thresh
  59. self.use_random = use_random
  60. def __call__(self, inputs, anchors):
  61. """
  62. inputs: ground-truth instances.
  63. anchor_box (Tensor): [num_anchors, 4], num_anchors are all anchors in all feature maps.
  64. """
  65. gt_boxes = inputs['gt_bbox']
  66. is_crowd = inputs.get('is_crowd', None)
  67. batch_size = len(gt_boxes)
  68. tgt_labels, tgt_bboxes, tgt_deltas = rpn_anchor_target(
  69. anchors, gt_boxes, self.batch_size_per_im, self.positive_overlap,
  70. self.negative_overlap, self.fg_fraction, self.use_random,
  71. batch_size, self.ignore_thresh, is_crowd)
  72. norm = self.batch_size_per_im * batch_size
  73. return tgt_labels, tgt_bboxes, tgt_deltas, norm
  74. @register
  75. class BBoxAssigner(object):
  76. __shared__ = ['num_classes']
  77. """
  78. RCNN targets assignment module
  79. The assignment consists of three steps:
  80. 1. Match RoIs and ground-truth box, label the RoIs with foreground
  81. or background sample
  82. 2. Sample anchors to keep the properly ratio between foreground and
  83. background
  84. 3. Generate the targets for classification and regression branch
  85. Args:
  86. batch_size_per_im (int): Total number of RoIs per image.
  87. default 512
  88. fg_fraction (float): Fraction of RoIs that is labeled
  89. foreground, default 0.25
  90. fg_thresh (float): Minimum overlap required between a RoI
  91. and ground-truth box for the (roi, gt box) pair to be
  92. a foreground sample. default 0.5
  93. bg_thresh (float): Maximum overlap allowed between a RoI
  94. and ground-truth box for the (roi, gt box) pair to be
  95. a background sample. default 0.5
  96. ignore_thresh(float): Threshold for ignoring the is_crowd ground-truth
  97. if the value is larger than zero.
  98. use_random (bool): Use random sampling to choose foreground and
  99. background boxes, default true
  100. cascade_iou (list[iou]): The list of overlap to select foreground and
  101. background of each stage, which is only used In Cascade RCNN.
  102. num_classes (int): The number of class.
  103. """
  104. def __init__(self,
  105. batch_size_per_im=512,
  106. fg_fraction=.25,
  107. fg_thresh=.5,
  108. bg_thresh=.5,
  109. ignore_thresh=-1.,
  110. use_random=True,
  111. cascade_iou=[0.5, 0.6, 0.7],
  112. num_classes=80):
  113. super(BBoxAssigner, self).__init__()
  114. self.batch_size_per_im = batch_size_per_im
  115. self.fg_fraction = fg_fraction
  116. self.fg_thresh = fg_thresh
  117. self.bg_thresh = bg_thresh
  118. self.ignore_thresh = ignore_thresh
  119. self.use_random = use_random
  120. self.cascade_iou = cascade_iou
  121. self.num_classes = num_classes
  122. def __call__(self,
  123. rpn_rois,
  124. rpn_rois_num,
  125. inputs,
  126. stage=0,
  127. is_cascade=False):
  128. gt_classes = inputs['gt_class']
  129. gt_boxes = inputs['gt_bbox']
  130. is_crowd = inputs.get('is_crowd', None)
  131. # rois, tgt_labels, tgt_bboxes, tgt_gt_inds
  132. # new_rois_num
  133. outs = generate_proposal_target(
  134. rpn_rois, gt_classes, gt_boxes, self.batch_size_per_im,
  135. self.fg_fraction, self.fg_thresh, self.bg_thresh, self.num_classes,
  136. self.ignore_thresh, is_crowd, self.use_random, is_cascade,
  137. self.cascade_iou[stage])
  138. rois = outs[0]
  139. rois_num = outs[-1]
  140. # tgt_labels, tgt_bboxes, tgt_gt_inds
  141. targets = outs[1:4]
  142. return rois, rois_num, targets
  143. @register
  144. class BBoxLibraAssigner(object):
  145. __shared__ = ['num_classes']
  146. """
  147. Libra-RCNN targets assignment module
  148. The assignment consists of three steps:
  149. 1. Match RoIs and ground-truth box, label the RoIs with foreground
  150. or background sample
  151. 2. Sample anchors to keep the properly ratio between foreground and
  152. background
  153. 3. Generate the targets for classification and regression branch
  154. Args:
  155. batch_size_per_im (int): Total number of RoIs per image.
  156. default 512
  157. fg_fraction (float): Fraction of RoIs that is labeled
  158. foreground, default 0.25
  159. fg_thresh (float): Minimum overlap required between a RoI
  160. and ground-truth box for the (roi, gt box) pair to be
  161. a foreground sample. default 0.5
  162. bg_thresh (float): Maximum overlap allowed between a RoI
  163. and ground-truth box for the (roi, gt box) pair to be
  164. a background sample. default 0.5
  165. use_random (bool): Use random sampling to choose foreground and
  166. background boxes, default true
  167. cascade_iou (list[iou]): The list of overlap to select foreground and
  168. background of each stage, which is only used In Cascade RCNN.
  169. num_classes (int): The number of class.
  170. num_bins (int): The number of libra_sample.
  171. """
  172. def __init__(self,
  173. batch_size_per_im=512,
  174. fg_fraction=.25,
  175. fg_thresh=.5,
  176. bg_thresh=.5,
  177. use_random=True,
  178. cascade_iou=[0.5, 0.6, 0.7],
  179. num_classes=80,
  180. num_bins=3):
  181. super(BBoxLibraAssigner, self).__init__()
  182. self.batch_size_per_im = batch_size_per_im
  183. self.fg_fraction = fg_fraction
  184. self.fg_thresh = fg_thresh
  185. self.bg_thresh = bg_thresh
  186. self.use_random = use_random
  187. self.cascade_iou = cascade_iou
  188. self.num_classes = num_classes
  189. self.num_bins = num_bins
  190. def __call__(self,
  191. rpn_rois,
  192. rpn_rois_num,
  193. inputs,
  194. stage=0,
  195. is_cascade=False):
  196. gt_classes = inputs['gt_class']
  197. gt_boxes = inputs['gt_bbox']
  198. # rois, tgt_labels, tgt_bboxes, tgt_gt_inds
  199. outs = libra_generate_proposal_target(
  200. rpn_rois, gt_classes, gt_boxes, self.batch_size_per_im,
  201. self.fg_fraction, self.fg_thresh, self.bg_thresh, self.num_classes,
  202. self.use_random, is_cascade, self.cascade_iou[stage],
  203. self.num_bins)
  204. rois = outs[0]
  205. rois_num = outs[-1]
  206. # tgt_labels, tgt_bboxes, tgt_gt_inds
  207. targets = outs[1:4]
  208. return rois, rois_num, targets
  209. @register
  210. @serializable
  211. class MaskAssigner(object):
  212. __shared__ = ['num_classes', 'mask_resolution']
  213. """
  214. Mask targets assignment module
  215. The assignment consists of three steps:
  216. 1. Select RoIs labels with foreground.
  217. 2. Encode the RoIs and corresponding gt polygons to generate
  218. mask target
  219. Args:
  220. num_classes (int): The number of class
  221. mask_resolution (int): The resolution of mask target, default 14
  222. """
  223. def __init__(self, num_classes=80, mask_resolution=14):
  224. super(MaskAssigner, self).__init__()
  225. self.num_classes = num_classes
  226. self.mask_resolution = mask_resolution
  227. def __call__(self, rois, tgt_labels, tgt_gt_inds, inputs):
  228. gt_segms = inputs['gt_poly']
  229. outs = generate_mask_target(gt_segms, rois, tgt_labels, tgt_gt_inds,
  230. self.num_classes, self.mask_resolution)
  231. # mask_rois, mask_rois_num, tgt_classes, tgt_masks, mask_index, tgt_weights
  232. return outs
  233. @register
  234. class RBoxAssigner(object):
  235. """
  236. assigner of rbox
  237. Args:
  238. pos_iou_thr (float): threshold of pos samples
  239. neg_iou_thr (float): threshold of neg samples
  240. min_iou_thr (float): the min threshold of samples
  241. ignore_iof_thr (int): the ignored threshold
  242. """
  243. def __init__(self,
  244. pos_iou_thr=0.5,
  245. neg_iou_thr=0.4,
  246. min_iou_thr=0.0,
  247. ignore_iof_thr=-2):
  248. super(RBoxAssigner, self).__init__()
  249. self.pos_iou_thr = pos_iou_thr
  250. self.neg_iou_thr = neg_iou_thr
  251. self.min_iou_thr = min_iou_thr
  252. self.ignore_iof_thr = ignore_iof_thr
  253. def anchor_valid(self, anchors):
  254. """
  255. Args:
  256. anchor: M x 4
  257. Returns:
  258. """
  259. if anchors.ndim == 3:
  260. anchors = anchors.reshape(-1, anchors.shape[-1])
  261. assert anchors.ndim == 2
  262. anchor_num = anchors.shape[0]
  263. anchor_valid = np.ones((anchor_num), np.int32)
  264. anchor_inds = np.arange(anchor_num)
  265. return anchor_inds
  266. def rbox2delta(self,
  267. proposals,
  268. gt,
  269. means=[0, 0, 0, 0, 0],
  270. stds=[1, 1, 1, 1, 1]):
  271. """
  272. Args:
  273. proposals: tensor [N, 5]
  274. gt: gt [N, 5]
  275. means: means [5]
  276. stds: stds [5]
  277. Returns:
  278. """
  279. proposals = proposals.astype(np.float64)
  280. PI = np.pi
  281. gt_widths = gt[..., 2]
  282. gt_heights = gt[..., 3]
  283. gt_angle = gt[..., 4]
  284. proposals_widths = proposals[..., 2]
  285. proposals_heights = proposals[..., 3]
  286. proposals_angle = proposals[..., 4]
  287. coord = gt[..., 0:2] - proposals[..., 0:2]
  288. dx = (np.cos(proposals[..., 4]) * coord[..., 0] +
  289. np.sin(proposals[..., 4]) * coord[..., 1]) / proposals_widths
  290. dy = (-np.sin(proposals[..., 4]) * coord[..., 0] +
  291. np.cos(proposals[..., 4]) * coord[..., 1]) / proposals_heights
  292. dw = np.log(gt_widths / proposals_widths)
  293. dh = np.log(gt_heights / proposals_heights)
  294. da = (gt_angle - proposals_angle)
  295. da = (da + PI / 4) % PI - PI / 4
  296. da /= PI
  297. deltas = np.stack([dx, dy, dw, dh, da], axis=-1)
  298. means = np.array(means, dtype=deltas.dtype)
  299. stds = np.array(stds, dtype=deltas.dtype)
  300. deltas = (deltas - means) / stds
  301. deltas = deltas.astype(np.float32)
  302. return deltas
  303. def assign_anchor(self,
  304. anchors,
  305. gt_bboxes,
  306. gt_lables,
  307. pos_iou_thr,
  308. neg_iou_thr,
  309. min_iou_thr=0.0,
  310. ignore_iof_thr=-2):
  311. """
  312. Args:
  313. anchors:
  314. gt_bboxes:[M, 5] rc,yc,w,h,angle
  315. gt_lables:
  316. Returns:
  317. """
  318. assert anchors.shape[1] == 4 or anchors.shape[1] == 5
  319. assert gt_bboxes.shape[1] == 4 or gt_bboxes.shape[1] == 5
  320. anchors_xc_yc = anchors
  321. gt_bboxes_xc_yc = gt_bboxes
  322. # calc rbox iou
  323. anchors_xc_yc = anchors_xc_yc.astype(np.float32)
  324. gt_bboxes_xc_yc = gt_bboxes_xc_yc.astype(np.float32)
  325. anchors_xc_yc = paddle.to_tensor(anchors_xc_yc)
  326. gt_bboxes_xc_yc = paddle.to_tensor(gt_bboxes_xc_yc)
  327. try:
  328. from rbox_iou_ops import rbox_iou
  329. except Exception as e:
  330. print("import custom_ops error, try install rbox_iou_ops " \
  331. "following ppdet/ext_op/README.md", e)
  332. sys.stdout.flush()
  333. sys.exit(-1)
  334. iou = rbox_iou(gt_bboxes_xc_yc, anchors_xc_yc)
  335. iou = iou.numpy()
  336. iou = iou.T
  337. # every gt's anchor's index
  338. gt_bbox_anchor_inds = iou.argmax(axis=0)
  339. gt_bbox_anchor_iou = iou[gt_bbox_anchor_inds, np.arange(iou.shape[1])]
  340. gt_bbox_anchor_iou_inds = np.where(iou == gt_bbox_anchor_iou)[0]
  341. # every anchor's gt bbox's index
  342. anchor_gt_bbox_inds = iou.argmax(axis=1)
  343. anchor_gt_bbox_iou = iou[np.arange(iou.shape[0]), anchor_gt_bbox_inds]
  344. # (1) set labels=-2 as default
  345. labels = np.ones((iou.shape[0], ), dtype=np.int32) * ignore_iof_thr
  346. # (2) assign ignore
  347. labels[anchor_gt_bbox_iou < min_iou_thr] = ignore_iof_thr
  348. # (3) assign neg_ids -1
  349. assign_neg_ids1 = anchor_gt_bbox_iou >= min_iou_thr
  350. assign_neg_ids2 = anchor_gt_bbox_iou < neg_iou_thr
  351. assign_neg_ids = np.logical_and(assign_neg_ids1, assign_neg_ids2)
  352. labels[assign_neg_ids] = -1
  353. # anchor_gt_bbox_iou_inds
  354. # (4) assign max_iou as pos_ids >=0
  355. anchor_gt_bbox_iou_inds = anchor_gt_bbox_inds[gt_bbox_anchor_iou_inds]
  356. # gt_bbox_anchor_iou_inds = np.logical_and(gt_bbox_anchor_iou_inds, anchor_gt_bbox_iou >= min_iou_thr)
  357. labels[gt_bbox_anchor_iou_inds] = gt_lables[anchor_gt_bbox_iou_inds]
  358. # (5) assign >= pos_iou_thr as pos_ids
  359. iou_pos_iou_thr_ids = anchor_gt_bbox_iou >= pos_iou_thr
  360. iou_pos_iou_thr_ids_box_inds = anchor_gt_bbox_inds[iou_pos_iou_thr_ids]
  361. labels[iou_pos_iou_thr_ids] = gt_lables[iou_pos_iou_thr_ids_box_inds]
  362. return anchor_gt_bbox_inds, anchor_gt_bbox_iou, labels
  363. def __call__(self, anchors, gt_bboxes, gt_labels, is_crowd):
  364. assert anchors.ndim == 2
  365. assert anchors.shape[1] == 5
  366. assert gt_bboxes.ndim == 2
  367. assert gt_bboxes.shape[1] == 5
  368. pos_iou_thr = self.pos_iou_thr
  369. neg_iou_thr = self.neg_iou_thr
  370. min_iou_thr = self.min_iou_thr
  371. ignore_iof_thr = self.ignore_iof_thr
  372. anchor_num = anchors.shape[0]
  373. gt_bboxes = gt_bboxes
  374. is_crowd_slice = is_crowd
  375. not_crowd_inds = np.where(is_crowd_slice == 0)
  376. # Step1: match anchor and gt_bbox
  377. anchor_gt_bbox_inds, anchor_gt_bbox_iou, labels = self.assign_anchor(
  378. anchors, gt_bboxes,
  379. gt_labels.reshape(-1), pos_iou_thr, neg_iou_thr, min_iou_thr,
  380. ignore_iof_thr)
  381. # Step2: sample anchor
  382. pos_inds = np.where(labels >= 0)[0]
  383. neg_inds = np.where(labels == -1)[0]
  384. # Step3: make output
  385. anchors_num = anchors.shape[0]
  386. bbox_targets = np.zeros_like(anchors)
  387. bbox_weights = np.zeros_like(anchors)
  388. bbox_gt_bboxes = np.zeros_like(anchors)
  389. pos_labels = np.ones(anchors_num, dtype=np.int32) * -1
  390. pos_labels_weights = np.zeros(anchors_num, dtype=np.float32)
  391. pos_sampled_anchors = anchors[pos_inds]
  392. pos_sampled_gt_boxes = gt_bboxes[anchor_gt_bbox_inds[pos_inds]]
  393. if len(pos_inds) > 0:
  394. pos_bbox_targets = self.rbox2delta(pos_sampled_anchors,
  395. pos_sampled_gt_boxes)
  396. bbox_targets[pos_inds, :] = pos_bbox_targets
  397. bbox_gt_bboxes[pos_inds, :] = pos_sampled_gt_boxes
  398. bbox_weights[pos_inds, :] = 1.0
  399. pos_labels[pos_inds] = labels[pos_inds]
  400. pos_labels_weights[pos_inds] = 1.0
  401. if len(neg_inds) > 0:
  402. pos_labels_weights[neg_inds] = 1.0
  403. return (pos_labels, pos_labels_weights, bbox_targets, bbox_weights,
  404. bbox_gt_bboxes, pos_inds, neg_inds)