target.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import paddle
  16. from ..bbox_utils import bbox2delta, bbox_overlaps
  17. def rpn_anchor_target(anchors,
  18. gt_boxes,
  19. rpn_batch_size_per_im,
  20. rpn_positive_overlap,
  21. rpn_negative_overlap,
  22. rpn_fg_fraction,
  23. use_random=True,
  24. batch_size=1,
  25. ignore_thresh=-1,
  26. is_crowd=None,
  27. weights=[1., 1., 1., 1.]):
  28. tgt_labels = []
  29. tgt_bboxes = []
  30. tgt_deltas = []
  31. for i in range(batch_size):
  32. gt_bbox = gt_boxes[i]
  33. is_crowd_i = is_crowd[i] if is_crowd else None
  34. # Step1: match anchor and gt_bbox
  35. matches, match_labels = label_box(
  36. anchors, gt_bbox, rpn_positive_overlap, rpn_negative_overlap, True,
  37. ignore_thresh, is_crowd_i)
  38. # Step2: sample anchor
  39. fg_inds, bg_inds = subsample_labels(match_labels,
  40. rpn_batch_size_per_im,
  41. rpn_fg_fraction, 0, use_random)
  42. # Fill with the ignore label (-1), then set positive and negative labels
  43. labels = paddle.full(match_labels.shape, -1, dtype='int32')
  44. if bg_inds.shape[0] > 0:
  45. labels = paddle.scatter(labels, bg_inds,
  46. paddle.zeros_like(bg_inds))
  47. if fg_inds.shape[0] > 0:
  48. labels = paddle.scatter(labels, fg_inds, paddle.ones_like(fg_inds))
  49. # Step3: make output
  50. if gt_bbox.shape[0] == 0:
  51. matched_gt_boxes = paddle.zeros([0, 4])
  52. tgt_delta = paddle.zeros([0, 4])
  53. else:
  54. matched_gt_boxes = paddle.gather(gt_bbox, matches)
  55. tgt_delta = bbox2delta(anchors, matched_gt_boxes, weights)
  56. matched_gt_boxes.stop_gradient = True
  57. tgt_delta.stop_gradient = True
  58. labels.stop_gradient = True
  59. tgt_labels.append(labels)
  60. tgt_bboxes.append(matched_gt_boxes)
  61. tgt_deltas.append(tgt_delta)
  62. return tgt_labels, tgt_bboxes, tgt_deltas
def label_box(anchors,
              gt_boxes,
              positive_overlap,
              negative_overlap,
              allow_low_quality,
              ignore_thresh,
              is_crowd=None):
    """Match every anchor to its best ground-truth box and label it.

    Args:
        anchors (Tensor): candidate boxes to label.
        gt_boxes (Tensor): ground-truth boxes of one image.
        positive_overlap (float): anchors whose best IoU is >= this value are
            labeled 1.
        negative_overlap (float): anchors whose best IoU is > -1 and below
            this value are labeled 0.
        allow_low_quality (bool): if True, additionally label 1 every anchor
            that attains the highest positive IoU of some ground-truth box.
        ignore_thresh (float): when > 0 and crowd boxes exist, anchors whose
            IoU with any crowd box exceeds it get IoU -1 with every box, so
            they remain labeled -1.
        is_crowd (Tensor | None): per-gt crowd flags; presumably shape
            [n_gt, 1] so that it broadcasts against the IoU matrix — TODO
            confirm.

    Returns:
        tuple(Tensor, Tensor): flattened ``matches`` (index of the best gt box
        per anchor) and int32 ``match_labels`` with values 1 (foreground),
        0 (background) or -1 (ignored). With no ground truth, or when every
        ground-truth box is crowd, all anchors get match 0 and label 0.
    """
    # IoU matrix with one row per gt box and one column per anchor
    # (best match is taken over axis 0 below).
    iou = bbox_overlaps(gt_boxes, anchors)
    n_gt = gt_boxes.shape[0]
    if n_gt == 0 or is_crowd is None:
        n_gt_crowd = 0
    else:
        n_gt_crowd = paddle.nonzero(is_crowd).shape[0]
    if iou.shape[0] == 0 or n_gt_crowd == n_gt:
        # No truth, assign everything to background
        default_matches = paddle.full((iou.shape[1], ), 0, dtype='int64')
        default_match_labels = paddle.full((iou.shape[1], ), 0, dtype='int32')
        return default_matches, default_match_labels
    # If ignore_thresh > 0, ignore anchors that are close to one of the
    # crowded ground-truth boxes.
    if n_gt_crowd > 0:
        N_a = anchors.shape[0]
        ones = paddle.ones([N_a])
        # Broadcast the per-gt crowd flag across all anchors.
        mask = is_crowd * ones

        if ignore_thresh > 0:
            crowd_iou = iou * mask
            # 1.0 for anchors overlapping any crowd box above ignore_thresh.
            valid = (paddle.sum((crowd_iou > ignore_thresh).cast('int32'),
                                axis=0) > 0).cast('float32')
            # Those anchors get IoU -1 against every gt box.
            iou = iou * (1 - valid) - valid

        # ignore the iou between anchor and crowded ground-truth:
        # crowd rows are set to -1 so they can never be the best match.
        iou = iou * (1 - mask) - mask

    matched_vals, matches = paddle.topk(iou, k=1, axis=0)
    # Start with every anchor ignored (-1).
    match_labels = paddle.full(matches.shape, -1, dtype='int32')
    # Anchors whose best IoU is -1 (suppressed above) stay ignored.
    neg_cond = paddle.logical_and(matched_vals > -1,
                                  matched_vals < negative_overlap)
    match_labels = paddle.where(neg_cond,
                                paddle.zeros_like(match_labels), match_labels)
    match_labels = paddle.where(matched_vals >= positive_overlap,
                                paddle.ones_like(match_labels), match_labels)
    if allow_low_quality:
        # For each gt box, mark the anchors achieving its highest (positive)
        # IoU as foreground even when below positive_overlap.
        highest_quality_foreach_gt = iou.max(axis=1, keepdim=True)
        pred_inds_with_highest_quality = paddle.logical_and(
            iou > 0, iou == highest_quality_foreach_gt).cast('int32').sum(
                0, keepdim=True)
        match_labels = paddle.where(pred_inds_with_highest_quality > 0,
                                    paddle.ones_like(match_labels),
                                    match_labels)

    matches = matches.flatten()
    match_labels = match_labels.flatten()

    return matches, match_labels
  114. def subsample_labels(labels,
  115. num_samples,
  116. fg_fraction,
  117. bg_label=0,
  118. use_random=True):
  119. positive = paddle.nonzero(
  120. paddle.logical_and(labels != -1, labels != bg_label))
  121. negative = paddle.nonzero(labels == bg_label)
  122. fg_num = int(num_samples * fg_fraction)
  123. fg_num = min(positive.numel(), fg_num)
  124. bg_num = num_samples - fg_num
  125. bg_num = min(negative.numel(), bg_num)
  126. if fg_num == 0 and bg_num == 0:
  127. fg_inds = paddle.zeros([0], dtype='int32')
  128. bg_inds = paddle.zeros([0], dtype='int32')
  129. return fg_inds, bg_inds
  130. # randomly select positive and negative examples
  131. negative = negative.cast('int32').flatten()
  132. bg_perm = paddle.randperm(negative.numel(), dtype='int32')
  133. bg_perm = paddle.slice(bg_perm, axes=[0], starts=[0], ends=[bg_num])
  134. if use_random:
  135. bg_inds = paddle.gather(negative, bg_perm)
  136. else:
  137. bg_inds = paddle.slice(negative, axes=[0], starts=[0], ends=[bg_num])
  138. if fg_num == 0:
  139. fg_inds = paddle.zeros([0], dtype='int32')
  140. return fg_inds, bg_inds
  141. positive = positive.cast('int32').flatten()
  142. fg_perm = paddle.randperm(positive.numel(), dtype='int32')
  143. fg_perm = paddle.slice(fg_perm, axes=[0], starts=[0], ends=[fg_num])
  144. if use_random:
  145. fg_inds = paddle.gather(positive, fg_perm)
  146. else:
  147. fg_inds = paddle.slice(positive, axes=[0], starts=[0], ends=[fg_num])
  148. return fg_inds, bg_inds
  149. def generate_proposal_target(rpn_rois,
  150. gt_classes,
  151. gt_boxes,
  152. batch_size_per_im,
  153. fg_fraction,
  154. fg_thresh,
  155. bg_thresh,
  156. num_classes,
  157. ignore_thresh=-1.,
  158. is_crowd=None,
  159. use_random=True,
  160. is_cascade=False,
  161. cascade_iou=0.5):
  162. rois_with_gt = []
  163. tgt_labels = []
  164. tgt_bboxes = []
  165. tgt_gt_inds = []
  166. new_rois_num = []
  167. # In cascade rcnn, the threshold for foreground and background
  168. # is used from cascade_iou
  169. fg_thresh = cascade_iou if is_cascade else fg_thresh
  170. bg_thresh = cascade_iou if is_cascade else bg_thresh
  171. for i, rpn_roi in enumerate(rpn_rois):
  172. gt_bbox = gt_boxes[i]
  173. is_crowd_i = is_crowd[i] if is_crowd else None
  174. gt_class = paddle.squeeze(gt_classes[i], axis=-1)
  175. # Concat RoIs and gt boxes except cascade rcnn or none gt
  176. if not is_cascade and gt_bbox.shape[0] > 0:
  177. bbox = paddle.concat([rpn_roi, gt_bbox])
  178. else:
  179. bbox = rpn_roi
  180. # Step1: label bbox
  181. matches, match_labels = label_box(bbox, gt_bbox, fg_thresh, bg_thresh,
  182. False, ignore_thresh, is_crowd_i)
  183. # Step2: sample bbox
  184. sampled_inds, sampled_gt_classes = sample_bbox(
  185. matches, match_labels, gt_class, batch_size_per_im, fg_fraction,
  186. num_classes, use_random, is_cascade)
  187. # Step3: make output
  188. rois_per_image = bbox if is_cascade else paddle.gather(bbox,
  189. sampled_inds)
  190. sampled_gt_ind = matches if is_cascade else paddle.gather(matches,
  191. sampled_inds)
  192. if gt_bbox.shape[0] > 0:
  193. sampled_bbox = paddle.gather(gt_bbox, sampled_gt_ind)
  194. else:
  195. num = rois_per_image.shape[0]
  196. sampled_bbox = paddle.zeros([num, 4], dtype='float32')
  197. rois_per_image.stop_gradient = True
  198. sampled_gt_ind.stop_gradient = True
  199. sampled_bbox.stop_gradient = True
  200. tgt_labels.append(sampled_gt_classes)
  201. tgt_bboxes.append(sampled_bbox)
  202. rois_with_gt.append(rois_per_image)
  203. tgt_gt_inds.append(sampled_gt_ind)
  204. new_rois_num.append(paddle.shape(sampled_inds)[0])
  205. new_rois_num = paddle.concat(new_rois_num)
  206. return rois_with_gt, tgt_labels, tgt_bboxes, tgt_gt_inds, new_rois_num
  207. def sample_bbox(matches,
  208. match_labels,
  209. gt_classes,
  210. batch_size_per_im,
  211. fg_fraction,
  212. num_classes,
  213. use_random=True,
  214. is_cascade=False):
  215. n_gt = gt_classes.shape[0]
  216. if n_gt == 0:
  217. # No truth, assign everything to background
  218. gt_classes = paddle.ones(matches.shape, dtype='int32') * num_classes
  219. #return matches, match_labels + num_classes
  220. else:
  221. gt_classes = paddle.gather(gt_classes, matches)
  222. gt_classes = paddle.where(match_labels == 0,
  223. paddle.ones_like(gt_classes) * num_classes,
  224. gt_classes)
  225. gt_classes = paddle.where(match_labels == -1,
  226. paddle.ones_like(gt_classes) * -1,
  227. gt_classes)
  228. if is_cascade:
  229. index = paddle.arange(matches.shape[0])
  230. return index, gt_classes
  231. rois_per_image = int(batch_size_per_im)
  232. fg_inds, bg_inds = subsample_labels(gt_classes, rois_per_image,
  233. fg_fraction, num_classes, use_random)
  234. if fg_inds.shape[0] == 0 and bg_inds.shape[0] == 0:
  235. # fake output labeled with -1 when all boxes are neither
  236. # foreground nor background
  237. sampled_inds = paddle.zeros([1], dtype='int32')
  238. else:
  239. sampled_inds = paddle.concat([fg_inds, bg_inds])
  240. sampled_gt_classes = paddle.gather(gt_classes, sampled_inds)
  241. return sampled_inds, sampled_gt_classes
  242. def polygons_to_mask(polygons, height, width):
  243. """
  244. Args:
  245. polygons (list[ndarray]): each array has shape (Nx2,)
  246. height, width (int)
  247. Returns:
  248. ndarray: a bool mask of shape (height, width)
  249. """
  250. import pycocotools.mask as mask_util
  251. assert len(polygons) > 0, "COCOAPI does not support empty polygons"
  252. rles = mask_util.frPyObjects(polygons, height, width)
  253. rle = mask_util.merge(rles)
  254. return mask_util.decode(rle).astype(np.bool)
  255. def rasterize_polygons_within_box(poly, box, resolution):
  256. w, h = box[2] - box[0], box[3] - box[1]
  257. polygons = [np.asarray(p, dtype=np.float64) for p in poly]
  258. for p in polygons:
  259. p[0::2] = p[0::2] - box[0]
  260. p[1::2] = p[1::2] - box[1]
  261. ratio_h = resolution / max(h, 0.1)
  262. ratio_w = resolution / max(w, 0.1)
  263. if ratio_h == ratio_w:
  264. for p in polygons:
  265. p *= ratio_h
  266. else:
  267. for p in polygons:
  268. p[0::2] *= ratio_w
  269. p[1::2] *= ratio_h
  270. # 3. Rasterize the polygons with coco api
  271. mask = polygons_to_mask(polygons, resolution, resolution)
  272. mask = paddle.to_tensor(mask, dtype='int32')
  273. return mask
  274. def generate_mask_target(gt_segms, rois, labels_int32, sampled_gt_inds,
  275. num_classes, resolution):
  276. mask_rois = []
  277. mask_rois_num = []
  278. tgt_masks = []
  279. tgt_classes = []
  280. mask_index = []
  281. tgt_weights = []
  282. for k in range(len(rois)):
  283. labels_per_im = labels_int32[k]
  284. # select rois labeled with foreground
  285. fg_inds = paddle.nonzero(
  286. paddle.logical_and(labels_per_im != -1, labels_per_im !=
  287. num_classes))
  288. has_fg = True
  289. # generate fake roi if foreground is empty
  290. if fg_inds.numel() == 0:
  291. has_fg = False
  292. fg_inds = paddle.ones([1], dtype='int32')
  293. inds_per_im = sampled_gt_inds[k]
  294. inds_per_im = paddle.gather(inds_per_im, fg_inds)
  295. rois_per_im = rois[k]
  296. fg_rois = paddle.gather(rois_per_im, fg_inds)
  297. # Copy the foreground roi to cpu
  298. # to generate mask target with ground-truth
  299. boxes = fg_rois.numpy()
  300. gt_segms_per_im = gt_segms[k]
  301. new_segm = []
  302. inds_per_im = inds_per_im.numpy()
  303. if len(gt_segms_per_im) > 0:
  304. for i in inds_per_im:
  305. new_segm.append(gt_segms_per_im[i])
  306. fg_inds_new = fg_inds.reshape([-1]).numpy()
  307. results = []
  308. if len(gt_segms_per_im) > 0:
  309. for j in fg_inds_new:
  310. results.append(
  311. rasterize_polygons_within_box(new_segm[j], boxes[j],
  312. resolution))
  313. else:
  314. results.append(
  315. paddle.ones(
  316. [resolution, resolution], dtype='int32'))
  317. fg_classes = paddle.gather(labels_per_im, fg_inds)
  318. weight = paddle.ones([fg_rois.shape[0]], dtype='float32')
  319. if not has_fg:
  320. # now all sampled classes are background
  321. # which will cause error in loss calculation,
  322. # make fake classes with weight of 0.
  323. fg_classes = paddle.zeros([1], dtype='int32')
  324. weight = weight - 1
  325. tgt_mask = paddle.stack(results)
  326. tgt_mask.stop_gradient = True
  327. fg_rois.stop_gradient = True
  328. mask_index.append(fg_inds)
  329. mask_rois.append(fg_rois)
  330. mask_rois_num.append(paddle.shape(fg_rois)[0])
  331. tgt_classes.append(fg_classes)
  332. tgt_masks.append(tgt_mask)
  333. tgt_weights.append(weight)
  334. mask_index = paddle.concat(mask_index)
  335. mask_rois_num = paddle.concat(mask_rois_num)
  336. tgt_classes = paddle.concat(tgt_classes, axis=0)
  337. tgt_masks = paddle.concat(tgt_masks, axis=0)
  338. tgt_weights = paddle.concat(tgt_weights, axis=0)
  339. return mask_rois, mask_rois_num, tgt_classes, tgt_masks, mask_index, tgt_weights
  340. def libra_sample_pos(max_overlaps, max_classes, pos_inds, num_expected):
  341. if len(pos_inds) <= num_expected:
  342. return pos_inds
  343. else:
  344. unique_gt_inds = np.unique(max_classes[pos_inds])
  345. num_gts = len(unique_gt_inds)
  346. num_per_gt = int(round(num_expected / float(num_gts)) + 1)
  347. sampled_inds = []
  348. for i in unique_gt_inds:
  349. inds = np.nonzero(max_classes == i)[0]
  350. before_len = len(inds)
  351. inds = list(set(inds) & set(pos_inds))
  352. after_len = len(inds)
  353. if len(inds) > num_per_gt:
  354. inds = np.random.choice(inds, size=num_per_gt, replace=False)
  355. sampled_inds.extend(list(inds)) # combine as a new sampler
  356. if len(sampled_inds) < num_expected:
  357. num_extra = num_expected - len(sampled_inds)
  358. extra_inds = np.array(list(set(pos_inds) - set(sampled_inds)))
  359. assert len(sampled_inds) + len(extra_inds) == len(pos_inds), \
  360. "sum of sampled_inds({}) and extra_inds({}) length must be equal with pos_inds({})!".format(
  361. len(sampled_inds), len(extra_inds), len(pos_inds))
  362. if len(extra_inds) > num_extra:
  363. extra_inds = np.random.choice(
  364. extra_inds, size=num_extra, replace=False)
  365. sampled_inds.extend(extra_inds.tolist())
  366. elif len(sampled_inds) > num_expected:
  367. sampled_inds = np.random.choice(
  368. sampled_inds, size=num_expected, replace=False)
  369. return paddle.to_tensor(sampled_inds)
  370. def libra_sample_via_interval(max_overlaps, full_set, num_expected, floor_thr,
  371. num_bins, bg_thresh):
  372. max_iou = max_overlaps.max()
  373. iou_interval = (max_iou - floor_thr) / num_bins
  374. per_num_expected = int(num_expected / num_bins)
  375. sampled_inds = []
  376. for i in range(num_bins):
  377. start_iou = floor_thr + i * iou_interval
  378. end_iou = floor_thr + (i + 1) * iou_interval
  379. tmp_set = set(
  380. np.where(
  381. np.logical_and(max_overlaps >= start_iou, max_overlaps <
  382. end_iou))[0])
  383. tmp_inds = list(tmp_set & full_set)
  384. if len(tmp_inds) > per_num_expected:
  385. tmp_sampled_set = np.random.choice(
  386. tmp_inds, size=per_num_expected, replace=False)
  387. else:
  388. tmp_sampled_set = np.array(tmp_inds, dtype=np.int)
  389. sampled_inds.append(tmp_sampled_set)
  390. sampled_inds = np.concatenate(sampled_inds)
  391. if len(sampled_inds) < num_expected:
  392. num_extra = num_expected - len(sampled_inds)
  393. extra_inds = np.array(list(full_set - set(sampled_inds)))
  394. assert len(sampled_inds) + len(extra_inds) == len(full_set), \
  395. "sum of sampled_inds({}) and extra_inds({}) length must be equal with full_set({})!".format(
  396. len(sampled_inds), len(extra_inds), len(full_set))
  397. if len(extra_inds) > num_extra:
  398. extra_inds = np.random.choice(extra_inds, num_extra, replace=False)
  399. sampled_inds = np.concatenate([sampled_inds, extra_inds])
  400. return sampled_inds
  401. def libra_sample_neg(max_overlaps,
  402. max_classes,
  403. neg_inds,
  404. num_expected,
  405. floor_thr=-1,
  406. floor_fraction=0,
  407. num_bins=3,
  408. bg_thresh=0.5):
  409. if len(neg_inds) <= num_expected:
  410. return neg_inds
  411. else:
  412. # balance sampling for negative samples
  413. neg_set = set(neg_inds.tolist())
  414. if floor_thr > 0:
  415. floor_set = set(
  416. np.where(
  417. np.logical_and(max_overlaps >= 0, max_overlaps <
  418. floor_thr))[0])
  419. iou_sampling_set = set(np.where(max_overlaps >= floor_thr)[0])
  420. elif floor_thr == 0:
  421. floor_set = set(np.where(max_overlaps == 0)[0])
  422. iou_sampling_set = set(np.where(max_overlaps > floor_thr)[0])
  423. else:
  424. floor_set = set()
  425. iou_sampling_set = set(np.where(max_overlaps > floor_thr)[0])
  426. floor_thr = 0
  427. floor_neg_inds = list(floor_set & neg_set)
  428. iou_sampling_neg_inds = list(iou_sampling_set & neg_set)
  429. num_expected_iou_sampling = int(num_expected * (1 - floor_fraction))
  430. if len(iou_sampling_neg_inds) > num_expected_iou_sampling:
  431. if num_bins >= 2:
  432. iou_sampled_inds = libra_sample_via_interval(
  433. max_overlaps,
  434. set(iou_sampling_neg_inds), num_expected_iou_sampling,
  435. floor_thr, num_bins, bg_thresh)
  436. else:
  437. iou_sampled_inds = np.random.choice(
  438. iou_sampling_neg_inds,
  439. size=num_expected_iou_sampling,
  440. replace=False)
  441. else:
  442. iou_sampled_inds = np.array(iou_sampling_neg_inds, dtype=np.int)
  443. num_expected_floor = num_expected - len(iou_sampled_inds)
  444. if len(floor_neg_inds) > num_expected_floor:
  445. sampled_floor_inds = np.random.choice(
  446. floor_neg_inds, size=num_expected_floor, replace=False)
  447. else:
  448. sampled_floor_inds = np.array(floor_neg_inds, dtype=np.int)
  449. sampled_inds = np.concatenate((sampled_floor_inds, iou_sampled_inds))
  450. if len(sampled_inds) < num_expected:
  451. num_extra = num_expected - len(sampled_inds)
  452. extra_inds = np.array(list(neg_set - set(sampled_inds)))
  453. if len(extra_inds) > num_extra:
  454. extra_inds = np.random.choice(
  455. extra_inds, size=num_extra, replace=False)
  456. sampled_inds = np.concatenate((sampled_inds, extra_inds))
  457. return paddle.to_tensor(sampled_inds)
  458. def libra_label_box(anchors, gt_boxes, gt_classes, positive_overlap,
  459. negative_overlap, num_classes):
  460. # TODO: use paddle API to speed up
  461. gt_classes = gt_classes.numpy()
  462. gt_overlaps = np.zeros((anchors.shape[0], num_classes))
  463. matches = np.zeros((anchors.shape[0]), dtype=np.int32)
  464. if len(gt_boxes) > 0:
  465. proposal_to_gt_overlaps = bbox_overlaps(anchors, gt_boxes).numpy()
  466. overlaps_argmax = proposal_to_gt_overlaps.argmax(axis=1)
  467. overlaps_max = proposal_to_gt_overlaps.max(axis=1)
  468. # Boxes which with non-zero overlap with gt boxes
  469. overlapped_boxes_ind = np.where(overlaps_max > 0)[0]
  470. overlapped_boxes_gt_classes = gt_classes[overlaps_argmax[
  471. overlapped_boxes_ind]]
  472. for idx in range(len(overlapped_boxes_ind)):
  473. gt_overlaps[overlapped_boxes_ind[idx], overlapped_boxes_gt_classes[
  474. idx]] = overlaps_max[overlapped_boxes_ind[idx]]
  475. matches[overlapped_boxes_ind[idx]] = overlaps_argmax[
  476. overlapped_boxes_ind[idx]]
  477. gt_overlaps = paddle.to_tensor(gt_overlaps)
  478. matches = paddle.to_tensor(matches)
  479. matched_vals = paddle.max(gt_overlaps, axis=1)
  480. match_labels = paddle.full(matches.shape, -1, dtype='int32')
  481. match_labels = paddle.where(matched_vals < negative_overlap,
  482. paddle.zeros_like(match_labels), match_labels)
  483. match_labels = paddle.where(matched_vals >= positive_overlap,
  484. paddle.ones_like(match_labels), match_labels)
  485. return matches, match_labels, matched_vals
  486. def libra_sample_bbox(matches,
  487. match_labels,
  488. matched_vals,
  489. gt_classes,
  490. batch_size_per_im,
  491. num_classes,
  492. fg_fraction,
  493. fg_thresh,
  494. bg_thresh,
  495. num_bins,
  496. use_random=True,
  497. is_cascade_rcnn=False):
  498. rois_per_image = int(batch_size_per_im)
  499. fg_rois_per_im = int(np.round(fg_fraction * rois_per_image))
  500. bg_rois_per_im = rois_per_image - fg_rois_per_im
  501. if is_cascade_rcnn:
  502. fg_inds = paddle.nonzero(matched_vals >= fg_thresh)
  503. bg_inds = paddle.nonzero(matched_vals < bg_thresh)
  504. else:
  505. matched_vals_np = matched_vals.numpy()
  506. match_labels_np = match_labels.numpy()
  507. # sample fg
  508. fg_inds = paddle.nonzero(matched_vals >= fg_thresh).flatten()
  509. fg_nums = int(np.minimum(fg_rois_per_im, fg_inds.shape[0]))
  510. if (fg_inds.shape[0] > fg_nums) and use_random:
  511. fg_inds = libra_sample_pos(matched_vals_np, match_labels_np,
  512. fg_inds.numpy(), fg_rois_per_im)
  513. fg_inds = fg_inds[:fg_nums]
  514. # sample bg
  515. bg_inds = paddle.nonzero(matched_vals < bg_thresh).flatten()
  516. bg_nums = int(np.minimum(rois_per_image - fg_nums, bg_inds.shape[0]))
  517. if (bg_inds.shape[0] > bg_nums) and use_random:
  518. bg_inds = libra_sample_neg(
  519. matched_vals_np,
  520. match_labels_np,
  521. bg_inds.numpy(),
  522. bg_rois_per_im,
  523. num_bins=num_bins,
  524. bg_thresh=bg_thresh)
  525. bg_inds = bg_inds[:bg_nums]
  526. sampled_inds = paddle.concat([fg_inds, bg_inds])
  527. gt_classes = paddle.gather(gt_classes, matches)
  528. gt_classes = paddle.where(match_labels == 0,
  529. paddle.ones_like(gt_classes) * num_classes,
  530. gt_classes)
  531. gt_classes = paddle.where(match_labels == -1,
  532. paddle.ones_like(gt_classes) * -1,
  533. gt_classes)
  534. sampled_gt_classes = paddle.gather(gt_classes, sampled_inds)
  535. return sampled_inds, sampled_gt_classes
  536. def libra_generate_proposal_target(rpn_rois,
  537. gt_classes,
  538. gt_boxes,
  539. batch_size_per_im,
  540. fg_fraction,
  541. fg_thresh,
  542. bg_thresh,
  543. num_classes,
  544. use_random=True,
  545. is_cascade_rcnn=False,
  546. max_overlaps=None,
  547. num_bins=3):
  548. rois_with_gt = []
  549. tgt_labels = []
  550. tgt_bboxes = []
  551. sampled_max_overlaps = []
  552. tgt_gt_inds = []
  553. new_rois_num = []
  554. for i, rpn_roi in enumerate(rpn_rois):
  555. max_overlap = max_overlaps[i] if is_cascade_rcnn else None
  556. gt_bbox = gt_boxes[i]
  557. gt_class = paddle.squeeze(gt_classes[i], axis=-1)
  558. if is_cascade_rcnn:
  559. rpn_roi = filter_roi(rpn_roi, max_overlap)
  560. bbox = paddle.concat([rpn_roi, gt_bbox])
  561. # Step1: label bbox
  562. matches, match_labels, matched_vals = libra_label_box(
  563. bbox, gt_bbox, gt_class, fg_thresh, bg_thresh, num_classes)
  564. # Step2: sample bbox
  565. sampled_inds, sampled_gt_classes = libra_sample_bbox(
  566. matches, match_labels, matched_vals, gt_class, batch_size_per_im,
  567. num_classes, fg_fraction, fg_thresh, bg_thresh, num_bins,
  568. use_random, is_cascade_rcnn)
  569. # Step3: make output
  570. rois_per_image = paddle.gather(bbox, sampled_inds)
  571. sampled_gt_ind = paddle.gather(matches, sampled_inds)
  572. sampled_bbox = paddle.gather(gt_bbox, sampled_gt_ind)
  573. sampled_overlap = paddle.gather(matched_vals, sampled_inds)
  574. rois_per_image.stop_gradient = True
  575. sampled_gt_ind.stop_gradient = True
  576. sampled_bbox.stop_gradient = True
  577. sampled_overlap.stop_gradient = True
  578. tgt_labels.append(sampled_gt_classes)
  579. tgt_bboxes.append(sampled_bbox)
  580. rois_with_gt.append(rois_per_image)
  581. sampled_max_overlaps.append(sampled_overlap)
  582. tgt_gt_inds.append(sampled_gt_ind)
  583. new_rois_num.append(paddle.shape(sampled_inds)[0])
  584. new_rois_num = paddle.concat(new_rois_num)
  585. # rois_with_gt, tgt_labels, tgt_bboxes, tgt_gt_inds, new_rois_num
  586. return rois_with_gt, tgt_labels, tgt_bboxes, tgt_gt_inds, new_rois_num