target.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import paddle
  16. from ..bbox_utils import bbox2delta, bbox_overlaps
  17. def rpn_anchor_target(anchors,
  18. gt_boxes,
  19. rpn_batch_size_per_im,
  20. rpn_positive_overlap,
  21. rpn_negative_overlap,
  22. rpn_fg_fraction,
  23. use_random=True,
  24. batch_size=1,
  25. ignore_thresh=-1,
  26. is_crowd=None,
  27. weights=[1., 1., 1., 1.],
  28. assign_on_cpu=False):
  29. tgt_labels = []
  30. tgt_bboxes = []
  31. tgt_deltas = []
  32. for i in range(batch_size):
  33. gt_bbox = gt_boxes[i]
  34. is_crowd_i = is_crowd[i] if is_crowd else None
  35. # Step1: match anchor and gt_bbox
  36. matches, match_labels = label_box(
  37. anchors, gt_bbox, rpn_positive_overlap, rpn_negative_overlap, True,
  38. ignore_thresh, is_crowd_i, assign_on_cpu)
  39. # Step2: sample anchor
  40. fg_inds, bg_inds = subsample_labels(match_labels,
  41. rpn_batch_size_per_im,
  42. rpn_fg_fraction, 0, use_random)
  43. # Fill with the ignore label (-1), then set positive and negative labels
  44. labels = paddle.full(match_labels.shape, -1, dtype='int32')
  45. if bg_inds.shape[0] > 0:
  46. labels = paddle.scatter(labels, bg_inds,
  47. paddle.zeros_like(bg_inds))
  48. if fg_inds.shape[0] > 0:
  49. labels = paddle.scatter(labels, fg_inds, paddle.ones_like(fg_inds))
  50. # Step3: make output
  51. if gt_bbox.shape[0] == 0:
  52. matched_gt_boxes = paddle.zeros([matches.shape[0], 4])
  53. tgt_delta = paddle.zeros([matches.shape[0], 4])
  54. else:
  55. matched_gt_boxes = paddle.gather(gt_bbox, matches)
  56. tgt_delta = bbox2delta(anchors, matched_gt_boxes, weights)
  57. matched_gt_boxes.stop_gradient = True
  58. tgt_delta.stop_gradient = True
  59. labels.stop_gradient = True
  60. tgt_labels.append(labels)
  61. tgt_bboxes.append(matched_gt_boxes)
  62. tgt_deltas.append(tgt_delta)
  63. return tgt_labels, tgt_bboxes, tgt_deltas
  64. def label_box(anchors,
  65. gt_boxes,
  66. positive_overlap,
  67. negative_overlap,
  68. allow_low_quality,
  69. ignore_thresh,
  70. is_crowd=None,
  71. assign_on_cpu=False):
  72. if assign_on_cpu:
  73. paddle.set_device("cpu")
  74. iou = bbox_overlaps(gt_boxes, anchors)
  75. paddle.set_device("gpu")
  76. else:
  77. iou = bbox_overlaps(gt_boxes, anchors)
  78. n_gt = gt_boxes.shape[0]
  79. if n_gt == 0 or is_crowd is None:
  80. n_gt_crowd = 0
  81. else:
  82. n_gt_crowd = paddle.nonzero(is_crowd).shape[0]
  83. if iou.shape[0] == 0 or n_gt_crowd == n_gt:
  84. # No truth, assign everything to background
  85. default_matches = paddle.full((iou.shape[1], ), 0, dtype='int64')
  86. default_match_labels = paddle.full((iou.shape[1], ), 0, dtype='int32')
  87. return default_matches, default_match_labels
  88. # if ignore_thresh > 0, remove anchor if it is closed to
  89. # one of the crowded ground-truth
  90. if n_gt_crowd > 0:
  91. N_a = anchors.shape[0]
  92. ones = paddle.ones([N_a])
  93. mask = is_crowd * ones
  94. if ignore_thresh > 0:
  95. crowd_iou = iou * mask
  96. valid = (paddle.sum((crowd_iou > ignore_thresh).cast('int32'),
  97. axis=0) > 0).cast('float32')
  98. iou = iou * (1 - valid) - valid
  99. # ignore the iou between anchor and crowded ground-truth
  100. iou = iou * (1 - mask) - mask
  101. matched_vals, matches = paddle.topk(iou, k=1, axis=0)
  102. match_labels = paddle.full(matches.shape, -1, dtype='int32')
  103. # set ignored anchor with iou = -1
  104. neg_cond = paddle.logical_and(matched_vals > -1,
  105. matched_vals < negative_overlap)
  106. match_labels = paddle.where(neg_cond,
  107. paddle.zeros_like(match_labels), match_labels)
  108. match_labels = paddle.where(matched_vals >= positive_overlap,
  109. paddle.ones_like(match_labels), match_labels)
  110. if allow_low_quality:
  111. highest_quality_foreach_gt = iou.max(axis=1, keepdim=True)
  112. pred_inds_with_highest_quality = paddle.logical_and(
  113. iou > 0, iou == highest_quality_foreach_gt).cast('int32').sum(
  114. 0, keepdim=True)
  115. match_labels = paddle.where(pred_inds_with_highest_quality > 0,
  116. paddle.ones_like(match_labels),
  117. match_labels)
  118. matches = matches.flatten()
  119. match_labels = match_labels.flatten()
  120. return matches, match_labels
  121. def subsample_labels(labels,
  122. num_samples,
  123. fg_fraction,
  124. bg_label=0,
  125. use_random=True):
  126. positive = paddle.nonzero(
  127. paddle.logical_and(labels != -1, labels != bg_label))
  128. negative = paddle.nonzero(labels == bg_label)
  129. fg_num = int(num_samples * fg_fraction)
  130. fg_num = min(positive.numel(), fg_num)
  131. bg_num = num_samples - fg_num
  132. bg_num = min(negative.numel(), bg_num)
  133. if fg_num == 0 and bg_num == 0:
  134. fg_inds = paddle.zeros([0], dtype='int32')
  135. bg_inds = paddle.zeros([0], dtype='int32')
  136. return fg_inds, bg_inds
  137. # randomly select positive and negative examples
  138. negative = negative.cast('int32').flatten()
  139. bg_perm = paddle.randperm(negative.numel(), dtype='int32')
  140. bg_perm = paddle.slice(bg_perm, axes=[0], starts=[0], ends=[bg_num])
  141. if use_random:
  142. bg_inds = paddle.gather(negative, bg_perm)
  143. else:
  144. bg_inds = paddle.slice(negative, axes=[0], starts=[0], ends=[bg_num])
  145. if fg_num == 0:
  146. fg_inds = paddle.zeros([0], dtype='int32')
  147. return fg_inds, bg_inds
  148. positive = positive.cast('int32').flatten()
  149. fg_perm = paddle.randperm(positive.numel(), dtype='int32')
  150. fg_perm = paddle.slice(fg_perm, axes=[0], starts=[0], ends=[fg_num])
  151. if use_random:
  152. fg_inds = paddle.gather(positive, fg_perm)
  153. else:
  154. fg_inds = paddle.slice(positive, axes=[0], starts=[0], ends=[fg_num])
  155. return fg_inds, bg_inds
  156. def generate_proposal_target(rpn_rois,
  157. gt_classes,
  158. gt_boxes,
  159. batch_size_per_im,
  160. fg_fraction,
  161. fg_thresh,
  162. bg_thresh,
  163. num_classes,
  164. ignore_thresh=-1.,
  165. is_crowd=None,
  166. use_random=True,
  167. is_cascade=False,
  168. cascade_iou=0.5,
  169. assign_on_cpu=False):
  170. rois_with_gt = []
  171. tgt_labels = []
  172. tgt_bboxes = []
  173. tgt_gt_inds = []
  174. new_rois_num = []
  175. # In cascade rcnn, the threshold for foreground and background
  176. # is used from cascade_iou
  177. fg_thresh = cascade_iou if is_cascade else fg_thresh
  178. bg_thresh = cascade_iou if is_cascade else bg_thresh
  179. for i, rpn_roi in enumerate(rpn_rois):
  180. gt_bbox = gt_boxes[i]
  181. is_crowd_i = is_crowd[i] if is_crowd else None
  182. gt_class = paddle.squeeze(gt_classes[i], axis=-1)
  183. # Concat RoIs and gt boxes except cascade rcnn or none gt
  184. if not is_cascade and gt_bbox.shape[0] > 0:
  185. bbox = paddle.concat([rpn_roi, gt_bbox])
  186. else:
  187. bbox = rpn_roi
  188. # Step1: label bbox
  189. matches, match_labels = label_box(bbox, gt_bbox, fg_thresh, bg_thresh,
  190. False, ignore_thresh, is_crowd_i,
  191. assign_on_cpu)
  192. # Step2: sample bbox
  193. sampled_inds, sampled_gt_classes = sample_bbox(
  194. matches, match_labels, gt_class, batch_size_per_im, fg_fraction,
  195. num_classes, use_random, is_cascade)
  196. # Step3: make output
  197. rois_per_image = bbox if is_cascade else paddle.gather(bbox,
  198. sampled_inds)
  199. sampled_gt_ind = matches if is_cascade else paddle.gather(matches,
  200. sampled_inds)
  201. if gt_bbox.shape[0] > 0:
  202. sampled_bbox = paddle.gather(gt_bbox, sampled_gt_ind)
  203. else:
  204. num = rois_per_image.shape[0]
  205. sampled_bbox = paddle.zeros([num, 4], dtype='float32')
  206. rois_per_image.stop_gradient = True
  207. sampled_gt_ind.stop_gradient = True
  208. sampled_bbox.stop_gradient = True
  209. tgt_labels.append(sampled_gt_classes)
  210. tgt_bboxes.append(sampled_bbox)
  211. rois_with_gt.append(rois_per_image)
  212. tgt_gt_inds.append(sampled_gt_ind)
  213. new_rois_num.append(paddle.shape(sampled_inds)[0])
  214. new_rois_num = paddle.concat(new_rois_num)
  215. return rois_with_gt, tgt_labels, tgt_bboxes, tgt_gt_inds, new_rois_num
  216. def sample_bbox(matches,
  217. match_labels,
  218. gt_classes,
  219. batch_size_per_im,
  220. fg_fraction,
  221. num_classes,
  222. use_random=True,
  223. is_cascade=False):
  224. n_gt = gt_classes.shape[0]
  225. if n_gt == 0:
  226. # No truth, assign everything to background
  227. gt_classes = paddle.ones(matches.shape, dtype='int32') * num_classes
  228. #return matches, match_labels + num_classes
  229. else:
  230. gt_classes = paddle.gather(gt_classes, matches)
  231. gt_classes = paddle.where(match_labels == 0,
  232. paddle.ones_like(gt_classes) * num_classes,
  233. gt_classes)
  234. gt_classes = paddle.where(match_labels == -1,
  235. paddle.ones_like(gt_classes) * -1,
  236. gt_classes)
  237. if is_cascade:
  238. index = paddle.arange(matches.shape[0])
  239. return index, gt_classes
  240. rois_per_image = int(batch_size_per_im)
  241. fg_inds, bg_inds = subsample_labels(gt_classes, rois_per_image,
  242. fg_fraction, num_classes, use_random)
  243. if fg_inds.shape[0] == 0 and bg_inds.shape[0] == 0:
  244. # fake output labeled with -1 when all boxes are neither
  245. # foreground nor background
  246. sampled_inds = paddle.zeros([1], dtype='int32')
  247. else:
  248. sampled_inds = paddle.concat([fg_inds, bg_inds])
  249. sampled_gt_classes = paddle.gather(gt_classes, sampled_inds)
  250. return sampled_inds, sampled_gt_classes
  251. def polygons_to_mask(polygons, height, width):
  252. """
  253. Convert the polygons to mask format
  254. Args:
  255. polygons (list[ndarray]): each array has shape (Nx2,)
  256. height (int): mask height
  257. width (int): mask width
  258. Returns:
  259. ndarray: a bool mask of shape (height, width)
  260. """
  261. import pycocotools.mask as mask_util
  262. assert len(polygons) > 0, "COCOAPI does not support empty polygons"
  263. rles = mask_util.frPyObjects(polygons, height, width)
  264. rle = mask_util.merge(rles)
  265. return mask_util.decode(rle).astype(np.bool)
  266. def rasterize_polygons_within_box(poly, box, resolution):
  267. w, h = box[2] - box[0], box[3] - box[1]
  268. polygons = [np.asarray(p, dtype=np.float64) for p in poly]
  269. for p in polygons:
  270. p[0::2] = p[0::2] - box[0]
  271. p[1::2] = p[1::2] - box[1]
  272. ratio_h = resolution / max(h, 0.1)
  273. ratio_w = resolution / max(w, 0.1)
  274. if ratio_h == ratio_w:
  275. for p in polygons:
  276. p *= ratio_h
  277. else:
  278. for p in polygons:
  279. p[0::2] *= ratio_w
  280. p[1::2] *= ratio_h
  281. # 3. Rasterize the polygons with coco api
  282. mask = polygons_to_mask(polygons, resolution, resolution)
  283. mask = paddle.to_tensor(mask, dtype='int32')
  284. return mask
  285. def generate_mask_target(gt_segms, rois, labels_int32, sampled_gt_inds,
  286. num_classes, resolution):
  287. mask_rois = []
  288. mask_rois_num = []
  289. tgt_masks = []
  290. tgt_classes = []
  291. mask_index = []
  292. tgt_weights = []
  293. for k in range(len(rois)):
  294. labels_per_im = labels_int32[k]
  295. # select rois labeled with foreground
  296. fg_inds = paddle.nonzero(
  297. paddle.logical_and(labels_per_im != -1, labels_per_im !=
  298. num_classes))
  299. has_fg = True
  300. # generate fake roi if foreground is empty
  301. if fg_inds.numel() == 0:
  302. has_fg = False
  303. fg_inds = paddle.ones([1], dtype='int32')
  304. inds_per_im = sampled_gt_inds[k]
  305. inds_per_im = paddle.gather(inds_per_im, fg_inds)
  306. rois_per_im = rois[k]
  307. fg_rois = paddle.gather(rois_per_im, fg_inds)
  308. # Copy the foreground roi to cpu
  309. # to generate mask target with ground-truth
  310. boxes = fg_rois.numpy()
  311. gt_segms_per_im = gt_segms[k]
  312. new_segm = []
  313. inds_per_im = inds_per_im.numpy()
  314. if len(gt_segms_per_im) > 0:
  315. for i in inds_per_im:
  316. new_segm.append(gt_segms_per_im[i])
  317. fg_inds_new = fg_inds.reshape([-1]).numpy()
  318. results = []
  319. if len(gt_segms_per_im) > 0:
  320. for j in fg_inds_new:
  321. results.append(
  322. rasterize_polygons_within_box(new_segm[j], boxes[j],
  323. resolution))
  324. else:
  325. results.append(
  326. paddle.ones(
  327. [resolution, resolution], dtype='int32'))
  328. fg_classes = paddle.gather(labels_per_im, fg_inds)
  329. weight = paddle.ones([fg_rois.shape[0]], dtype='float32')
  330. if not has_fg:
  331. # now all sampled classes are background
  332. # which will cause error in loss calculation,
  333. # make fake classes with weight of 0.
  334. fg_classes = paddle.zeros([1], dtype='int32')
  335. weight = weight - 1
  336. tgt_mask = paddle.stack(results)
  337. tgt_mask.stop_gradient = True
  338. fg_rois.stop_gradient = True
  339. mask_index.append(fg_inds)
  340. mask_rois.append(fg_rois)
  341. mask_rois_num.append(paddle.shape(fg_rois)[0])
  342. tgt_classes.append(fg_classes)
  343. tgt_masks.append(tgt_mask)
  344. tgt_weights.append(weight)
  345. mask_index = paddle.concat(mask_index)
  346. mask_rois_num = paddle.concat(mask_rois_num)
  347. tgt_classes = paddle.concat(tgt_classes, axis=0)
  348. tgt_masks = paddle.concat(tgt_masks, axis=0)
  349. tgt_weights = paddle.concat(tgt_weights, axis=0)
  350. return mask_rois, mask_rois_num, tgt_classes, tgt_masks, mask_index, tgt_weights
  351. def libra_sample_pos(max_overlaps, max_classes, pos_inds, num_expected):
  352. if len(pos_inds) <= num_expected:
  353. return pos_inds
  354. else:
  355. unique_gt_inds = np.unique(max_classes[pos_inds])
  356. num_gts = len(unique_gt_inds)
  357. num_per_gt = int(round(num_expected / float(num_gts)) + 1)
  358. sampled_inds = []
  359. for i in unique_gt_inds:
  360. inds = np.nonzero(max_classes == i)[0]
  361. before_len = len(inds)
  362. inds = list(set(inds) & set(pos_inds))
  363. after_len = len(inds)
  364. if len(inds) > num_per_gt:
  365. inds = np.random.choice(inds, size=num_per_gt, replace=False)
  366. sampled_inds.extend(list(inds)) # combine as a new sampler
  367. if len(sampled_inds) < num_expected:
  368. num_extra = num_expected - len(sampled_inds)
  369. extra_inds = np.array(list(set(pos_inds) - set(sampled_inds)))
  370. assert len(sampled_inds) + len(extra_inds) == len(pos_inds), \
  371. "sum of sampled_inds({}) and extra_inds({}) length must be equal with pos_inds({})!".format(
  372. len(sampled_inds), len(extra_inds), len(pos_inds))
  373. if len(extra_inds) > num_extra:
  374. extra_inds = np.random.choice(
  375. extra_inds, size=num_extra, replace=False)
  376. sampled_inds.extend(extra_inds.tolist())
  377. elif len(sampled_inds) > num_expected:
  378. sampled_inds = np.random.choice(
  379. sampled_inds, size=num_expected, replace=False)
  380. return paddle.to_tensor(sampled_inds)
  381. def libra_sample_via_interval(max_overlaps, full_set, num_expected, floor_thr,
  382. num_bins, bg_thresh):
  383. max_iou = max_overlaps.max()
  384. iou_interval = (max_iou - floor_thr) / num_bins
  385. per_num_expected = int(num_expected / num_bins)
  386. sampled_inds = []
  387. for i in range(num_bins):
  388. start_iou = floor_thr + i * iou_interval
  389. end_iou = floor_thr + (i + 1) * iou_interval
  390. tmp_set = set(
  391. np.where(
  392. np.logical_and(max_overlaps >= start_iou, max_overlaps <
  393. end_iou))[0])
  394. tmp_inds = list(tmp_set & full_set)
  395. if len(tmp_inds) > per_num_expected:
  396. tmp_sampled_set = np.random.choice(
  397. tmp_inds, size=per_num_expected, replace=False)
  398. else:
  399. tmp_sampled_set = np.array(tmp_inds, dtype=np.int)
  400. sampled_inds.append(tmp_sampled_set)
  401. sampled_inds = np.concatenate(sampled_inds)
  402. if len(sampled_inds) < num_expected:
  403. num_extra = num_expected - len(sampled_inds)
  404. extra_inds = np.array(list(full_set - set(sampled_inds)))
  405. assert len(sampled_inds) + len(extra_inds) == len(full_set), \
  406. "sum of sampled_inds({}) and extra_inds({}) length must be equal with full_set({})!".format(
  407. len(sampled_inds), len(extra_inds), len(full_set))
  408. if len(extra_inds) > num_extra:
  409. extra_inds = np.random.choice(extra_inds, num_extra, replace=False)
  410. sampled_inds = np.concatenate([sampled_inds, extra_inds])
  411. return sampled_inds
  412. def libra_sample_neg(max_overlaps,
  413. max_classes,
  414. neg_inds,
  415. num_expected,
  416. floor_thr=-1,
  417. floor_fraction=0,
  418. num_bins=3,
  419. bg_thresh=0.5):
  420. if len(neg_inds) <= num_expected:
  421. return neg_inds
  422. else:
  423. # balance sampling for negative samples
  424. neg_set = set(neg_inds.tolist())
  425. if floor_thr > 0:
  426. floor_set = set(
  427. np.where(
  428. np.logical_and(max_overlaps >= 0, max_overlaps <
  429. floor_thr))[0])
  430. iou_sampling_set = set(np.where(max_overlaps >= floor_thr)[0])
  431. elif floor_thr == 0:
  432. floor_set = set(np.where(max_overlaps == 0)[0])
  433. iou_sampling_set = set(np.where(max_overlaps > floor_thr)[0])
  434. else:
  435. floor_set = set()
  436. iou_sampling_set = set(np.where(max_overlaps > floor_thr)[0])
  437. floor_thr = 0
  438. floor_neg_inds = list(floor_set & neg_set)
  439. iou_sampling_neg_inds = list(iou_sampling_set & neg_set)
  440. num_expected_iou_sampling = int(num_expected * (1 - floor_fraction))
  441. if len(iou_sampling_neg_inds) > num_expected_iou_sampling:
  442. if num_bins >= 2:
  443. iou_sampled_inds = libra_sample_via_interval(
  444. max_overlaps,
  445. set(iou_sampling_neg_inds), num_expected_iou_sampling,
  446. floor_thr, num_bins, bg_thresh)
  447. else:
  448. iou_sampled_inds = np.random.choice(
  449. iou_sampling_neg_inds,
  450. size=num_expected_iou_sampling,
  451. replace=False)
  452. else:
  453. iou_sampled_inds = np.array(iou_sampling_neg_inds, dtype=np.int)
  454. num_expected_floor = num_expected - len(iou_sampled_inds)
  455. if len(floor_neg_inds) > num_expected_floor:
  456. sampled_floor_inds = np.random.choice(
  457. floor_neg_inds, size=num_expected_floor, replace=False)
  458. else:
  459. sampled_floor_inds = np.array(floor_neg_inds, dtype=np.int)
  460. sampled_inds = np.concatenate((sampled_floor_inds, iou_sampled_inds))
  461. if len(sampled_inds) < num_expected:
  462. num_extra = num_expected - len(sampled_inds)
  463. extra_inds = np.array(list(neg_set - set(sampled_inds)))
  464. if len(extra_inds) > num_extra:
  465. extra_inds = np.random.choice(
  466. extra_inds, size=num_extra, replace=False)
  467. sampled_inds = np.concatenate((sampled_inds, extra_inds))
  468. return paddle.to_tensor(sampled_inds)
  469. def libra_label_box(anchors, gt_boxes, gt_classes, positive_overlap,
  470. negative_overlap, num_classes):
  471. # TODO: use paddle API to speed up
  472. gt_classes = gt_classes.numpy()
  473. gt_overlaps = np.zeros((anchors.shape[0], num_classes))
  474. matches = np.zeros((anchors.shape[0]), dtype=np.int32)
  475. if len(gt_boxes) > 0:
  476. proposal_to_gt_overlaps = bbox_overlaps(anchors, gt_boxes).numpy()
  477. overlaps_argmax = proposal_to_gt_overlaps.argmax(axis=1)
  478. overlaps_max = proposal_to_gt_overlaps.max(axis=1)
  479. # Boxes which with non-zero overlap with gt boxes
  480. overlapped_boxes_ind = np.where(overlaps_max > 0)[0]
  481. overlapped_boxes_gt_classes = gt_classes[overlaps_argmax[
  482. overlapped_boxes_ind]]
  483. for idx in range(len(overlapped_boxes_ind)):
  484. gt_overlaps[overlapped_boxes_ind[idx], overlapped_boxes_gt_classes[
  485. idx]] = overlaps_max[overlapped_boxes_ind[idx]]
  486. matches[overlapped_boxes_ind[idx]] = overlaps_argmax[
  487. overlapped_boxes_ind[idx]]
  488. gt_overlaps = paddle.to_tensor(gt_overlaps)
  489. matches = paddle.to_tensor(matches)
  490. matched_vals = paddle.max(gt_overlaps, axis=1)
  491. match_labels = paddle.full(matches.shape, -1, dtype='int32')
  492. match_labels = paddle.where(matched_vals < negative_overlap,
  493. paddle.zeros_like(match_labels), match_labels)
  494. match_labels = paddle.where(matched_vals >= positive_overlap,
  495. paddle.ones_like(match_labels), match_labels)
  496. return matches, match_labels, matched_vals
  497. def libra_sample_bbox(matches,
  498. match_labels,
  499. matched_vals,
  500. gt_classes,
  501. batch_size_per_im,
  502. num_classes,
  503. fg_fraction,
  504. fg_thresh,
  505. bg_thresh,
  506. num_bins,
  507. use_random=True,
  508. is_cascade_rcnn=False):
  509. rois_per_image = int(batch_size_per_im)
  510. fg_rois_per_im = int(np.round(fg_fraction * rois_per_image))
  511. bg_rois_per_im = rois_per_image - fg_rois_per_im
  512. if is_cascade_rcnn:
  513. fg_inds = paddle.nonzero(matched_vals >= fg_thresh)
  514. bg_inds = paddle.nonzero(matched_vals < bg_thresh)
  515. else:
  516. matched_vals_np = matched_vals.numpy()
  517. match_labels_np = match_labels.numpy()
  518. # sample fg
  519. fg_inds = paddle.nonzero(matched_vals >= fg_thresh).flatten()
  520. fg_nums = int(np.minimum(fg_rois_per_im, fg_inds.shape[0]))
  521. if (fg_inds.shape[0] > fg_nums) and use_random:
  522. fg_inds = libra_sample_pos(matched_vals_np, match_labels_np,
  523. fg_inds.numpy(), fg_rois_per_im)
  524. fg_inds = fg_inds[:fg_nums]
  525. # sample bg
  526. bg_inds = paddle.nonzero(matched_vals < bg_thresh).flatten()
  527. bg_nums = int(np.minimum(rois_per_image - fg_nums, bg_inds.shape[0]))
  528. if (bg_inds.shape[0] > bg_nums) and use_random:
  529. bg_inds = libra_sample_neg(
  530. matched_vals_np,
  531. match_labels_np,
  532. bg_inds.numpy(),
  533. bg_rois_per_im,
  534. num_bins=num_bins,
  535. bg_thresh=bg_thresh)
  536. bg_inds = bg_inds[:bg_nums]
  537. sampled_inds = paddle.concat([fg_inds, bg_inds])
  538. gt_classes = paddle.gather(gt_classes, matches)
  539. gt_classes = paddle.where(match_labels == 0,
  540. paddle.ones_like(gt_classes) * num_classes,
  541. gt_classes)
  542. gt_classes = paddle.where(match_labels == -1,
  543. paddle.ones_like(gt_classes) * -1,
  544. gt_classes)
  545. sampled_gt_classes = paddle.gather(gt_classes, sampled_inds)
  546. return sampled_inds, sampled_gt_classes
  547. def libra_generate_proposal_target(rpn_rois,
  548. gt_classes,
  549. gt_boxes,
  550. batch_size_per_im,
  551. fg_fraction,
  552. fg_thresh,
  553. bg_thresh,
  554. num_classes,
  555. use_random=True,
  556. is_cascade_rcnn=False,
  557. max_overlaps=None,
  558. num_bins=3):
  559. rois_with_gt = []
  560. tgt_labels = []
  561. tgt_bboxes = []
  562. sampled_max_overlaps = []
  563. tgt_gt_inds = []
  564. new_rois_num = []
  565. for i, rpn_roi in enumerate(rpn_rois):
  566. max_overlap = max_overlaps[i] if is_cascade_rcnn else None
  567. gt_bbox = gt_boxes[i]
  568. gt_class = paddle.squeeze(gt_classes[i], axis=-1)
  569. if is_cascade_rcnn:
  570. rpn_roi = filter_roi(rpn_roi, max_overlap)
  571. bbox = paddle.concat([rpn_roi, gt_bbox])
  572. # Step1: label bbox
  573. matches, match_labels, matched_vals = libra_label_box(
  574. bbox, gt_bbox, gt_class, fg_thresh, bg_thresh, num_classes)
  575. # Step2: sample bbox
  576. sampled_inds, sampled_gt_classes = libra_sample_bbox(
  577. matches, match_labels, matched_vals, gt_class, batch_size_per_im,
  578. num_classes, fg_fraction, fg_thresh, bg_thresh, num_bins,
  579. use_random, is_cascade_rcnn)
  580. # Step3: make output
  581. rois_per_image = paddle.gather(bbox, sampled_inds)
  582. sampled_gt_ind = paddle.gather(matches, sampled_inds)
  583. sampled_bbox = paddle.gather(gt_bbox, sampled_gt_ind)
  584. sampled_overlap = paddle.gather(matched_vals, sampled_inds)
  585. rois_per_image.stop_gradient = True
  586. sampled_gt_ind.stop_gradient = True
  587. sampled_bbox.stop_gradient = True
  588. sampled_overlap.stop_gradient = True
  589. tgt_labels.append(sampled_gt_classes)
  590. tgt_bboxes.append(sampled_bbox)
  591. rois_with_gt.append(rois_per_image)
  592. sampled_max_overlaps.append(sampled_overlap)
  593. tgt_gt_inds.append(sampled_gt_ind)
  594. new_rois_num.append(paddle.shape(sampled_inds)[0])
  595. new_rois_num = paddle.concat(new_rois_num)
  596. # rois_with_gt, tgt_labels, tgt_bboxes, tgt_gt_inds, new_rois_num
  597. return rois_with_gt, tgt_labels, tgt_bboxes, tgt_gt_inds, new_rois_num