# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Sequence, Tuple, Union, Optional

import cv2
import numpy as np
from numpy import ndarray

from ..common import Resize as CommonResize
from ..common import Normalize as CommonNormalize
from ...common.reader import ReadImage as CommonReadImage
from ...utils.benchmark import benchmark

Boxes = List[dict]
Number = Union[int, float]

@benchmark.timeit
class ReadImage(CommonReadImage):
    """Reads images from a list of raw image data or file paths."""

    def __call__(self, raw_imgs: List[Union[ndarray, str, dict]]) -> List[dict]:
        """Processes the input list of raw image data or file paths and returns a list of dictionaries containing image information.

        Args:
            raw_imgs (List[Union[ndarray, str, dict]]): A list of raw image data (numpy ndarrays), file paths (strings),
                or dicts carrying either under the 'img' or 'img_path' key.

        Returns:
            List[dict]: A list of dictionaries, each containing image information.
        """
        out_datas = []
        for raw_img in raw_imgs:
            data = dict()
            if isinstance(raw_img, str):
                data["img_path"] = raw_img
            if isinstance(raw_img, dict):
                if "img" in raw_img:
                    src_img = raw_img["img"]
                elif "img_path" in raw_img:
                    src_img = raw_img["img_path"]
                    data["img_path"] = src_img
                else:
                    raise ValueError(
                        "When raw_img is dict, must have one of keys ['img', 'img_path']."
                    )
                data.update(raw_img)
                raw_img = src_img
            img, ori_img = self.read(raw_img)
            data["img"] = img
            data["ori_img"] = ori_img
            data["img_size"] = [img.shape[1], img.shape[0]]  # [size_w, size_h]
            data["ori_img_size"] = [ori_img.shape[1], ori_img.shape[0]]  # [size_w, size_h]
            out_datas.append(data)
        return out_datas

    def read(self, img):
        if isinstance(img, np.ndarray):
            ori_img = img
            if self.format == "RGB":
                img = img[:, :, ::-1]
            return img, ori_img
        elif isinstance(img, str):
            blob = self._img_reader.read(img)
            if blob is None:
                raise Exception(f"Image read error: {img}")
            ori_img = blob
            if self.format == "RGB":
                if blob.ndim != 3:
                    raise RuntimeError("Array is not 3-dimensional.")
                # BGR to RGB
                blob = blob[..., ::-1]
            return blob, ori_img
        else:
            raise TypeError(
                "ReadImage only supports the following types:\n"
                "1. str, indicating an image file path or a directory containing image files.\n"
                "2. numpy.ndarray.\n"
                f"However, got type: {type(img).__name__}."
            )

@benchmark.timeit
class Resize(CommonResize):
    def __call__(self, datas: List[dict]) -> List[dict]:
        """
        Args:
            datas (List[dict]): A list of dictionaries, each containing image data with the key 'img'.

        Returns:
            List[dict]: A list of dictionaries with updated image data, including resized images,
                original image sizes, resized image sizes, and scale factors.
        """
        for data in datas:
            ori_img = data["img"]
            if "ori_img_size" not in data:
                data["ori_img_size"] = [ori_img.shape[1], ori_img.shape[0]]
            ori_img_size = data["ori_img_size"]

            img = self.resize(ori_img)
            data["img"] = img

            img_size = [img.shape[1], img.shape[0]]
            data["img_size"] = img_size  # [size_w, size_h]

            data["scale_factors"] = [  # [w_scale, h_scale]
                img_size[0] / ori_img_size[0],
                img_size[1] / ori_img_size[1],
            ]
        return datas

@benchmark.timeit
class Normalize(CommonNormalize):
    """Normalizes images in a list of dictionaries containing image data."""

    def apply(self, img: ndarray) -> ndarray:
        """Applies normalization to a single image."""
        old_type = img.dtype
        # XXX: If `old_type` has higher precision than float32,
        # we will lose some precision.
        img = img.astype("float32", copy=False)
        img *= self.scale
        img -= self.mean
        img /= self.std
        if self.preserve_dtype:
            img = img.astype(old_type, copy=False)
        return img

    def __call__(self, datas: List[dict]) -> List[dict]:
        """Normalizes images in a list of dictionaries. Iterates over each dictionary,
        applies normalization to the 'img' key, and returns the modified list.
        """
        for data in datas:
            data["img"] = self.apply(data["img"])
        return datas

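
# A minimal sanity-check sketch (not part of the pipeline): Normalize.apply
# computes (img * scale - mean) / std element-wise. The scale/mean/std values
# below are illustrative, not the pipeline defaults.
def _demo_normalize_arithmetic():
    img = np.full((2, 2, 3), 255.0, dtype=np.float32)
    scale, mean, std = 1.0 / 255.0, 0.5, 0.5
    out = (img * scale - mean) / std
    assert np.allclose(out, 1.0)
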
@benchmark.timeit
class ToCHWImage:
    """Converts images in a list of dictionaries from HWC to CHW format."""

    def __call__(self, datas: List[dict]) -> List[dict]:
        """Converts the image data in the list of dictionaries from HWC to CHW format in-place.

        Args:
            datas (List[dict]): A list of dictionaries, each containing an image tensor under the 'img' key in HWC format.

        Returns:
            List[dict]: The same list of dictionaries with the image tensors converted to CHW format.
        """
        for data in datas:
            data["img"] = data["img"].transpose((2, 0, 1))
        return datas

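
# A minimal sanity-check sketch (not part of the pipeline): a 4x5 RGB image
# in HWC layout becomes a channel-first (3, 4, 5) tensor.
def _demo_to_chw():
    datas = [{"img": np.zeros((4, 5, 3), dtype=np.float32)}]
    datas = ToCHWImage()(datas)
    assert datas[0]["img"].shape == (3, 4, 5)
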
@benchmark.timeit
class ToBatch:
    """
    Class for batch processing of data dictionaries.

    Args:
        ordered_required_keys (Optional[Tuple[str]]): A tuple of keys that need to be present in the input data dictionaries in a specific order.
    """

    def __init__(self, ordered_required_keys: Optional[Tuple[str]] = None):
        self.ordered_required_keys = ordered_required_keys

    def apply(
        self, datas: List[dict], key: str, dtype: np.dtype = np.float32
    ) -> np.ndarray:
        """
        Apply batch processing to a list of data dictionaries.

        Args:
            datas (List[dict]): A list of data dictionaries to process.
            key (str): The key in the data dictionaries to extract and batch.
            dtype (np.dtype): The desired data type of the output array (default is np.float32).

        Returns:
            np.ndarray: A numpy array containing the batched data.

        Raises:
            KeyError: If the specified key is not found in any of the data dictionaries.
        """
        if key == "img_size":
            # [h, w] size for det models
            img_sizes = [data[key][::-1] for data in datas]
            return np.stack(img_sizes, axis=0).astype(dtype=dtype, copy=False)
        elif key == "scale_factors":
            # [h, w] scale factors for det models, default [1.0, 1.0]
            scale_factors = [data.get(key, [1.0, 1.0])[::-1] for data in datas]
            return np.stack(scale_factors, axis=0).astype(dtype=dtype, copy=False)
        else:
            return np.stack([data[key] for data in datas], axis=0).astype(
                dtype=dtype, copy=False
            )

    def __call__(self, datas: List[dict]) -> Sequence[ndarray]:
        return [self.apply(datas, key) for key in self.ordered_required_keys]

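
# A minimal sanity-check sketch (not part of the pipeline): ToBatch stacks
# each requested key along a new batch axis; note that 'img_size' entries are
# reversed to [h, w] order for the detection model.
def _demo_to_batch():
    datas = [
        {"img": np.zeros((3, 4, 4), dtype=np.float32), "img_size": [4, 4]},
        {"img": np.ones((3, 4, 4), dtype=np.float32), "img_size": [4, 4]},
    ]
    batch = ToBatch(ordered_required_keys=("img_size", "img"))(datas)
    assert batch[0].shape == (2, 2)  # batched [h, w] sizes
    assert batch[1].shape == (2, 3, 4, 4)  # batched CHW images
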
@benchmark.timeit
class DetPad:
    """
    Pad the image to a specified size.

    Args:
        size (list[int]): target size [h, w] of the image
        fill_value (list[float]): RGB value of the padded area, default (114.0, 114.0, 114.0)
    """

    def __init__(
        self,
        size: List[int],
        fill_value: List[Union[int, float]] = [114.0, 114.0, 114.0],
    ):
        super().__init__()
        if isinstance(size, int):
            size = [size, size]
        self.size = size
        self.fill_value = fill_value

    def apply(self, img: ndarray) -> ndarray:
        im = img
        im_h, im_w = im.shape[:2]
        h, w = self.size
        if h == im_h and w == im_w:
            return im
        canvas = np.ones((h, w, 3), dtype=np.float32)
        canvas *= np.array(self.fill_value, dtype=np.float32)
        canvas[0:im_h, 0:im_w, :] = im.astype(np.float32)
        return canvas

    def __call__(self, datas: List[dict]) -> List[dict]:
        for data in datas:
            data["img"] = self.apply(data["img"])
        return datas

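
# A minimal sanity-check sketch (not part of the pipeline): a 2x3 image is
# placed in the top-left corner of a 4x4 canvas filled with the pad value.
def _demo_det_pad():
    img = np.zeros((2, 3, 3), dtype=np.float32)
    padded = DetPad(size=[4, 4]).apply(img)
    assert padded.shape == (4, 4, 3)
    assert (padded[:2, :3] == 0).all() and (padded[2:, :] == 114.0).all()
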
@benchmark.timeit
class PadStride:
    """Pad the image so that its height and width are multiples of `stride`.

    Replaces PadBatch(pad_to_stride, pad_gt) from the original config; models
    with an FPN require the image shape to satisfy shape % stride == 0.

    Args:
        stride (int): coarsest stride of the model with FPN
    """

    def __init__(self, stride: int = 0):
        super().__init__()
        self.coarsest_stride = stride

    def apply(self, img: ndarray):
        """
        Args:
            img (np.ndarray): image in CHW format (np.ndarray)

        Returns:
            np.ndarray: processed image (np.ndarray)
        """
        im = img
        coarsest_stride = self.coarsest_stride
        if coarsest_stride <= 0:
            return img
        im_c, im_h, im_w = im.shape
        pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
        pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
        padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
        padding_im[:, :im_h, :im_w] = im
        return padding_im

    def __call__(self, datas: List[dict]) -> List[dict]:
        for data in datas:
            data["img"] = self.apply(data["img"])
        return datas

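
# A minimal sanity-check sketch (not part of the pipeline): a 3x50x60 CHW
# image padded with stride 32 becomes 3x64x64, with zeros in the padded area.
def _demo_pad_stride():
    img = np.ones((3, 50, 60), dtype=np.float32)
    padded = PadStride(stride=32).apply(img)
    assert padded.shape == (3, 64, 64)
    assert padded[:, 50:, :].sum() == 0
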
def rotate_point(pt: List[float], angle_rad: float) -> List[float]:
    """Rotate a point by an angle.

    Args:
        pt (list[float]): 2-dimensional point to be rotated
        angle_rad (float): rotation angle in radians

    Returns:
        list[float]: Rotated point.
    """
    assert len(pt) == 2
    sn, cs = np.sin(angle_rad), np.cos(angle_rad)
    new_x = pt[0] * cs - pt[1] * sn
    new_y = pt[0] * sn + pt[1] * cs
    rotated_pt = [new_x, new_y]
    return rotated_pt

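
# A minimal sanity-check sketch (not part of the pipeline): rotating (1, 0)
# by pi/2 yields (0, 1) up to floating-point error.
def _demo_rotate_point():
    x, y = rotate_point([1.0, 0.0], np.pi / 2)
    assert abs(x) < 1e-6 and abs(y - 1.0) < 1e-6
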
def _get_3rd_point(a: ndarray, b: ndarray) -> ndarray:
    """To calculate the affine matrix, three pairs of points are required. This
    function is used to get the 3rd point, given 2D points a & b.

    The 3rd point is defined by rotating the vector `a - b` by 90 degrees
    anticlockwise, using b as the rotation center.

    Args:
        a (np.ndarray): point (x, y)
        b (np.ndarray): point (x, y)

    Returns:
        np.ndarray: The 3rd point.
    """
    assert len(a) == 2
    assert len(b) == 2
    direction = a - b
    third_pt = b + np.array([-direction[1], direction[0]], dtype=np.float32)
    return third_pt

def get_affine_transform(
    center: ndarray,
    input_size: Union[Number, Tuple[Number, Number], ndarray],
    rot: float,
    output_size: ndarray,
    shift: Tuple[float, float] = (0.0, 0.0),
    inv: bool = False,
):
    """Get the affine transform matrix, given the center/scale/rot/output_size.

    Args:
        center (np.ndarray[2, ]): Center of the bounding box (x, y).
        input_size (number | tuple | np.ndarray[2, ]): Scale of the bounding box
            wrt [width, height]; a scalar is broadcast to both dimensions.
        rot (float): Rotation angle (degrees).
        output_size (np.ndarray[2, ]): Size of the destination heatmaps.
        shift (tuple[float, float]): Shift translation ratio wrt the
            width/height. Default (0., 0.).
        inv (bool): Option to invert the affine transform direction.
            (inv=False: src->dst or inv=True: dst->src)

    Returns:
        np.ndarray: The transform matrix.
    """
    assert len(center) == 2
    assert len(output_size) == 2
    assert len(shift) == 2
    if not isinstance(input_size, (ndarray, list)):
        input_size = np.array([input_size, input_size], dtype=np.float32)
    scale_tmp = input_size

    shift = np.array(shift)
    src_w = scale_tmp[0]
    dst_w = output_size[0]
    dst_h = output_size[1]

    rot_rad = np.pi * rot / 180
    src_dir = rotate_point([0.0, src_w * -0.5], rot_rad)
    dst_dir = np.array([0.0, dst_w * -0.5])

    src = np.zeros((3, 2), dtype=np.float32)
    src[0, :] = center + scale_tmp * shift
    src[1, :] = center + src_dir + scale_tmp * shift
    src[2, :] = _get_3rd_point(src[0, :], src[1, :])

    dst = np.zeros((3, 2), dtype=np.float32)
    dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
    dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
    dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])

    if inv:
        trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
    else:
        trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))

    return trans

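
# A minimal sanity-check sketch (not part of the pipeline): with no rotation,
# equal input and output sizes, and the center at the middle of the image,
# the resulting affine transform is (approximately) the identity.
def _demo_affine_identity():
    trans = get_affine_transform(
        center=np.array([50.0, 50.0]),
        input_size=100,
        rot=0.0,
        output_size=np.array([100, 100]),
    )
    assert np.allclose(trans, [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], atol=1e-6)
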
@benchmark.timeit
class WarpAffine:
    """Apply a warp affine transformation to the image based on the given parameters.

    Args:
        keep_res (bool): Whether to keep the original resolution aspect ratio during transformation.
        pad (int): Padding value used when keep_res is True.
        input_h (int): Target height for the input image when keep_res is False.
        input_w (int): Target width for the input image when keep_res is False.
        scale (float): Scale factor for resizing.
        shift (float): Shift factor for the transformation.
        down_ratio (int): Downsampling ratio for the output image.
    """

    def __init__(
        self,
        keep_res=False,
        pad=31,
        input_h=512,
        input_w=512,
        scale=0.4,
        shift=0.1,
        down_ratio=4,
    ):
        super().__init__()
        self.keep_res = keep_res
        self.pad = pad
        self.input_h = input_h
        self.input_w = input_w
        self.scale = scale
        self.shift = shift
        self.down_ratio = down_ratio

    def apply(self, img: ndarray):
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

        h, w = img.shape[:2]

        if self.keep_res:
            # True in detection eval/infer
            input_h = (h | self.pad) + 1
            input_w = (w | self.pad) + 1
            s = np.array([input_w, input_h], dtype=np.float32)
            c = np.array([w // 2, h // 2], dtype=np.float32)
        else:
            # False in CenterTrack eval_mot/infer_mot
            s = max(h, w) * 1.0
            input_h, input_w = self.input_h, self.input_w
            c = np.array([w / 2.0, h / 2.0], dtype=np.float32)

        trans_input = get_affine_transform(c, s, 0, [input_w, input_h])
        img = cv2.resize(img, (w, h))
        inp = cv2.warpAffine(
            img, trans_input, (input_w, input_h), flags=cv2.INTER_LINEAR
        )
        if not self.keep_res:
            out_h = input_h // self.down_ratio
            out_w = input_w // self.down_ratio
            # Computed for parity with the original implementation; unused here.
            trans_output = get_affine_transform(c, s, 0, [out_w, out_h])  # noqa: F841

        return inp

    def __call__(self, datas: List[dict]) -> List[dict]:
        for data in datas:
            ori_img = data["img"]
            if "ori_img_size" not in data:
                data["ori_img_size"] = [ori_img.shape[1], ori_img.shape[0]]

            img = self.apply(ori_img)
            data["img"] = img

        return datas

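
# A minimal sanity-check sketch (not part of the pipeline): with keep_res
# False, any input image is warped to the configured input_w x input_h size.
def _demo_warp_affine():
    img = np.zeros((100, 200, 3), dtype=np.uint8)
    out = WarpAffine(keep_res=False, input_h=512, input_w=512).apply(img)
    assert out.shape == (512, 512, 3)
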
def restructured_boxes(
    boxes: ndarray, labels: List[str], img_size: Tuple[int, int]
) -> Boxes:
    """
    Restructure the given bounding boxes and labels based on the image size.

    Args:
        boxes (ndarray): A 2D array of bounding boxes with each box represented as [cls_id, score, xmin, ymin, xmax, ymax].
        labels (List[str]): A list of class labels corresponding to the class ids.
        img_size (Tuple[int, int]): A tuple representing the width and height of the image.

    Returns:
        Boxes: A list of dictionaries, each containing 'cls_id', 'label', 'score', and 'coordinate' keys.
    """
    box_list = []
    w, h = img_size

    for box in boxes:
        xmin, ymin, xmax, ymax = box[2:]
        xmin = max(0, xmin)
        ymin = max(0, ymin)
        xmax = min(w, xmax)
        ymax = min(h, ymax)
        box_list.append(
            {
                "cls_id": int(box[0]),
                "label": labels[int(box[0])],
                "score": float(box[1]),
                "coordinate": [xmin, ymin, xmax, ymax],
            }
        )

    return box_list

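
# A minimal sanity-check sketch (not part of the pipeline): a detection row
# [cls_id, score, xmin, ymin, xmax, ymax] is clipped to the image bounds and
# labeled; the label list here is illustrative.
def _demo_restructured_boxes():
    boxes = np.array([[0, 0.9, -5.0, 10.0, 130.0, 50.0]])
    out = restructured_boxes(boxes, labels=["text"], img_size=(120, 80))
    assert out[0]["coordinate"] == [0.0, 10.0, 120.0, 50.0]
    assert out[0]["label"] == "text"
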
def restructured_rotated_boxes(
    boxes: ndarray, labels: List[str], img_size: Tuple[int, int]
) -> Boxes:
    """
    Restructure the given rotated bounding boxes and labels based on the image size.

    Args:
        boxes (ndarray): A 2D array of rotated bounding boxes with each box represented as [cls_id, score, x1, y1, x2, y2, x3, y3, x4, y4].
        labels (List[str]): A list of class labels corresponding to the class ids.
        img_size (Tuple[int, int]): A tuple representing the width and height of the image.

    Returns:
        Boxes: A list of dictionaries, each containing 'cls_id', 'label', 'score', and 'coordinate' keys.
    """
    box_list = []
    w, h = img_size

    assert boxes.shape[1] == 10, "The shape of rotated boxes should be [N, 10]"
    for box in boxes:
        x1, y1, x2, y2, x3, y3, x4, y4 = box[2:]
        x1 = min(max(0, x1), w)
        y1 = min(max(0, y1), h)
        x2 = min(max(0, x2), w)
        y2 = min(max(0, y2), h)
        x3 = min(max(0, x3), w)
        y3 = min(max(0, y3), h)
        x4 = min(max(0, x4), w)
        y4 = min(max(0, y4), h)
        box_list.append(
            {
                "cls_id": int(box[0]),
                "label": labels[int(box[0])],
                "score": float(box[1]),
                "coordinate": [x1, y1, x2, y2, x3, y3, x4, y4],
            }
        )

    return box_list

def unclip_boxes(boxes, unclip_ratio=None):
    """
    Expand bounding boxes in (x1, y1, x2, y2) format using an unclipping ratio.

    Parameters:
    - boxes: np.ndarray of shape (N, 6), where each row is (class_id, score, x1, y1, x2, y2).
    - unclip_ratio: tuple of (width_ratio, height_ratio), or a dict mapping a class id to such a tuple, optional.

    Returns:
    - expanded_boxes: np.ndarray of shape (N, 6), where each row is (class_id, score, x1, y1, x2, y2).
    """
    if unclip_ratio is None:
        return boxes

    if isinstance(unclip_ratio, dict):
        expanded_boxes = []
        for box in boxes:
            class_id, score, x1, y1, x2, y2 = box
            if class_id in unclip_ratio:
                width_ratio, height_ratio = unclip_ratio[class_id]

                width = x2 - x1
                height = y2 - y1

                new_w = width * width_ratio
                new_h = height * height_ratio
                center_x = x1 + width / 2
                center_y = y1 + height / 2

                new_x1 = center_x - new_w / 2
                new_y1 = center_y - new_h / 2
                new_x2 = center_x + new_w / 2
                new_y2 = center_y + new_h / 2

                expanded_boxes.append([class_id, score, new_x1, new_y1, new_x2, new_y2])
            else:
                expanded_boxes.append(box)
        return np.array(expanded_boxes)
    else:
        widths = boxes[:, 4] - boxes[:, 2]
        heights = boxes[:, 5] - boxes[:, 3]

        new_w = widths * unclip_ratio[0]
        new_h = heights * unclip_ratio[1]
        center_x = boxes[:, 2] + widths / 2
        center_y = boxes[:, 3] + heights / 2

        new_x1 = center_x - new_w / 2
        new_y1 = center_y - new_h / 2
        new_x2 = center_x + new_w / 2
        new_y2 = center_y + new_h / 2

        expanded_boxes = np.column_stack(
            (boxes[:, 0], boxes[:, 1], new_x1, new_y1, new_x2, new_y2)
        )
        return expanded_boxes

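
# A minimal sanity-check sketch (not part of the pipeline): with a (2.0, 1.0)
# ratio, a 10-wide box doubles in width around its center while its height is
# unchanged.
def _demo_unclip_boxes():
    boxes = np.array([[0, 0.9, 10.0, 10.0, 20.0, 30.0]])
    out = unclip_boxes(boxes, unclip_ratio=(2.0, 1.0))
    assert np.allclose(out[0, 2:], [5.0, 10.0, 25.0, 30.0])
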
def iou(box1, box2):
    """Compute the Intersection over Union (IoU) of two bounding boxes."""
    x1, y1, x2, y2 = box1
    x1_p, y1_p, x2_p, y2_p = box2

    # Compute the intersection coordinates
    x1_i = max(x1, x1_p)
    y1_i = max(y1, y1_p)
    x2_i = min(x2, x2_p)
    y2_i = min(y2, y2_p)

    # Compute the area of intersection (pixel-inclusive convention, hence +1)
    inter_area = max(0, x2_i - x1_i + 1) * max(0, y2_i - y1_i + 1)
    # Compute the area of both bounding boxes
    box1_area = (x2 - x1 + 1) * (y2 - y1 + 1)
    box2_area = (x2_p - x1_p + 1) * (y2_p - y1_p + 1)
    # Compute the IoU
    iou_value = inter_area / float(box1_area + box2_area - inter_area)

    return iou_value

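
# A minimal sanity-check sketch (not part of the pipeline): identical boxes
# give an IoU of 1.0 and disjoint boxes give 0.0.
def _demo_iou():
    assert iou([0, 0, 10, 10], [0, 0, 10, 10]) == 1.0
    assert iou([0, 0, 10, 10], [20, 20, 30, 30]) == 0.0
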
def nms(boxes, iou_same=0.6, iou_diff=0.95):
    """Perform Non-Maximum Suppression (NMS) with different IoU thresholds for
    same-class and different-class boxes, returning the indices of the kept boxes.
    """
    # Extract class scores
    scores = boxes[:, 1]
    # Sort indices by scores in descending order
    indices = np.argsort(scores)[::-1]
    selected_indices = []
    while len(indices) > 0:
        current = indices[0]
        current_box = boxes[current]
        current_class = current_box[0]
        current_coords = current_box[2:]
        selected_indices.append(current)
        indices = indices[1:]

        filtered_indices = []
        for i in indices:
            box = boxes[i]
            box_class = box[0]
            box_coords = box[2:]
            iou_value = iou(current_coords, box_coords)
            threshold = iou_same if current_class == box_class else iou_diff
            # If the IoU is below the threshold, keep the box
            if iou_value < threshold:
                filtered_indices.append(i)
        indices = filtered_indices

    return selected_indices

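
# A minimal sanity-check sketch (not part of the pipeline): of two heavily
# overlapping same-class boxes, only the higher-scoring one survives NMS.
def _demo_nms():
    boxes = np.array(
        [
            [0, 0.9, 0.0, 0.0, 10.0, 10.0],
            [0, 0.8, 1.0, 1.0, 10.0, 10.0],
        ]
    )
    assert nms(boxes, iou_same=0.6, iou_diff=0.95) == [0]
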
def is_contained(box1, box2):
    """Check if box1 is (almost entirely) contained within box2."""
    _, _, x1, y1, x2, y2 = box1
    _, _, x1_p, y1_p, x2_p, y2_p = box2

    box1_area = (x2 - x1) * (y2 - y1)
    xi1 = max(x1, x1_p)
    yi1 = max(y1, y1_p)
    xi2 = min(x2, x2_p)
    yi2 = min(y2, y2_p)
    inter_width = max(0, xi2 - xi1)
    inter_height = max(0, yi2 - yi1)
    intersect_area = inter_width * inter_height

    # Fraction of box1's own area covered by box2 (not a symmetric IoU)
    containment_ratio = intersect_area / box1_area if box1_area > 0 else 0
    return containment_ratio >= 0.9

def check_containment(boxes, formula_index=None, category_index=None, mode=None):
    """Check containment relationships among boxes.

    Returns two 0/1 arrays over the boxes: `contains_other[i]` marks boxes that
    contain another box, and `contained_by_other[i]` marks boxes contained by
    another box.
    """
    n = len(boxes)
    contains_other = np.zeros(n, dtype=int)
    contained_by_other = np.zeros(n, dtype=int)
    for i in range(n):
        for j in range(n):
            if i == j:
                continue
            if formula_index is not None:
                if boxes[i][0] == formula_index and boxes[j][0] != formula_index:
                    continue
            if category_index is not None and mode is not None:
                if mode == "large" and boxes[j][0] == category_index:
                    if is_contained(boxes[i], boxes[j]):
                        contained_by_other[i] = 1
                        contains_other[j] = 1
                if mode == "small" and boxes[i][0] == category_index:
                    if is_contained(boxes[i], boxes[j]):
                        contained_by_other[i] = 1
                        contains_other[j] = 1
            else:
                if is_contained(boxes[i], boxes[j]):
                    contained_by_other[i] = 1
                    contains_other[j] = 1
    return contains_other, contained_by_other

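
# A minimal sanity-check sketch (not part of the pipeline): a small box fully
# inside a large one is flagged as contained, and the large one as containing.
def _demo_check_containment():
    boxes = np.array(
        [
            [0, 0.9, 0.0, 0.0, 100.0, 100.0],
            [1, 0.8, 10.0, 10.0, 20.0, 20.0],
        ]
    )
    contains_other, contained_by_other = check_containment(boxes)
    assert contains_other.tolist() == [1, 0]
    assert contained_by_other.tolist() == [0, 1]
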
@benchmark.timeit
class DetPostProcess:
    """Save Result Transform

    This class is responsible for post-processing detection results, including
    thresholding, non-maximum suppression (NMS), and restructuring the boxes
    based on the input type (normal or rotated object detection).
    """

    def __init__(self, labels: Optional[List[str]] = None) -> None:
        """Initialize the DetPostProcess class.

        Args:
            labels (Optional[List[str]], optional): The list of labels for the detection categories. Defaults to None.
        """
        super().__init__()
        self.labels = labels

    def apply(
        self,
        boxes: ndarray,
        img_size: Tuple[int, int],
        threshold: Union[float, dict],
        layout_nms: Optional[bool],
        layout_unclip_ratio: Optional[Union[float, Tuple[float, float], dict]],
        layout_merge_bboxes_mode: Optional[Union[str, dict]],
    ) -> Boxes:
        """Apply post-processing to the detection boxes.

        Args:
            boxes (ndarray): The input detection boxes with scores.
            img_size (tuple): The original image size.
            threshold (float | dict): Score threshold, either global or per class id.
            layout_nms (bool, optional): Whether to apply NMS.
            layout_unclip_ratio (float | tuple | dict, optional): Ratio(s) used to expand the kept boxes.
            layout_merge_bboxes_mode (str | dict, optional): How to merge nested boxes ('union', 'large', or 'small').

        Returns:
            Boxes: The post-processed detection boxes.
        """
        if isinstance(threshold, float):
            expect_boxes = (boxes[:, 1] > threshold) & (boxes[:, 0] > -1)
            boxes = boxes[expect_boxes, :]
        elif isinstance(threshold, dict):
            category_filtered_boxes = []
            for cat_id in np.unique(boxes[:, 0]):
                category_boxes = boxes[boxes[:, 0] == cat_id]
                category_threshold = threshold.get(int(cat_id), 0.5)
                selected_indices = (category_boxes[:, 1] > category_threshold) & (
                    category_boxes[:, 0] > -1
                )
                category_filtered_boxes.append(category_boxes[selected_indices])
            boxes = (
                np.vstack(category_filtered_boxes)
                if category_filtered_boxes
                else np.array([])
            )

        if layout_nms:
            ### Layout postprocess for NMS
            selected_indices = nms(boxes, iou_same=0.6, iou_diff=0.98)
            boxes = np.array(boxes[selected_indices])

        if layout_merge_bboxes_mode:
            formula_index = (
                self.labels.index("formula") if "formula" in self.labels else None
            )
            if isinstance(layout_merge_bboxes_mode, str):
                assert layout_merge_bboxes_mode in [
                    "union",
                    "large",
                    "small",
                ], f"The value of `layout_merge_bboxes_mode` must be one of ['union', 'large', 'small'], but got {layout_merge_bboxes_mode}"
                if layout_merge_bboxes_mode == "union":
                    pass
                else:
                    contains_other, contained_by_other = check_containment(
                        boxes, formula_index
                    )
                    if layout_merge_bboxes_mode == "large":
                        # Keep only boxes that are not contained by another box
                        boxes = boxes[contained_by_other == 0]
                    elif layout_merge_bboxes_mode == "small":
                        # Keep boxes that contain nothing or are themselves contained
                        boxes = boxes[(contains_other == 0) | (contained_by_other == 1)]
            elif isinstance(layout_merge_bboxes_mode, dict):
                keep_mask = np.ones(len(boxes), dtype=bool)
                for category_index, layout_mode in layout_merge_bboxes_mode.items():
                    assert layout_mode in [
                        "union",
                        "large",
                        "small",
                    ], f"The value of `layout_merge_bboxes_mode` must be one of ['union', 'large', 'small'], but got {layout_mode}"
                    if layout_mode == "union":
                        pass
                    else:
                        if layout_mode == "large":
                            contains_other, contained_by_other = check_containment(
                                boxes, formula_index, category_index, mode=layout_mode
                            )
                            # Remove boxes that are contained by other boxes
                            keep_mask &= contained_by_other == 0
                        elif layout_mode == "small":
                            contains_other, contained_by_other = check_containment(
                                boxes, formula_index, category_index, mode=layout_mode
                            )
                            # Keep boxes that do not contain others or are contained by others
                            keep_mask &= (contains_other == 0) | (
                                contained_by_other == 1
                            )
                boxes = boxes[keep_mask]

        if layout_unclip_ratio:
            if isinstance(layout_unclip_ratio, float):
                layout_unclip_ratio = (layout_unclip_ratio, layout_unclip_ratio)
            elif isinstance(layout_unclip_ratio, (tuple, list)):
                assert (
                    len(layout_unclip_ratio) == 2
                ), "The length of `layout_unclip_ratio` should be 2."
            elif isinstance(layout_unclip_ratio, dict):
                pass
            else:
                raise ValueError(
                    f"The type of `layout_unclip_ratio` must be float, Tuple[float, float] or Dict[int, Tuple[float, float]], but got {type(layout_unclip_ratio)}."
                )
            boxes = unclip_boxes(boxes, layout_unclip_ratio)

        if boxes.shape[1] == 6:
            # Normal object detection: [cls_id, score, xmin, ymin, xmax, ymax]
            boxes = restructured_boxes(boxes, self.labels, img_size)
        elif boxes.shape[1] == 10:
            # Rotated object detection: [cls_id, score, x1, y1, ..., x4, y4]
            boxes = restructured_rotated_boxes(boxes, self.labels, img_size)
        else:
            raise ValueError(
                f"The number of columns in `boxes` should be 6 or 10, but got {boxes.shape[1]}."
            )
        return boxes

    def __call__(
        self,
        batch_outputs: List[dict],
        datas: List[dict],
        threshold: Optional[Union[float, dict]] = None,
        layout_nms: Optional[bool] = None,
        layout_unclip_ratio: Optional[Union[float, Tuple[float, float]]] = None,
        layout_merge_bboxes_mode: Optional[str] = None,
    ) -> List[Boxes]:
        """Apply the post-processing to a batch of outputs.

        Args:
            batch_outputs (List[dict]): The list of detection outputs.
            datas (List[dict]): The list of input data.

        Returns:
            List[Boxes]: The list of post-processed detection boxes.
        """
        outputs = []
        for data, output in zip(datas, batch_outputs):
            boxes = self.apply(
                output["boxes"],
                data["ori_img_size"],
                threshold,
                layout_nms,
                layout_unclip_ratio,
                layout_merge_bboxes_mode,
            )
            outputs.append(boxes)
        return outputs

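
# A minimal end-to-end sketch (not part of the pipeline): two fake detections
# for a 100x100 image, one below the 0.5 score threshold, post-processed into
# labeled boxes. The label list is illustrative.
def _demo_det_post_process():
    post = DetPostProcess(labels=["text", "figure"])
    batch_outputs = [
        {
            "boxes": np.array(
                [
                    [0, 0.9, 10.0, 10.0, 50.0, 50.0],
                    [1, 0.3, 20.0, 20.0, 40.0, 40.0],
                ]
            )
        }
    ]
    datas = [{"ori_img_size": [100, 100]}]
    results = post(batch_outputs, datas, threshold=0.5)
    assert len(results[0]) == 1 and results[0][0]["label"] == "text"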