mot_operators.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. try:
  18. from collections.abc import Sequence
  19. except Exception:
  20. from collections import Sequence
  21. from numbers import Integral
  22. import cv2
  23. import copy
  24. import numpy as np
  25. import random
  26. import math
  27. from .operators import BaseOperator, register_op
  28. from .batch_operators import Gt2TTFTarget
  29. from paddlex.ppdet.modeling.bbox_utils import bbox_iou_np_expand
  30. from paddlex.ppdet.core.workspace import serializable
  31. from paddlex.ppdet.utils.logger import setup_logger
  32. logger = setup_logger(__name__)
  33. __all__ = [
  34. 'RGBReverse', 'LetterBoxResize', 'MOTRandomAffine', 'Gt2JDETargetThres',
  35. 'Gt2JDETargetMax', 'Gt2FairMOTTarget'
  36. ]
  37. @register_op
  38. class RGBReverse(BaseOperator):
  39. """RGB to BGR, or BGR to RGB, sensitive to MOTRandomAffine
  40. """
  41. def __init__(self):
  42. super(RGBReverse, self).__init__()
  43. def apply(self, sample, context=None):
  44. im = sample['image']
  45. sample['image'] = np.ascontiguousarray(im[:, :, ::-1])
  46. return sample
  47. @register_op
  48. class LetterBoxResize(BaseOperator):
  49. def __init__(self, target_size):
  50. """
  51. Resize image to target size, convert normalized xywh to pixel xyxy
  52. format ([x_center, y_center, width, height] -> [x0, y0, x1, y1]).
  53. Args:
  54. target_size (int|list): image target size.
  55. """
  56. super(LetterBoxResize, self).__init__()
  57. if not isinstance(target_size, (Integral, Sequence)):
  58. raise TypeError(
  59. "Type of target_size is invalid. Must be Integer or List or Tuple, now is {}".
  60. format(type(target_size)))
  61. if isinstance(target_size, Integral):
  62. target_size = [target_size, target_size]
  63. self.target_size = target_size
  64. def apply_image(self, img, height, width, color=(127.5, 127.5, 127.5)):
  65. # letterbox: resize a rectangular image to a padded rectangular
  66. shape = img.shape[:2] # [height, width]
  67. ratio_h = float(height) / shape[0]
  68. ratio_w = float(width) / shape[1]
  69. ratio = min(ratio_h, ratio_w)
  70. new_shape = (round(shape[1] * ratio),
  71. round(shape[0] * ratio)) # [width, height]
  72. padw = (width - new_shape[0]) / 2
  73. padh = (height - new_shape[1]) / 2
  74. top, bottom = round(padh - 0.1), round(padh + 0.1)
  75. left, right = round(padw - 0.1), round(padw + 0.1)
  76. img = cv2.resize(
  77. img, new_shape, interpolation=cv2.INTER_AREA) # resized, no border
  78. img = cv2.copyMakeBorder(
  79. img, top, bottom, left, right, cv2.BORDER_CONSTANT,
  80. value=color) # padded rectangular
  81. return img, ratio, padw, padh
  82. def apply_bbox(self, bbox0, h, w, ratio, padw, padh):
  83. bboxes = bbox0.copy()
  84. bboxes[:, 0] = ratio * w * (bbox0[:, 0] - bbox0[:, 2] / 2) + padw
  85. bboxes[:, 1] = ratio * h * (bbox0[:, 1] - bbox0[:, 3] / 2) + padh
  86. bboxes[:, 2] = ratio * w * (bbox0[:, 0] + bbox0[:, 2] / 2) + padw
  87. bboxes[:, 3] = ratio * h * (bbox0[:, 1] + bbox0[:, 3] / 2) + padh
  88. return bboxes
  89. def apply(self, sample, context=None):
  90. """ Resize the image numpy.
  91. """
  92. im = sample['image']
  93. h, w = sample['im_shape']
  94. if not isinstance(im, np.ndarray):
  95. raise TypeError("{}: image type is not numpy.".format(self))
  96. if len(im.shape) != 3:
  97. raise ImageError('{}: image is not 3-dimensional.'.format(self))
  98. # apply image
  99. height, width = self.target_size
  100. img, ratio, padw, padh = self.apply_image(
  101. im, height=height, width=width)
  102. sample['image'] = img
  103. new_shape = (round(h * ratio), round(w * ratio))
  104. sample['im_shape'] = np.asarray(new_shape, dtype=np.float32)
  105. sample['scale_factor'] = np.asarray([ratio, ratio], dtype=np.float32)
  106. # apply bbox
  107. if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
  108. sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], h, w, ratio,
  109. padw, padh)
  110. return sample
  111. @register_op
  112. class MOTRandomAffine(BaseOperator):
  113. """
  114. Affine transform to image and coords to achieve the rotate, scale and
  115. shift effect for training image.
  116. Args:
  117. degrees (list[2]): the rotate range to apply, transform range is [min, max]
  118. translate (list[2]): the translate range to apply, ransform range is [min, max]
  119. scale (list[2]): the scale range to apply, transform range is [min, max]
  120. shear (list[2]): the shear range to apply, transform range is [min, max]
  121. borderValue (list[3]): value used in case of a constant border when appling
  122. the perspective transformation
  123. reject_outside (bool): reject warped bounding bboxes outside of image
  124. Returns:
  125. records(dict): contain the image and coords after tranformed
  126. """
  127. def __init__(self,
  128. degrees=(-5, 5),
  129. translate=(0.10, 0.10),
  130. scale=(0.50, 1.20),
  131. shear=(-2, 2),
  132. borderValue=(127.5, 127.5, 127.5),
  133. reject_outside=True):
  134. super(MOTRandomAffine, self).__init__()
  135. self.degrees = degrees
  136. self.translate = translate
  137. self.scale = scale
  138. self.shear = shear
  139. self.borderValue = borderValue
  140. self.reject_outside = reject_outside
  141. def apply(self, sample, context=None):
  142. # https://medium.com/uruvideo/dataset-augmentation-with-random-homographies-a8f4b44830d4
  143. border = 0 # width of added border (optional)
  144. img = sample['image']
  145. height, width = img.shape[0], img.shape[1]
  146. # Rotation and Scale
  147. R = np.eye(3)
  148. a = random.random() * (self.degrees[1] - self.degrees[0]
  149. ) + self.degrees[0]
  150. s = random.random() * (self.scale[1] - self.scale[0]) + self.scale[0]
  151. R[:2] = cv2.getRotationMatrix2D(
  152. angle=a, center=(width / 2, height / 2), scale=s)
  153. # Translation
  154. T = np.eye(3)
  155. T[0, 2] = (
  156. random.random() * 2 - 1
  157. ) * self.translate[0] * height + border # x translation (pixels)
  158. T[1, 2] = (
  159. random.random() * 2 - 1
  160. ) * self.translate[1] * width + border # y translation (pixels)
  161. # Shear
  162. S = np.eye(3)
  163. S[0, 1] = math.tan((random.random() *
  164. (self.shear[1] - self.shear[0]) + self.shear[0]) *
  165. math.pi / 180) # x shear (deg)
  166. S[1, 0] = math.tan((random.random() *
  167. (self.shear[1] - self.shear[0]) + self.shear[0]) *
  168. math.pi / 180) # y shear (deg)
  169. M = S @T @R # Combined rotation matrix. ORDER IS IMPORTANT HERE!!
  170. imw = cv2.warpPerspective(
  171. img,
  172. M,
  173. dsize=(width, height),
  174. flags=cv2.INTER_LINEAR,
  175. borderValue=self.borderValue) # BGR order borderValue
  176. if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
  177. targets = sample['gt_bbox']
  178. n = targets.shape[0]
  179. points = targets.copy()
  180. area0 = (points[:, 2] - points[:, 0]) * (
  181. points[:, 3] - points[:, 1])
  182. # warp points
  183. xy = np.ones((n * 4, 3))
  184. xy[:, :2] = points[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(
  185. n * 4, 2) # x1y1, x2y2, x1y2, x2y1
  186. xy = (xy @M.T)[:, :2].reshape(n, 8)
  187. # create new boxes
  188. x = xy[:, [0, 2, 4, 6]]
  189. y = xy[:, [1, 3, 5, 7]]
  190. xy = np.concatenate(
  191. (x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
  192. # apply angle-based reduction
  193. radians = a * math.pi / 180
  194. reduction = max(abs(math.sin(radians)), abs(math.cos(radians)))**0.5
  195. x = (xy[:, 2] + xy[:, 0]) / 2
  196. y = (xy[:, 3] + xy[:, 1]) / 2
  197. w = (xy[:, 2] - xy[:, 0]) * reduction
  198. h = (xy[:, 3] - xy[:, 1]) * reduction
  199. xy = np.concatenate(
  200. (x - w / 2, y - h / 2, x + w / 2, y + h / 2)).reshape(4, n).T
  201. # reject warped points outside of image
  202. if self.reject_outside:
  203. np.clip(xy[:, 0], 0, width, out=xy[:, 0])
  204. np.clip(xy[:, 2], 0, width, out=xy[:, 2])
  205. np.clip(xy[:, 1], 0, height, out=xy[:, 1])
  206. np.clip(xy[:, 3], 0, height, out=xy[:, 3])
  207. w = xy[:, 2] - xy[:, 0]
  208. h = xy[:, 3] - xy[:, 1]
  209. area = w * h
  210. ar = np.maximum(w / (h + 1e-16), h / (w + 1e-16))
  211. i = (w > 4) & (h > 4) & (area / (area0 + 1e-16) > 0.1) & (ar < 10)
  212. if sum(i) > 0:
  213. sample['gt_bbox'] = xy[i].astype(sample['gt_bbox'].dtype)
  214. sample['gt_class'] = sample['gt_class'][i]
  215. if 'difficult' in sample:
  216. sample['difficult'] = sample['difficult'][i]
  217. if 'gt_ide' in sample:
  218. sample['gt_ide'] = sample['gt_ide'][i]
  219. if 'is_crowd' in sample:
  220. sample['is_crowd'] = sample['is_crowd'][i]
  221. sample['image'] = imw
  222. return sample
  223. else:
  224. return sample
  225. @register_op
  226. class Gt2JDETargetThres(BaseOperator):
  227. __shared__ = ['num_classes']
  228. """
  229. Generate JDE targets by groud truth data when training
  230. Args:
  231. anchors (list): anchors of JDE model
  232. anchor_masks (list): anchor_masks of JDE model
  233. downsample_ratios (list): downsample ratios of JDE model
  234. ide_thresh (float): thresh of identity, higher is groud truth
  235. fg_thresh (float): thresh of foreground, higher is foreground
  236. bg_thresh (float): thresh of background, lower is background
  237. num_classes (int): number of classes
  238. """
  239. def __init__(self,
  240. anchors,
  241. anchor_masks,
  242. downsample_ratios,
  243. ide_thresh=0.5,
  244. fg_thresh=0.5,
  245. bg_thresh=0.4,
  246. num_classes=1):
  247. super(Gt2JDETargetThres, self).__init__()
  248. self.anchors = anchors
  249. self.anchor_masks = anchor_masks
  250. self.downsample_ratios = downsample_ratios
  251. self.ide_thresh = ide_thresh
  252. self.fg_thresh = fg_thresh
  253. self.bg_thresh = bg_thresh
  254. self.num_classes = num_classes
  255. def generate_anchor(self, nGh, nGw, anchor_hw):
  256. nA = len(anchor_hw)
  257. yy, xx = np.meshgrid(np.arange(nGh), np.arange(nGw))
  258. mesh = np.stack([xx.T, yy.T], axis=0) # [2, nGh, nGw]
  259. mesh = np.repeat(mesh[None, :], nA, axis=0) # [nA, 2, nGh, nGw]
  260. anchor_offset_mesh = anchor_hw[:, :, None][:, :, :, None]
  261. anchor_offset_mesh = np.repeat(anchor_offset_mesh, nGh, axis=-2)
  262. anchor_offset_mesh = np.repeat(anchor_offset_mesh, nGw, axis=-1)
  263. anchor_mesh = np.concatenate(
  264. [mesh, anchor_offset_mesh], axis=1) # [nA, 4, nGh, nGw]
  265. return anchor_mesh
  266. def encode_delta(self, gt_box_list, fg_anchor_list):
  267. px, py, pw, ph = fg_anchor_list[:, 0], fg_anchor_list[:,1], \
  268. fg_anchor_list[:, 2], fg_anchor_list[:,3]
  269. gx, gy, gw, gh = gt_box_list[:, 0], gt_box_list[:, 1], \
  270. gt_box_list[:, 2], gt_box_list[:, 3]
  271. dx = (gx - px) / pw
  272. dy = (gy - py) / ph
  273. dw = np.log(gw / pw)
  274. dh = np.log(gh / ph)
  275. return np.stack([dx, dy, dw, dh], axis=1)
  276. def pad_box(self, sample, num_max):
  277. assert 'gt_bbox' in sample
  278. bbox = sample['gt_bbox']
  279. gt_num = len(bbox)
  280. pad_bbox = np.zeros((num_max, 4), dtype=np.float32)
  281. if gt_num > 0:
  282. pad_bbox[:gt_num, :] = bbox[:gt_num, :]
  283. sample['gt_bbox'] = pad_bbox
  284. if 'gt_score' in sample:
  285. pad_score = np.zeros((num_max, ), dtype=np.float32)
  286. if gt_num > 0:
  287. pad_score[:gt_num] = sample['gt_score'][:gt_num, 0]
  288. sample['gt_score'] = pad_score
  289. if 'difficult' in sample:
  290. pad_diff = np.zeros((num_max, ), dtype=np.int32)
  291. if gt_num > 0:
  292. pad_diff[:gt_num] = sample['difficult'][:gt_num, 0]
  293. sample['difficult'] = pad_diff
  294. if 'is_crowd' in sample:
  295. pad_crowd = np.zeros((num_max, ), dtype=np.int32)
  296. if gt_num > 0:
  297. pad_crowd[:gt_num] = sample['is_crowd'][:gt_num, 0]
  298. sample['is_crowd'] = pad_crowd
  299. if 'gt_ide' in sample:
  300. pad_ide = np.zeros((num_max, ), dtype=np.int32)
  301. if gt_num > 0:
  302. pad_ide[:gt_num] = sample['gt_ide'][:gt_num, 0]
  303. sample['gt_ide'] = pad_ide
  304. return sample
  305. def __call__(self, samples, context=None):
  306. assert len(self.anchor_masks) == len(self.downsample_ratios), \
  307. "anchor_masks', and 'downsample_ratios' should have same length."
  308. h, w = samples[0]['image'].shape[1:3]
  309. num_max = 0
  310. for sample in samples:
  311. num_max = max(num_max, len(sample['gt_bbox']))
  312. for sample in samples:
  313. gt_bbox = sample['gt_bbox']
  314. gt_ide = sample['gt_ide']
  315. for i, (anchor_hw, downsample_ratio
  316. ) in enumerate(zip(self.anchors, self.downsample_ratios)):
  317. anchor_hw = np.array(
  318. anchor_hw, dtype=np.float32) / downsample_ratio
  319. nA = len(anchor_hw)
  320. nGh, nGw = int(h / downsample_ratio), int(w / downsample_ratio)
  321. tbox = np.zeros((nA, nGh, nGw, 4), dtype=np.float32)
  322. tconf = np.zeros((nA, nGh, nGw), dtype=np.float32)
  323. tid = -np.ones((nA, nGh, nGw, 1), dtype=np.float32)
  324. gxy, gwh = gt_bbox[:, 0:2].copy(), gt_bbox[:, 2:4].copy()
  325. gxy[:, 0] = gxy[:, 0] * nGw
  326. gxy[:, 1] = gxy[:, 1] * nGh
  327. gwh[:, 0] = gwh[:, 0] * nGw
  328. gwh[:, 1] = gwh[:, 1] * nGh
  329. gxy[:, 0] = np.clip(gxy[:, 0], 0, nGw - 1)
  330. gxy[:, 1] = np.clip(gxy[:, 1], 0, nGh - 1)
  331. tboxes = np.concatenate([gxy, gwh], axis=1)
  332. anchor_mesh = self.generate_anchor(nGh, nGw, anchor_hw)
  333. anchor_list = np.transpose(anchor_mesh,
  334. (0, 2, 3, 1)).reshape(-1, 4)
  335. iou_pdist = bbox_iou_np_expand(
  336. anchor_list, tboxes, x1y1x2y2=False)
  337. iou_max = np.max(iou_pdist, axis=1)
  338. max_gt_index = np.argmax(iou_pdist, axis=1)
  339. iou_map = iou_max.reshape(nA, nGh, nGw)
  340. gt_index_map = max_gt_index.reshape(nA, nGh, nGw)
  341. id_index = iou_map > self.ide_thresh
  342. fg_index = iou_map > self.fg_thresh
  343. bg_index = iou_map < self.bg_thresh
  344. ign_index = (iou_map < self.fg_thresh) * (
  345. iou_map > self.bg_thresh)
  346. tconf[fg_index] = 1
  347. tconf[bg_index] = 0
  348. tconf[ign_index] = -1
  349. gt_index = gt_index_map[fg_index]
  350. gt_box_list = tboxes[gt_index]
  351. gt_id_list = gt_ide[gt_index_map[id_index]]
  352. if np.sum(fg_index) > 0:
  353. tid[id_index] = gt_id_list
  354. fg_anchor_list = anchor_list.reshape(nA, nGh, nGw,
  355. 4)[fg_index]
  356. delta_target = self.encode_delta(gt_box_list,
  357. fg_anchor_list)
  358. tbox[fg_index] = delta_target
  359. sample['tbox{}'.format(i)] = tbox
  360. sample['tconf{}'.format(i)] = tconf
  361. sample['tide{}'.format(i)] = tid
  362. sample.pop('gt_class')
  363. sample = self.pad_box(sample, num_max)
  364. return samples
  365. @register_op
  366. class Gt2JDETargetMax(BaseOperator):
  367. __shared__ = ['num_classes']
  368. """
  369. Generate JDE targets by groud truth data when evaluating
  370. Args:
  371. anchors (list): anchors of JDE model
  372. anchor_masks (list): anchor_masks of JDE model
  373. downsample_ratios (list): downsample ratios of JDE model
  374. max_iou_thresh (float): iou thresh for high quality anchor
  375. num_classes (int): number of classes
  376. """
  377. def __init__(self,
  378. anchors,
  379. anchor_masks,
  380. downsample_ratios,
  381. max_iou_thresh=0.60,
  382. num_classes=1):
  383. super(Gt2JDETargetMax, self).__init__()
  384. self.anchors = anchors
  385. self.anchor_masks = anchor_masks
  386. self.downsample_ratios = downsample_ratios
  387. self.max_iou_thresh = max_iou_thresh
  388. self.num_classes = num_classes
  389. def __call__(self, samples, context=None):
  390. assert len(self.anchor_masks) == len(self.downsample_ratios), \
  391. "anchor_masks', and 'downsample_ratios' should have same length."
  392. h, w = samples[0]['image'].shape[1:3]
  393. for sample in samples:
  394. gt_bbox = sample['gt_bbox']
  395. gt_ide = sample['gt_ide']
  396. for i, (anchor_hw, downsample_ratio
  397. ) in enumerate(zip(self.anchors, self.downsample_ratios)):
  398. anchor_hw = np.array(
  399. anchor_hw, dtype=np.float32) / downsample_ratio
  400. nA = len(anchor_hw)
  401. nGh, nGw = int(h / downsample_ratio), int(w / downsample_ratio)
  402. tbox = np.zeros((nA, nGh, nGw, 4), dtype=np.float32)
  403. tconf = np.zeros((nA, nGh, nGw), dtype=np.float32)
  404. tid = -np.ones((nA, nGh, nGw, 1), dtype=np.float32)
  405. gxy, gwh = gt_bbox[:, 0:2].copy(), gt_bbox[:, 2:4].copy()
  406. gxy[:, 0] = gxy[:, 0] * nGw
  407. gxy[:, 1] = gxy[:, 1] * nGh
  408. gwh[:, 0] = gwh[:, 0] * nGw
  409. gwh[:, 1] = gwh[:, 1] * nGh
  410. gi = np.clip(gxy[:, 0], 0, nGw - 1).astype(int)
  411. gj = np.clip(gxy[:, 1], 0, nGh - 1).astype(int)
  412. # iou of targets-anchors (using wh only)
  413. box1 = gwh
  414. box2 = anchor_hw[:, None, :]
  415. inter_area = np.minimum(box1, box2).prod(2)
  416. iou = inter_area / (
  417. box1.prod(1) + box2.prod(2) - inter_area + 1e-16)
  418. # Select best iou_pred and anchor
  419. iou_best = iou.max(0) # best anchor [0-2] for each target
  420. a = np.argmax(iou, axis=0)
  421. # Select best unique target-anchor combinations
  422. iou_order = np.argsort(-iou_best) # best to worst
  423. # Unique anchor selection
  424. u = np.stack((gi, gj, a), 0)[:, iou_order]
  425. _, first_unique = np.unique(u, axis=1, return_index=True)
  426. mask = iou_order[first_unique]
  427. # best anchor must share significant commonality (iou) with target
  428. # TODO: examine arbitrary threshold
  429. idx = mask[iou_best[mask] > self.max_iou_thresh]
  430. if len(idx) > 0:
  431. a_i, gj_i, gi_i = a[idx], gj[idx], gi[idx]
  432. t_box = gt_bbox[idx]
  433. t_id = gt_ide[idx]
  434. if len(t_box.shape) == 1:
  435. t_box = t_box.reshape(1, 4)
  436. gxy, gwh = t_box[:, 0:2].copy(), t_box[:, 2:4].copy()
  437. gxy[:, 0] = gxy[:, 0] * nGw
  438. gxy[:, 1] = gxy[:, 1] * nGh
  439. gwh[:, 0] = gwh[:, 0] * nGw
  440. gwh[:, 1] = gwh[:, 1] * nGh
  441. # XY coordinates
  442. tbox[:, :, :, 0:2][a_i, gj_i, gi_i] = gxy - gxy.astype(int)
  443. # Width and height in yolo method
  444. tbox[:, :, :, 2:4][a_i, gj_i, gi_i] = np.log(gwh /
  445. anchor_hw[a_i])
  446. tconf[a_i, gj_i, gi_i] = 1
  447. tid[a_i, gj_i, gi_i] = t_id
  448. sample['tbox{}'.format(i)] = tbox
  449. sample['tconf{}'.format(i)] = tconf
  450. sample['tide{}'.format(i)] = tid
  451. class Gt2FairMOTTarget(Gt2TTFTarget):
  452. __shared__ = ['num_classes']
  453. """
  454. Generate FairMOT targets by ground truth data.
  455. Difference between Gt2FairMOTTarget and Gt2TTFTarget are:
  456. 1. the gaussian kernal radius to generate a heatmap.
  457. 2. the targets needed during traing.
  458. Args:
  459. num_classes(int): the number of classes.
  460. down_ratio(int): the down ratio from images to heatmap, 4 by default.
  461. max_objs(int): the maximum number of ground truth objects in a image, 500 by default.
  462. """
  463. def __init__(self, num_classes=1, down_ratio=4, max_objs=500):
  464. super(Gt2TTFTarget, self).__init__()
  465. self.down_ratio = down_ratio
  466. self.num_classes = num_classes
  467. self.max_objs = max_objs
  468. def __call__(self, samples, context=None):
  469. for b_id, sample in enumerate(samples):
  470. output_h = sample['image'].shape[1] // self.down_ratio
  471. output_w = sample['image'].shape[2] // self.down_ratio
  472. heatmap = np.zeros(
  473. (self.num_classes, output_h, output_w), dtype='float32')
  474. bbox_size = np.zeros((self.max_objs, 4), dtype=np.float32)
  475. center_offset = np.zeros((self.max_objs, 2), dtype=np.float32)
  476. index = np.zeros((self.max_objs, ), dtype=np.int64)
  477. index_mask = np.zeros((self.max_objs, ), dtype=np.int32)
  478. reid = np.zeros((self.max_objs, ), dtype=np.int64)
  479. bbox_xys = np.zeros((self.max_objs, 4), dtype=np.float32)
  480. gt_bbox = sample['gt_bbox']
  481. gt_class = sample['gt_class']
  482. gt_ide = sample['gt_ide']
  483. for k in range(len(gt_bbox)):
  484. cls_id = gt_class[k][0]
  485. bbox = gt_bbox[k]
  486. ide = gt_ide[k][0]
  487. bbox[[0, 2]] = bbox[[0, 2]] * output_w
  488. bbox[[1, 3]] = bbox[[1, 3]] * output_h
  489. bbox_amodal = copy.deepcopy(bbox)
  490. bbox_amodal[0] = bbox_amodal[0] - bbox_amodal[2] / 2.
  491. bbox_amodal[1] = bbox_amodal[1] - bbox_amodal[3] / 2.
  492. bbox_amodal[2] = bbox_amodal[0] + bbox_amodal[2]
  493. bbox_amodal[3] = bbox_amodal[1] + bbox_amodal[3]
  494. bbox[0] = np.clip(bbox[0], 0, output_w - 1)
  495. bbox[1] = np.clip(bbox[1], 0, output_h - 1)
  496. h = bbox[3]
  497. w = bbox[2]
  498. bbox_xy = copy.deepcopy(bbox)
  499. bbox_xy[0] = bbox_xy[0] - bbox_xy[2] / 2
  500. bbox_xy[1] = bbox_xy[1] - bbox_xy[3] / 2
  501. bbox_xy[2] = bbox_xy[0] + bbox_xy[2]
  502. bbox_xy[3] = bbox_xy[1] + bbox_xy[3]
  503. if h > 0 and w > 0:
  504. radius = self.gaussian_radius((math.ceil(h), math.ceil(w)))
  505. radius = max(0, int(radius))
  506. ct = np.array([bbox[0], bbox[1]], dtype=np.float32)
  507. ct_int = ct.astype(np.int32)
  508. self.draw_truncate_gaussian(heatmap[cls_id], ct_int, radius,
  509. radius)
  510. bbox_size[k] = ct[0] - bbox_amodal[0], ct[1] - bbox_amodal[1], \
  511. bbox_amodal[2] - ct[0], bbox_amodal[3] - ct[1]
  512. index[k] = ct_int[1] * output_w + ct_int[0]
  513. center_offset[k] = ct - ct_int
  514. index_mask[k] = 1
  515. reid[k] = ide
  516. bbox_xys[k] = bbox_xy
  517. sample['heatmap'] = heatmap
  518. sample['index'] = index
  519. sample['offset'] = center_offset
  520. sample['size'] = bbox_size
  521. sample['index_mask'] = index_mask
  522. sample['reid'] = reid
  523. sample['bbox_xys'] = bbox_xys
  524. sample.pop('is_crowd', None)
  525. sample.pop('difficult', None)
  526. sample.pop('gt_class', None)
  527. sample.pop('gt_bbox', None)
  528. sample.pop('gt_score', None)
  529. sample.pop('gt_ide', None)
  530. return samples
  531. def gaussian_radius(self, det_size, min_overlap=0.7):
  532. height, width = det_size
  533. a1 = 1
  534. b1 = (height + width)
  535. c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
  536. sq1 = np.sqrt(b1**2 - 4 * a1 * c1)
  537. r1 = (b1 + sq1) / 2
  538. a2 = 4
  539. b2 = 2 * (height + width)
  540. c2 = (1 - min_overlap) * width * height
  541. sq2 = np.sqrt(b2**2 - 4 * a2 * c2)
  542. r2 = (b2 + sq2) / 2
  543. a3 = 4 * min_overlap
  544. b3 = -2 * min_overlap * (height + width)
  545. c3 = (min_overlap - 1) * width * height
  546. sq3 = np.sqrt(b3**2 - 4 * a3 * c3)
  547. r3 = (b3 + sq3) / 2
  548. return min(r1, r2, r3)