# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

try:
    from collections.abc import Sequence
except Exception:
    from collections import Sequence

import cv2
import numpy as np
from .operators import register_op, BaseOperator, Resize
from .op_helper import jaccard_overlap, gaussian2D
from scipy import ndimage
from paddlex.ppdet.modeling import bbox_utils
from paddlex.ppdet.utils.logger import setup_logger

logger = setup_logger(__name__)

__all__ = [
    'PadBatch', 'BatchRandomResize', 'Gt2YoloTarget', 'Gt2FCOSTarget',
    'Gt2TTFTarget', 'Gt2Solov2Target'
]

@register_op
class PadBatch(BaseOperator):
    """
    Pad a batch of samples so that the height and width of every image
    are divisible by a stride. The layout of each image should be 'CHW'.

    Args:
        pad_to_stride (int): If `pad_to_stride > 0`, pad zeros to ensure
            that height and width are divisible by `pad_to_stride`.
    """

    def __init__(self, pad_to_stride=0):
        super(PadBatch, self).__init__()
        self.pad_to_stride = pad_to_stride

    def __call__(self, samples, context=None):
        """
        Args:
            samples (list): a batch of samples, each one a dict.
        """
        coarsest_stride = self.pad_to_stride
        max_shape = np.array([data['image'].shape for data in samples]).max(
            axis=0)
        if coarsest_stride > 0:
            max_shape[1] = int(
                np.ceil(max_shape[1] / coarsest_stride) * coarsest_stride)
            max_shape[2] = int(
                np.ceil(max_shape[2] / coarsest_stride) * coarsest_stride)
        for data in samples:
            im = data['image']
            im_c, im_h, im_w = im.shape[:]
            padding_im = np.zeros(
                (im_c, max_shape[1], max_shape[2]), dtype=np.float32)
            padding_im[:, :im_h, :im_w] = im
            data['image'] = padding_im
            if 'semantic' in data and data['semantic'] is not None:
                semantic = data['semantic']
                padding_sem = np.zeros(
                    (1, max_shape[1], max_shape[2]), dtype=np.float32)
                padding_sem[:, :im_h, :im_w] = semantic
                data['semantic'] = padding_sem
            if 'gt_segm' in data and data['gt_segm'] is not None:
                gt_segm = data['gt_segm']
                padding_segm = np.zeros(
                    (gt_segm.shape[0], max_shape[1], max_shape[2]),
                    dtype=np.uint8)
                padding_segm[:, :im_h, :im_w] = gt_segm
                data['gt_segm'] = padding_segm
            if 'gt_rbox2poly' in data and data['gt_rbox2poly'] is not None:
                # poly to rbox
                polys = data['gt_rbox2poly']
                rbox = bbox_utils.poly2rbox(polys)
                data['gt_rbox'] = rbox

        return samples
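
# A minimal usage sketch (not part of the upstream file): two synthetic
# CHW samples are zero-padded to the smallest common shape whose height
# and width are multiples of pad_to_stride. With the shapes below, both
# images come out as (3, 64, 64).
def _demo_pad_batch():
    batch = [{'image': np.zeros((3, 30, 45), dtype=np.float32)},
             {'image': np.zeros((3, 62, 17), dtype=np.float32)}]
    batch = PadBatch(pad_to_stride=32)(batch)
    print([s['image'].shape for s in batch])  # [(3, 64, 64), (3, 64, 64)]
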
@register_op
class BatchRandomResize(BaseOperator):
    """
    Resize a batch of images to a randomly selected target size and,
    optionally, with a randomly selected interpolation method.

    Args:
        target_size (int, list, tuple): image target size; if random_size
            is True, it must be a list or tuple
        keep_ratio (bool): whether to keep the aspect ratio
        interp (int): the interpolation method
        random_size (bool): whether to randomly select a target size
        random_interp (bool): whether to randomly select an interpolation
            method
    """

    def __init__(self,
                 target_size,
                 keep_ratio,
                 interp=cv2.INTER_NEAREST,
                 random_size=True,
                 random_interp=False):
        super(BatchRandomResize, self).__init__()
        self.keep_ratio = keep_ratio
        self.interps = [
            cv2.INTER_NEAREST,
            cv2.INTER_LINEAR,
            cv2.INTER_AREA,
            cv2.INTER_CUBIC,
            cv2.INTER_LANCZOS4,
        ]
        self.interp = interp
        assert isinstance(target_size, (
            int, Sequence)), "target_size must be int, list or tuple"
        if random_size and not isinstance(target_size, list):
            raise TypeError(
                "Type of target_size is invalid when random_size is True. "
                "Must be list, now is {}".format(type(target_size)))
        self.target_size = target_size
        self.random_size = random_size
        self.random_interp = random_interp

    def __call__(self, samples, context=None):
        if self.random_size:
            index = np.random.choice(len(self.target_size))
            target_size = self.target_size[index]
        else:
            target_size = self.target_size
        if self.random_interp:
            interp = np.random.choice(self.interps)
        else:
            interp = self.interp
        resizer = Resize(
            target_size, keep_ratio=self.keep_ratio, interp=interp)
        return resizer(samples, context=context)
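
# A minimal usage sketch (not part of the upstream file). With
# random_size=True one size is drawn per batch, so all images in the
# batch are resized to the same randomly chosen size. This assumes the
# companion Resize operator needs only the 'image' key and an HWC
# layout, which is how it is used earlier in the pipeline; if the local
# Resize requires more keys, extend the sample dicts accordingly.
def _demo_batch_random_resize():
    op = BatchRandomResize(
        target_size=[320, 352, 384, 416], keep_ratio=False,
        random_size=True, random_interp=True)
    batch = [{'image': np.zeros((240, 320, 3), dtype=np.uint8)}
             for _ in range(2)]
    batch = op(batch)
    print(batch[0]['image'].shape)  # e.g. (384, 384, 3), same for both
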
@register_op
class Gt2YoloTarget(BaseOperator):
    """
    Generate YOLOv3 targets from ground-truth data; this operator is only
    used in the fine-grained YOLOv3 loss mode.
    """

    def __init__(self,
                 anchors,
                 anchor_masks,
                 downsample_ratios,
                 num_classes=80,
                 iou_thresh=1.):
        super(Gt2YoloTarget, self).__init__()
        self.anchors = anchors
        self.anchor_masks = anchor_masks
        self.downsample_ratios = downsample_ratios
        self.num_classes = num_classes
        self.iou_thresh = iou_thresh

    def __call__(self, samples, context=None):
        assert len(self.anchor_masks) == len(self.downsample_ratios), \
            "'anchor_masks' and 'downsample_ratios' should have same length."

        h, w = samples[0]['image'].shape[1:3]
        an_hw = np.array(self.anchors) / np.array([[w, h]])
        for sample in samples:
            # im, gt_bbox, gt_class, gt_score = sample
            im = sample['image']
            gt_bbox = sample['gt_bbox']
            gt_class = sample['gt_class']
            if 'gt_score' not in sample:
                sample['gt_score'] = np.ones(
                    (gt_bbox.shape[0], 1), dtype=np.float32)
            gt_score = sample['gt_score']
            for i, (
                    mask, downsample_ratio
            ) in enumerate(zip(self.anchor_masks, self.downsample_ratios)):
                grid_h = int(h / downsample_ratio)
                grid_w = int(w / downsample_ratio)
                target = np.zeros(
                    (len(mask), 6 + self.num_classes, grid_h, grid_w),
                    dtype=np.float32)
                for b in range(gt_bbox.shape[0]):
                    gx, gy, gw, gh = gt_bbox[b, :]
                    cls = gt_class[b]
                    score = gt_score[b]
                    if gw <= 0. or gh <= 0. or score <= 0.:
                        continue

                    # find the best matching anchor index
                    best_iou = 0.
                    best_idx = -1
                    for an_idx in range(an_hw.shape[0]):
                        iou = jaccard_overlap(
                            [0., 0., gw, gh],
                            [0., 0., an_hw[an_idx, 0], an_hw[an_idx, 1]])
                        if iou > best_iou:
                            best_iou = iou
                            best_idx = an_idx

                    gi = int(gx * grid_w)
                    gj = int(gy * grid_h)

                    # the gt box should be regressed in this layer if the
                    # best matching anchor index is in this layer's
                    # anchor mask
                    if best_idx in mask:
                        best_n = mask.index(best_idx)

                        # x, y, w, h, scale
                        target[best_n, 0, gj, gi] = gx * grid_w - gi
                        target[best_n, 1, gj, gi] = gy * grid_h - gj
                        target[best_n, 2, gj, gi] = np.log(
                            gw * w / self.anchors[best_idx][0])
                        target[best_n, 3, gj, gi] = np.log(
                            gh * h / self.anchors[best_idx][1])
                        target[best_n, 4, gj, gi] = 2.0 - gw * gh

                        # objectness records gt_score
                        target[best_n, 5, gj, gi] = score

                        # classification
                        target[best_n, 6 + cls, gj, gi] = 1.

                    # For non-matched anchors, calculate the target if the
                    # iou between the anchor and the gt is larger than
                    # iou_thresh
                    if self.iou_thresh < 1:
                        for idx, mask_i in enumerate(mask):
                            if mask_i == best_idx:
                                continue
                            iou = jaccard_overlap(
                                [0., 0., gw, gh],
                                [0., 0., an_hw[mask_i, 0], an_hw[mask_i, 1]])
                            if iou > self.iou_thresh and target[
                                    idx, 5, gj, gi] == 0.:
                                # x, y, w, h, scale
                                target[idx, 0, gj, gi] = gx * grid_w - gi
                                target[idx, 1, gj, gi] = gy * grid_h - gj
                                target[idx, 2, gj, gi] = np.log(
                                    gw * w / self.anchors[mask_i][0])
                                target[idx, 3, gj, gi] = np.log(
                                    gh * h / self.anchors[mask_i][1])
                                target[idx, 4, gj, gi] = 2.0 - gw * gh

                                # objectness records gt_score
                                target[idx, 5, gj, gi] = score

                                # classification
                                target[idx, 6 + cls, gj, gi] = 1.
                sample['target{}'.format(i)] = target

            # remove the now-useless gt_class and gt_score after the
            # targets have been calculated
            sample.pop('gt_class')
            sample.pop('gt_score')

        return samples
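
# A minimal usage sketch (not part of the upstream file). gt_bbox is
# assumed to hold normalized [center_x, center_y, w, h] boxes, which is
# what the gi/gj grid indexing above implies; the anchor settings follow
# the standard YOLOv3 configuration. target0 is the coarsest level:
# (num_anchors, 6 + num_classes, 13, 13) for a 416 input.
def _demo_gt2yolo_target():
    op = Gt2YoloTarget(
        anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
                 [59, 119], [116, 90], [156, 198], [373, 326]],
        anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
        downsample_ratios=[32, 16, 8])
    sample = {
        'image': np.zeros((3, 416, 416), dtype=np.float32),
        'gt_bbox': np.array([[0.5, 0.5, 0.2, 0.3]], dtype=np.float32),
        'gt_class': np.array([[7]], dtype=np.int32),
    }
    out = op([sample])[0]
    print(out['target0'].shape)  # (3, 86, 13, 13)
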
@register_op
class Gt2FCOSTarget(BaseOperator):
    """
    Generate FCOS targets from ground-truth data.
    """

    def __init__(self,
                 object_sizes_boundary,
                 center_sampling_radius,
                 downsample_ratios,
                 norm_reg_targets=False):
        super(Gt2FCOSTarget, self).__init__()
        self.center_sampling_radius = center_sampling_radius
        self.downsample_ratios = downsample_ratios
        self.INF = np.inf
        self.object_sizes_boundary = [-1] + object_sizes_boundary + [self.INF]
        object_sizes_of_interest = []
        for i in range(len(self.object_sizes_boundary) - 1):
            object_sizes_of_interest.append([
                self.object_sizes_boundary[i],
                self.object_sizes_boundary[i + 1]
            ])
        self.object_sizes_of_interest = object_sizes_of_interest
        self.norm_reg_targets = norm_reg_targets

    def _compute_points(self, w, h):
        """
        compute the corresponding points in each feature map
        :param h: image height
        :param w: image width
        :return: points from all feature maps
        """
        locations = []
        for stride in self.downsample_ratios:
            shift_x = np.arange(0, w, stride).astype(np.float32)
            shift_y = np.arange(0, h, stride).astype(np.float32)
            shift_x, shift_y = np.meshgrid(shift_x, shift_y)
            shift_x = shift_x.flatten()
            shift_y = shift_y.flatten()
            location = np.stack([shift_x, shift_y], axis=1) + stride // 2
            locations.append(location)
        num_points_each_level = [len(location) for location in locations]
        locations = np.concatenate(locations, axis=0)
        return locations, num_points_each_level

    def _convert_xywh2xyxy(self, gt_bbox, w, h):
        """
        convert bounding boxes from xywh style to xyxy style
        :param gt_bbox: bounding boxes normalized into [0, 1]
        :param w: image width
        :param h: image height
        :return: bounding boxes in xyxy style
        """
        bboxes = gt_bbox.copy()
        bboxes[:, [0, 2]] = bboxes[:, [0, 2]] * w
        bboxes[:, [1, 3]] = bboxes[:, [1, 3]] * h
        bboxes[:, 2] = bboxes[:, 0] + bboxes[:, 2]
        bboxes[:, 3] = bboxes[:, 1] + bboxes[:, 3]
        return bboxes

    def _check_inside_boxes_limited(self, gt_bbox, xs, ys,
                                    num_points_each_level):
        """
        check whether the points are within the clipped boxes
        :param gt_bbox: bounding boxes
        :param xs: horizontal coordinates of points
        :param ys: vertical coordinates of points
        :return: a mask indicating whether each point is within a gt box
        """
        bboxes = np.reshape(
            gt_bbox, newshape=[1, gt_bbox.shape[0], gt_bbox.shape[1]])
        bboxes = np.tile(bboxes, reps=[xs.shape[0], 1, 1])
        ct_x = (bboxes[:, :, 0] + bboxes[:, :, 2]) / 2
        ct_y = (bboxes[:, :, 1] + bboxes[:, :, 3]) / 2
        beg = 0
        clipped_box = bboxes.copy()
        for lvl, stride in enumerate(self.downsample_ratios):
            end = beg + num_points_each_level[lvl]
            stride_exp = self.center_sampling_radius * stride
            clipped_box[beg:end, :, 0] = np.maximum(
                bboxes[beg:end, :, 0], ct_x[beg:end, :] - stride_exp)
            clipped_box[beg:end, :, 1] = np.maximum(
                bboxes[beg:end, :, 1], ct_y[beg:end, :] - stride_exp)
            clipped_box[beg:end, :, 2] = np.minimum(
                bboxes[beg:end, :, 2], ct_x[beg:end, :] + stride_exp)
            clipped_box[beg:end, :, 3] = np.minimum(
                bboxes[beg:end, :, 3], ct_y[beg:end, :] + stride_exp)
            beg = end
        l_res = xs - clipped_box[:, :, 0]
        r_res = clipped_box[:, :, 2] - xs
        t_res = ys - clipped_box[:, :, 1]
        b_res = clipped_box[:, :, 3] - ys
        clipped_box_reg_targets = np.stack(
            [l_res, t_res, r_res, b_res], axis=2)
        inside_gt_box = np.min(clipped_box_reg_targets, axis=2) > 0
        return inside_gt_box

    def __call__(self, samples, context=None):
        assert len(self.object_sizes_of_interest) == len(self.downsample_ratios), \
            "'object_sizes_of_interest' and 'downsample_ratios' should have same length."

        for sample in samples:
            # im, gt_bbox, gt_class, gt_score = sample
            im = sample['image']
            bboxes = sample['gt_bbox']
            gt_class = sample['gt_class']

            # calculate the locations
            h, w = im.shape[1:3]
            points, num_points_each_level = self._compute_points(w, h)
            object_scale_exp = []
            for i, num_pts in enumerate(num_points_each_level):
                object_scale_exp.append(
                    np.tile(
                        np.array([self.object_sizes_of_interest[i]]),
                        reps=[num_pts, 1]))
            object_scale_exp = np.concatenate(object_scale_exp, axis=0)

            gt_area = (bboxes[:, 2] - bboxes[:, 0]) * (
                bboxes[:, 3] - bboxes[:, 1])
            xs, ys = points[:, 0], points[:, 1]
            xs = np.reshape(xs, newshape=[xs.shape[0], 1])
            xs = np.tile(xs, reps=[1, bboxes.shape[0]])
            ys = np.reshape(ys, newshape=[ys.shape[0], 1])
            ys = np.tile(ys, reps=[1, bboxes.shape[0]])

            l_res = xs - bboxes[:, 0]
            r_res = bboxes[:, 2] - xs
            t_res = ys - bboxes[:, 1]
            b_res = bboxes[:, 3] - ys
            reg_targets = np.stack([l_res, t_res, r_res, b_res], axis=2)
            if self.center_sampling_radius > 0:
                is_inside_box = self._check_inside_boxes_limited(
                    bboxes, xs, ys, num_points_each_level)
            else:
                is_inside_box = np.min(reg_targets, axis=2) > 0

            # check whether the targets are inside the corresponding level
            max_reg_targets = np.max(reg_targets, axis=2)
            lower_bound = np.tile(
                np.expand_dims(
                    object_scale_exp[:, 0], axis=1),
                reps=[1, max_reg_targets.shape[1]])
            high_bound = np.tile(
                np.expand_dims(
                    object_scale_exp[:, 1], axis=1),
                reps=[1, max_reg_targets.shape[1]])
            is_match_current_level = \
                (max_reg_targets > lower_bound) & \
                (max_reg_targets < high_bound)

            points2gtarea = np.tile(
                np.expand_dims(
                    gt_area, axis=0), reps=[xs.shape[0], 1])
            points2gtarea[is_inside_box == 0] = self.INF
            points2gtarea[is_match_current_level == 0] = self.INF
            points2min_area = points2gtarea.min(axis=1)
            points2min_area_ind = points2gtarea.argmin(axis=1)
            labels = gt_class[points2min_area_ind] + 1
            labels[points2min_area == self.INF] = 0

            reg_targets = reg_targets[range(xs.shape[0]), points2min_area_ind]
            ctn_targets = np.sqrt((reg_targets[:, [0, 2]].min(axis=1) / \
                                   reg_targets[:, [0, 2]].max(axis=1)) * \
                                  (reg_targets[:, [1, 3]].min(axis=1) / \
                                   reg_targets[:, [1, 3]].max(axis=1))).astype(np.float32)
            ctn_targets = np.reshape(
                ctn_targets, newshape=[ctn_targets.shape[0], 1])
            ctn_targets[labels <= 0] = 0

            pos_ind = np.nonzero(labels != 0)
            reg_targets_pos = reg_targets[pos_ind[0], :]

            split_sections = []
            beg = 0
            for lvl in range(len(num_points_each_level)):
                end = beg + num_points_each_level[lvl]
                split_sections.append(end)
                beg = end
            labels_by_level = np.split(labels, split_sections, axis=0)
            reg_targets_by_level = np.split(
                reg_targets, split_sections, axis=0)
            ctn_targets_by_level = np.split(
                ctn_targets, split_sections, axis=0)

            for lvl in range(len(self.downsample_ratios)):
                grid_w = int(np.ceil(w / self.downsample_ratios[lvl]))
                grid_h = int(np.ceil(h / self.downsample_ratios[lvl]))
                if self.norm_reg_targets:
                    sample['reg_target{}'.format(lvl)] = \
                        np.reshape(
                            reg_targets_by_level[lvl] / \
                            self.downsample_ratios[lvl],
                            newshape=[grid_h, grid_w, 4])
                else:
                    sample['reg_target{}'.format(lvl)] = np.reshape(
                        reg_targets_by_level[lvl],
                        newshape=[grid_h, grid_w, 4])
                sample['labels{}'.format(lvl)] = np.reshape(
                    labels_by_level[lvl], newshape=[grid_h, grid_w, 1])
                sample['centerness{}'.format(lvl)] = np.reshape(
                    ctn_targets_by_level[lvl], newshape=[grid_h, grid_w, 1])

            sample.pop('is_crowd', None)
            sample.pop('difficult', None)
            sample.pop('gt_class', None)
            sample.pop('gt_bbox', None)

        return samples
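
# A minimal usage sketch (not part of the upstream file), using FCOS-like
# settings: gt_bbox is in absolute xyxy pixels, which is what the l/r/t/b
# residual computation above expects. With a 512x512 input and stride 8,
# the first level produces 64x64 targets.
def _demo_gt2fcos_target():
    op = Gt2FCOSTarget(
        object_sizes_boundary=[64, 128, 256, 512],
        center_sampling_radius=1.5,
        downsample_ratios=[8, 16, 32, 64, 128],
        norm_reg_targets=True)
    sample = {
        'image': np.zeros((3, 512, 512), dtype=np.float32),
        'gt_bbox': np.array([[100., 120., 300., 340.]], dtype=np.float32),
        'gt_class': np.array([[3]], dtype=np.int32),
    }
    out = op([sample])[0]
    print(out['labels0'].shape, out['reg_target0'].shape)
    # (64, 64, 1) (64, 64, 4)
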
@register_op
class Gt2TTFTarget(BaseOperator):
    """
    Generate TTFNet targets from ground-truth data.

    Args:
        num_classes (int): the number of classes.
        down_ratio (int): the down ratio from images to the heatmap,
            4 by default.
        alpha (float): the alpha parameter used to generate the gaussian
            target, 0.54 by default.
    """
    __shared__ = ['num_classes']

    def __init__(self, num_classes=80, down_ratio=4, alpha=0.54):
        super(Gt2TTFTarget, self).__init__()
        self.down_ratio = down_ratio
        self.num_classes = num_classes
        self.alpha = alpha

    def __call__(self, samples, context=None):
        output_size = samples[0]['image'].shape[1]
        feat_size = output_size // self.down_ratio
        for sample in samples:
            heatmap = np.zeros(
                (self.num_classes, feat_size, feat_size), dtype='float32')
            box_target = np.ones(
                (4, feat_size, feat_size), dtype='float32') * -1
            reg_weight = np.zeros((1, feat_size, feat_size), dtype='float32')

            gt_bbox = sample['gt_bbox']
            gt_class = sample['gt_class']

            bbox_w = gt_bbox[:, 2] - gt_bbox[:, 0] + 1
            bbox_h = gt_bbox[:, 3] - gt_bbox[:, 1] + 1
            area = bbox_w * bbox_h
            boxes_areas_log = np.log(area)
            boxes_ind = np.argsort(boxes_areas_log, axis=0)[::-1]
            boxes_area_topk_log = boxes_areas_log[boxes_ind]
            gt_bbox = gt_bbox[boxes_ind]
            gt_class = gt_class[boxes_ind]

            feat_gt_bbox = gt_bbox / self.down_ratio
            feat_gt_bbox = np.clip(feat_gt_bbox, 0, feat_size - 1)
            feat_hs, feat_ws = (feat_gt_bbox[:, 3] - feat_gt_bbox[:, 1],
                                feat_gt_bbox[:, 2] - feat_gt_bbox[:, 0])

            ct_inds = np.stack(
                [(gt_bbox[:, 0] + gt_bbox[:, 2]) / 2,
                 (gt_bbox[:, 1] + gt_bbox[:, 3]) / 2],
                axis=1) / self.down_ratio

            h_radiuses_alpha = (feat_hs / 2. * self.alpha).astype('int32')
            w_radiuses_alpha = (feat_ws / 2. * self.alpha).astype('int32')

            for k in range(len(gt_bbox)):
                cls_id = gt_class[k]
                fake_heatmap = np.zeros(
                    (feat_size, feat_size), dtype='float32')
                self.draw_truncate_gaussian(fake_heatmap, ct_inds[k],
                                            h_radiuses_alpha[k],
                                            w_radiuses_alpha[k])

                heatmap[cls_id] = np.maximum(heatmap[cls_id], fake_heatmap)
                box_target_inds = fake_heatmap > 0
                box_target[:, box_target_inds] = gt_bbox[k][:, None]

                local_heatmap = fake_heatmap[box_target_inds]
                ct_div = np.sum(local_heatmap)
                local_heatmap *= boxes_area_topk_log[k]
                reg_weight[0, box_target_inds] = local_heatmap / ct_div
            sample['ttf_heatmap'] = heatmap
            sample['ttf_box_target'] = box_target
            sample['ttf_reg_weight'] = reg_weight
            sample.pop('is_crowd', None)
            sample.pop('difficult', None)
            sample.pop('gt_class', None)
            sample.pop('gt_bbox', None)
            sample.pop('gt_score', None)
        return samples

    def draw_truncate_gaussian(self, heatmap, center, h_radius, w_radius):
        h, w = 2 * h_radius + 1, 2 * w_radius + 1
        sigma_x = w / 6
        sigma_y = h / 6
        gaussian = gaussian2D((h, w), sigma_x, sigma_y)
        x, y = int(center[0]), int(center[1])

        height, width = heatmap.shape[0:2]

        left, right = min(x, w_radius), min(width - x, w_radius + 1)
        top, bottom = min(y, h_radius), min(height - y, h_radius + 1)

        masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
        masked_gaussian = gaussian[h_radius - top:h_radius + bottom,
                                   w_radius - left:w_radius + right]
        if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0:
            heatmap[y - top:y + bottom, x - left:x + right] = np.maximum(
                masked_heatmap, masked_gaussian)
        return heatmap
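
# A minimal usage sketch (not part of the upstream file). gt_bbox is in
# absolute xyxy pixels; with a 512 input and down_ratio=4 the heatmap is
# 80 x 128 x 128, and each box paints a truncated 2D gaussian centered
# at its box center on the channel of its class.
def _demo_gt2ttf_target():
    op = Gt2TTFTarget(num_classes=80, down_ratio=4)
    sample = {
        'image': np.zeros((3, 512, 512), dtype=np.float32),
        'gt_bbox': np.array([[100., 120., 300., 340.]], dtype=np.float32),
        'gt_class': np.array([[3]], dtype=np.int32),
    }
    out = op([sample])[0]
    print(out['ttf_heatmap'].shape)  # (80, 128, 128)
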
@register_op
class Gt2Solov2Target(BaseOperator):
    """Assign mask targets and labels in the SOLOv2 network.

    Args:
        num_grids (list): The list of feature map grid sizes.
        scale_ranges (list): The list of mask boundary ranges.
        coord_sigma (float): The coefficient of coordinate area length.
        sampling_ratio (float): The ratio of down sampling.
    """

    def __init__(self,
                 num_grids=[40, 36, 24, 16, 12],
                 scale_ranges=[[1, 96], [48, 192], [96, 384], [192, 768],
                               [384, 2048]],
                 coord_sigma=0.2,
                 sampling_ratio=4.0):
        super(Gt2Solov2Target, self).__init__()
        self.num_grids = num_grids
        self.scale_ranges = scale_ranges
        self.coord_sigma = coord_sigma
        self.sampling_ratio = sampling_ratio

    def _scale_size(self, im, scale):
        resized_img = cv2.resize(
            im, None, None, fx=scale, fy=scale,
            interpolation=cv2.INTER_LINEAR)
        return resized_img

    def __call__(self, samples, context=None):
        sample_id = 0
        max_ins_num = [0] * len(self.num_grids)
        for sample in samples:
            gt_bboxes_raw = sample['gt_bbox']
            gt_labels_raw = sample['gt_class'] + 1
            im_c, im_h, im_w = sample['image'].shape[:]
            gt_masks_raw = sample['gt_segm'].astype(np.uint8)
            mask_feat_size = [
                int(im_h / self.sampling_ratio),
                int(im_w / self.sampling_ratio)
            ]
            gt_areas = np.sqrt((gt_bboxes_raw[:, 2] - gt_bboxes_raw[:, 0]) *
                               (gt_bboxes_raw[:, 3] - gt_bboxes_raw[:, 1]))
            ins_ind_label_list = []
            idx = 0
            for (lower_bound, upper_bound), num_grid \
                    in zip(self.scale_ranges, self.num_grids):

                hit_indices = ((gt_areas >= lower_bound) &
                               (gt_areas <= upper_bound)).nonzero()[0]
                num_ins = len(hit_indices)

                ins_label = []
                grid_order = []
                cate_label = np.zeros([num_grid, num_grid], dtype=np.int64)
                ins_ind_label = np.zeros([num_grid**2], dtype=bool)

                if num_ins == 0:
                    ins_label = np.zeros(
                        [1, mask_feat_size[0], mask_feat_size[1]],
                        dtype=np.uint8)
                    ins_ind_label_list.append(ins_ind_label)
                    sample['cate_label{}'.format(idx)] = cate_label.flatten()
                    sample['ins_label{}'.format(idx)] = ins_label
                    sample['grid_order{}'.format(idx)] = np.asarray(
                        [sample_id * num_grid * num_grid + 0], dtype=np.int32)
                    idx += 1
                    continue
                gt_bboxes = gt_bboxes_raw[hit_indices]
                gt_labels = gt_labels_raw[hit_indices]
                gt_masks = gt_masks_raw[hit_indices, ...]

                half_ws = 0.5 * (
                    gt_bboxes[:, 2] - gt_bboxes[:, 0]) * self.coord_sigma
                half_hs = 0.5 * (
                    gt_bboxes[:, 3] - gt_bboxes[:, 1]) * self.coord_sigma

                for seg_mask, gt_label, half_h, half_w in zip(
                        gt_masks, gt_labels, half_hs, half_ws):
                    if seg_mask.sum() == 0:
                        continue
                    # mass center
                    upsampled_size = (mask_feat_size[0] * 4,
                                      mask_feat_size[1] * 4)
                    center_h, center_w = ndimage.center_of_mass(seg_mask)
                    coord_w = int(
                        (center_w / upsampled_size[1]) // (1. / num_grid))
                    coord_h = int(
                        (center_h / upsampled_size[0]) // (1. / num_grid))

                    # left, top, right, down
                    top_box = max(0,
                                  int(((center_h - half_h) /
                                       upsampled_size[0]) // (1. / num_grid)))
                    down_box = min(
                        num_grid - 1,
                        int(((center_h + half_h) / upsampled_size[0]) //
                            (1. / num_grid)))
                    left_box = max(
                        0,
                        int(((center_w - half_w) / upsampled_size[1]) //
                            (1. / num_grid)))
                    right_box = min(num_grid - 1,
                                    int(((center_w + half_w) /
                                         upsampled_size[1]) //
                                        (1. / num_grid)))

                    top = max(top_box, coord_h - 1)
                    down = min(down_box, coord_h + 1)
                    left = max(coord_w - 1, left_box)
                    right = min(right_box, coord_w + 1)

                    cate_label[top:(down + 1), left:(right + 1)] = gt_label
                    seg_mask = self._scale_size(
                        seg_mask, scale=1. / self.sampling_ratio)
                    for i in range(top, down + 1):
                        for j in range(left, right + 1):
                            label = int(i * num_grid + j)
                            cur_ins_label = np.zeros(
                                [mask_feat_size[0], mask_feat_size[1]],
                                dtype=np.uint8)
                            cur_ins_label[:seg_mask.shape[0], :seg_mask.
                                          shape[1]] = seg_mask
                            ins_label.append(cur_ins_label)
                            ins_ind_label[label] = True
                            grid_order.append(
                                sample_id * num_grid * num_grid + label)
                if ins_label == []:
                    ins_label = np.zeros(
                        [1, mask_feat_size[0], mask_feat_size[1]],
                        dtype=np.uint8)
                    ins_ind_label_list.append(ins_ind_label)
                    sample['cate_label{}'.format(idx)] = cate_label.flatten()
                    sample['ins_label{}'.format(idx)] = ins_label
                    sample['grid_order{}'.format(idx)] = np.asarray(
                        [sample_id * num_grid * num_grid + 0], dtype=np.int32)
                else:
                    ins_label = np.stack(ins_label, axis=0)
                    ins_ind_label_list.append(ins_ind_label)
                    sample['cate_label{}'.format(idx)] = cate_label.flatten()
                    sample['ins_label{}'.format(idx)] = ins_label
                    sample['grid_order{}'.format(idx)] = np.asarray(
                        grid_order, dtype=np.int32)
                    assert len(grid_order) > 0
                max_ins_num[idx] = max(
                    max_ins_num[idx],
                    sample['ins_label{}'.format(idx)].shape[0])
                idx += 1
            ins_ind_labels = np.concatenate([
                ins_ind_labels_level_img
                for ins_ind_labels_level_img in ins_ind_label_list
            ])
            fg_num = np.sum(ins_ind_labels)
            sample['fg_num'] = fg_num
            sample_id += 1

            sample.pop('is_crowd')
            sample.pop('gt_class')
            sample.pop('gt_bbox')
            sample.pop('gt_poly')
            sample.pop('gt_segm')

        # pad the batch so that every sample has the same number of
        # instances per level
        for data in samples:
            for idx in range(len(self.num_grids)):
                gt_ins_data = np.zeros(
                    [
                        max_ins_num[idx],
                        data['ins_label{}'.format(idx)].shape[1],
                        data['ins_label{}'.format(idx)].shape[2]
                    ],
                    dtype=np.uint8)
                gt_ins_data[0:data['ins_label{}'.format(idx)].shape[
                    0], :, :] = data['ins_label{}'.format(idx)]
                gt_grid_order = np.zeros([max_ins_num[idx]], dtype=np.int32)
                gt_grid_order[0:data['grid_order{}'.format(idx)].shape[
                    0]] = data['grid_order{}'.format(idx)]
                data['ins_label{}'.format(idx)] = gt_ins_data
                data['grid_order{}'.format(idx)] = gt_grid_order

        return samples
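
# A minimal usage sketch (not part of the upstream file). The operator
# also expects 'is_crowd' and 'gt_poly' keys, because it pops them at
# the end; gt_segm holds one binary mask per instance at image
# resolution, and fg_num counts the positive grid cells over all levels.
def _demo_gt2solov2_target():
    seg = np.zeros((1, 256, 256), dtype=np.uint8)
    seg[0, 100:180, 60:140] = 1
    sample = {
        'image': np.zeros((3, 256, 256), dtype=np.float32),
        'gt_bbox': np.array([[60., 100., 140., 180.]], dtype=np.float32),
        'gt_class': np.array([[3]], dtype=np.int32),
        'gt_segm': seg,
        'gt_poly': [],
        'is_crowd': np.zeros((1, 1), dtype=np.int32),
    }
    out = Gt2Solov2Target()([sample])[0]
    print(out['fg_num'], out['ins_label0'].shape)  # e.g. 18 (9, 64, 64)
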