processors.py 36 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numbers
  15. import numpy as np
  16. from ....utils.deps import class_requires_deps, is_dep_available
  17. from ...common.reader.det_3d_reader import Sample
  18. from ...utils.benchmark import benchmark
  19. if is_dep_available("opencv-contrib-python"):
  20. import cv2
  21. @benchmark.timeit
  22. class LoadPointsFromFile:
  23. """Load points from a file and process them according to specified parameters."""
  24. def __init__(
  25. self, load_dim=6, use_dim=[0, 1, 2], shift_height=False, use_color=False
  26. ):
  27. """Initializes the LoadPointsFromFile object.
  28. Args:
  29. load_dim (int): Dimensions loaded in points.
  30. use_dim (list or int): Dimensions used in points. If int, will use a range from 0 to use_dim (exclusive).
  31. shift_height (bool): Whether to shift height values.
  32. use_color (bool): Whether to include color attributes in the loaded points.
  33. """
  34. self.shift_height = shift_height
  35. self.use_color = use_color
  36. if isinstance(use_dim, int):
  37. use_dim = list(range(use_dim))
  38. assert (
  39. max(use_dim) < load_dim
  40. ), f"Expect all used dimensions < {load_dim}, got {use_dim}"
  41. self.load_dim = load_dim
  42. self.use_dim = use_dim
  43. def _load_points(self, pts_filename):
  44. """Private function to load point clouds data from a file.
  45. Args:
  46. pts_filename (str): Path to the point cloud file.
  47. Returns:
  48. numpy.ndarray: Loaded point cloud data.
  49. """
  50. points = np.fromfile(pts_filename, dtype=np.float32)
  51. return points
  52. def __call__(self, results):
  53. """Call function to load points data from file and process it.
  54. Args:
  55. results (dict): Dictionary containing the 'pts_filename' key with the path to the point cloud file.
  56. Returns:
  57. dict: Updated results dictionary with 'points' key added.
  58. """
  59. pts_filename = results["pts_filename"]
  60. points = self._load_points(pts_filename)
  61. points = points.reshape(-1, self.load_dim)
  62. points = points[:, self.use_dim]
  63. attribute_dims = None
  64. if self.shift_height:
  65. floor_height = np.percentile(points[:, 2], 0.99)
  66. height = points[:, 2] - floor_height
  67. points = np.concatenate(
  68. [points[:, :3], np.expand_dims(height, 1), points[:, 3:]], 1
  69. )
  70. attribute_dims = dict(height=3)
  71. if self.use_color:
  72. assert len(self.use_dim) >= 6
  73. if attribute_dims is None:
  74. attribute_dims = dict()
  75. attribute_dims.update(
  76. dict(
  77. color=[
  78. points.shape[1] - 3,
  79. points.shape[1] - 2,
  80. points.shape[1] - 1,
  81. ]
  82. )
  83. )
  84. results["points"] = points
  85. return results
  86. @benchmark.timeit
  87. class LoadPointsFromMultiSweeps(object):
  88. """Load points from multiple sweeps.This is usually used for nuScenes dataset to utilize previous sweeps."""
  89. def __init__(
  90. self,
  91. sweeps_num=10,
  92. load_dim=5,
  93. use_dim=[0, 1, 2, 4],
  94. pad_empty_sweeps=False,
  95. remove_close=False,
  96. test_mode=False,
  97. point_cloud_angle_range=None,
  98. ):
  99. """Initializes the LoadPointsFromMultiSweeps object
  100. Args:
  101. sweeps_num (int): Number of sweeps. Defaults to 10.
  102. load_dim (int): Dimension number of the loaded points. Defaults to 5.
  103. use_dim (list[int]): Which dimension to use. Defaults to [0, 1, 2, 4].
  104. for more details. Defaults to dict(backend='disk').
  105. pad_empty_sweeps (bool): Whether to repeat keyframe when
  106. sweeps is empty. Defaults to False.
  107. remove_close (bool): Whether to remove close points.
  108. Defaults to False.
  109. test_mode (bool): If test_model=True used for testing, it will not
  110. randomly sample sweeps but select the nearest N frames.
  111. Defaults to False.
  112. """
  113. self.load_dim = load_dim
  114. self.sweeps_num = sweeps_num
  115. self.use_dim = use_dim
  116. self.pad_empty_sweeps = pad_empty_sweeps
  117. self.remove_close = remove_close
  118. self.test_mode = test_mode
  119. if point_cloud_angle_range is not None:
  120. self.filter_by_angle = True
  121. self.point_cloud_angle_range = point_cloud_angle_range
  122. print(point_cloud_angle_range)
  123. else:
  124. self.filter_by_angle = False
  125. # self.point_cloud_angle_range = point_cloud_angle_range
  126. def _load_points(self, pts_filename):
  127. """Private function to load point clouds data.
  128. Args:
  129. pts_filename (str): Filename of point clouds data.
  130. Returns:
  131. np.ndarray: An array containing point clouds data.
  132. """
  133. points = np.fromfile(pts_filename, dtype=np.float32)
  134. return points
  135. def _remove_close(self, points, radius=1.0):
  136. """Removes point too close within a certain radius from origin.
  137. Args:
  138. points (np.ndarray): Sweep points.
  139. radius (float): Radius below which points are removed.
  140. Defaults to 1.0.
  141. Returns:
  142. np.ndarray: Points after removing.
  143. """
  144. if isinstance(points, np.ndarray):
  145. points_numpy = points
  146. else:
  147. raise NotImplementedError
  148. x_filt = np.abs(points_numpy[:, 0]) < radius
  149. y_filt = np.abs(points_numpy[:, 1]) < radius
  150. not_close = np.logical_not(np.logical_and(x_filt, y_filt))
  151. return points[not_close]
  152. def filter_point_by_angle(self, points):
  153. """
  154. Filters points based on their angle in relation to the origin.
  155. Args:
  156. points (np.ndarray): An array of points with shape (N, 2), where each row
  157. is a point in 2D space.
  158. Returns:
  159. np.ndarray: A filtered array of points that fall within the specified
  160. angle range.
  161. """
  162. if isinstance(points, np.ndarray):
  163. points_numpy = points
  164. else:
  165. raise NotImplementedError
  166. pts_phi = (
  167. np.arctan(points_numpy[:, 0] / points_numpy[:, 1])
  168. + (points_numpy[:, 1] < 0) * np.pi
  169. + np.pi * 2
  170. ) % (np.pi * 2)
  171. pts_phi[pts_phi > np.pi] -= np.pi * 2
  172. pts_phi = pts_phi / np.pi * 180
  173. assert np.all(-180 <= pts_phi) and np.all(pts_phi <= 180)
  174. filt = np.logical_and(
  175. pts_phi >= self.point_cloud_angle_range[0],
  176. pts_phi <= self.point_cloud_angle_range[1],
  177. )
  178. return points[filt]
  179. def __call__(self, results):
  180. """Call function to load multi-sweep point clouds from files.
  181. Args:
  182. results (dict): Result dict containing multi-sweep point cloud \
  183. filenames.
  184. Returns:
  185. dict: The result dict containing the multi-sweep points data. \
  186. Added key and value are described below.
  187. - points (np.ndarray): Multi-sweep point cloud arrays.
  188. """
  189. points = results["points"]
  190. points[:, 4] = 0
  191. sweep_points_list = [points]
  192. ts = results["timestamp"]
  193. if self.pad_empty_sweeps and len(results["sweeps"]) == 0:
  194. for i in range(self.sweeps_num):
  195. if self.remove_close:
  196. sweep_points_list.append(self._remove_close(points))
  197. else:
  198. sweep_points_list.append(points)
  199. else:
  200. if len(results["sweeps"]) <= self.sweeps_num:
  201. choices = np.arange(len(results["sweeps"]))
  202. elif self.test_mode:
  203. choices = np.arange(self.sweeps_num)
  204. else:
  205. choices = np.random.choice(
  206. len(results["sweeps"]), self.sweeps_num, replace=False
  207. )
  208. for idx in choices:
  209. sweep = results["sweeps"][idx]
  210. points_sweep = self._load_points(sweep["data_path"])
  211. points_sweep = np.copy(points_sweep).reshape(-1, self.load_dim)
  212. if self.remove_close:
  213. points_sweep = self._remove_close(points_sweep)
  214. sweep_ts = sweep["timestamp"] / 1e6
  215. points_sweep[:, :3] = (
  216. points_sweep[:, :3] @ sweep["sensor2lidar_rotation"].T
  217. )
  218. points_sweep[:, :3] += sweep["sensor2lidar_translation"]
  219. points_sweep[:, 4] = ts - sweep_ts
  220. # points_sweep = points.new_point(points_sweep)
  221. sweep_points_list.append(points_sweep)
  222. points = np.concatenate(sweep_points_list, axis=0)
  223. if self.filter_by_angle:
  224. points = self.filter_point_by_angle(points)
  225. points = points[:, self.use_dim]
  226. results["points"] = points
  227. return results
  228. @benchmark.timeit
  229. @class_requires_deps("opencv-contrib-python")
  230. class LoadMultiViewImageFromFiles:
  231. """Load multi-view images from files."""
  232. def __init__(
  233. self,
  234. to_float32=False,
  235. project_pts_to_img_depth=False,
  236. cam_depth_range=[4.0, 45.0, 1.0],
  237. constant_std=0.5,
  238. imread_flag=-1,
  239. ):
  240. """
  241. Initializes the LoadMultiViewImageFromFiles object.
  242. Args:
  243. to_float32 (bool): Whether to convert the loaded images to float32. Default: False.
  244. project_pts_to_img_depth (bool): Whether to project points to image depth. Default: False.
  245. cam_depth_range (list): Camera depth range in the format [min, max, focal]. Default: [4.0, 45.0, 1.0].
  246. constant_std (float): Constant standard deviation for normalization. Default: 0.5.
  247. imread_flag (int): Flag determining the color type of the loaded image.
  248. - -1: cv2.IMREAD_UNCHANGED
  249. - 0: cv2.IMREAD_GRAYSCALE
  250. - 1: cv2.IMREAD_COLOR
  251. Default: -1.
  252. """
  253. self.to_float32 = to_float32
  254. self.project_pts_to_img_depth = project_pts_to_img_depth
  255. self.cam_depth_range = cam_depth_range
  256. self.constant_std = constant_std
  257. self.imread_flag = imread_flag
  258. def __call__(self, sample):
  259. """
  260. Call method to load multi-view image from files and update the sample dictionary.
  261. Args:
  262. sample (dict): Dictionary containing the image filename key.
  263. Returns:
  264. dict: Updated sample dictionary with loaded images and additional information.
  265. """
  266. filename = sample["img_filename"]
  267. img = np.stack(
  268. [cv2.imread(name, self.imread_flag) for name in filename], axis=-1
  269. )
  270. if self.to_float32:
  271. img = img.astype(np.float32)
  272. sample["filename"] = filename
  273. sample["img"] = [img[..., i] for i in range(img.shape[-1])]
  274. sample["img_shape"] = img.shape
  275. sample["ori_shape"] = img.shape
  276. sample["pad_shape"] = img.shape
  277. # sample['scale_factor'] = 1.0
  278. num_channels = 1 if len(img.shape) < 3 else img.shape[2]
  279. sample["img_norm_cfg"] = dict(
  280. mean=np.zeros(num_channels, dtype=np.float32),
  281. std=np.ones(num_channels, dtype=np.float32),
  282. to_rgb=False,
  283. )
  284. sample["img_fields"] = ["img"]
  285. return sample
  286. @benchmark.timeit
  287. @class_requires_deps("opencv-contrib-python")
  288. class ResizeImage:
  289. """Resize images & bbox & mask."""
  290. def __init__(
  291. self,
  292. img_scale=None,
  293. multiscale_mode="range",
  294. ratio_range=None,
  295. keep_ratio=True,
  296. bbox_clip_border=True,
  297. backend="cv2",
  298. override=False,
  299. ):
  300. """Initializes the ResizeImage object.
  301. Args:
  302. img_scale (list or int, optional): The scale of the image. If a single integer is provided, it will be converted to a list. Defaults to None.
  303. multiscale_mode (str): The mode for multiscale resizing. Can be "value" or "range". Defaults to "range".
  304. ratio_range (list, optional): The range of image aspect ratios. Only used when img_scale is a single value. Defaults to None.
  305. keep_ratio (bool): Whether to keep the aspect ratio when resizing. Defaults to True.
  306. bbox_clip_border (bool): Whether to clip the bounding box to the image border. Defaults to True.
  307. backend (str): The backend to use for image resizing. Can be "cv2". Defaults to "cv2".
  308. override (bool): Whether to override certain resize parameters. Note: This option needs refactoring. Defaults to False.
  309. """
  310. if img_scale is None:
  311. self.img_scale = None
  312. else:
  313. if isinstance(img_scale, list):
  314. self.img_scale = img_scale
  315. else:
  316. self.img_scale = [img_scale]
  317. if ratio_range is not None:
  318. # mode 1: given a scale and a range of image ratio
  319. assert len(self.img_scale) == 1
  320. else:
  321. # mode 2: given multiple scales or a range of scales
  322. assert multiscale_mode in ["value", "range"]
  323. self.backend = backend
  324. self.multiscale_mode = multiscale_mode
  325. self.ratio_range = ratio_range
  326. self.keep_ratio = keep_ratio
  327. # TODO: refactor the override option in Resize
  328. self.override = override
  329. self.bbox_clip_border = bbox_clip_border
  330. @staticmethod
  331. def random_select(img_scales):
  332. """Randomly select an img_scale from the given list of candidates.
  333. Args:
  334. img_scales (list): A list of image scales to choose from.
  335. Returns:
  336. tuple: A tuple containing the selected image scale and its index in the list.
  337. """
  338. scale_idx = np.random.randint(len(img_scales))
  339. img_scale = img_scales[scale_idx]
  340. return img_scale, scale_idx
  341. @staticmethod
  342. def random_sample(img_scales):
  343. """
  344. Randomly sample an img_scale when `multiscale_mode` is set to 'range'.
  345. Args:
  346. img_scales (list of tuples): A list of tuples, where each tuple contains
  347. the minimum and maximum scale dimensions for an image.
  348. Returns:
  349. tuple: A tuple containing the randomly sampled img_scale (long_edge, short_edge)
  350. and None (to maintain function signature compatibility).
  351. """
  352. img_scale_long = [max(s) for s in img_scales]
  353. img_scale_short = [min(s) for s in img_scales]
  354. long_edge = np.random.randint(min(img_scale_long), max(img_scale_long) + 1)
  355. short_edge = np.random.randint(min(img_scale_short), max(img_scale_short) + 1)
  356. img_scale = (long_edge, short_edge)
  357. return img_scale, None
  358. @staticmethod
  359. def random_sample_ratio(img_scale, ratio_range):
  360. """
  361. Randomly sample an img_scale based on the specified ratio_range.
  362. Args:
  363. img_scale (list): A list of two integers representing the minimum and maximum
  364. scale for the image.
  365. ratio_range (tuple): A tuple of two floats representing the minimum and maximum
  366. ratio for sampling the img_scale.
  367. Returns:
  368. tuple: A tuple containing the sampled scale (as a tuple of two integers)
  369. and None.
  370. """
  371. assert isinstance(img_scale, list) and len(img_scale) == 2
  372. min_ratio, max_ratio = ratio_range
  373. assert min_ratio <= max_ratio
  374. ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio
  375. scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio)
  376. return scale, None
  377. def _random_scale(self, results):
  378. """Randomly sample an img_scale according to `ratio_range` and `multiscale_mode`.
  379. Args:
  380. results (dict): A dictionary to store the sampled scale and its index.
  381. Returns:
  382. None. The sampled scale and its index are stored in `results` dictionary.
  383. """
  384. if self.ratio_range is not None:
  385. scale, scale_idx = self.random_sample_ratio(
  386. self.img_scale[0], self.ratio_range
  387. )
  388. elif len(self.img_scale) == 1:
  389. scale, scale_idx = self.img_scale[0], 0
  390. elif self.multiscale_mode == "range":
  391. scale, scale_idx = self.random_sample(self.img_scale)
  392. elif self.multiscale_mode == "value":
  393. scale, scale_idx = self.random_select(self.img_scale)
  394. else:
  395. raise NotImplementedError
  396. results["scale"] = scale
  397. results["scale_idx"] = scale_idx
  398. def _resize_img(self, results):
  399. """Resize images based on the scale factor provided in ``results['scale']`` while maintaining the aspect ratio if ``self.keep_ratio`` is True.
  400. Args:
  401. results (dict): A dictionary containing image fields and their corresponding scales.
  402. Returns:
  403. None. The ``results`` dictionary is modified in place with resized images and additional fields like `img_shape`, `pad_shape`, `scale_factor`, and `keep_ratio`.
  404. """
  405. for key in results.get("img_fields", ["img"]):
  406. for idx in range(len(results["img"])):
  407. if self.keep_ratio:
  408. img, scale_factor = self.imrescale(
  409. results[key][idx],
  410. results["scale"],
  411. interpolation="bilinear" if key == "img" else "nearest",
  412. return_scale=True,
  413. backend=self.backend,
  414. )
  415. new_h, new_w = img.shape[:2]
  416. h, w = results[key][idx].shape[:2]
  417. w_scale = new_w / w
  418. h_scale = new_h / h
  419. else:
  420. raise NotImplementedError
  421. results[key][idx] = img
  422. scale_factor = np.array(
  423. [w_scale, h_scale, w_scale, h_scale], dtype=np.float32
  424. )
  425. results["img_shape"] = img.shape
  426. # in case that there is no padding
  427. results["pad_shape"] = img.shape
  428. results["scale_factor"] = scale_factor
  429. results["keep_ratio"] = self.keep_ratio
  430. def rescale_size(self, old_size, scale, return_scale=False):
  431. """
  432. Calculate the new size to be rescaled to based on the given scale.
  433. Args:
  434. old_size (tuple): A tuple containing the width and height of the original size.
  435. scale (float, int, or list of int): The scale factor or a list of integers representing the maximum and minimum allowed size.
  436. return_scale (bool): Whether to return the scale factor along with the new size.
  437. Returns:
  438. tuple: A tuple containing the new size and optionally the scale factor if return_scale is True.
  439. """
  440. w, h = old_size
  441. if isinstance(scale, (float, int)):
  442. if scale <= 0:
  443. raise ValueError(f"Invalid scale {scale}, must be positive.")
  444. scale_factor = scale
  445. elif isinstance(scale, list):
  446. max_long_edge = max(scale)
  447. max_short_edge = min(scale)
  448. scale_factor = min(max_long_edge / max(h, w), max_short_edge / min(h, w))
  449. else:
  450. raise TypeError(
  451. f"Scale must be a number or list of int, but got {type(scale)}"
  452. )
  453. def _scale_size(size, scale):
  454. if isinstance(scale, (float, int)):
  455. scale = (scale, scale)
  456. w, h = size
  457. return int(w * float(scale[0]) + 0.5), int(h * float(scale[1]) + 0.5)
  458. new_size = _scale_size((w, h), scale_factor)
  459. if return_scale:
  460. return new_size, scale_factor
  461. else:
  462. return new_size
  463. def imrescale(
  464. self, img, scale, return_scale=False, interpolation="bilinear", backend=None
  465. ):
  466. """Resize image while keeping the aspect ratio.
  467. Args:
  468. img (numpy.ndarray): The input image.
  469. scale (float): The scaling factor.
  470. return_scale (bool): Whether to return the scaling factor along with the resized image.
  471. interpolation (str): The interpolation method to use. Defaults to 'bilinear'.
  472. backend (str): The backend to use for resizing. Defaults to None.
  473. Returns:
  474. tuple or numpy.ndarray: The resized image, and optionally the scaling factor.
  475. """
  476. h, w = img.shape[:2]
  477. new_size, scale_factor = self.rescale_size((w, h), scale, return_scale=True)
  478. rescaled_img = self.imresize(
  479. img, new_size, interpolation=interpolation, backend=backend
  480. )
  481. if return_scale:
  482. return rescaled_img, scale_factor
  483. else:
  484. return rescaled_img
  485. def imresize(
  486. self,
  487. img,
  488. size,
  489. return_scale=False,
  490. interpolation="bilinear",
  491. out=None,
  492. backend=None,
  493. ):
  494. """Resize an image to a given size.
  495. Args:
  496. img (numpy.ndarray): The input image to be resized.
  497. size (tuple): The new size for the image as (height, width).
  498. return_scale (bool): Whether to return the scaling factors along with the resized image.
  499. interpolation (str): The interpolation method to use. Default is 'bilinear'.
  500. out (numpy.ndarray, optional): Output array. If provided, it must have the same shape and dtype as the output array.
  501. backend (str, optional): The backend to use for resizing. Supported backends are 'cv2' and 'pillow'.
  502. Returns:
  503. numpy.ndarray or tuple: The resized image. If return_scale is True, returns a tuple containing the resized image and the scaling factors (w_scale, h_scale).
  504. """
  505. cv2_interp_codes = {
  506. "nearest": cv2.INTER_NEAREST,
  507. "bilinear": cv2.INTER_LINEAR,
  508. "bicubic": cv2.INTER_CUBIC,
  509. "area": cv2.INTER_AREA,
  510. "lanczos": cv2.INTER_LANCZOS4,
  511. }
  512. h, w = img.shape[:2]
  513. if backend not in ["cv2", "pillow"]:
  514. raise ValueError(
  515. f"backend: {backend} is not supported for resize."
  516. f"Supported backends are 'cv2', 'pillow'"
  517. )
  518. if backend == "pillow":
  519. raise NotImplementedError
  520. else:
  521. resized_img = cv2.resize(
  522. img, size, dst=out, interpolation=cv2_interp_codes[interpolation]
  523. )
  524. if not return_scale:
  525. return resized_img
  526. else:
  527. w_scale = size[0] / w
  528. h_scale = size[1] / h
  529. return resized_img, w_scale, h_scale
  530. def _resize_bboxes(self, results):
  531. """Resize bounding boxes with `results['scale_factor']`.
  532. Args:
  533. results (dict): A dictionary containing the bounding boxes and other related information.
  534. """
  535. for key in results.get("bbox_fields", []):
  536. bboxes = results[key] * results["scale_factor"]
  537. if self.bbox_clip_border:
  538. img_shape = results["img_shape"]
  539. bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1])
  540. bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0])
  541. results[key] = bboxes
  542. def _resize_masks(self, results):
  543. """Resize masks with ``results['scale']``"""
  544. raise NotImplementedError
  545. def _resize_seg(self, results):
  546. """Resize semantic segmentation map with ``results['scale']``."""
  547. raise NotImplementedError
  548. def __call__(self, results):
  549. """Call function to resize images, bounding boxes, masks, and semantic segmentation maps according to the provided scale or scale factor.
  550. Args:
  551. results (dict): A dictionary containing the input data, including 'img', 'scale', and optionally 'scale_factor'.
  552. Returns:
  553. dict: A dictionary with the resized data.
  554. """
  555. if "scale" not in results:
  556. if "scale_factor" in results:
  557. img_shape = results["img"][0].shape[:2]
  558. scale_factor = results["scale_factor"]
  559. assert isinstance(scale_factor, float)
  560. results["scale"] = list(
  561. [int(x * scale_factor) for x in img_shape][::-1]
  562. )
  563. else:
  564. self._random_scale(results)
  565. else:
  566. if not self.override:
  567. assert (
  568. "scale_factor" not in results
  569. ), "scale and scale_factor cannot be both set."
  570. else:
  571. results.pop("scale")
  572. if "scale_factor" in results:
  573. results.pop("scale_factor")
  574. self._random_scale(results)
  575. self._resize_img(results)
  576. self._resize_bboxes(results)
  577. return results
  578. @benchmark.timeit
  579. @class_requires_deps("opencv-contrib-python")
  580. class NormalizeImage:
  581. """Normalize the image."""
  582. """Normalize an image by subtracting the mean and dividing by the standard deviation.
  583. Args:
  584. mean (list or tuple): Mean values for each channel.
  585. std (list or tuple): Standard deviation values for each channel.
  586. to_rgb (bool): Whether to convert the image from BGR to RGB.
  587. """
  588. def __init__(self, mean, std, to_rgb=True):
  589. """Initializes the NormalizeImage class with mean, std, and to_rgb parameters."""
  590. self.mean = np.array(mean, dtype=np.float32)
  591. self.std = np.array(std, dtype=np.float32)
  592. self.to_rgb = to_rgb
  593. def _imnormalize(self, img, mean, std, to_rgb=True):
  594. """Normalize the given image inplace.
  595. Args:
  596. img (numpy.ndarray): The image to normalize.
  597. mean (numpy.ndarray): Mean values for normalization.
  598. std (numpy.ndarray): Standard deviation values for normalization.
  599. to_rgb (bool): Whether to convert the image from BGR to RGB.
  600. Returns:
  601. numpy.ndarray: The normalized image.
  602. """
  603. img = img.copy().astype(np.float32)
  604. mean = np.float64(mean.reshape(1, -1))
  605. stdinv = 1 / np.float64(std.reshape(1, -1))
  606. if to_rgb:
  607. cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img) # inplace
  608. cv2.subtract(img, mean, img) # inplace
  609. cv2.multiply(img, stdinv, img) # inplace
  610. return img
  611. def __call__(self, results):
  612. """Call method to normalize images in the results dictionary.
  613. Args:
  614. results (dict): A dictionary containing image fields to normalize.
  615. Returns:
  616. dict: The results dictionary with normalized images.
  617. """
  618. for key in results.get("img_fields", ["img"]):
  619. if key == "img_depth":
  620. continue
  621. for idx in range(len(results["img"])):
  622. results[key][idx] = self._imnormalize(
  623. results[key][idx], self.mean, self.std, self.to_rgb
  624. )
  625. results["img_norm_cfg"] = dict(mean=self.mean, std=self.std, to_rgb=self.to_rgb)
  626. return results
  627. @benchmark.timeit
  628. @class_requires_deps("opencv-contrib-python")
  629. class PadImage(object):
  630. """Pad the image & mask."""
  631. def __init__(self, size=None, size_divisor=None, pad_val=0):
  632. self.size = size
  633. self.size_divisor = size_divisor
  634. self.pad_val = pad_val
  635. # only one of size and size_divisor should be valid
  636. assert size is not None or size_divisor is not None
  637. assert size is None or size_divisor is None
  638. def impad(
  639. self, img, *, shape=None, padding=None, pad_val=0, padding_mode="constant"
  640. ):
  641. """Pad the given image to a certain shape or pad on all sides
  642. Args:
  643. img (numpy.ndarray): The input image to be padded.
  644. shape (tuple, optional): Desired output shape in the form (height, width). One of shape or padding must be specified.
  645. padding (int, tuple, optional): Number of pixels to pad on each side of the image. If a single int is provided this
  646. is used to pad all sides with this value. If a tuple of length 2 is provided this is interpreted as (top_bottom, left_right).
  647. If a tuple of length 4 is provided this is interpreted as (top, right, bottom, left).
  648. pad_val (int, list, optional): Pixel value used for padding. If a list is provided, it must have the same length as the
  649. last dimension of the input image. Defaults to 0.
  650. padding_mode (str, optional): Padding mode to use. One of 'constant', 'edge', 'reflect', 'symmetric'.
  651. Defaults to 'constant'.
  652. Returns:
  653. numpy.ndarray: The padded image.
  654. """
  655. assert (shape is not None) ^ (padding is not None)
  656. if shape is not None:
  657. padding = [0, 0, shape[1] - img.shape[1], shape[0] - img.shape[0]]
  658. # check pad_val
  659. if isinstance(pad_val, list):
  660. assert len(pad_val) == img.shape[-1]
  661. elif not isinstance(pad_val, numbers.Number):
  662. raise TypeError(
  663. "pad_val must be a int or a list. " f"But received {type(pad_val)}"
  664. )
  665. # check padding
  666. if isinstance(padding, list) and len(padding) in [2, 4]:
  667. if len(padding) == 2:
  668. padding = [padding[0], padding[1], padding[0], padding[1]]
  669. elif isinstance(padding, numbers.Number):
  670. padding = [padding, padding, padding, padding]
  671. else:
  672. raise ValueError(
  673. "Padding must be a int or a 2, or 4 element list."
  674. f"But received {padding}"
  675. )
  676. # check padding mode
  677. assert padding_mode in ["constant", "edge", "reflect", "symmetric"]
  678. border_type = {
  679. "constant": cv2.BORDER_CONSTANT,
  680. "edge": cv2.BORDER_REPLICATE,
  681. "reflect": cv2.BORDER_REFLECT_101,
  682. "symmetric": cv2.BORDER_REFLECT,
  683. }
  684. img = cv2.copyMakeBorder(
  685. img,
  686. padding[1],
  687. padding[3],
  688. padding[0],
  689. padding[2],
  690. border_type[padding_mode],
  691. value=pad_val,
  692. )
  693. return img
  694. def impad_to_multiple(self, img, divisor, pad_val=0):
  695. """
  696. Pad an image to ensure each edge length is a multiple of a given number.
  697. Args:
  698. img (numpy.ndarray): The input image.
  699. divisor (int): The number to which each edge length should be a multiple.
  700. pad_val (int, optional): The value to pad the image with. Defaults to 0.
  701. Returns:
  702. numpy.ndarray: The padded image.
  703. """
  704. pad_h = int(np.ceil(img.shape[0] / divisor)) * divisor
  705. pad_w = int(np.ceil(img.shape[1] / divisor)) * divisor
  706. return self.impad(img, shape=(pad_h, pad_w), pad_val=pad_val)
  707. def _pad_img(self, results):
  708. """
  709. Pad images according to ``self.size`` or adjust their shapes to be multiples of ``self.size_divisor``.
  710. Args:
  711. results (dict): A dictionary containing image data, with 'img_fields' as an optional key
  712. pointing to a list of image field names.
  713. """
  714. for key in results.get("img_fields", ["img"]):
  715. if self.size is not None:
  716. padded_img = self.impad(
  717. results[key], shape=self.size, pad_val=self.pad_val
  718. )
  719. elif self.size_divisor is not None:
  720. for idx in range(len(results[key])):
  721. padded_img = self.impad_to_multiple(
  722. results[key][idx], self.size_divisor, pad_val=self.pad_val
  723. )
  724. results[key][idx] = padded_img
  725. results["pad_shape"] = padded_img.shape
  726. results["pad_fixed_size"] = self.size
  727. results["pad_size_divisor"] = self.size_divisor
  728. def _pad_masks(self, results):
  729. """Pad masks according to ``results['pad_shape']``."""
  730. raise NotImplementedError
  731. def _pad_seg(self, results):
  732. """Pad semantic segmentation map according to ``results['pad_shape']``."""
  733. raise NotImplementedError
  734. def __call__(self, results):
  735. """Call function to pad images, masks, semantic segmentation maps."""
  736. self._pad_img(results)
  737. return results
  738. @benchmark.timeit
  739. class SampleFilterByKey:
  740. """Collect data from the loader relevant to the specific task."""
  741. def __init__(
  742. self,
  743. keys,
  744. meta_keys=(
  745. "filename",
  746. "ori_shape",
  747. "img_shape",
  748. "lidar2img",
  749. "depth2img",
  750. "cam2img",
  751. "pad_shape",
  752. "scale_factor",
  753. "flip",
  754. "pcd_horizontal_flip",
  755. "pcd_vertical_flip",
  756. "box_type_3d",
  757. "img_norm_cfg",
  758. "pcd_trans",
  759. "sample_idx",
  760. "pcd_scale_factor",
  761. "pcd_rotation",
  762. "pts_filename",
  763. "transformation_3d_flow",
  764. ),
  765. ):
  766. self.keys = keys
  767. self.meta_keys = meta_keys
  768. def __call__(self, sample):
  769. """Call function to filter sample by keys. The keys in `meta_keys` are used to filter metadata from the input sample.
  770. Args:
  771. sample (Sample): The input sample to be filtered.
  772. Returns:
  773. Sample: A new Sample object containing only the filtered metadata and specified keys.
  774. """
  775. filtered_sample = Sample(path=sample.path, modality=sample.modality)
  776. filtered_sample.meta.id = sample.meta.id
  777. img_metas = {}
  778. for key in self.meta_keys:
  779. if key in sample:
  780. img_metas[key] = sample[key]
  781. filtered_sample["img_metas"] = img_metas
  782. for key in self.keys:
  783. filtered_sample[key] = sample[key]
  784. return filtered_sample
  785. @benchmark.timeit
  786. class GetInferInput:
  787. """Collect infer input data from transformed sample"""
  788. def collate_fn(self, batch):
  789. sample = batch[0]
  790. collated_batch = {}
  791. collated_fields = [
  792. "img",
  793. "points",
  794. "img_metas",
  795. "gt_bboxes_3d",
  796. "gt_labels_3d",
  797. "modality",
  798. "meta",
  799. "idx",
  800. "img_depth",
  801. ]
  802. for k in list(sample.keys()):
  803. if k not in collated_fields:
  804. continue
  805. if k == "img":
  806. collated_batch[k] = np.stack([elem[k] for elem in batch], axis=0)
  807. elif k == "img_depth":
  808. collated_batch[k] = np.stack(
  809. [np.stack(elem[k], axis=0) for elem in batch], axis=0
  810. )
  811. else:
  812. collated_batch[k] = [elem[k] for elem in batch]
  813. return collated_batch
  814. def __call__(self, sample):
  815. """Call function to infer input data from transformed sample
  816. Args:
  817. sample (Sample): The input sample data.
  818. Returns:
  819. infer_input (list): A list containing all the input data for inference.
  820. sample_id (str): token id of the input sample.
  821. """
  822. if sample.modality == "multimodal" or sample.modality == "multiview":
  823. if "img" in sample.keys():
  824. sample.img = np.stack(
  825. [img.transpose(2, 0, 1) for img in sample.img], axis=0
  826. )
  827. sample = self.collate_fn([sample])
  828. infer_input = []
  829. img = sample.get("img", None)[0]
  830. infer_input.append(img.astype(np.float32))
  831. lidar2img = np.stack(sample["img_metas"][0]["lidar2img"]).astype(np.float32)
  832. infer_input.append(lidar2img)
  833. points = sample.get("points", None)[0]
  834. infer_input.append(points.astype(np.float32))
  835. img_metas = {
  836. "input_lidar_path": sample["img_metas"][0]["pts_filename"],
  837. "input_img_paths": sample["img_metas"][0]["filename"],
  838. "sample_id": sample["img_metas"][0]["sample_idx"],
  839. }
  840. return infer_input, img_metas