# batch_operators.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import typing

try:
    from collections.abc import Sequence
except Exception:
    from collections import Sequence

import cv2
import math
import numpy as np

from .operators import register_op, BaseOperator, Resize
from .op_helper import jaccard_overlap, gaussian2D, gaussian_radius, draw_umich_gaussian
from .atss_assigner import ATSSAssigner
from scipy import ndimage

from paddlex.ppdet.modeling import bbox_utils
from paddlex.ppdet.utils.logger import setup_logger
from paddlex.ppdet.modeling.keypoint_utils import get_affine_transform, affine_transform

logger = setup_logger(__name__)

__all__ = [
    'PadBatch',
    'BatchRandomResize',
    'Gt2YoloTarget',
    'Gt2FCOSTarget',
    'Gt2TTFTarget',
    'Gt2Solov2Target',
    'Gt2SparseRCNNTarget',
    'PadMaskBatch',
    'Gt2GFLTarget',
    'Gt2CenterNetTarget',
]

@register_op
class PadBatch(BaseOperator):
    """
    Pad a batch of samples so they can be divisible by a stride.
    The layout of each image should be 'CHW'.
    Args:
        pad_to_stride (int): If `pad_to_stride > 0`, pad zeros to ensure
            height and width is divisible by `pad_to_stride`.
    """

    def __init__(self, pad_to_stride=0):
        super(PadBatch, self).__init__()
        self.pad_to_stride = pad_to_stride

    def __call__(self, samples, context=None):
        """
        Args:
            samples (list): a batch of sample, each is dict.
        """
        coarsest_stride = self.pad_to_stride

        # multi scale input is nested list
        if isinstance(samples,
                      typing.Sequence) and len(samples) > 0 and isinstance(
                          samples[0], typing.Sequence):
            inner_samples = samples[0]
        else:
            inner_samples = samples

        max_shape = np.array(
            [data['image'].shape for data in inner_samples]).max(axis=0)
        if coarsest_stride > 0:
            max_shape[1] = int(
                np.ceil(max_shape[1] / coarsest_stride) * coarsest_stride)
            max_shape[2] = int(
                np.ceil(max_shape[2] / coarsest_stride) * coarsest_stride)

        for data in inner_samples:
            im = data['image']
            im_c, im_h, im_w = im.shape[:]
            padding_im = np.zeros(
                (im_c, max_shape[1], max_shape[2]), dtype=np.float32)
            padding_im[:, :im_h, :im_w] = im
            data['image'] = padding_im
            if 'semantic' in data and data['semantic'] is not None:
                semantic = data['semantic']
                padding_sem = np.zeros(
                    (1, max_shape[1], max_shape[2]), dtype=np.float32)
                padding_sem[:, :im_h, :im_w] = semantic
                data['semantic'] = padding_sem
            if 'gt_segm' in data and data['gt_segm'] is not None:
                gt_segm = data['gt_segm']
                padding_segm = np.zeros(
                    (gt_segm.shape[0], max_shape[1], max_shape[2]),
                    dtype=np.uint8)
                padding_segm[:, :im_h, :im_w] = gt_segm
                data['gt_segm'] = padding_segm
            if 'gt_rbox2poly' in data and data['gt_rbox2poly'] is not None:
                # poly to rbox
                polys = data['gt_rbox2poly']
                rbox = bbox_utils.poly2rbox(polys)
                data['gt_rbox'] = rbox

        return samples
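
# A minimal usage sketch (illustrative only, not part of the original module):
# with `pad_to_stride=32`, every image in the batch is zero-padded up to the
# batch-wise max height/width, rounded up to a multiple of 32.
#
#   pad = PadBatch(pad_to_stride=32)
#   batch = [{'image': np.zeros((3, 300, 400), dtype=np.float32)},
#            {'image': np.zeros((3, 320, 410), dtype=np.float32)}]
#   batch = pad(batch)
#   # both batch[0]['image'] and batch[1]['image'] now have shape (3, 320, 416)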

@register_op
class BatchRandomResize(BaseOperator):
    """
    Resize a batch of images to a randomly selected target size and
    interpolation method.
    Args:
        target_size (int, list, tuple): image target size; if random_size
            is True, it must be a list or tuple
        keep_ratio (bool): whether to keep the aspect ratio
        interp (int): the interpolation method
        random_size (bool): whether to randomly select a target size
        random_interp (bool): whether to randomly select an interpolation method
    """

    def __init__(self,
                 target_size,
                 keep_ratio,
                 interp=cv2.INTER_NEAREST,
                 random_size=True,
                 random_interp=False):
        super(BatchRandomResize, self).__init__()
        self.keep_ratio = keep_ratio
        self.interps = [
            cv2.INTER_NEAREST,
            cv2.INTER_LINEAR,
            cv2.INTER_AREA,
            cv2.INTER_CUBIC,
            cv2.INTER_LANCZOS4,
        ]
        self.interp = interp
        assert isinstance(target_size, (
            int, Sequence)), "target_size must be int, list or tuple"
        if random_size and not isinstance(target_size, list):
            raise TypeError(
                "Type of target_size is invalid when random_size is True. Must be List, now is {}".
                format(type(target_size)))
        self.target_size = target_size
        self.random_size = random_size
        self.random_interp = random_interp

    def __call__(self, samples, context=None):
        if self.random_size:
            index = np.random.choice(len(self.target_size))
            target_size = self.target_size[index]
        else:
            target_size = self.target_size
        if self.random_interp:
            interp = np.random.choice(self.interps)
        else:
            interp = self.interp
        resizer = Resize(
            target_size, keep_ratio=self.keep_ratio, interp=interp)
        return resizer(samples, context=context)
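
# A usage sketch (the size list below is a common multi-scale training
# configuration, given only as an illustration): one size is drawn per batch,
# so every image in the batch ends up with the same shape.
#
#   resize = BatchRandomResize(
#       target_size=[320, 352, 384, 416, 448, 480, 512, 544, 576, 608],
#       keep_ratio=False, random_size=True, random_interp=True)
#   samples = resize(samples)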

@register_op
class Gt2YoloTarget(BaseOperator):
    """
    Generate YOLOv3 targets from ground truth data; this operator is only
    used in fine grained YOLOv3 loss mode.
    """

    def __init__(self,
                 anchors,
                 anchor_masks,
                 downsample_ratios,
                 num_classes=80,
                 iou_thresh=1.):
        super(Gt2YoloTarget, self).__init__()
        self.anchors = anchors
        self.anchor_masks = anchor_masks
        self.downsample_ratios = downsample_ratios
        self.num_classes = num_classes
        self.iou_thresh = iou_thresh

    def __call__(self, samples, context=None):
        assert len(self.anchor_masks) == len(self.downsample_ratios), \
            "'anchor_masks' and 'downsample_ratios' should have same length."

        h, w = samples[0]['image'].shape[1:3]
        an_hw = np.array(self.anchors) / np.array([[w, h]])
        for sample in samples:
            gt_bbox = sample['gt_bbox']
            gt_class = sample['gt_class']
            if 'gt_score' not in sample:
                sample['gt_score'] = np.ones(
                    (gt_bbox.shape[0], 1), dtype=np.float32)
            gt_score = sample['gt_score']
            for i, (
                    mask, downsample_ratio
            ) in enumerate(zip(self.anchor_masks, self.downsample_ratios)):
                grid_h = int(h / downsample_ratio)
                grid_w = int(w / downsample_ratio)
                target = np.zeros(
                    (len(mask), 6 + self.num_classes, grid_h, grid_w),
                    dtype=np.float32)
                for b in range(gt_bbox.shape[0]):
                    gx, gy, gw, gh = gt_bbox[b, :]
                    cls = gt_class[b]
                    score = gt_score[b]
                    if gw <= 0. or gh <= 0. or score <= 0.:
                        continue

                    # find best match anchor index
                    best_iou = 0.
                    best_idx = -1
                    for an_idx in range(an_hw.shape[0]):
                        iou = jaccard_overlap(
                            [0., 0., gw, gh],
                            [0., 0., an_hw[an_idx, 0], an_hw[an_idx, 1]])
                        if iou > best_iou:
                            best_iou = iou
                            best_idx = an_idx

                    gi = int(gx * grid_w)
                    gj = int(gy * grid_h)

                    # the gt box should be regressed in this layer if the best
                    # matching anchor index is in the anchor mask of this layer
                    if best_idx in mask:
                        best_n = mask.index(best_idx)

                        # x, y, w, h, scale
                        target[best_n, 0, gj, gi] = gx * grid_w - gi
                        target[best_n, 1, gj, gi] = gy * grid_h - gj
                        target[best_n, 2, gj, gi] = np.log(
                            gw * w / self.anchors[best_idx][0])
                        target[best_n, 3, gj, gi] = np.log(
                            gh * h / self.anchors[best_idx][1])
                        target[best_n, 4, gj, gi] = 2.0 - gw * gh

                        # objectness records gt_score
                        target[best_n, 5, gj, gi] = score

                        # classification
                        target[best_n, 6 + cls, gj, gi] = 1.

                    # For non-matched anchors, calculate the target if the iou
                    # between anchor and gt is larger than iou_thresh
                    if self.iou_thresh < 1:
                        for idx, mask_i in enumerate(mask):
                            if mask_i == best_idx: continue
                            iou = jaccard_overlap(
                                [0., 0., gw, gh],
                                [0., 0., an_hw[mask_i, 0], an_hw[mask_i, 1]])
                            if iou > self.iou_thresh and target[idx, 5, gj,
                                                                gi] == 0.:
                                # x, y, w, h, scale
                                target[idx, 0, gj, gi] = gx * grid_w - gi
                                target[idx, 1, gj, gi] = gy * grid_h - gj
                                target[idx, 2, gj, gi] = np.log(
                                    gw * w / self.anchors[mask_i][0])
                                target[idx, 3, gj, gi] = np.log(
                                    gh * h / self.anchors[mask_i][1])
                                target[idx, 4, gj, gi] = 2.0 - gw * gh

                                # objectness records gt_score
                                target[idx, 5, gj, gi] = score

                                # classification
                                target[idx, 6 + cls, gj, gi] = 1.
                sample['target{}'.format(i)] = target

            # remove useless gt_class and gt_score after targets are calculated
            sample.pop('gt_class')
            sample.pop('gt_score')
        return samples
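
# A configuration sketch (the values are the common YOLOv3 COCO defaults,
# shown here only as an illustration): nine anchors split across three levels,
# coarsest level first.
#
#   gt2yolo = Gt2YoloTarget(
#       anchors=[[10, 13], [16, 30], [33, 23],
#                [30, 61], [62, 45], [59, 119],
#                [116, 90], [156, 198], [373, 326]],
#       anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
#       downsample_ratios=[32, 16, 8])
#   # each sample gains 'target0'..'target2', each of shape
#   # (len(mask), 6 + num_classes, grid_h, grid_w)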

@register_op
class Gt2FCOSTarget(BaseOperator):
    """
    Generate FCOS targets from ground truth data
    """

    def __init__(self,
                 object_sizes_boundary,
                 center_sampling_radius,
                 downsample_ratios,
                 norm_reg_targets=False):
        super(Gt2FCOSTarget, self).__init__()
        self.center_sampling_radius = center_sampling_radius
        self.downsample_ratios = downsample_ratios
        self.INF = np.inf
        self.object_sizes_boundary = [-1] + object_sizes_boundary + [self.INF]
        object_sizes_of_interest = []
        for i in range(len(self.object_sizes_boundary) - 1):
            object_sizes_of_interest.append([
                self.object_sizes_boundary[i],
                self.object_sizes_boundary[i + 1]
            ])
        self.object_sizes_of_interest = object_sizes_of_interest
        self.norm_reg_targets = norm_reg_targets

    def _compute_points(self, w, h):
        """
        compute the corresponding points in each feature map
        :param w: image width
        :param h: image height
        :return: points from all feature maps
        """
        locations = []
        for stride in self.downsample_ratios:
            shift_x = np.arange(0, w, stride).astype(np.float32)
            shift_y = np.arange(0, h, stride).astype(np.float32)
            shift_x, shift_y = np.meshgrid(shift_x, shift_y)
            shift_x = shift_x.flatten()
            shift_y = shift_y.flatten()
            location = np.stack([shift_x, shift_y], axis=1) + stride // 2
            locations.append(location)
        num_points_each_level = [len(location) for location in locations]
        locations = np.concatenate(locations, axis=0)
        return locations, num_points_each_level

    def _convert_xywh2xyxy(self, gt_bbox, w, h):
        """
        convert the bounding box from style xywh to xyxy
        :param gt_bbox: bounding boxes normalized into [0, 1]
        :param w: image width
        :param h: image height
        :return: bounding boxes in xyxy style
        """
        bboxes = gt_bbox.copy()
        bboxes[:, [0, 2]] = bboxes[:, [0, 2]] * w
        bboxes[:, [1, 3]] = bboxes[:, [1, 3]] * h
        bboxes[:, 2] = bboxes[:, 0] + bboxes[:, 2]
        bboxes[:, 3] = bboxes[:, 1] + bboxes[:, 3]
        return bboxes

    def _check_inside_boxes_limited(self, gt_bbox, xs, ys,
                                    num_points_each_level):
        """
        check if points are within the clipped boxes
        :param gt_bbox: bounding boxes
        :param xs: horizontal coordinates of points
        :param ys: vertical coordinates of points
        :return: the mask of whether points are within gt_bbox or not
        """
        bboxes = np.reshape(
            gt_bbox, newshape=[1, gt_bbox.shape[0], gt_bbox.shape[1]])
        bboxes = np.tile(bboxes, reps=[xs.shape[0], 1, 1])
        ct_x = (bboxes[:, :, 0] + bboxes[:, :, 2]) / 2
        ct_y = (bboxes[:, :, 1] + bboxes[:, :, 3]) / 2
        beg = 0
        clipped_box = bboxes.copy()
        for lvl, stride in enumerate(self.downsample_ratios):
            end = beg + num_points_each_level[lvl]
            stride_exp = self.center_sampling_radius * stride
            clipped_box[beg:end, :, 0] = np.maximum(
                bboxes[beg:end, :, 0], ct_x[beg:end, :] - stride_exp)
            clipped_box[beg:end, :, 1] = np.maximum(
                bboxes[beg:end, :, 1], ct_y[beg:end, :] - stride_exp)
            clipped_box[beg:end, :, 2] = np.minimum(
                bboxes[beg:end, :, 2], ct_x[beg:end, :] + stride_exp)
            clipped_box[beg:end, :, 3] = np.minimum(
                bboxes[beg:end, :, 3], ct_y[beg:end, :] + stride_exp)
            beg = end
        l_res = xs - clipped_box[:, :, 0]
        r_res = clipped_box[:, :, 2] - xs
        t_res = ys - clipped_box[:, :, 1]
        b_res = clipped_box[:, :, 3] - ys
        clipped_box_reg_targets = np.stack(
            [l_res, t_res, r_res, b_res], axis=2)
        inside_gt_box = np.min(clipped_box_reg_targets, axis=2) > 0
        return inside_gt_box

    def __call__(self, samples, context=None):
        assert len(self.object_sizes_of_interest) == len(self.downsample_ratios), \
            "'object_sizes_of_interest' and 'downsample_ratios' should have same length."

        for sample in samples:
            im = sample['image']
            bboxes = sample['gt_bbox']
            gt_class = sample['gt_class']

            # calculate the locations
            h, w = im.shape[1:3]
            points, num_points_each_level = self._compute_points(w, h)
            object_scale_exp = []
            for i, num_pts in enumerate(num_points_each_level):
                object_scale_exp.append(
                    np.tile(
                        np.array([self.object_sizes_of_interest[i]]),
                        reps=[num_pts, 1]))
            object_scale_exp = np.concatenate(object_scale_exp, axis=0)

            gt_area = (bboxes[:, 2] - bboxes[:, 0]) * (
                bboxes[:, 3] - bboxes[:, 1])
            xs, ys = points[:, 0], points[:, 1]
            xs = np.reshape(xs, newshape=[xs.shape[0], 1])
            xs = np.tile(xs, reps=[1, bboxes.shape[0]])
            ys = np.reshape(ys, newshape=[ys.shape[0], 1])
            ys = np.tile(ys, reps=[1, bboxes.shape[0]])

            l_res = xs - bboxes[:, 0]
            r_res = bboxes[:, 2] - xs
            t_res = ys - bboxes[:, 1]
            b_res = bboxes[:, 3] - ys
            reg_targets = np.stack([l_res, t_res, r_res, b_res], axis=2)
            if self.center_sampling_radius > 0:
                is_inside_box = self._check_inside_boxes_limited(
                    bboxes, xs, ys, num_points_each_level)
            else:
                is_inside_box = np.min(reg_targets, axis=2) > 0

            # check if the targets are inside the corresponding level
            max_reg_targets = np.max(reg_targets, axis=2)
            lower_bound = np.tile(
                np.expand_dims(
                    object_scale_exp[:, 0], axis=1),
                reps=[1, max_reg_targets.shape[1]])
            high_bound = np.tile(
                np.expand_dims(
                    object_scale_exp[:, 1], axis=1),
                reps=[1, max_reg_targets.shape[1]])
            is_match_current_level = \
                (max_reg_targets > lower_bound) & \
                (max_reg_targets < high_bound)
            points2gtarea = np.tile(
                np.expand_dims(
                    gt_area, axis=0), reps=[xs.shape[0], 1])
            points2gtarea[is_inside_box == 0] = self.INF
            points2gtarea[is_match_current_level == 0] = self.INF
            points2min_area = points2gtarea.min(axis=1)
            points2min_area_ind = points2gtarea.argmin(axis=1)
            labels = gt_class[points2min_area_ind] + 1
            labels[points2min_area == self.INF] = 0
            reg_targets = reg_targets[range(xs.shape[0]), points2min_area_ind]
            ctn_targets = np.sqrt((reg_targets[:, [0, 2]].min(axis=1) / \
                                   reg_targets[:, [0, 2]].max(axis=1)) * \
                                  (reg_targets[:, [1, 3]].min(axis=1) / \
                                   reg_targets[:, [1, 3]].max(axis=1))).astype(np.float32)
            ctn_targets = np.reshape(
                ctn_targets, newshape=[ctn_targets.shape[0], 1])
            ctn_targets[labels <= 0] = 0
            pos_ind = np.nonzero(labels != 0)
            reg_targets_pos = reg_targets[pos_ind[0], :]
            split_sections = []
            beg = 0
            for lvl in range(len(num_points_each_level)):
                end = beg + num_points_each_level[lvl]
                split_sections.append(end)
                beg = end
            labels_by_level = np.split(labels, split_sections, axis=0)
            reg_targets_by_level = np.split(
                reg_targets, split_sections, axis=0)
            ctn_targets_by_level = np.split(
                ctn_targets, split_sections, axis=0)
            for lvl in range(len(self.downsample_ratios)):
                grid_w = int(np.ceil(w / self.downsample_ratios[lvl]))
                grid_h = int(np.ceil(h / self.downsample_ratios[lvl]))
                if self.norm_reg_targets:
                    sample['reg_target{}'.format(lvl)] = \
                        np.reshape(
                            reg_targets_by_level[lvl] / \
                            self.downsample_ratios[lvl],
                            newshape=[grid_h, grid_w, 4])
                else:
                    sample['reg_target{}'.format(lvl)] = np.reshape(
                        reg_targets_by_level[lvl],
                        newshape=[grid_h, grid_w, 4])
                sample['labels{}'.format(lvl)] = np.reshape(
                    labels_by_level[lvl], newshape=[grid_h, grid_w, 1])
                sample['centerness{}'.format(lvl)] = np.reshape(
                    ctn_targets_by_level[lvl], newshape=[grid_h, grid_w, 1])
            sample.pop('is_crowd', None)
            sample.pop('difficult', None)
            sample.pop('gt_class', None)
            sample.pop('gt_bbox', None)
        return samples
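
# A configuration sketch (typical FCOS settings, shown only for illustration):
# the four boundaries below split objects into five scale ranges,
# [-1, 64], [64, 128], [128, 256], [256, 512] and [512, inf), one per level,
# so each location only regresses objects whose max l/t/r/b offset fits its
# level's range.
#
#   gt2fcos = Gt2FCOSTarget(
#       object_sizes_boundary=[64, 128, 256, 512],
#       center_sampling_radius=1.5,
#       downsample_ratios=[8, 16, 32, 64, 128],
#       norm_reg_targets=True)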

@register_op
class Gt2GFLTarget(BaseOperator):
    """
    Generate GFocal loss targets from ground truth data
    """

    def __init__(self,
                 num_classes=80,
                 downsample_ratios=[8, 16, 32, 64, 128],
                 grid_cell_scale=4,
                 cell_offset=0):
        super(Gt2GFLTarget, self).__init__()
        self.num_classes = num_classes
        self.downsample_ratios = downsample_ratios
        self.grid_cell_scale = grid_cell_scale
        self.cell_offset = cell_offset
        self.assigner = ATSSAssigner()

    def get_grid_cells(self, featmap_size, scale, stride, offset=0):
        """
        Generate grid cells of a feature map for target assignment.
        Args:
            featmap_size: Size of a single level feature map.
            scale: Grid cell scale.
            stride: Down sample stride of the feature map.
            offset: Offset of grid cells.
        return:
            Grid_cells xyxy position. Size should be [feat_w * feat_h, 4]
        """
        cell_size = stride * scale
        h, w = featmap_size
        x_range = (np.arange(w, dtype=np.float32) + offset) * stride
        y_range = (np.arange(h, dtype=np.float32) + offset) * stride
        x, y = np.meshgrid(x_range, y_range)
        y = y.flatten()
        x = x.flatten()
        grid_cells = np.stack(
            [
                x - 0.5 * cell_size, y - 0.5 * cell_size, x + 0.5 * cell_size,
                y + 0.5 * cell_size
            ],
            axis=-1)
        return grid_cells

    def get_sample(self, assign_gt_inds, gt_bboxes):
        pos_inds = np.unique(np.nonzero(assign_gt_inds > 0)[0])
        neg_inds = np.unique(np.nonzero(assign_gt_inds == 0)[0])
        pos_assigned_gt_inds = assign_gt_inds[pos_inds] - 1

        if gt_bboxes.size == 0:
            # hack for index error case
            assert pos_assigned_gt_inds.size == 0
            pos_gt_bboxes = np.empty_like(gt_bboxes).reshape(-1, 4)
        else:
            if len(gt_bboxes.shape) < 2:
                # note: ndarray.resize() works in place and returns None,
                # so reshape() is the correct call here
                gt_bboxes = gt_bboxes.reshape(-1, 4)
            pos_gt_bboxes = gt_bboxes[pos_assigned_gt_inds, :]
        return pos_inds, neg_inds, pos_gt_bboxes, pos_assigned_gt_inds

    def __call__(self, samples, context=None):
        assert len(samples) > 0
        batch_size = len(samples)

        # get grid cells of image
        h, w = samples[0]['image'].shape[1:3]
        multi_level_grid_cells = []
        for stride in self.downsample_ratios:
            featmap_size = (int(math.ceil(h / stride)),
                            int(math.ceil(w / stride)))
            multi_level_grid_cells.append(
                self.get_grid_cells(featmap_size, self.grid_cell_scale, stride,
                                    self.cell_offset))
        mlvl_grid_cells_list = [
            multi_level_grid_cells for i in range(batch_size)
        ]

        # pixel cell number of multi-level feature maps
        num_level_cells = [
            grid_cells.shape[0] for grid_cells in mlvl_grid_cells_list[0]
        ]
        num_level_cells_list = [num_level_cells] * batch_size

        # concat all level cells into a single array
        for i in range(batch_size):
            mlvl_grid_cells_list[i] = np.concatenate(mlvl_grid_cells_list[i])

        # target assign on all images
        for sample, grid_cells, num_level_cells in zip(
                samples, mlvl_grid_cells_list, num_level_cells_list):
            gt_bboxes = sample['gt_bbox']
            gt_labels = sample['gt_class'].squeeze()
            if gt_labels.size == 1:
                gt_labels = np.array([gt_labels]).astype(np.int32)
            gt_bboxes_ignore = None
            assign_gt_inds, _ = self.assigner(grid_cells, num_level_cells,
                                              gt_bboxes, gt_bboxes_ignore,
                                              gt_labels)
            pos_inds, neg_inds, pos_gt_bboxes, pos_assigned_gt_inds = self.get_sample(
                assign_gt_inds, gt_bboxes)

            num_cells = grid_cells.shape[0]
            bbox_targets = np.zeros_like(grid_cells)
            bbox_weights = np.zeros_like(grid_cells)
            labels = np.ones([num_cells], dtype=np.int64) * self.num_classes
            label_weights = np.zeros([num_cells], dtype=np.float32)

            if len(pos_inds) > 0:
                pos_bbox_targets = pos_gt_bboxes
                bbox_targets[pos_inds, :] = pos_bbox_targets
                bbox_weights[pos_inds, :] = 1.0
                if not np.any(gt_labels):
                    labels[pos_inds] = 0
                else:
                    labels[pos_inds] = gt_labels[pos_assigned_gt_inds]
                label_weights[pos_inds] = 1.0
            if len(neg_inds) > 0:
                label_weights[neg_inds] = 1.0
            sample['grid_cells'] = grid_cells
            sample['labels'] = labels
            sample['label_weights'] = label_weights
            sample['bbox_targets'] = bbox_targets
            sample['pos_num'] = max(pos_inds.size, 1)
            sample.pop('is_crowd', None)
            sample.pop('difficult', None)
            sample.pop('gt_class', None)
            sample.pop('gt_bbox', None)
            sample.pop('gt_score', None)
        return samples
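
# A worked example of get_grid_cells (illustrative only): for a 2x2 feature
# map with stride=8 and scale=4, each cell is a 32x32 xyxy box centered on
# the corresponding feature map point.
#
#   cells = Gt2GFLTarget().get_grid_cells((2, 2), scale=4, stride=8)
#   # cells[0] == [-16., -16., 16., 16.]  (center (0, 0), side 8 * 4 = 32)
#   # cells[3] == [ -8.,  -8., 24., 24.]  (center (8, 8))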

@register_op
class Gt2TTFTarget(BaseOperator):
    __shared__ = ['num_classes']
    """
    Gt2TTFTarget
    Generate TTFNet targets by ground truth data
    Args:
        num_classes(int): the number of classes.
        down_ratio(int): the down ratio from images to heatmap, 4 by default.
        alpha(float): the alpha parameter to generate gaussian target.
            0.54 by default.
    """

    def __init__(self, num_classes=80, down_ratio=4, alpha=0.54):
        super(Gt2TTFTarget, self).__init__()
        self.down_ratio = down_ratio
        self.num_classes = num_classes
        self.alpha = alpha

    def __call__(self, samples, context=None):
        output_size = samples[0]['image'].shape[1]
        feat_size = output_size // self.down_ratio
        for sample in samples:
            heatmap = np.zeros(
                (self.num_classes, feat_size, feat_size), dtype='float32')
            box_target = np.ones(
                (4, feat_size, feat_size), dtype='float32') * -1
            reg_weight = np.zeros((1, feat_size, feat_size), dtype='float32')

            gt_bbox = sample['gt_bbox']
            gt_class = sample['gt_class']

            bbox_w = gt_bbox[:, 2] - gt_bbox[:, 0] + 1
            bbox_h = gt_bbox[:, 3] - gt_bbox[:, 1] + 1
            area = bbox_w * bbox_h
            boxes_areas_log = np.log(area)
            boxes_ind = np.argsort(boxes_areas_log, axis=0)[::-1]
            boxes_area_topk_log = boxes_areas_log[boxes_ind]
            gt_bbox = gt_bbox[boxes_ind]
            gt_class = gt_class[boxes_ind]

            feat_gt_bbox = gt_bbox / self.down_ratio
            feat_gt_bbox = np.clip(feat_gt_bbox, 0, feat_size - 1)
            feat_hs, feat_ws = (feat_gt_bbox[:, 3] - feat_gt_bbox[:, 1],
                                feat_gt_bbox[:, 2] - feat_gt_bbox[:, 0])

            ct_inds = np.stack(
                [(gt_bbox[:, 0] + gt_bbox[:, 2]) / 2,
                 (gt_bbox[:, 1] + gt_bbox[:, 3]) / 2],
                axis=1) / self.down_ratio

            h_radiuses_alpha = (feat_hs / 2. * self.alpha).astype('int32')
            w_radiuses_alpha = (feat_ws / 2. * self.alpha).astype('int32')

            for k in range(len(gt_bbox)):
                cls_id = gt_class[k]
                fake_heatmap = np.zeros(
                    (feat_size, feat_size), dtype='float32')
                self.draw_truncate_gaussian(fake_heatmap, ct_inds[k],
                                            h_radiuses_alpha[k],
                                            w_radiuses_alpha[k])

                heatmap[cls_id] = np.maximum(heatmap[cls_id], fake_heatmap)
                box_target_inds = fake_heatmap > 0
                box_target[:, box_target_inds] = gt_bbox[k][:, None]

                local_heatmap = fake_heatmap[box_target_inds]
                ct_div = np.sum(local_heatmap)
                local_heatmap *= boxes_area_topk_log[k]
                reg_weight[0, box_target_inds] = local_heatmap / ct_div
            sample['ttf_heatmap'] = heatmap
            sample['ttf_box_target'] = box_target
            sample['ttf_reg_weight'] = reg_weight
            sample.pop('is_crowd', None)
            sample.pop('difficult', None)
            sample.pop('gt_class', None)
            sample.pop('gt_bbox', None)
            sample.pop('gt_score', None)
        return samples

    def draw_truncate_gaussian(self, heatmap, center, h_radius, w_radius):
        h, w = 2 * h_radius + 1, 2 * w_radius + 1
        sigma_x = w / 6
        sigma_y = h / 6
        gaussian = gaussian2D((h, w), sigma_x, sigma_y)

        x, y = int(center[0]), int(center[1])

        height, width = heatmap.shape[0:2]

        left, right = min(x, w_radius), min(width - x, w_radius + 1)
        top, bottom = min(y, h_radius), min(height - y, h_radius + 1)

        masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
        masked_gaussian = gaussian[h_radius - top:h_radius + bottom, w_radius -
                                   left:w_radius + right]
        if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0:
            heatmap[y - top:y + bottom, x - left:x + right] = np.maximum(
                masked_heatmap, masked_gaussian)
        return heatmap
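
# A small illustration of the truncated gaussian (not part of the original
# module): radii come from the box size on the feature map, so a box spanning
# about 11x7 feature cells with alpha=0.54 gets radii int(11/2 * 0.54) = 2
# and int(7/2 * 0.54) = 1.
#
#   op = Gt2TTFTarget(num_classes=80, down_ratio=4, alpha=0.54)
#   hm = np.zeros((16, 16), dtype='float32')
#   op.draw_truncate_gaussian(hm, center=(8, 8), h_radius=2, w_radius=1)
#   # hm[8, 8] is now 1.0, and values decay inside the 5x3 window
#   # around the center; everything outside stays 0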

@register_op
class Gt2Solov2Target(BaseOperator):
    """Assign mask target and labels in SOLOv2 network.
    The code of this function is based on:
    https://github.com/WXinlong/SOLO/blob/master/mmdet/models/anchor_heads/solov2_head.py#L271
    Args:
        num_grids (list): The list of feature map grids size.
        scale_ranges (list): The list of mask boundary range.
        coord_sigma (float): The coefficient of coordinate area length.
        sampling_ratio (float): The ratio of down sampling.
    """

    def __init__(self,
                 num_grids=[40, 36, 24, 16, 12],
                 scale_ranges=[[1, 96], [48, 192], [96, 384], [192, 768],
                               [384, 2048]],
                 coord_sigma=0.2,
                 sampling_ratio=4.0):
        super(Gt2Solov2Target, self).__init__()
        self.num_grids = num_grids
        self.scale_ranges = scale_ranges
        self.coord_sigma = coord_sigma
        self.sampling_ratio = sampling_ratio

    def _scale_size(self, im, scale):
        h, w = im.shape[:2]
        new_size = (int(w * float(scale) + 0.5), int(h * float(scale) + 0.5))
        resized_img = cv2.resize(
            im, None, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
        return resized_img

    def __call__(self, samples, context=None):
        sample_id = 0
        max_ins_num = [0] * len(self.num_grids)
        for sample in samples:
            gt_bboxes_raw = sample['gt_bbox']
            gt_labels_raw = sample['gt_class'] + 1
            im_c, im_h, im_w = sample['image'].shape[:]
            gt_masks_raw = sample['gt_segm'].astype(np.uint8)
            mask_feat_size = [
                int(im_h / self.sampling_ratio),
                int(im_w / self.sampling_ratio)
            ]
            gt_areas = np.sqrt((gt_bboxes_raw[:, 2] - gt_bboxes_raw[:, 0]) *
                               (gt_bboxes_raw[:, 3] - gt_bboxes_raw[:, 1]))
            ins_ind_label_list = []
            idx = 0
            for (lower_bound, upper_bound), num_grid \
                    in zip(self.scale_ranges, self.num_grids):

                hit_indices = ((gt_areas >= lower_bound) &
                               (gt_areas <= upper_bound)).nonzero()[0]
                num_ins = len(hit_indices)

                ins_label = []
                grid_order = []
                cate_label = np.zeros([num_grid, num_grid], dtype=np.int64)
                # np.bool is a deprecated alias; the builtin bool is used here
                ins_ind_label = np.zeros([num_grid**2], dtype=bool)

                if num_ins == 0:
                    ins_label = np.zeros(
                        [1, mask_feat_size[0], mask_feat_size[1]],
                        dtype=np.uint8)
                    ins_ind_label_list.append(ins_ind_label)
                    sample['cate_label{}'.format(idx)] = cate_label.flatten()
                    sample['ins_label{}'.format(idx)] = ins_label
                    sample['grid_order{}'.format(idx)] = np.asarray(
                        [sample_id * num_grid * num_grid + 0], dtype=np.int32)
                    idx += 1
                    continue
                gt_bboxes = gt_bboxes_raw[hit_indices]
                gt_labels = gt_labels_raw[hit_indices]
                gt_masks = gt_masks_raw[hit_indices, ...]

                half_ws = 0.5 * (
                    gt_bboxes[:, 2] - gt_bboxes[:, 0]) * self.coord_sigma
                half_hs = 0.5 * (
                    gt_bboxes[:, 3] - gt_bboxes[:, 1]) * self.coord_sigma

                for seg_mask, gt_label, half_h, half_w in zip(
                        gt_masks, gt_labels, half_hs, half_ws):
                    if seg_mask.sum() == 0:
                        continue
                    # mass center
                    upsampled_size = (mask_feat_size[0] * 4,
                                      mask_feat_size[1] * 4)
                    center_h, center_w = ndimage.measurements.center_of_mass(
                        seg_mask)
                    coord_w = int(
                        (center_w / upsampled_size[1]) // (1. / num_grid))
                    coord_h = int(
                        (center_h / upsampled_size[0]) // (1. / num_grid))

                    # left, top, right, down
                    top_box = max(0,
                                  int(((center_h - half_h) / upsampled_size[0])
                                      // (1. / num_grid)))
                    down_box = min(
                        num_grid - 1,
                        int(((center_h + half_h) / upsampled_size[0]) //
                            (1. / num_grid)))
                    left_box = max(
                        0,
                        int(((center_w - half_w) / upsampled_size[1]) //
                            (1. / num_grid)))
                    right_box = min(num_grid - 1,
                                    int(((center_w + half_w) /
                                         upsampled_size[1]) //
                                        (1. / num_grid)))

                    top = max(top_box, coord_h - 1)
                    down = min(down_box, coord_h + 1)
                    left = max(coord_w - 1, left_box)
                    right = min(right_box, coord_w + 1)

                    cate_label[top:(down + 1), left:(right + 1)] = gt_label
                    seg_mask = self._scale_size(
                        seg_mask, scale=1. / self.sampling_ratio)
                    for i in range(top, down + 1):
                        for j in range(left, right + 1):
                            label = int(i * num_grid + j)
                            cur_ins_label = np.zeros(
                                [mask_feat_size[0], mask_feat_size[1]],
                                dtype=np.uint8)
                            cur_ins_label[:seg_mask.shape[0], :seg_mask.shape[
                                1]] = seg_mask
                            ins_label.append(cur_ins_label)
                            ins_ind_label[label] = True
                            grid_order.append(sample_id * num_grid * num_grid +
                                              label)
                if ins_label == []:
                    ins_label = np.zeros(
                        [1, mask_feat_size[0], mask_feat_size[1]],
                        dtype=np.uint8)
                    ins_ind_label_list.append(ins_ind_label)
                    sample['cate_label{}'.format(idx)] = cate_label.flatten()
                    sample['ins_label{}'.format(idx)] = ins_label
                    sample['grid_order{}'.format(idx)] = np.asarray(
                        [sample_id * num_grid * num_grid + 0], dtype=np.int32)
                else:
                    ins_label = np.stack(ins_label, axis=0)
                    ins_ind_label_list.append(ins_ind_label)
                    sample['cate_label{}'.format(idx)] = cate_label.flatten()
                    sample['ins_label{}'.format(idx)] = ins_label
                    sample['grid_order{}'.format(idx)] = np.asarray(
                        grid_order, dtype=np.int32)
                    assert len(grid_order) > 0
                max_ins_num[idx] = max(
                    max_ins_num[idx],
                    sample['ins_label{}'.format(idx)].shape[0])
                idx += 1
            ins_ind_labels = np.concatenate([
                ins_ind_labels_level_img
                for ins_ind_labels_level_img in ins_ind_label_list
            ])
            fg_num = np.sum(ins_ind_labels)
            sample['fg_num'] = fg_num
            sample_id += 1

            sample.pop('is_crowd')
            sample.pop('gt_class')
            sample.pop('gt_bbox')
            sample.pop('gt_poly')
            sample.pop('gt_segm')

        # padding batch
        for data in samples:
            for idx in range(len(self.num_grids)):
                gt_ins_data = np.zeros(
                    [
                        max_ins_num[idx],
                        data['ins_label{}'.format(idx)].shape[1],
                        data['ins_label{}'.format(idx)].shape[2]
                    ],
                    dtype=np.uint8)
                gt_ins_data[0:data['ins_label{}'.format(idx)].shape[
                    0], :, :] = data['ins_label{}'.format(idx)]
                gt_grid_order = np.zeros([max_ins_num[idx]], dtype=np.int32)
                gt_grid_order[0:data['grid_order{}'.format(idx)].shape[
                    0]] = data['grid_order{}'.format(idx)]
                data['ins_label{}'.format(idx)] = gt_ins_data
                data['grid_order{}'.format(idx)] = gt_grid_order
        return samples
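
# A usage sketch (illustrative only): each FPN level k keeps instances whose
# sqrt(box area) falls in scale_ranges[k] and assigns them on a
# num_grids[k] x num_grids[k] category grid. After the operator runs, each
# sample carries per-level targets padded to the same instance count across
# the batch:
#
#   solo_target = Gt2Solov2Target()  # defaults: 5 levels, grids 40..12
#   samples = solo_target(samples)
#   # samples[0]['cate_label0'].shape == (40 * 40,)
#   # samples[0]['ins_label0'].shape  == (N0, im_h // 4, im_w // 4)
#   # samples[0]['grid_order0'].shape == (N0,)   with N0 == max_ins_num[0]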

@register_op
class Gt2SparseRCNNTarget(BaseOperator):
    '''
    Generate SparseRCNN targets from ground truth data
    '''

    def __init__(self):
        super(Gt2SparseRCNNTarget, self).__init__()

    def __call__(self, samples, context=None):
        for sample in samples:
            im = sample["image"]
            h, w = im.shape[1:3]
            img_whwh = np.array([w, h, w, h], dtype=np.int32)
            sample["img_whwh"] = img_whwh
            if "scale_factor" in sample:
                sample["scale_factor_wh"] = np.array(
                    [sample["scale_factor"][1], sample["scale_factor"][0]],
                    dtype=np.float32)
            else:
                sample["scale_factor_wh"] = np.array(
                    [1.0, 1.0], dtype=np.float32)

        return samples
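
# For illustration (not part of the original module): a CHW image of shape
# (3, 480, 640) yields the [w, h, w, h] vector SparseRCNN uses to map
# normalized boxes back to pixels, and scale_factor_wh swaps the stored
# scale pair into width-first order.
#
#   sample = {'image': np.zeros((3, 480, 640), dtype=np.float32),
#             'scale_factor': np.array([0.75, 0.80], dtype=np.float32)}
#   sample = Gt2SparseRCNNTarget()([sample])[0]
#   # sample['img_whwh']        -> [640, 480, 640, 480]
#   # sample['scale_factor_wh'] -> [0.80, 0.75]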

@register_op
class PadMaskBatch(BaseOperator):
    """
    Pad a batch of samples so they can be divisible by a stride.
    The layout of each image should be 'CHW'.
    Args:
        pad_to_stride (int): If `pad_to_stride > 0`, pad zeros to ensure
            height and width is divisible by `pad_to_stride`.
        return_pad_mask (bool): If `return_pad_mask = True`, return
            `pad_mask` for transformer.
    """

    def __init__(self, pad_to_stride=0, return_pad_mask=False):
        super(PadMaskBatch, self).__init__()
        self.pad_to_stride = pad_to_stride
        self.return_pad_mask = return_pad_mask

    def __call__(self, samples, context=None):
        """
        Args:
            samples (list): a batch of sample, each is dict.
        """
        coarsest_stride = self.pad_to_stride
        max_shape = np.array([data['image'].shape for data in samples]).max(
            axis=0)
        if coarsest_stride > 0:
            max_shape[1] = int(
                np.ceil(max_shape[1] / coarsest_stride) * coarsest_stride)
            max_shape[2] = int(
                np.ceil(max_shape[2] / coarsest_stride) * coarsest_stride)

        for data in samples:
            im = data['image']
            im_c, im_h, im_w = im.shape[:]
            padding_im = np.zeros(
                (im_c, max_shape[1], max_shape[2]), dtype=np.float32)
            padding_im[:, :im_h, :im_w] = im
            data['image'] = padding_im
            if 'semantic' in data and data['semantic'] is not None:
                semantic = data['semantic']
                padding_sem = np.zeros(
                    (1, max_shape[1], max_shape[2]), dtype=np.float32)
                padding_sem[:, :im_h, :im_w] = semantic
                data['semantic'] = padding_sem
            if 'gt_segm' in data and data['gt_segm'] is not None:
                gt_segm = data['gt_segm']
                padding_segm = np.zeros(
                    (gt_segm.shape[0], max_shape[1], max_shape[2]),
                    dtype=np.uint8)
                padding_segm[:, :im_h, :im_w] = gt_segm
                data['gt_segm'] = padding_segm
            if self.return_pad_mask:
                padding_mask = np.zeros(
                    (max_shape[1], max_shape[2]), dtype=np.float32)
                padding_mask[:im_h, :im_w] = 1.
                data['pad_mask'] = padding_mask
            if 'gt_rbox2poly' in data and data['gt_rbox2poly'] is not None:
                # poly to rbox
                polys = data['gt_rbox2poly']
                rbox = bbox_utils.poly2rbox(polys)
                data['gt_rbox'] = rbox

        return samples
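
# A usage sketch (illustrative only): with return_pad_mask=True each sample
# also gains a 'pad_mask' that is 1 over valid pixels and 0 over padding,
# which transformer-style heads can use to mask out padded regions.
#
#   pad = PadMaskBatch(pad_to_stride=32, return_pad_mask=True)
#   batch = pad(batch)
#   # batch[0]['pad_mask'].shape == batch[0]['image'].shape[1:]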

@register_op
class Gt2CenterNetTarget(BaseOperator):
    """Gt2CenterNetTarget
    Generate CenterNet targets from ground truth data
    Args:
        down_ratio (int): The down sample ratio between output feature and
            input image.
        num_classes (int): The number of classes, 80 by default.
        max_objs (int): The maximum objects detected, 128 by default.
    """

    def __init__(self, down_ratio, num_classes=80, max_objs=128):
        super(Gt2CenterNetTarget, self).__init__()
        self.down_ratio = down_ratio
        self.num_classes = num_classes
        self.max_objs = max_objs

    def __call__(self, sample, context=None):
        input_h, input_w = sample['image'].shape[1:]
        output_h = input_h // self.down_ratio
        output_w = input_w // self.down_ratio
        num_classes = self.num_classes
        c = sample['center']
        s = sample['scale']
        gt_bbox = sample['gt_bbox']
        gt_class = sample['gt_class']

        hm = np.zeros((num_classes, output_h, output_w), dtype=np.float32)
        wh = np.zeros((self.max_objs, 2), dtype=np.float32)
        dense_wh = np.zeros((2, output_h, output_w), dtype=np.float32)
        reg = np.zeros((self.max_objs, 2), dtype=np.float32)
        ind = np.zeros((self.max_objs), dtype=np.int64)
        reg_mask = np.zeros((self.max_objs), dtype=np.int32)
        cat_spec_wh = np.zeros(
            (self.max_objs, num_classes * 2), dtype=np.float32)
        cat_spec_mask = np.zeros(
            (self.max_objs, num_classes * 2), dtype=np.int32)

        trans_output = get_affine_transform(c, [s, s], 0, [output_w, output_h])

        gt_det = []
        for i, (bbox, cls) in enumerate(zip(gt_bbox, gt_class)):
            cls = int(cls)
            bbox[:2] = affine_transform(bbox[:2], trans_output)
            bbox[2:] = affine_transform(bbox[2:], trans_output)
            bbox[[0, 2]] = np.clip(bbox[[0, 2]], 0, output_w - 1)
            bbox[[1, 3]] = np.clip(bbox[[1, 3]], 0, output_h - 1)
            h, w = bbox[3] - bbox[1], bbox[2] - bbox[0]
            if h > 0 and w > 0:
                radius = gaussian_radius((math.ceil(h), math.ceil(w)), 0.7)
                radius = max(0, int(radius))
                ct = np.array(
                    [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2],
                    dtype=np.float32)
                ct_int = ct.astype(np.int32)

                draw_umich_gaussian(hm[cls], ct_int, radius)
                wh[i] = 1. * w, 1. * h
                ind[i] = ct_int[1] * output_w + ct_int[0]
                reg[i] = ct - ct_int
                reg_mask[i] = 1
                cat_spec_wh[i, cls * 2:cls * 2 + 2] = wh[i]
                cat_spec_mask[i, cls * 2:cls * 2 + 2] = 1
                gt_det.append([
                    ct[0] - w / 2, ct[1] - h / 2, ct[0] + w / 2, ct[1] + h / 2,
                    1, cls
                ])

        sample.pop('gt_bbox', None)
        sample.pop('gt_class', None)
        sample.pop('center', None)
        sample.pop('scale', None)
        sample.pop('is_crowd', None)
        sample.pop('difficult', None)
        sample['heatmap'] = hm
        sample['index_mask'] = reg_mask
        sample['index'] = ind
        sample['size'] = wh
        sample['offset'] = reg
        return sample
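
# Output summary (for illustration, assuming down_ratio=4 and a 512x512
# input): the operator replaces the raw annotations with dense and
# per-object-slot targets.
#
#   op = Gt2CenterNetTarget(down_ratio=4, num_classes=80, max_objs=128)
#   # sample['heatmap']    (80, 128, 128)  per-class gaussian peaks at centers
#   # sample['size']       (128, 2)        (w, h) per object slot
#   # sample['offset']     (128, 2)        sub-cell center offset ct - int(ct)
#   # sample['index']      (128,)          flattened center index y * W + x
#   # sample['index_mask'] (128,)          1 where the slot holds an object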