# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Sequence, Tuple, Union

import cv2
import numpy as np
from numpy import ndarray

from ..common import Resize as CommonResize
from ..common import Normalize as CommonNormalize
from ...common.reader import ReadImage as CommonReadImage

Boxes = List[dict]
Number = Union[int, float]


class ReadImage(CommonReadImage):
    """Reads images from a list of raw image data or file paths."""

    def __call__(self, raw_imgs: List[Union[ndarray, str, dict]]) -> List[dict]:
        """Processes the input list of raw image data, file paths, or dicts and
        returns a list of dictionaries containing image information.

        Args:
            raw_imgs (List[Union[ndarray, str, dict]]): A list of raw image data
                (numpy ndarrays), file paths (strings), or dicts with an 'img'
                or 'img_path' key.

        Returns:
            List[dict]: A list of dictionaries, each containing image information.
        """
        out_datas = []
        for raw_img in raw_imgs:
            data = dict()
            if isinstance(raw_img, str):
                data["img_path"] = raw_img
            if isinstance(raw_img, dict):
                if "img" in raw_img:
                    src_img = raw_img["img"]
                elif "img_path" in raw_img:
                    src_img = raw_img["img_path"]
                    data["img_path"] = src_img
                else:
                    raise ValueError(
                        "When raw_img is dict, must have one of keys ['img', 'img_path']."
                    )
                data.update(raw_img)
                raw_img = src_img
            img = self.read(raw_img)
            data["img"] = img
            data["ori_img"] = img
            data["img_size"] = [img.shape[1], img.shape[0]]  # [size_w, size_h]
            data["ori_img_size"] = [img.shape[1], img.shape[0]]  # [size_w, size_h]
            out_datas.append(data)
        return out_datas
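
# Illustrative usage (a sketch; the base-class constructor may take format
# arguments not shown here, and the paths are hypothetical):
#
#     reader = ReadImage()
#     datas = reader([
#         "path/to/image.jpg",                       # file path
#         np.zeros((480, 640, 3), dtype=np.uint8),   # raw ndarray
#         {"img_path": "path/to/other.jpg"},         # dict form
#     ])
#     # Each output dict carries 'img', 'ori_img', 'img_size', 'ori_img_size'.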


class Resize(CommonResize):
    def __call__(self, datas: List[dict]) -> List[dict]:
        """
        Args:
            datas (List[dict]): A list of dictionaries, each containing image data with key 'img'.

        Returns:
            List[dict]: A list of dictionaries with updated image data, including resized images,
                original image sizes, resized image sizes, and scale factors.
        """
        for data in datas:
            ori_img = data["img"]
            if "ori_img_size" not in data:
                data["ori_img_size"] = [ori_img.shape[1], ori_img.shape[0]]
            ori_img_size = data["ori_img_size"]

            img = self.resize(ori_img)
            data["img"] = img

            img_size = [img.shape[1], img.shape[0]]
            data["img_size"] = img_size  # [size_w, size_h]

            data["scale_factors"] = [  # [w_scale, h_scale]
                img_size[0] / ori_img_size[0],
                img_size[1] / ori_img_size[1],
            ]
        return datas
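
# For example, resizing a 640x480 image to 320x240 yields
# scale_factors == [320 / 640, 240 / 480] == [0.5, 0.5] ([w_scale, h_scale]).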


class Normalize(CommonNormalize):
    """Normalizes images in a list of dictionaries containing image data."""

    def apply(self, img: ndarray) -> ndarray:
        """Applies normalization to a single image: (img * scale - mean) / std."""
        old_type = img.dtype
        # XXX: If `old_type` has higher precision than float32,
        # we will lose some precision.
        img = img.astype("float32", copy=False)
        img *= self.scale
        img -= self.mean
        img /= self.std
        if self.preserve_dtype:
            img = img.astype(old_type, copy=False)
        return img

    def __call__(self, datas: List[dict]) -> List[dict]:
        """Normalizes images in a list of dictionaries. Iterates over each dictionary,
        applies normalization to the 'img' key, and returns the modified list.
        """
        for data in datas:
            data["img"] = self.apply(data["img"])
        return datas
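
# Worked example (assuming ImageNet-style parameters scale=1/255, mean=0.485,
# std=0.229 for a given channel): a uint8 value of 128 maps to
# (128 / 255 - 0.485) / 0.229 ≈ 0.074.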


class ToCHWImage:
    """Converts images in a list of dictionaries from HWC to CHW format."""

    def __call__(self, datas: List[dict]) -> List[dict]:
        """Converts the image data in the list of dictionaries from HWC to CHW format in-place.

        Args:
            datas (List[dict]): A list of dictionaries, each containing an image tensor in the 'img' key with HWC format.

        Returns:
            List[dict]: The same list of dictionaries with the image tensors converted to CHW format.
        """
        for data in datas:
            data["img"] = data["img"].transpose((2, 0, 1))
        return datas
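
# For example, an image of shape (512, 608, 3) (HWC) becomes (3, 512, 608)
# (CHW) via transpose((2, 0, 1)).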


class ToBatch:
    """
    Class for batch processing of data dictionaries.

    Args:
        ordered_required_keys (Optional[Tuple[str, ...]]): A tuple of keys that need to be present in the input data dictionaries in a specific order.
    """

    def __init__(self, ordered_required_keys: Optional[Tuple[str, ...]] = None):
        self.ordered_required_keys = ordered_required_keys

    def apply(
        self, datas: List[dict], key: str, dtype: np.dtype = np.float32
    ) -> np.ndarray:
        """
        Apply batch processing to a list of data dictionaries.

        Args:
            datas (List[dict]): A list of data dictionaries to process.
            key (str): The key in the data dictionaries to extract and batch.
            dtype (np.dtype): The desired data type of the output array (default is np.float32).

        Returns:
            np.ndarray: A numpy array containing the batched data.

        Raises:
            KeyError: If the specified key is not found in any of the data dictionaries.
        """
        if key == "img_size":
            # [h, w] size for det models
            img_sizes = [data[key][::-1] for data in datas]
            return np.stack(img_sizes, axis=0).astype(dtype=dtype, copy=False)
        elif key == "scale_factors":
            # [h, w] scale factors for det models, default [1.0, 1.0]
            scale_factors = [data.get(key, [1.0, 1.0])[::-1] for data in datas]
            return np.stack(scale_factors, axis=0).astype(dtype=dtype, copy=False)
        else:
            return np.stack([data[key] for data in datas], axis=0).astype(
                dtype=dtype, copy=False
            )

    def __call__(self, datas: List[dict]) -> Sequence[ndarray]:
        return [self.apply(datas, key) for key in self.ordered_required_keys]
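
# Note the [w, h] -> [h, w] reversal above: 'img_size' and 'scale_factors' are
# stored per-image as [w, h] but batched as [h, w], which is the layout the
# detection models expect. An illustrative (model-dependent, assumed) key order:
#
#     to_batch = ToBatch(ordered_required_keys=("img_size", "img", "scale_factors"))
#     inputs = to_batch(datas)  # -> [img_size_batch, img_batch, scale_factors_batch]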


class DetPad:
    """
    Pad image to a specified size.

    Args:
        size (List[int]): Image target size, as [h, w].
        fill_value (List[float]): RGB value of the padded area, default (114.0, 114.0, 114.0).
    """

    def __init__(
        self,
        size: List[int],
        fill_value: List[Union[int, float]] = [114.0, 114.0, 114.0],
    ):
        super().__init__()
        if isinstance(size, int):
            size = [size, size]
        self.size = size
        self.fill_value = fill_value

    def apply(self, img: ndarray) -> ndarray:
        im = img
        im_h, im_w = im.shape[:2]
        h, w = self.size
        if h == im_h and w == im_w:
            return im
        canvas = np.ones((h, w, 3), dtype=np.float32)
        canvas *= np.array(self.fill_value, dtype=np.float32)
        canvas[0:im_h, 0:im_w, :] = im.astype(np.float32)
        return canvas

    def __call__(self, datas: List[dict]) -> List[dict]:
        for data in datas:
            data["img"] = self.apply(data["img"])
        return datas
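
# For example, DetPad(size=[608, 608]) turns a (512, 600, 3) image into a
# (608, 608, 3) float32 canvas filled with 114.0, with the original image
# copied into the top-left corner.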


class PadStride:
    """Pads the image so that its height and width are divisible by `stride`,
    replacing PadBatch(pad_to_stride, pad_gt) from the original config. Models
    with an FPN require image shape % stride == 0.

    Args:
        stride (int): The stride the padded image dimensions must be divisible by.
    """

    def __init__(self, stride: int = 0):
        super().__init__()
        self.coarsest_stride = stride

    def apply(self, img: ndarray) -> ndarray:
        """
        Args:
            img (np.ndarray): Image in CHW layout.

        Returns:
            np.ndarray: Zero-padded image in CHW layout.
        """
        im = img
        coarsest_stride = self.coarsest_stride
        if coarsest_stride <= 0:
            return img
        im_c, im_h, im_w = im.shape
        pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
        pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
        padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
        padding_im[:, :im_h, :im_w] = im
        return padding_im

    def __call__(self, datas: List[dict]) -> List[dict]:
        for data in datas:
            data["img"] = self.apply(data["img"])
        return datas
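
# For example, with stride=32 a CHW image of shape (3, 500, 600) is zero-padded
# to (3, 512, 608), since ceil(500/32)*32 == 512 and ceil(600/32)*32 == 608.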


def rotate_point(pt: List[float], angle_rad: float) -> List[float]:
    """Rotate a point by an angle.

    Args:
        pt (List[float]): 2-dimensional point to be rotated.
        angle_rad (float): Rotation angle in radians.

    Returns:
        List[float]: Rotated point.
    """
    assert len(pt) == 2
    sn, cs = np.sin(angle_rad), np.cos(angle_rad)
    new_x = pt[0] * cs - pt[1] * sn
    new_y = pt[0] * sn + pt[1] * cs
    rotated_pt = [new_x, new_y]
    return rotated_pt
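
# For example, rotate_point([1.0, 0.0], np.pi / 2) ≈ [0.0, 1.0]
# (a 90-degree counter-clockwise rotation about the origin).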


def _get_3rd_point(a: ndarray, b: ndarray) -> ndarray:
    """To calculate the affine matrix, three pairs of points are required. This
    function is used to get the 3rd point, given 2D points a & b.

    The 3rd point is defined by rotating the vector `a - b` by 90 degrees
    anticlockwise, using b as the rotation center.

    Args:
        a (np.ndarray): point (x, y)
        b (np.ndarray): point (x, y)

    Returns:
        np.ndarray: The 3rd point.
    """
    assert len(a) == 2
    assert len(b) == 2
    direction = a - b
    third_pt = b + np.array([-direction[1], direction[0]], dtype=np.float32)
    return third_pt


def get_affine_transform(
    center: ndarray,
    input_size: Union[Number, Tuple[Number, Number], ndarray],
    rot: float,
    output_size: ndarray,
    shift: Tuple[float, float] = (0.0, 0.0),
    inv: bool = False,
) -> ndarray:
    """Get the affine transform matrix, given the center/scale/rot/output_size.

    Args:
        center (np.ndarray[2, ]): Center of the bounding box (x, y).
        input_size (np.ndarray[2, ]): Scale of the bounding box
            wrt [width, height].
        rot (float): Rotation angle (degrees).
        output_size (np.ndarray[2, ]): Size of the destination heatmaps.
        shift (Tuple[float, float]): Shift translation ratio wrt the
            width/height, in [0, 1]. Default (0.0, 0.0).
        inv (bool): Option to invert the affine transform direction.
            (inv=False: src->dst or inv=True: dst->src)

    Returns:
        np.ndarray: The 2x3 transform matrix.
    """
    assert len(center) == 2
    assert len(output_size) == 2
    assert len(shift) == 2
    if not isinstance(input_size, (ndarray, list)):
        input_size = np.array([input_size, input_size], dtype=np.float32)
    scale_tmp = input_size

    shift = np.array(shift)
    src_w = scale_tmp[0]
    dst_w = output_size[0]
    dst_h = output_size[1]

    rot_rad = np.pi * rot / 180
    src_dir = rotate_point([0.0, src_w * -0.5], rot_rad)
    dst_dir = np.array([0.0, dst_w * -0.5])

    # Three point pairs (the center, a rotated direction point, and a third
    # point perpendicular to them) uniquely determine the affine transform.
    src = np.zeros((3, 2), dtype=np.float32)
    src[0, :] = center + scale_tmp * shift
    src[1, :] = center + src_dir + scale_tmp * shift
    src[2, :] = _get_3rd_point(src[0, :], src[1, :])

    dst = np.zeros((3, 2), dtype=np.float32)
    dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
    dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
    dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])

    if inv:
        trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
    else:
        trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))

    return trans
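
# Illustrative call (values chosen for the example): map a 512x512 region
# centered at (320, 240) onto a 128x128 output with no rotation. The resulting
# 2x3 matrix scales by 128/512 = 0.25 and sends the center to (64, 64):
#
#     M = get_affine_transform(
#         center=np.array([320.0, 240.0]), input_size=512, rot=0,
#         output_size=np.array([128, 128]),
#     )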


class WarpAffine:
    """Apply a warp-affine transformation to the image based on the given parameters.

    Args:
        keep_res (bool): Whether to keep the original resolution aspect ratio during transformation.
        pad (int): Padding value used when keep_res is True.
        input_h (int): Target height for the input image when keep_res is False.
        input_w (int): Target width for the input image when keep_res is False.
        scale (float): Scale factor for resizing.
        shift (float): Shift factor for transformation.
        down_ratio (int): Downsampling ratio for the output image.
    """

    def __init__(
        self,
        keep_res=False,
        pad=31,
        input_h=512,
        input_w=512,
        scale=0.4,
        shift=0.1,
        down_ratio=4,
    ):
        super().__init__()
        self.keep_res = keep_res
        self.pad = pad
        self.input_h = input_h
        self.input_w = input_w
        self.scale = scale
        self.shift = shift
        self.down_ratio = down_ratio

    def apply(self, img: ndarray) -> ndarray:
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        h, w = img.shape[:2]

        if self.keep_res:
            # True in detection eval/infer. For pad = 2**k - 1 this rounds
            # each side up to the next multiple of 2**k (e.g. pad=31 -> 32).
            input_h = (h | self.pad) + 1
            input_w = (w | self.pad) + 1
            s = np.array([input_w, input_h], dtype=np.float32)
            c = np.array([w // 2, h // 2], dtype=np.float32)
        else:
            # False in centertrack eval_mot/infer_mot
            s = max(h, w) * 1.0
            input_h, input_w = self.input_h, self.input_w
            c = np.array([w / 2.0, h / 2.0], dtype=np.float32)

        trans_input = get_affine_transform(c, s, 0, [input_w, input_h])
        img = cv2.resize(img, (w, h))  # no-op resize to the image's own size
        inp = cv2.warpAffine(
            img, trans_input, (input_w, input_h), flags=cv2.INTER_LINEAR
        )

        if not self.keep_res:
            # The output transform is computed but not used in this method.
            out_h = input_h // self.down_ratio
            out_w = input_w // self.down_ratio
            trans_output = get_affine_transform(c, s, 0, [out_w, out_h])

        return inp

    def __call__(self, datas: List[dict]) -> List[dict]:
        for data in datas:
            ori_img = data["img"]
            if "ori_img_size" not in data:
                data["ori_img_size"] = [ori_img.shape[1], ori_img.shape[0]]
            img = self.apply(ori_img)
            data["img"] = img
        return datas
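
# keep_res example: for a 500x640 (h x w) image with pad=31,
# input_h == (500 | 31) + 1 == 512 and input_w == (640 | 31) + 1 == 672,
# i.e. each side is rounded up to the next multiple of 32.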


def compute_iou(box1: List[Number], box2: List[Number]) -> float:
    """Compute the Intersection over Union (IoU) of two bounding boxes.

    Args:
        box1 (List[Number]): Coordinates of the first bounding box, as [x1, y1, x2, y2].
        box2 (List[Number]): Coordinates of the second bounding box, as [x1, y1, x2, y2].

    Returns:
        float: The IoU of the two bounding boxes.
    """
    x1 = max(box1[0], box2[0])
    y1 = max(box1[1], box2[1])
    x2 = min(box1[2], box2[2])
    y2 = min(box1[3], box2[3])
    # The +1 terms treat coordinates as inclusive pixel indices.
    inter_area = max(0, x2 - x1 + 1) * max(0, y2 - y1 + 1)
    box1_area = (box1[2] - box1[0] + 1) * (box1[3] - box1[1] + 1)
    box2_area = (box2[2] - box2[0] + 1) * (box2[3] - box2[1] + 1)
    iou = inter_area / float(box1_area + box2_area - inter_area)
    return iou
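
# Worked example: compute_iou([0, 0, 9, 9], [5, 5, 14, 14]) gives an
# intersection of 5 * 5 = 25 (with the +1 convention) over a union of
# 100 + 100 - 25 = 175, i.e. IoU = 25 / 175 ≈ 0.143.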


def is_box_mostly_inside(
    inner_box: List[Number], outer_box: List[Number], threshold: float = 0.9
) -> bool:
    """Determine if one bounding box is mostly inside another bounding box.

    Args:
        inner_box (List[Number]): Coordinates of the inner bounding box, as [x1, y1, x2, y2].
        outer_box (List[Number]): Coordinates of the outer bounding box, as [x1, y1, x2, y2].
        threshold (float): The threshold for determining if the inner box is mostly inside the outer box (default is 0.9).

    Returns:
        bool: True if the ratio of the intersection area to the inner box area is greater than or equal to the threshold, False otherwise.
    """
    x1 = max(inner_box[0], outer_box[0])
    y1 = max(inner_box[1], outer_box[1])
    x2 = min(inner_box[2], outer_box[2])
    y2 = min(inner_box[3], outer_box[3])
    inter_area = max(0, x2 - x1 + 1) * max(0, y2 - y1 + 1)
    inner_box_area = (inner_box[2] - inner_box[0] + 1) * (
        inner_box[3] - inner_box[1] + 1
    )
    return (inter_area / inner_box_area) >= threshold


def restructured_boxes(
    boxes: ndarray, labels: List[str], img_size: Tuple[int, int]
) -> Boxes:
    """
    Restructure the given bounding boxes and labels based on the image size.

    Args:
        boxes (ndarray): A 2D array of bounding boxes with each box represented as
            [cls_id, score, xmin, ymin, xmax, ymax].
        labels (List[str]): A list of class labels corresponding to the class ids.
        img_size (Tuple[int, int]): A tuple representing the width and height of the image.

    Returns:
        Boxes: A list of dictionaries, each containing 'cls_id', 'label', 'score', and 'coordinate' keys.
    """
    box_list = []
    w, h = img_size
    for box in boxes:
        xmin, ymin, xmax, ymax = box[2:]
        # Clip the box to the image boundaries.
        xmin = max(0, xmin)
        ymin = max(0, ymin)
        xmax = min(w, xmax)
        ymax = min(h, ymax)
        box_list.append(
            {
                "cls_id": int(box[0]),
                "label": labels[int(box[0])],
                "score": float(box[1]),
                "coordinate": [xmin, ymin, xmax, ymax],
            }
        )
    return box_list


def restructured_rotated_boxes(
    boxes: ndarray, labels: List[str], img_size: Tuple[int, int]
) -> Boxes:
    """
    Restructure the given rotated bounding boxes and labels based on the image size.

    Args:
        boxes (ndarray): A 2D array of rotated bounding boxes with each box represented as
            [cls_id, score, x1, y1, x2, y2, x3, y3, x4, y4].
        labels (List[str]): A list of class labels corresponding to the class ids.
        img_size (Tuple[int, int]): A tuple representing the width and height of the image.

    Returns:
        Boxes: A list of dictionaries, each containing 'cls_id', 'label', 'score', and 'coordinate' keys.
    """
    box_list = []
    w, h = img_size

    assert boxes.shape[1] == 10, "The shape of rotated boxes should be [N, 10]"
    for box in boxes:
        x1, y1, x2, y2, x3, y3, x4, y4 = box[2:]
        x1 = min(max(0, x1), w)
        y1 = min(max(0, y1), h)
        x2 = min(max(0, x2), w)
        y2 = min(max(0, y2), h)
        x3 = min(max(0, x3), w)
        y3 = min(max(0, y3), h)
        x4 = min(max(0, x4), w)
        y4 = min(max(0, y4), h)
        box_list.append(
            {
                "cls_id": int(box[0]),
                "label": labels[int(box[0])],
                "score": float(box[1]),
                "coordinate": [x1, y1, x2, y2, x3, y3, x4, y4],
            }
        )
    return box_list


def non_max_suppression(
    boxes: ndarray, scores: ndarray, iou_threshold: float
) -> List[int]:
    """
    Perform non-maximum suppression to remove redundant overlapping boxes with
    lower scores. This function is commonly used in object detection tasks.

    Parameters:
        boxes (ndarray): An array of shape (N, 4) representing the bounding boxes.
            Each row is in the format [x1, y1, x2, y2].
        scores (ndarray): An array of shape (N,) containing the scores for each box.
        iou_threshold (float): The Intersection over Union (IoU) threshold to use
            for suppressing overlapping boxes.

    Returns:
        List[int]: A list of indices representing the boxes to keep.
    """
    if len(boxes) == 0:
        return []
    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    # Process boxes in descending order of score.
    order = scores.argsort()[::-1]
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        # IoU of the current best box with all remaining boxes.
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])
        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h
        iou = inter / (areas[i] + areas[order[1:]] - inter)
        # Keep only boxes whose overlap with the current box is small enough.
        inds = np.where(iou <= iou_threshold)[0]
        order = order[inds + 1]
    return keep
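
# Illustrative call: box 1 overlaps box 0 (IoU ≈ 0.68) and is suppressed,
# while the disjoint box 2 survives:
#
#     boxes = np.array([[0, 0, 9, 9], [1, 1, 10, 10], [50, 50, 60, 60]])
#     scores = np.array([0.9, 0.8, 0.7])
#     non_max_suppression(boxes, scores, iou_threshold=0.5)  # -> [0, 2]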


class DetPostProcess:
    """Post-process the raw detection output.

    This class is responsible for post-processing detection results, including
    thresholding, non-maximum suppression (NMS), and restructuring the boxes
    based on the input type (normal or rotated object detection).
    """

    def __init__(
        self,
        threshold: float = 0.5,
        labels: Optional[List[str]] = None,
        layout_postprocess: bool = False,
    ) -> None:
        """Initialize the DetPostProcess class.

        Args:
            threshold (float, optional): The threshold to apply to the detection scores. Defaults to 0.5.
            labels (Optional[List[str]], optional): The list of labels for the detection categories. Defaults to None.
            layout_postprocess (bool, optional): Whether to apply layout post-processing. Defaults to False.
        """
        super().__init__()
        self.threshold = threshold
        self.labels = labels
        self.layout_postprocess = layout_postprocess

    def apply(
        self, boxes: ndarray, img_size: Tuple[int, int], threshold: Union[float, dict]
    ) -> Boxes:
        """Apply post-processing to the detection boxes.

        Args:
            boxes (ndarray): The input detection boxes with scores.
            img_size (tuple): The original image size, as (w, h).
            threshold (Union[float, dict]): A global score threshold, or a
                per-category mapping from class id to threshold.

        Returns:
            Boxes: The post-processed detection boxes.
        """
        if isinstance(threshold, float):
            expect_boxes = (boxes[:, 1] > threshold) & (boxes[:, 0] > -1)
            boxes = boxes[expect_boxes, :]
        elif isinstance(threshold, dict):
            category_filtered_boxes = []
            for cat_id in np.unique(boxes[:, 0]):
                category_boxes = boxes[boxes[:, 0] == cat_id]
                category_scores = category_boxes[:, 1]
                category_threshold = threshold.get(int(cat_id), 0.5)
                selected_indices = category_scores > category_threshold
                category_filtered_boxes.append(category_boxes[selected_indices])
            boxes = (
                np.vstack(category_filtered_boxes)
                if category_filtered_boxes
                else np.array([])
            )

        if self.layout_postprocess:
            filtered_boxes = []
            ### Layout postprocess for NMS
            for cat_id in np.unique(boxes[:, 0]):
                category_boxes = boxes[boxes[:, 0] == cat_id]
                category_scores = category_boxes[:, 1]
                if len(category_boxes) > 0:
                    nms_indices = non_max_suppression(
                        category_boxes[:, 2:], category_scores, 0.5
                    )
                    category_boxes = category_boxes[nms_indices]
                    keep_boxes = []
                    for i, box in enumerate(category_boxes):
                        if all(
                            not is_box_mostly_inside(box[2:], other_box[2:])
                            for j, other_box in enumerate(category_boxes)
                            if i != j
                        ):
                            keep_boxes.append(box)
                    filtered_boxes.extend(keep_boxes)
            boxes = np.array(filtered_boxes)
            ### Layout postprocess for removing boxes inside image category box
            if self.labels and "image" in self.labels:
                image_cls_id = self.labels.index("image")
                if len(boxes) > 0:
                    image_boxes = boxes[boxes[:, 0] == image_cls_id]
                    other_boxes = boxes[boxes[:, 0] != image_cls_id]
                    to_keep = []
                    for box in other_boxes:
                        keep = True
                        for img_box in image_boxes:
                            if (
                                box[2] >= img_box[2]
                                and box[3] >= img_box[3]
                                and box[4] <= img_box[4]
                                and box[5] <= img_box[5]
                            ):
                                keep = False
                                break
                        if keep:
                            to_keep.append(box)
                    boxes = (
                        np.vstack([image_boxes, to_keep]) if to_keep else image_boxes
                    )
            ### Layout postprocess for overlaps
            final_boxes = []
            while len(boxes) > 0:
                current_box = boxes[0]
                current_score = current_box[1]
                overlaps = [current_box]
                non_overlaps = []
                for other_box in boxes[1:]:
                    iou = compute_iou(current_box[2:], other_box[2:])
                    if iou > 0.95:
                        if other_box[1] > current_score:
                            overlaps.append(other_box)
                    else:
                        non_overlaps.append(other_box)
                best_box = max(overlaps, key=lambda x: x[1])
                final_boxes.append(best_box)
                boxes = np.array(non_overlaps)
            boxes = np.array(final_boxes)

        if boxes.shape[1] == 6:
            # For normal object detection: [cls_id, score, xmin, ymin, xmax, ymax]
            boxes = restructured_boxes(boxes, self.labels, img_size)
        elif boxes.shape[1] == 10:
            # For rotated object detection: [cls_id, score, x1, y1, ..., x4, y4]
            boxes = restructured_rotated_boxes(boxes, self.labels, img_size)
        else:
            raise ValueError(
                f"The last dimension of boxes should be 6 or 10, but got {boxes.shape[1]}"
            )
        return boxes

    def __call__(
        self,
        batch_outputs: List[dict],
        datas: List[dict],
        threshold: Optional[Union[float, dict]] = None,
    ) -> List[Boxes]:
        """Apply the post-processing to a batch of outputs.

        Args:
            batch_outputs (List[dict]): The list of detection outputs.
            datas (List[dict]): The list of input data.
            threshold (Optional[Union[float, dict]]): Overrides the threshold
                configured at construction time when given.

        Returns:
            List[Boxes]: The list of post-processed detection boxes.
        """
        outputs = []
        for data, output in zip(datas, batch_outputs):
            boxes = self.apply(
                output["boxes"],
                data["ori_img_size"],
                threshold if threshold is not None else self.threshold,
            )
            outputs.append(boxes)
        return outputs
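
# End-to-end sketch (hypothetical wiring; the actual composition, constructor
# arguments, and key order are defined by the surrounding pipeline code):
#
#     datas = ReadImage()(["image.jpg"])
#     datas = Resize(target_size=[640, 640])(datas)   # argument name assumed
#     datas = Normalize()(datas)
#     datas = ToCHWImage()(datas)
#     inputs = ToBatch(ordered_required_keys=("img_size", "img", "scale_factors"))(datas)
#     batch_outputs = ...  # run the detector on `inputs`
#     results = DetPostProcess(threshold=0.5, labels=labels)(batch_outputs, datas)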