# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Sequence, Tuple, Union

import cv2
import numpy as np
from numpy import ndarray

from ..common import Resize as CommonResize
from ..common import Normalize as CommonNormalize
from ...common.reader import ReadImage as CommonReadImage

Boxes = List[dict]
Number = Union[int, float]


class ReadImage(CommonReadImage):
    """Reads images from a list of raw image data or file paths."""

    def __call__(self, raw_imgs: List[Union[ndarray, str]]) -> List[dict]:
        """Processes the input list of raw image data or file paths and returns a list of dictionaries containing image information.

        Args:
            raw_imgs (List[Union[ndarray, str]]): A list of raw image data (numpy ndarrays) or file paths (strings).

        Returns:
            List[dict]: A list of dictionaries, each containing image information.
        """
        out_datas = []
        for raw_img in raw_imgs:
            data = dict()
            if isinstance(raw_img, str):
                data["img_path"] = raw_img
            img = self.read(raw_img)
            data["img"] = img
            data["ori_img"] = img
            data["img_size"] = [img.shape[1], img.shape[0]]  # [size_w, size_h]
            data["ori_img_size"] = [img.shape[1], img.shape[0]]  # [size_w, size_h]
            out_datas.append(data)
        return out_datas


class Resize(CommonResize):
    def __call__(self, datas: List[dict]) -> List[dict]:
        """
        Args:
            datas (List[dict]): A list of dictionaries, each containing image data with key 'img'.

        Returns:
            List[dict]: A list of dictionaries with updated image data, including resized images,
                original image sizes, resized image sizes, and scale factors.
        """
        for data in datas:
            ori_img = data["img"]
            if "ori_img_size" not in data:
                data["ori_img_size"] = [ori_img.shape[1], ori_img.shape[0]]
            ori_img_size = data["ori_img_size"]

            img = self.resize(ori_img)
            data["img"] = img

            img_size = [img.shape[1], img.shape[0]]
            data["img_size"] = img_size  # [size_w, size_h]
            data["scale_factors"] = [  # [w_scale, h_scale]
                img_size[0] / ori_img_size[0],
                img_size[1] / ori_img_size[1],
            ]
        return datas
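
# A worked example of the bookkeeping above (illustrative values): resizing a
# 1280x960 (w x h) image down to 640x480 stores img_size == [640, 480] and
# scale_factors == [640 / 1280, 480 / 960] == [0.5, 0.5].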


class Normalize(CommonNormalize):
    """Normalizes images in a list of dictionaries containing image data"""

    def apply(self, img: ndarray) -> ndarray:
        """Applies normalization to a single image."""
        old_type = img.dtype
        # XXX: If `old_type` has higher precision than float32,
        # we will lose some precision.
        img = img.astype("float32", copy=False)

        img *= self.scale
        img -= self.mean
        img /= self.std
        if self.preserve_dtype:
            img = img.astype(old_type, copy=False)
        return img

    def __call__(self, datas: List[dict]) -> List[dict]:
        """Normalizes images in a list of dictionaries. Iterates over each dictionary,
        applies normalization to the 'img' key, and returns the modified list.
        """
        for data in datas:
            data["img"] = self.apply(data["img"])
        return datas
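
# Worked example of the arithmetic in Normalize.apply (the scale/mean/std
# values here are illustrative assumptions, not defaults of CommonNormalize):
# with scale = 1 / 255, mean = 0.5, and std = 0.5, a uint8 pixel value of 255
# maps to (255 * (1 / 255) - 0.5) / 0.5 == 1.0, and a value of 0 maps to -1.0.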


class ToCHWImage:
    """Converts images in a list of dictionaries from HWC to CHW format."""

    def __call__(self, datas: List[dict]) -> List[dict]:
        """Converts the image data in the list of dictionaries from HWC to CHW format in-place.

        Args:
            datas (List[dict]): A list of dictionaries, each containing an image tensor in 'img' key with HWC format.

        Returns:
            List[dict]: The same list of dictionaries with the image tensors converted to CHW format.
        """
        for data in datas:
            data["img"] = data["img"].transpose((2, 0, 1))
        return datas
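
# For example, an HWC image of shape (480, 640, 3) becomes a CHW array of
# shape (3, 480, 640). Note that transpose((2, 0, 1)) returns a view, so no
# pixel data is copied.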


class ToBatch:
    """
    Class for batch processing of data dictionaries.

    Args:
        ordered_required_keys (Optional[Tuple[str, ...]]): A tuple of keys that need to be present in the input data dictionaries in a specific order.
    """

    def __init__(self, ordered_required_keys: Optional[Tuple[str, ...]] = None):
        self.ordered_required_keys = ordered_required_keys

    def apply(
        self, datas: List[dict], key: str, dtype: np.dtype = np.float32
    ) -> np.ndarray:
        """
        Apply batch processing to a list of data dictionaries.

        Args:
            datas (List[dict]): A list of data dictionaries to process.
            key (str): The key in the data dictionaries to extract and batch.
            dtype (np.dtype): The desired data type of the output array (default is np.float32).

        Returns:
            np.ndarray: A numpy array containing the batched data.

        Raises:
            KeyError: If the specified key is not found in any of the data dictionaries.
        """
        if key == "img_size":
            # [h, w] size for det models
            img_sizes = [data[key][::-1] for data in datas]
            return np.stack(img_sizes, axis=0).astype(dtype=dtype, copy=False)
        elif key == "scale_factors":
            # [h, w] scale factors for det models, default [1.0, 1.0]
            scale_factors = [data.get(key, [1.0, 1.0])[::-1] for data in datas]
            return np.stack(scale_factors, axis=0).astype(dtype=dtype, copy=False)
        else:
            return np.stack([data[key] for data in datas], axis=0).astype(
                dtype=dtype, copy=False
            )

    def __call__(self, datas: List[dict]) -> Sequence[ndarray]:
        return [self.apply(datas, key) for key in self.ordered_required_keys]
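
# A minimal batching sketch (illustrative values): "img_size" is stored as
# [w, h] upstream but detection models expect [h, w], so the [::-1] above
# reverses it while stacking.
#
#     to_batch = ToBatch(ordered_required_keys=("img", "img_size"))
#     datas = [{"img": np.zeros((3, 480, 640)), "img_size": [640, 480]}]
#     img_batch, size_batch = to_batch(datas)
#     # img_batch.shape == (1, 3, 480, 640); size_batch[0] == [480., 640.]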


class DetPad:
    """
    Pad image to a specified size.

    Args:
        size (list[int]): image target size
        fill_value (list[float]): rgb value of pad area, default (114.0, 114.0, 114.0)
    """

    def __init__(
        self,
        size: List[int],
        fill_value: List[Union[int, float]] = [114.0, 114.0, 114.0],
    ):
        super().__init__()
        if isinstance(size, int):
            size = [size, size]
        self.size = size
        self.fill_value = fill_value

    def apply(self, img: ndarray) -> ndarray:
        im = img
        im_h, im_w = im.shape[:2]
        h, w = self.size
        if h == im_h and w == im_w:
            return im
        canvas = np.ones((h, w, 3), dtype=np.float32)
        canvas *= np.array(self.fill_value, dtype=np.float32)
        canvas[0:im_h, 0:im_w, :] = im.astype(np.float32)
        return canvas

    def __call__(self, datas: List[dict]) -> List[dict]:
        for data in datas:
            data["img"] = self.apply(data["img"])
        return datas
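
# For example, padding a 400x300 (w x h) image with DetPad(size=[640, 640])
# places the original pixels in the top-left corner of a 640x640 canvas and
# fills the remainder with the RGB value (114, 114, 114). Note that apply
# assumes a 3-channel HWC image.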


class PadStride:
    """Pads the image so its spatial dimensions are divisible by the stride;
    replaces PadBatch(pad_to_stride, pad_gt) in the original config. Intended
    for models with FPN.

    Args:
        stride (int): Models with FPN require the image shape to satisfy shape % stride == 0.
    """

    def __init__(self, stride: int = 0):
        super().__init__()
        self.coarsest_stride = stride

    def apply(self, img: ndarray):
        """
        Args:
            img (np.ndarray): image (np.ndarray)

        Returns:
            im (np.ndarray): processed image (np.ndarray)
        """
        im = img
        coarsest_stride = self.coarsest_stride
        if coarsest_stride <= 0:
            return img
        im_c, im_h, im_w = im.shape
        pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
        pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
        padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
        padding_im[:, :im_h, :im_w] = im
        return padding_im

    def __call__(self, datas: List[dict]) -> List[dict]:
        for data in datas:
            data["img"] = self.apply(data["img"])
        return datas
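
# Worked example: with stride = 32, a CHW image of shape (3, 100, 200) is
# zero-padded to (3, 128, 224), since ceil(100 / 32) * 32 == 128 and
# ceil(200 / 32) * 32 == 224. PadStride expects CHW input, so it must run
# after ToCHWImage.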


def rotate_point(pt: List[float], angle_rad: float) -> List[float]:
    """Rotate a point by an angle.

    Args:
        pt (list[float]): 2 dimensional point to be rotated
        angle_rad (float): rotation angle by radian

    Returns:
        list[float]: Rotated point.
    """
    assert len(pt) == 2
    sn, cs = np.sin(angle_rad), np.cos(angle_rad)
    new_x = pt[0] * cs - pt[1] * sn
    new_y = pt[0] * sn + pt[1] * cs
    rotated_pt = [new_x, new_y]
    return rotated_pt
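
# Worked example: rotating (1, 0) by pi / 2 radians counter-clockwise gives
# new_x = 1 * cos(pi / 2) - 0 * sin(pi / 2) == 0 and
# new_y = 1 * sin(pi / 2) + 0 * cos(pi / 2) == 1, i.e. the point (0, 1).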


def _get_3rd_point(a: ndarray, b: ndarray) -> ndarray:
    """To calculate the affine matrix, three pairs of points are required. This
    function is used to get the 3rd point, given 2D points a & b.

    The 3rd point is defined by rotating vector `a - b` by 90 degrees
    anticlockwise, using b as the rotation center.

    Args:
        a (np.ndarray): point(x,y)
        b (np.ndarray): point(x,y)

    Returns:
        np.ndarray: The 3rd point.
    """
    assert len(a) == 2
    assert len(b) == 2
    direction = a - b
    third_pt = b + np.array([-direction[1], direction[0]], dtype=np.float32)
    return third_pt
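
# For example, with a = (1, 0) and b = (0, 0), the direction a - b == (1, 0)
# rotated 90 degrees anticlockwise about b gives the third point (0, 1), so
# the three points span a non-degenerate triangle as required by
# cv2.getAffineTransform.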


def get_affine_transform(
    center: ndarray,
    input_size: Union[Number, Tuple[Number, Number], ndarray],
    rot: float,
    output_size: ndarray,
    shift: Tuple[float, float] = (0.0, 0.0),
    inv: bool = False,
):
    """Get the affine transform matrix, given the center/scale/rot/output_size.

    Args:
        center (np.ndarray[2, ]): Center of the bounding box (x, y).
        input_size (np.ndarray[2, ]): Scale of the bounding box
            wrt [width, height].
        rot (float): Rotation angle (degree).
        output_size (np.ndarray[2, ]): Size of the destination heatmaps.
        shift (Tuple[float, float]): Shift translation ratio wrt the
            width/height, in the range [0, 1]. Default (0., 0.).
        inv (bool): Option to invert the affine transform direction.
            (inv=False: src->dst or inv=True: dst->src)

    Returns:
        np.ndarray: The transform matrix.
    """
    assert len(center) == 2
    assert len(output_size) == 2
    assert len(shift) == 2
    if not isinstance(input_size, (ndarray, list)):
        input_size = np.array([input_size, input_size], dtype=np.float32)
    scale_tmp = input_size

    shift = np.array(shift)
    src_w = scale_tmp[0]
    dst_w = output_size[0]
    dst_h = output_size[1]

    rot_rad = np.pi * rot / 180
    src_dir = rotate_point([0.0, src_w * -0.5], rot_rad)
    dst_dir = np.array([0.0, dst_w * -0.5])

    src = np.zeros((3, 2), dtype=np.float32)
    src[0, :] = center + scale_tmp * shift
    src[1, :] = center + src_dir + scale_tmp * shift
    src[2, :] = _get_3rd_point(src[0, :], src[1, :])

    dst = np.zeros((3, 2), dtype=np.float32)
    dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
    dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
    dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])

    if inv:
        trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
    else:
        trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
    return trans
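
# A minimal sketch of how the matrix is used (illustrative values): mapping a
# 512-pixel-wide region centered at (256, 256) onto a 128x128 output is a
# uniform downscale by 4 with no rotation.
#
#     trans = get_affine_transform(
#         center=np.array([256.0, 256.0]),
#         input_size=512,
#         rot=0.0,
#         output_size=np.array([128, 128]),
#     )
#     # cv2.warpAffine(img, trans, (128, 128)) then resamples the region.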


class WarpAffine:
    """Apply warp affine transformation to the image based on the given parameters.

    Args:
        keep_res (bool): Whether to keep the original resolution aspect ratio during transformation.
        pad (int): Padding value used when keep_res is True.
        input_h (int): Target height for the input image when keep_res is False.
        input_w (int): Target width for the input image when keep_res is False.
        scale (float): Scale factor for resizing.
        shift (float): Shift factor for transformation.
        down_ratio (int): Downsampling ratio for the output image.
    """

    def __init__(
        self,
        keep_res=False,
        pad=31,
        input_h=512,
        input_w=512,
        scale=0.4,
        shift=0.1,
        down_ratio=4,
    ):
        super().__init__()
        self.keep_res = keep_res
        self.pad = pad
        self.input_h = input_h
        self.input_w = input_w
        self.scale = scale
        self.shift = shift
        self.down_ratio = down_ratio

    def apply(self, img: ndarray):
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

        h, w = img.shape[:2]

        if self.keep_res:
            # True in detection eval/infer
            input_h = (h | self.pad) + 1
            input_w = (w | self.pad) + 1
            s = np.array([input_w, input_h], dtype=np.float32)
            c = np.array([w // 2, h // 2], dtype=np.float32)
        else:
            # False in centertrack eval_mot/eval_mot
            s = max(h, w) * 1.0
            input_h, input_w = self.input_h, self.input_w
            c = np.array([w / 2.0, h / 2.0], dtype=np.float32)

        trans_input = get_affine_transform(c, s, 0, [input_w, input_h])
        # NOTE: (w, h) is already the image's current size, so this resize is
        # a no-op.
        img = cv2.resize(img, (w, h))
        inp = cv2.warpAffine(
            img, trans_input, (input_w, input_h), flags=cv2.INTER_LINEAR
        )

        if not self.keep_res:
            out_h = input_h // self.down_ratio
            out_w = input_w // self.down_ratio
            # NOTE: `trans_output` is computed here but not used by this
            # processor; only the warped input image is returned.
            trans_output = get_affine_transform(c, s, 0, [out_w, out_h])

        return inp

    def __call__(self, datas: List[dict]) -> List[dict]:
        for data in datas:
            ori_img = data["img"]
            if "ori_img_size" not in data:
                data["ori_img_size"] = [ori_img.shape[1], ori_img.shape[0]]
            ori_img_size = data["ori_img_size"]

            img = self.apply(ori_img)
            data["img"] = img

            img_size = [img.shape[1], img.shape[0]]
            data["img_size"] = img_size  # [size_w, size_h]
            data["scale_factors"] = [  # [w_scale, h_scale]
                img_size[0] / ori_img_size[0],
                img_size[1] / ori_img_size[1],
            ]
        return datas
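
# Worked example of the keep_res padding above: with pad = 31, a height of
# h = 500 gives input_h = (500 | 31) + 1 == 511 + 1 == 512, i.e. the size is
# rounded up to the next multiple of 32 (pad + 1).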


def compute_iou(box1: List[Number], box2: List[Number]) -> float:
    """Compute the Intersection over Union (IoU) of two bounding boxes.

    Args:
        box1 (List[Number]): Coordinates of the first bounding box in format [x1, y1, x2, y2].
        box2 (List[Number]): Coordinates of the second bounding box in format [x1, y1, x2, y2].

    Returns:
        float: The IoU of the two bounding boxes.
    """
    x1 = max(box1[0], box2[0])
    y1 = max(box1[1], box2[1])
    x2 = min(box1[2], box2[2])
    y2 = min(box1[3], box2[3])
    inter_area = max(0, x2 - x1 + 1) * max(0, y2 - y1 + 1)
    box1_area = (box1[2] - box1[0] + 1) * (box1[3] - box1[1] + 1)
    box2_area = (box2[2] - box2[0] + 1) * (box2[3] - box2[1] + 1)
    iou = inter_area / float(box1_area + box2_area - inter_area)
    return iou
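
# Worked example (note the +1 pixel-inclusive convention used above): for
# box1 = [0, 0, 9, 9] and box2 = [5, 5, 14, 14], the intersection area is
# 5 * 5 == 25, each box area is 10 * 10 == 100, and
# IoU == 25 / (100 + 100 - 25) ≈ 0.143.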


def is_box_mostly_inside(
    inner_box: List[Number], outer_box: List[Number], threshold: float = 0.9
) -> bool:
    """Determine if one bounding box is mostly inside another bounding box.

    Args:
        inner_box (List[Number]): Coordinates of the inner bounding box in format [x1, y1, x2, y2].
        outer_box (List[Number]): Coordinates of the outer bounding box in format [x1, y1, x2, y2].
        threshold (float): The threshold for determining if the inner box is mostly inside the outer box (default is 0.9).

    Returns:
        bool: True if the ratio of the intersection area to the inner box area is greater than or equal to the threshold, False otherwise.
    """
    x1 = max(inner_box[0], outer_box[0])
    y1 = max(inner_box[1], outer_box[1])
    x2 = min(inner_box[2], outer_box[2])
    y2 = min(inner_box[3], outer_box[3])
    inter_area = max(0, x2 - x1 + 1) * max(0, y2 - y1 + 1)
    inner_box_area = (inner_box[2] - inner_box[0] + 1) * (
        inner_box[3] - inner_box[1] + 1
    )
    return (inter_area / inner_box_area) >= threshold
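
# For example, inner_box = [2, 2, 5, 5] lies entirely within
# outer_box = [0, 0, 10, 10]: the intersection area equals the inner box area
# (16), the ratio is 1.0 >= 0.9, and the function returns True.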


def restructured_boxes(
    boxes: ndarray, labels: List[str], img_size: Tuple[int, int]
) -> Boxes:
    """
    Restructure the given bounding boxes and labels based on the image size.

    Args:
        boxes (ndarray): A 2D array of bounding boxes with each box represented as [cls_id, score, xmin, ymin, xmax, ymax].
        labels (List[str]): A list of class labels corresponding to the class ids.
        img_size (Tuple[int, int]): A tuple representing the width and height of the image.

    Returns:
        Boxes: A list of dictionaries, each containing 'cls_id', 'label', 'score', and 'coordinate' keys.
    """
    box_list = []
    w, h = img_size
    for box in boxes:
        xmin, ymin, xmax, ymax = box[2:]
        xmin = max(0, xmin)
        ymin = max(0, ymin)
        xmax = min(w, xmax)
        ymax = min(h, ymax)
        box_list.append(
            {
                "cls_id": int(box[0]),
                "label": labels[int(box[0])],
                "score": float(box[1]),
                "coordinate": [xmin, ymin, xmax, ymax],
            }
        )
    return box_list
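
# For example, with labels = ["person"] (illustrative) and
# img_size = (640, 480), the raw row [0, 0.98, -5.0, 10.0, 700.0, 400.0] is
# clipped to the image bounds and restructured as {"cls_id": 0,
# "label": "person", "score": 0.98, "coordinate": [0, 10.0, 640, 400.0]}.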


def restructured_rotated_boxes(
    boxes: ndarray, labels: List[str], img_size: Tuple[int, int]
) -> Boxes:
    """
    Restructure the given rotated bounding boxes and labels based on the image size.

    Args:
        boxes (ndarray): A 2D array of rotated bounding boxes with each box represented as [cls_id, score, x1, y1, x2, y2, x3, y3, x4, y4].
        labels (List[str]): A list of class labels corresponding to the class ids.
        img_size (Tuple[int, int]): A tuple representing the width and height of the image.

    Returns:
        Boxes: A list of dictionaries, each containing 'cls_id', 'label', 'score', and 'coordinate' keys.
    """
    box_list = []
    w, h = img_size

    assert boxes.shape[1] == 10, "The shape of rotated boxes should be [N, 10]"
    for box in boxes:
        x1, y1, x2, y2, x3, y3, x4, y4 = box[2:]
        x1 = min(max(0, x1), w)
        y1 = min(max(0, y1), h)
        x2 = min(max(0, x2), w)
        y2 = min(max(0, y2), h)
        x3 = min(max(0, x3), w)
        y3 = min(max(0, y3), h)
        x4 = min(max(0, x4), w)
        y4 = min(max(0, y4), h)
        box_list.append(
            {
                "cls_id": int(box[0]),
                "label": labels[int(box[0])],
                "score": float(box[1]),
                "coordinate": [x1, y1, x2, y2, x3, y3, x4, y4],
            }
        )
    return box_list


def non_max_suppression(
    boxes: ndarray, scores: ndarray, iou_threshold: float
) -> List[int]:
    """
    Perform non-maximum suppression to remove redundant overlapping boxes with
    lower scores. This function is commonly used in object detection tasks.

    Parameters:
        boxes (ndarray): An array of shape (N, 4) representing the bounding boxes.
            Each row is in the format [x1, y1, x2, y2].
        scores (ndarray): An array of shape (N,) containing the scores for each box.
        iou_threshold (float): The Intersection over Union (IoU) threshold to use
            for suppressing overlapping boxes.

    Returns:
        List[int]: A list of indices representing the indices of the boxes to keep.
    """
    if len(boxes) == 0:
        return []
    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])
        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h
        iou = inter / (areas[i] + areas[order[1:]] - inter)
        inds = np.where(iou <= iou_threshold)[0]
        order = order[inds + 1]
    return keep
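
# A minimal sketch (illustrative values): boxes 0 and 1 below overlap with
# IoU ≈ 0.70, so with iou_threshold = 0.5 the lower-scoring box 1 is
# suppressed and the isolated box 2 is kept.
#
#     boxes = np.array([[0, 0, 10, 10], [1, 1, 11, 11], [50, 50, 60, 60]])
#     scores = np.array([0.9, 0.8, 0.7])
#     non_max_suppression(boxes, scores, iou_threshold=0.5)  # -> [0, 2]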


class DetPostProcess:
    """Post-processes raw detection outputs.

    This class is responsible for post-processing detection results, including
    thresholding, non-maximum suppression (NMS), and restructuring the boxes
    based on the input type (normal or rotated object detection).
    """

    def __init__(
        self,
        threshold: float = 0.5,
        labels: Optional[List[str]] = None,
        layout_postprocess: bool = False,
    ) -> None:
        """Initialize the DetPostProcess class.

        Args:
            threshold (float, optional): The threshold to apply to the detection scores. Defaults to 0.5.
            labels (Optional[List[str]], optional): The list of labels for the detection categories. Defaults to None.
            layout_postprocess (bool, optional): Whether to apply layout post-processing. Defaults to False.
        """
        super().__init__()
        self.threshold = threshold
        self.labels = labels
        self.layout_postprocess = layout_postprocess

    def apply(self, boxes: ndarray, img_size) -> Boxes:
        """Apply post-processing to the detection boxes.

        Args:
            boxes (ndarray): The input detection boxes with scores.
            img_size (tuple): The original image size.

        Returns:
            Boxes: The post-processed detection boxes.
        """
        if isinstance(self.threshold, float):
            expect_boxes = (boxes[:, 1] > self.threshold) & (boxes[:, 0] > -1)
            boxes = boxes[expect_boxes, :]
        elif isinstance(self.threshold, dict):
            # Filter each category with its own score threshold (0.5 if unset).
            category_filtered_boxes = []
            for cat_id in np.unique(boxes[:, 0]):
                category_boxes = boxes[boxes[:, 0] == cat_id]
                category_scores = category_boxes[:, 1]
                category_threshold = self.threshold.get(int(cat_id), 0.5)
                selected_indices = category_scores > category_threshold
                category_filtered_boxes.append(category_boxes[selected_indices])
            boxes = (
                np.vstack(category_filtered_boxes)
                if category_filtered_boxes
                # `boxes` had no rows, so the loop never ran; keep the empty
                # (0, C) array instead of collapsing it to a 1-D array.
                else boxes
            )

        if self.layout_postprocess:
            filtered_boxes = []
            # Layout postprocess for NMS
            for cat_id in np.unique(boxes[:, 0]):
                category_boxes = boxes[boxes[:, 0] == cat_id]
                category_scores = category_boxes[:, 1]
                if len(category_boxes) > 0:
                    nms_indices = non_max_suppression(
                        category_boxes[:, 2:], category_scores, 0.5
                    )
                    category_boxes = category_boxes[nms_indices]
                    keep_boxes = []
                    for i, box in enumerate(category_boxes):
                        if all(
                            not is_box_mostly_inside(box[2:], other_box[2:])
                            for j, other_box in enumerate(category_boxes)
                            if i != j
                        ):
                            keep_boxes.append(box)
                    filtered_boxes.extend(keep_boxes)
            boxes = np.array(filtered_boxes)

            # Layout postprocess for removing boxes inside an "image" category box
            if self.labels and "image" in self.labels:
                image_cls_id = self.labels.index("image")
                if len(boxes) > 0:
                    image_boxes = boxes[boxes[:, 0] == image_cls_id]
                    other_boxes = boxes[boxes[:, 0] != image_cls_id]
                    to_keep = []
                    for box in other_boxes:
                        keep = True
                        for img_box in image_boxes:
                            if (
                                box[2] >= img_box[2]
                                and box[3] >= img_box[3]
                                and box[4] <= img_box[4]
                                and box[5] <= img_box[5]
                            ):
                                keep = False
                                break
                        if keep:
                            to_keep.append(box)
                    boxes = (
                        np.vstack([image_boxes, to_keep]) if to_keep else image_boxes
                    )

            # Layout postprocess for overlaps: among boxes with IoU > 0.95,
            # keep only the highest-scoring one.
            final_boxes = []
            while len(boxes) > 0:
                current_box = boxes[0]
                current_score = current_box[1]
                overlaps = [current_box]
                non_overlaps = []
                for other_box in boxes[1:]:
                    iou = compute_iou(current_box[2:], other_box[2:])
                    if iou > 0.95:
                        if other_box[1] > current_score:
                            overlaps.append(other_box)
                    else:
                        non_overlaps.append(other_box)
                best_box = max(overlaps, key=lambda x: x[1])
                final_boxes.append(best_box)
                boxes = np.array(non_overlaps)
            boxes = np.array(final_boxes)

        if boxes.shape[1] == 6:
            # For normal object detection
            boxes = restructured_boxes(boxes, self.labels, img_size)
        elif boxes.shape[1] == 10:
            # For rotated object detection
            boxes = restructured_rotated_boxes(boxes, self.labels, img_size)
        else:
            # Unexpected input box shape
            raise ValueError(
                f"The number of columns in `boxes` should be 6 or 10, got {boxes.shape[1]}"
            )
        return boxes

    def __call__(self, batch_outputs: List[dict], datas: List[dict]) -> List[Boxes]:
        """Apply the post-processing to a batch of outputs.

        Args:
            batch_outputs (List[dict]): The list of detection outputs.
            datas (List[dict]): The list of input data.

        Returns:
            List[Boxes]: The list of post-processed detection boxes.
        """
        outputs = []
        for data, output in zip(datas, batch_outputs):
            boxes = self.apply(output["boxes"], data["ori_img_size"])
            outputs.append(boxes)
        return outputs
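
# End-to-end usage sketch (hedged: constructor arguments for ReadImage,
# Resize, and Normalize live in the Common* base classes and are not visible
# in this module, so they are elided; the labels and keys are illustrative):
#
#     read = ReadImage()
#     to_batch = ToBatch(
#         ordered_required_keys=("img", "img_size", "scale_factors")
#     )
#     post = DetPostProcess(threshold=0.5, labels=["person", "image"])
#
#     datas = read(["/path/to/image.jpg"])
#     ...  # Resize / Normalize / ToCHWImage, then model inference produces
#          # batch_outputs, a list of dicts each holding a "boxes" array
#     results = post(batch_outputs, datas)  # one list of box dicts per image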