processors.py 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import List, Optional, Sequence, Tuple, Union
  15. import numpy as np
  16. from numpy import ndarray
  17. from ....utils.deps import class_requires_deps, function_requires_deps, is_dep_available
  18. from ...common.reader import ReadImage as CommonReadImage
  19. from ...utils.benchmark import benchmark
  20. from ..common import Normalize as CommonNormalize
  21. from ..common import Resize as CommonResize
  22. if is_dep_available("opencv-contrib-python"):
  23. import cv2
# A list of per-box result dicts; the restructuring helpers in this module
# build dicts with the keys "cls_id", "label", "score" and "coordinate".
Boxes = List[dict]
# A plain real scalar, used where a size/scale may be given as one number.
Number = Union[int, float]
  26. @benchmark.timeit_with_options(name=None, is_read_operation=True)
  27. @class_requires_deps("opencv-contrib-python")
  28. class ReadImage(CommonReadImage):
  29. """Reads images from a list of raw image data or file paths."""
  30. def __call__(self, raw_imgs: List[Union[ndarray, str, dict]]) -> List[dict]:
  31. """Processes the input list of raw image data or file paths and returns a list of dictionaries containing image information.
  32. Args:
  33. raw_imgs (List[Union[ndarray, str]]): A list of raw image data (numpy ndarrays) or file paths (strings).
  34. Returns:
  35. List[dict]: A list of dictionaries, each containing image information.
  36. """
  37. out_datas = []
  38. for raw_img in raw_imgs:
  39. data = dict()
  40. if isinstance(raw_img, str):
  41. data["img_path"] = raw_img
  42. if isinstance(raw_img, dict):
  43. if "img" in raw_img:
  44. src_img = raw_img["img"]
  45. elif "img_path" in raw_img:
  46. src_img = raw_img["img_path"]
  47. data["img_path"] = src_img
  48. else:
  49. raise ValueError(
  50. "When raw_img is dict, must have one of keys ['img', 'img_path']."
  51. )
  52. data.update(raw_img)
  53. raw_img = src_img
  54. img, ori_img = self.read(raw_img)
  55. data["img"] = img
  56. data["ori_img"] = ori_img
  57. data["img_size"] = [img.shape[1], img.shape[0]] # [size_w, size_h]
  58. data["ori_img_size"] = [img.shape[1], img.shape[0]] # [size_w, size_h]
  59. out_datas.append(data)
  60. return out_datas
  61. def read(self, img):
  62. if isinstance(img, np.ndarray):
  63. ori_img = img
  64. if self.format == "RGB":
  65. img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)
  66. return img, ori_img
  67. elif isinstance(img, str):
  68. blob = self._img_reader.read(img)
  69. if blob is None:
  70. raise Exception(f"Image read Error: {img}")
  71. ori_img = blob
  72. if self.format == "RGB":
  73. if blob.ndim != 3:
  74. raise RuntimeError("Array is not 3-dimensional.")
  75. # BGR to RGB
  76. blob = cv2.cvtColor(blob, cv2.COLOR_BGR2RGB)
  77. return blob, ori_img
  78. else:
  79. raise TypeError(
  80. f"ReadImage only supports the following types:\n"
  81. f"1. str, indicating a image file path or a directory containing image files.\n"
  82. f"2. numpy.ndarray.\n"
  83. f"However, got type: {type(img).__name__}."
  84. )
  85. @benchmark.timeit
  86. class Resize(CommonResize):
  87. def __call__(self, datas: List[dict]) -> List[dict]:
  88. """
  89. Args:
  90. datas (List[dict]): A list of dictionaries, each containing image data with key 'img'.
  91. Returns:
  92. List[dict]: A list of dictionaries with updated image data, including resized images,
  93. original image sizes, resized image sizes, and scale factors.
  94. """
  95. for data in datas:
  96. ori_img = data["img"]
  97. if "ori_img_size" not in data:
  98. data["ori_img_size"] = [ori_img.shape[1], ori_img.shape[0]]
  99. ori_img_size = data["ori_img_size"]
  100. img = self.resize(ori_img)
  101. data["img"] = img
  102. img_size = [img.shape[1], img.shape[0]]
  103. data["img_size"] = img_size # [size_w, size_h]
  104. data["scale_factors"] = [ # [w_scale, h_scale]
  105. img_size[0] / ori_img_size[0],
  106. img_size[1] / ori_img_size[1],
  107. ]
  108. return datas
  109. @benchmark.timeit
  110. class Normalize(CommonNormalize):
  111. def __call__(self, datas: List[dict]) -> List[dict]:
  112. """Normalizes images in a list of dictionaries. Iterates over each dictionary,
  113. applies normalization to the 'img' key, and returns the modified list.
  114. """
  115. for data in datas:
  116. data["img"] = self.norm(data["img"])
  117. return datas
  118. @benchmark.timeit
  119. class ToCHWImage:
  120. """Converts images in a list of dictionaries from HWC to CHW format."""
  121. def __call__(self, datas: List[dict]) -> List[dict]:
  122. """Converts the image data in the list of dictionaries from HWC to CHW format in-place.
  123. Args:
  124. datas (List[dict]): A list of dictionaries, each containing an image tensor in 'img' key with HWC format.
  125. Returns:
  126. List[dict]: The same list of dictionaries with the image tensors converted to CHW format.
  127. """
  128. for data in datas:
  129. data["img"] = data["img"].transpose((2, 0, 1))
  130. return datas
  131. @benchmark.timeit
  132. class ToBatch:
  133. """
  134. Class for batch processing of data dictionaries.
  135. Args:
  136. ordered_required_keys (Optional[Tuple[str]]): A tuple of keys that need to be present in the input data dictionaries in a specific order.
  137. """
  138. def __init__(self, ordered_required_keys: Optional[Tuple[str]] = None):
  139. self.ordered_required_keys = ordered_required_keys
  140. def apply(
  141. self, datas: List[dict], key: str, dtype: np.dtype = np.float32
  142. ) -> np.ndarray:
  143. """
  144. Apply batch processing to a list of data dictionaries.
  145. Args:
  146. datas (List[dict]): A list of data dictionaries to process.
  147. key (str): The key in the data dictionaries to extract and batch.
  148. dtype (np.dtype): The desired data type of the output array (default is np.float32).
  149. Returns:
  150. np.ndarray: A numpy array containing the batched data.
  151. Raises:
  152. KeyError: If the specified key is not found in any of the data dictionaries.
  153. """
  154. if key == "img_size":
  155. # [h, w] size for det models
  156. img_sizes = [data[key][::-1] for data in datas]
  157. return np.stack(img_sizes, axis=0).astype(dtype=dtype, copy=False)
  158. elif key == "scale_factors":
  159. # [h, w] scale factors for det models, default [1.0, 1.0]
  160. scale_factors = [data.get(key, [1.0, 1.0])[::-1] for data in datas]
  161. return np.stack(scale_factors, axis=0).astype(dtype=dtype, copy=False)
  162. else:
  163. return np.stack([data[key] for data in datas], axis=0).astype(
  164. dtype=dtype, copy=False
  165. )
  166. def __call__(self, datas: List[dict]) -> Sequence[ndarray]:
  167. return [self.apply(datas, key) for key in self.ordered_required_keys]
  168. @benchmark.timeit
  169. class DetPad:
  170. """
  171. Pad image to a specified size.
  172. Args:
  173. size (list[int]): image target size
  174. fill_value (list[float]): rgb value of pad area, default (114.0, 114.0, 114.0)
  175. """
  176. def __init__(
  177. self,
  178. size: List[int],
  179. fill_value: List[Union[int, float]] = [114.0, 114.0, 114.0],
  180. ):
  181. super().__init__()
  182. if isinstance(size, int):
  183. size = [size, size]
  184. self.size = size
  185. self.fill_value = fill_value
  186. def apply(self, img: ndarray) -> ndarray:
  187. im = img
  188. im_h, im_w = im.shape[:2]
  189. h, w = self.size
  190. if h == im_h and w == im_w:
  191. return im
  192. canvas = np.ones((h, w, 3), dtype=np.float32)
  193. canvas *= np.array(self.fill_value, dtype=np.float32)
  194. canvas[0:im_h, 0:im_w, :] = im.astype(np.float32)
  195. return canvas
  196. def __call__(self, datas: List[dict]) -> List[dict]:
  197. for data in datas:
  198. data["img"] = self.apply(data["img"])
  199. return datas
  200. @benchmark.timeit
  201. class PadStride:
  202. """padding image for model with FPN , instead PadBatch(pad_to_stride, pad_gt) in original config
  203. Args:
  204. stride (bool): model with FPN need image shape % stride == 0
  205. """
  206. def __init__(self, stride: int = 0):
  207. super().__init__()
  208. self.coarsest_stride = stride
  209. def apply(self, img: ndarray):
  210. """
  211. Args:
  212. im (np.ndarray): image (np.ndarray)
  213. Returns:
  214. im (np.ndarray): processed image (np.ndarray)
  215. """
  216. im = img
  217. coarsest_stride = self.coarsest_stride
  218. if coarsest_stride <= 0:
  219. return img
  220. im_c, im_h, im_w = im.shape
  221. pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
  222. pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
  223. padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
  224. padding_im[:, :im_h, :im_w] = im
  225. return padding_im
  226. def __call__(self, datas: List[dict]) -> List[dict]:
  227. for data in datas:
  228. data["img"] = self.apply(data["img"])
  229. return datas
  230. def rotate_point(pt: List[float], angle_rad: float) -> List[float]:
  231. """Rotate a point by an angle.
  232. Args:
  233. pt (list[float]): 2 dimensional point to be rotated
  234. angle_rad (float): rotation angle by radian
  235. Returns:
  236. list[float]: Rotated point.
  237. """
  238. assert len(pt) == 2
  239. sn, cs = np.sin(angle_rad), np.cos(angle_rad)
  240. new_x = pt[0] * cs - pt[1] * sn
  241. new_y = pt[0] * sn + pt[1] * cs
  242. rotated_pt = [new_x, new_y]
  243. return rotated_pt
  244. def _get_3rd_point(a: ndarray, b: ndarray) -> ndarray:
  245. """To calculate the affine matrix, three pairs of points are required. This
  246. function is used to get the 3rd point, given 2D points a & b.
  247. The 3rd point is defined by rotating vector `a - b` by 90 degrees
  248. anticlockwise, using b as the rotation center.
  249. Args:
  250. a (np.ndarray): point(x,y)
  251. b (np.ndarray): point(x,y)
  252. Returns:
  253. np.ndarray: The 3rd point.
  254. """
  255. assert len(a) == 2
  256. assert len(b) == 2
  257. direction = a - b
  258. third_pt = b + np.array([-direction[1], direction[0]], dtype=np.float32)
  259. return third_pt
  260. @function_requires_deps("opencv-contrib-python")
  261. def get_affine_transform(
  262. center: ndarray,
  263. input_size: Union[Number, Tuple[Number, Number], ndarray],
  264. rot: float,
  265. output_size: ndarray,
  266. shift: Tuple[float, float] = (0.0, 0.0),
  267. inv: bool = False,
  268. ):
  269. """Get the affine transform matrix, given the center/scale/rot/output_size.
  270. Args:
  271. center (np.ndarray[2, ]): Center of the bounding box (x, y).
  272. input_size (np.ndarray[2, ]): Scale of the bounding box
  273. wrt [width, height].
  274. rot (float): Rotation angle (degree).
  275. output_size (np.ndarray[2, ]): Size of the destination heatmaps.
  276. shift (0-100%): Shift translation ratio wrt the width/height.
  277. Default (0., 0.).
  278. inv (bool): Option to inverse the affine transform direction.
  279. (inv=False: src->dst or inv=True: dst->src)
  280. Returns:
  281. np.ndarray: The transform matrix.
  282. """
  283. assert len(center) == 2
  284. assert len(output_size) == 2
  285. assert len(shift) == 2
  286. if not isinstance(input_size, (ndarray, list)):
  287. input_size = np.array([input_size, input_size], dtype=np.float32)
  288. scale_tmp = input_size
  289. shift = np.array(shift)
  290. src_w = scale_tmp[0]
  291. dst_w = output_size[0]
  292. dst_h = output_size[1]
  293. rot_rad = np.pi * rot / 180
  294. src_dir = rotate_point([0.0, src_w * -0.5], rot_rad)
  295. dst_dir = np.array([0.0, dst_w * -0.5])
  296. src = np.zeros((3, 2), dtype=np.float32)
  297. src[0, :] = center + scale_tmp * shift
  298. src[1, :] = center + src_dir + scale_tmp * shift
  299. src[2, :] = _get_3rd_point(src[0, :], src[1, :])
  300. dst = np.zeros((3, 2), dtype=np.float32)
  301. dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
  302. dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
  303. dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
  304. if inv:
  305. trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
  306. else:
  307. trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
  308. return trans
  309. @benchmark.timeit
  310. @class_requires_deps("opencv-contrib-python")
  311. class WarpAffine:
  312. """Apply warp affine transformation to the image based on the given parameters.
  313. Args:
  314. keep_res (bool): Whether to keep the original resolution aspect ratio during transformation.
  315. pad (int): Padding value used when keep_res is True.
  316. input_h (int): Target height for the input image when keep_res is False.
  317. input_w (int): Target width for the input image when keep_res is False.
  318. scale (float): Scale factor for resizing.
  319. shift (float): Shift factor for transformation.
  320. down_ratio (int): Downsampling ratio for the output image.
  321. """
  322. def __init__(
  323. self,
  324. keep_res=False,
  325. pad=31,
  326. input_h=512,
  327. input_w=512,
  328. scale=0.4,
  329. shift=0.1,
  330. down_ratio=4,
  331. ):
  332. super().__init__()
  333. self.keep_res = keep_res
  334. self.pad = pad
  335. self.input_h = input_h
  336. self.input_w = input_w
  337. self.scale = scale
  338. self.shift = shift
  339. self.down_ratio = down_ratio
  340. def apply(self, img: ndarray):
  341. img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
  342. h, w = img.shape[:2]
  343. if self.keep_res:
  344. # True in detection eval/infer
  345. input_h = (h | self.pad) + 1
  346. input_w = (w | self.pad) + 1
  347. s = np.array([input_w, input_h], dtype=np.float32)
  348. c = np.array([w // 2, h // 2], dtype=np.float32)
  349. else:
  350. # False in centertrack eval_mot/eval_mot
  351. s = max(h, w) * 1.0
  352. input_h, input_w = self.input_h, self.input_w
  353. c = np.array([w / 2.0, h / 2.0], dtype=np.float32)
  354. trans_input = get_affine_transform(c, s, 0, [input_w, input_h])
  355. img = cv2.resize(img, (w, h))
  356. inp = cv2.warpAffine(
  357. img, trans_input, (input_w, input_h), flags=cv2.INTER_LINEAR
  358. )
  359. if not self.keep_res:
  360. out_h = input_h // self.down_ratio
  361. out_w = input_w // self.down_ratio
  362. get_affine_transform(c, s, 0, [out_w, out_h])
  363. return inp
  364. def __call__(self, datas: List[dict]) -> List[dict]:
  365. for data in datas:
  366. ori_img = data["img"]
  367. if "ori_img_size" not in data:
  368. data["ori_img_size"] = [ori_img.shape[1], ori_img.shape[0]]
  369. img = self.apply(ori_img)
  370. data["img"] = img
  371. return datas
  372. def restructured_boxes(
  373. boxes: ndarray, labels: List[str], img_size: Tuple[int, int]
  374. ) -> Boxes:
  375. """
  376. Restructure the given bounding boxes and labels based on the image size.
  377. Args:
  378. boxes (ndarray): A 2D array of bounding boxes with each box represented as [cls_id, score, xmin, ymin, xmax, ymax].
  379. labels (List[str]): A list of class labels corresponding to the class ids.
  380. img_size (Tuple[int, int]): A tuple representing the width and height of the image.
  381. Returns:
  382. Boxes: A list of dictionaries, each containing 'cls_id', 'label', 'score', and 'coordinate' keys.
  383. """
  384. box_list = []
  385. w, h = img_size
  386. for box in boxes:
  387. xmin, ymin, xmax, ymax = box[2:]
  388. xmin = max(0, xmin)
  389. ymin = max(0, ymin)
  390. xmax = min(w, xmax)
  391. ymax = min(h, ymax)
  392. box_list.append(
  393. {
  394. "cls_id": int(box[0]),
  395. "label": labels[int(box[0])],
  396. "score": float(box[1]),
  397. "coordinate": [xmin, ymin, xmax, ymax],
  398. }
  399. )
  400. return box_list
  401. def restructured_rotated_boxes(
  402. boxes: ndarray, labels: List[str], img_size: Tuple[int, int]
  403. ) -> Boxes:
  404. """
  405. Restructure the given rotated bounding boxes and labels based on the image size.
  406. Args:
  407. boxes (ndarray): A 2D array of rotated bounding boxes with each box represented as [cls_id, score, x1, y1, x2, y2, x3, y3, x4, y4].
  408. labels (List[str]): A list of class labels corresponding to the class ids.
  409. img_size (Tuple[int, int]): A tuple representing the width and height of the image.
  410. Returns:
  411. Boxes: A list of dictionaries, each containing 'cls_id', 'label', 'score', and 'coordinate' keys.
  412. """
  413. box_list = []
  414. w, h = img_size
  415. assert boxes.shape[1] == 10, "The shape of rotated boxes should be [N, 10]"
  416. for box in boxes:
  417. x1, y1, x2, y2, x3, y3, x4, y4 = box[2:]
  418. x1 = min(max(0, x1), w)
  419. y1 = min(max(0, y1), h)
  420. x2 = min(max(0, x2), w)
  421. y2 = min(max(0, y2), h)
  422. x3 = min(max(0, x3), w)
  423. y3 = min(max(0, y3), h)
  424. x4 = min(max(0, x4), w)
  425. y4 = min(max(0, y4), h)
  426. box_list.append(
  427. {
  428. "cls_id": int(box[0]),
  429. "label": labels[int(box[0])],
  430. "score": float(box[1]),
  431. "coordinate": [x1, y1, x2, y2, x3, y3, x4, y4],
  432. }
  433. )
  434. return box_list
  435. def unclip_boxes(boxes, unclip_ratio=None):
  436. """
  437. Expand bounding boxes from (x1, y1, x2, y2) format using an unclipping ratio.
  438. Parameters:
  439. - boxes: np.ndarray of shape (N, 4), where each row is (x1, y1, x2, y2).
  440. - unclip_ratio: tuple of (width_ratio, height_ratio), optional.
  441. Returns:
  442. - expanded_boxes: np.ndarray of shape (N, 4), where each row is (x1, y1, x2, y2).
  443. """
  444. if unclip_ratio is None:
  445. return boxes
  446. if isinstance(unclip_ratio, dict):
  447. expanded_boxes = []
  448. for box in boxes:
  449. class_id, score, x1, y1, x2, y2 = box
  450. if class_id in unclip_ratio:
  451. width_ratio, height_ratio = unclip_ratio[class_id]
  452. width = x2 - x1
  453. height = y2 - y1
  454. new_w = width * width_ratio
  455. new_h = height * height_ratio
  456. center_x = x1 + width / 2
  457. center_y = y1 + height / 2
  458. new_x1 = center_x - new_w / 2
  459. new_y1 = center_y - new_h / 2
  460. new_x2 = center_x + new_w / 2
  461. new_y2 = center_y + new_h / 2
  462. expanded_boxes.append([class_id, score, new_x1, new_y1, new_x2, new_y2])
  463. else:
  464. expanded_boxes.append(box)
  465. return np.array(expanded_boxes)
  466. else:
  467. widths = boxes[:, 4] - boxes[:, 2]
  468. heights = boxes[:, 5] - boxes[:, 3]
  469. new_w = widths * unclip_ratio[0]
  470. new_h = heights * unclip_ratio[1]
  471. center_x = boxes[:, 2] + widths / 2
  472. center_y = boxes[:, 3] + heights / 2
  473. new_x1 = center_x - new_w / 2
  474. new_y1 = center_y - new_h / 2
  475. new_x2 = center_x + new_w / 2
  476. new_y2 = center_y + new_h / 2
  477. expanded_boxes = np.column_stack(
  478. (boxes[:, 0], boxes[:, 1], new_x1, new_y1, new_x2, new_y2)
  479. )
  480. return expanded_boxes
  481. def iou(box1, box2):
  482. """Compute the Intersection over Union (IoU) of two bounding boxes."""
  483. x1, y1, x2, y2 = box1
  484. x1_p, y1_p, x2_p, y2_p = box2
  485. # Compute the intersection coordinates
  486. x1_i = max(x1, x1_p)
  487. y1_i = max(y1, y1_p)
  488. x2_i = min(x2, x2_p)
  489. y2_i = min(y2, y2_p)
  490. # Compute the area of intersection
  491. inter_area = max(0, x2_i - x1_i + 1) * max(0, y2_i - y1_i + 1)
  492. # Compute the area of both bounding boxes
  493. box1_area = (x2 - x1 + 1) * (y2 - y1 + 1)
  494. box2_area = (x2_p - x1_p + 1) * (y2_p - y1_p + 1)
  495. # Compute the IoU
  496. iou_value = inter_area / float(box1_area + box2_area - inter_area)
  497. return iou_value
  498. def nms(boxes, iou_same=0.6, iou_diff=0.95):
  499. """Perform Non-Maximum Suppression (NMS) with different IoU thresholds for same and different classes."""
  500. # Extract class scores
  501. scores = boxes[:, 1]
  502. # Sort indices by scores in descending order
  503. indices = np.argsort(scores)[::-1]
  504. selected_boxes = []
  505. while len(indices) > 0:
  506. current = indices[0]
  507. current_box = boxes[current]
  508. current_class = current_box[0]
  509. current_box[1]
  510. current_coords = current_box[2:]
  511. selected_boxes.append(current)
  512. indices = indices[1:]
  513. filtered_indices = []
  514. for i in indices:
  515. box = boxes[i]
  516. box_class = box[0]
  517. box_coords = box[2:]
  518. iou_value = iou(current_coords, box_coords)
  519. threshold = iou_same if current_class == box_class else iou_diff
  520. # If the IoU is below the threshold, keep the box
  521. if iou_value < threshold:
  522. filtered_indices.append(i)
  523. indices = filtered_indices
  524. return selected_boxes
  525. def is_contained(box1, box2):
  526. """Check if box1 is contained within box2."""
  527. _, _, x1, y1, x2, y2 = box1
  528. _, _, x1_p, y1_p, x2_p, y2_p = box2
  529. box1_area = (x2 - x1) * (y2 - y1)
  530. xi1 = max(x1, x1_p)
  531. yi1 = max(y1, y1_p)
  532. xi2 = min(x2, x2_p)
  533. yi2 = min(y2, y2_p)
  534. inter_width = max(0, xi2 - xi1)
  535. inter_height = max(0, yi2 - yi1)
  536. intersect_area = inter_width * inter_height
  537. iou = intersect_area / box1_area if box1_area > 0 else 0
  538. return iou >= 0.9
  539. def check_containment(boxes, formula_index=None, category_index=None, mode=None):
  540. """Check containment relationships among boxes."""
  541. n = len(boxes)
  542. contains_other = np.zeros(n, dtype=int)
  543. contained_by_other = np.zeros(n, dtype=int)
  544. for i in range(n):
  545. for j in range(n):
  546. if i == j:
  547. continue
  548. if formula_index is not None:
  549. if boxes[i][0] == formula_index and boxes[j][0] != formula_index:
  550. continue
  551. if category_index is not None and mode is not None:
  552. if mode == "large" and boxes[j][0] == category_index:
  553. if is_contained(boxes[i], boxes[j]):
  554. contained_by_other[i] = 1
  555. contains_other[j] = 1
  556. if mode == "small" and boxes[i][0] == category_index:
  557. if is_contained(boxes[i], boxes[j]):
  558. contained_by_other[i] = 1
  559. contains_other[j] = 1
  560. else:
  561. if is_contained(boxes[i], boxes[j]):
  562. contained_by_other[i] = 1
  563. contains_other[j] = 1
  564. return contains_other, contained_by_other
  565. @benchmark.timeit
  566. class DetPostProcess:
  567. """Save Result Transform
  568. This class is responsible for post-processing detection results, including
  569. thresholding, non-maximum suppression (NMS), and restructuring the boxes
  570. based on the input type (normal or rotated object detection).
  571. """
  572. def __init__(self, labels: Optional[List[str]] = None) -> None:
  573. """Initialize the DetPostProcess class.
  574. Args:
  575. threshold (float, optional): The threshold to apply to the detection scores. Defaults to 0.5.
  576. labels (Optional[List[str]], optional): The list of labels for the detection categories. Defaults to None.
  577. layout_postprocess (bool, optional): Whether to apply layout post-processing. Defaults to False.
  578. """
  579. super().__init__()
  580. self.labels = labels
  581. def apply(
  582. self,
  583. boxes: ndarray,
  584. img_size: Tuple[int, int],
  585. threshold: Union[float, dict],
  586. layout_nms: Optional[bool],
  587. layout_unclip_ratio: Optional[Union[float, Tuple[float, float], dict]],
  588. layout_merge_bboxes_mode: Optional[Union[str, dict]],
  589. ) -> Boxes:
  590. """Apply post-processing to the detection boxes.
  591. Args:
  592. boxes (ndarray): The input detection boxes with scores.
  593. img_size (tuple): The original image size.
  594. Returns:
  595. Boxes: The post-processed detection boxes.
  596. """
  597. if isinstance(threshold, float):
  598. expect_boxes = (boxes[:, 1] > threshold) & (boxes[:, 0] > -1)
  599. boxes = boxes[expect_boxes, :]
  600. elif isinstance(threshold, dict):
  601. category_filtered_boxes = []
  602. for cat_id in np.unique(boxes[:, 0]):
  603. category_boxes = boxes[boxes[:, 0] == cat_id]
  604. category_threshold = threshold.get(int(cat_id), 0.5)
  605. selected_indices = (category_boxes[:, 1] > category_threshold) & (
  606. category_boxes[:, 0] > -1
  607. )
  608. category_filtered_boxes.append(category_boxes[selected_indices])
  609. boxes = (
  610. np.vstack(category_filtered_boxes)
  611. if category_filtered_boxes
  612. else np.array([])
  613. )
  614. if layout_nms:
  615. selected_indices = nms(boxes, iou_same=0.6, iou_diff=0.98)
  616. boxes = np.array(boxes[selected_indices])
  617. filter_large_image = True
  618. if filter_large_image and len(boxes) > 1:
  619. if img_size[0] > img_size[1]:
  620. area_thres = 0.82
  621. else:
  622. area_thres = 0.93
  623. image_index = (
  624. self.labels.index("image") if "image" in self.labels else None
  625. )
  626. img_area = img_size[0] * img_size[1]
  627. filtered_boxes = []
  628. for box in boxes:
  629. label_index, score, xmin, ymin, xmax, ymax = box
  630. if label_index == image_index:
  631. xmin = max(0, xmin)
  632. ymin = max(0, ymin)
  633. xmax = min(img_size[0], xmax)
  634. ymax = min(img_size[1], ymax)
  635. box_area = (xmax - xmin) * (ymax - ymin)
  636. if box_area <= area_thres * img_area:
  637. filtered_boxes.append(box)
  638. else:
  639. filtered_boxes.append(box)
  640. if len(filtered_boxes) == 0:
  641. filtered_boxes = boxes
  642. boxes = np.array(filtered_boxes)
  643. if layout_merge_bboxes_mode:
  644. formula_index = (
  645. self.labels.index("formula") if "formula" in self.labels else None
  646. )
  647. if isinstance(layout_merge_bboxes_mode, str):
  648. assert layout_merge_bboxes_mode in [
  649. "union",
  650. "large",
  651. "small",
  652. ], f"The value of `layout_merge_bboxes_mode` must be one of ['union', 'large', 'small'], but got {layout_merge_bboxes_mode}"
  653. if layout_merge_bboxes_mode == "union":
  654. pass
  655. else:
  656. contains_other, contained_by_other = check_containment(
  657. boxes, formula_index
  658. )
  659. if layout_merge_bboxes_mode == "large":
  660. boxes = boxes[contained_by_other == 0]
  661. elif layout_merge_bboxes_mode == "small":
  662. boxes = boxes[(contains_other == 0) | (contained_by_other == 1)]
  663. elif isinstance(layout_merge_bboxes_mode, dict):
  664. keep_mask = np.ones(len(boxes), dtype=bool)
  665. for category_index, layout_mode in layout_merge_bboxes_mode.items():
  666. assert layout_mode in [
  667. "union",
  668. "large",
  669. "small",
  670. ], f"The value of `layout_merge_bboxes_mode` must be one of ['union', 'large', 'small'], but got {layout_mode}"
  671. if layout_mode == "union":
  672. pass
  673. else:
  674. if layout_mode == "large":
  675. contains_other, contained_by_other = check_containment(
  676. boxes, formula_index, category_index, mode=layout_mode
  677. )
  678. # Remove boxes that are contained by other boxes
  679. keep_mask &= contained_by_other == 0
  680. elif layout_mode == "small":
  681. contains_other, contained_by_other = check_containment(
  682. boxes, formula_index, category_index, mode=layout_mode
  683. )
  684. # Keep boxes that do not contain others or are contained by others
  685. keep_mask &= (contains_other == 0) | (
  686. contained_by_other == 1
  687. )
  688. boxes = boxes[keep_mask]
  689. if boxes.size == 0:
  690. return np.array([])
  691. if layout_unclip_ratio:
  692. if isinstance(layout_unclip_ratio, float):
  693. layout_unclip_ratio = (layout_unclip_ratio, layout_unclip_ratio)
  694. elif isinstance(layout_unclip_ratio, (tuple, list)):
  695. assert (
  696. len(layout_unclip_ratio) == 2
  697. ), f"The length of `layout_unclip_ratio` should be 2."
  698. elif isinstance(layout_unclip_ratio, dict):
  699. pass
  700. else:
  701. raise ValueError(
  702. f"The type of `layout_unclip_ratio` must be float, Tuple[float, float] or Dict[int, Tuple[float, float]], but got {type(layout_unclip_ratio)}."
  703. )
  704. boxes = unclip_boxes(boxes, layout_unclip_ratio)
  705. if boxes.shape[1] == 6:
  706. """For Normal Object Detection"""
  707. boxes = restructured_boxes(boxes, self.labels, img_size)
  708. elif boxes.shape[1] == 10:
  709. """Adapt For Rotated Object Detection"""
  710. boxes = restructured_rotated_boxes(boxes, self.labels, img_size)
  711. else:
  712. """Unexpected Input Box Shape"""
  713. raise ValueError(
  714. f"The shape of boxes should be 6 or 10, instead of {boxes.shape[1]}"
  715. )
  716. return boxes
  717. def __call__(
  718. self,
  719. batch_outputs: List[dict],
  720. datas: List[dict],
  721. threshold: Optional[Union[float, dict]] = None,
  722. layout_nms: Optional[bool] = None,
  723. layout_unclip_ratio: Optional[Union[float, Tuple[float, float]]] = None,
  724. layout_merge_bboxes_mode: Optional[str] = None,
  725. ) -> List[Boxes]:
  726. """Apply the post-processing to a batch of outputs.
  727. Args:
  728. batch_outputs (List[dict]): The list of detection outputs.
  729. datas (List[dict]): The list of input data.
  730. Returns:
  731. List[Boxes]: The list of post-processed detection boxes.
  732. """
  733. outputs = []
  734. for data, output in zip(datas, batch_outputs):
  735. boxes = self.apply(
  736. output["boxes"],
  737. data["ori_img_size"],
  738. threshold,
  739. layout_nms,
  740. layout_unclip_ratio,
  741. layout_merge_bboxes_mode,
  742. )
  743. outputs.append(boxes)
  744. return outputs