# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Sequence, Tuple, Union

import numpy as np
from numpy import ndarray

from ....utils.deps import (
    class_requires_deps,
    function_requires_deps,
    is_dep_available,
)
from ...common.reader import ReadImage as CommonReadImage
from ...utils.benchmark import benchmark
from ..common import Normalize as CommonNormalize
from ..common import Resize as CommonResize

if is_dep_available("opencv-contrib-python"):
    import cv2

Boxes = List[dict]
Number = Union[int, float]

@benchmark.timeit_with_options(name=None, is_read_operation=True)
@class_requires_deps("opencv-contrib-python")
class ReadImage(CommonReadImage):
    """Reads images from a list of raw image data, file paths, or data dicts."""

    def __call__(self, raw_imgs: List[Union[ndarray, str, dict]]) -> List[dict]:
        """Processes the input list of raw image data, file paths, or data dicts
        and returns a list of dictionaries containing image information.

        Args:
            raw_imgs (List[Union[ndarray, str, dict]]): A list of raw image data
                (numpy ndarrays), file paths (strings), or dicts carrying an
                'img' or 'img_path' key.

        Returns:
            List[dict]: A list of dictionaries, each containing image information.
        """
        out_datas = []
        for raw_img in raw_imgs:
            data = dict()
            if isinstance(raw_img, str):
                data["img_path"] = raw_img
            if isinstance(raw_img, dict):
                if "img" in raw_img:
                    src_img = raw_img["img"]
                elif "img_path" in raw_img:
                    src_img = raw_img["img_path"]
                    data["img_path"] = src_img
                else:
                    raise ValueError(
                        "When raw_img is a dict, it must have one of the keys "
                        "['img', 'img_path']."
                    )
                data.update(raw_img)
                raw_img = src_img
            img, ori_img = self.read(raw_img)
            data["img"] = img
            data["ori_img"] = ori_img
            data["img_size"] = [img.shape[1], img.shape[0]]  # [size_w, size_h]
            data["ori_img_size"] = [ori_img.shape[1], ori_img.shape[0]]  # [size_w, size_h]
            out_datas.append(data)
        return out_datas

    def read(self, img):
        if isinstance(img, np.ndarray):
            ori_img = img
            if self.format == "RGB":
                img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)
            return img, ori_img
        elif isinstance(img, str):
            blob = self._img_reader.read(img)
            if blob is None:
                raise RuntimeError(f"Failed to read image: {img}")
            ori_img = blob
            if self.format == "RGB":
                if blob.ndim != 3:
                    raise RuntimeError("Array is not 3-dimensional.")
                # BGR to RGB
                blob = cv2.cvtColor(blob, cv2.COLOR_BGR2RGB)
            return blob, ori_img
        else:
            raise TypeError(
                f"ReadImage only supports the following types:\n"
                f"1. str, indicating an image file path or a directory containing image files.\n"
                f"2. numpy.ndarray.\n"
                f"However, got type: {type(img).__name__}."
            )
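
# Usage sketch (illustrative): assuming the inherited reader accepts a
# format argument and "sample.jpg" is a hypothetical BGR image on disk,
#
#     reader = ReadImage(format="RGB")
#     batch = reader(["sample.jpg", {"img_path": "sample.jpg"}])
#     batch[0]["img"]       # RGB ndarray
#     batch[0]["img_size"]  # [width, height]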

@benchmark.timeit
class Resize(CommonResize):
    def __call__(self, datas: List[dict]) -> List[dict]:
        """
        Args:
            datas (List[dict]): A list of dictionaries, each containing image data
                under the key 'img'.

        Returns:
            List[dict]: A list of dictionaries with updated image data, including
                resized images, original image sizes, resized image sizes, and
                scale factors.
        """
        for data in datas:
            ori_img = data["img"]
            if "ori_img_size" not in data:
                data["ori_img_size"] = [ori_img.shape[1], ori_img.shape[0]]
            ori_img_size = data["ori_img_size"]

            img = self.resize(ori_img)
            data["img"] = img

            img_size = [img.shape[1], img.shape[0]]
            data["img_size"] = img_size  # [size_w, size_h]

            data["scale_factors"] = [  # [w_scale, h_scale]
                img_size[0] / ori_img_size[0],
                img_size[1] / ori_img_size[1],
            ]
        return datas
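
# Scale-factor arithmetic: a Resize configured to produce a 640x640 output
# maps a 1000x500 (w x h) input to scale_factors == [0.64, 1.28]
# ([w_scale, h_scale]), since 640/1000 == 0.64 and 640/500 == 1.28.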

@benchmark.timeit
class Normalize(CommonNormalize):
    def __call__(self, datas: List[dict]) -> List[dict]:
        """Normalizes images in a list of dictionaries. Iterates over each
        dictionary, applies normalization to the 'img' key, and returns the
        modified list.
        """
        for data in datas:
            data["img"] = self.norm(data["img"])
        return datas

@benchmark.timeit
class ToCHWImage:
    """Converts images in a list of dictionaries from HWC to CHW format."""

    def __call__(self, datas: List[dict]) -> List[dict]:
        """Converts the image data in the list of dictionaries from HWC to CHW
        format in-place.

        Args:
            datas (List[dict]): A list of dictionaries, each containing an image
                tensor under the 'img' key in HWC format.

        Returns:
            List[dict]: The same list of dictionaries with the image tensors
                converted to CHW format.
        """
        for data in datas:
            data["img"] = data["img"].transpose((2, 0, 1))
        return datas
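
# Shape sketch: an HWC image of shape (480, 640, 3) becomes (3, 480, 640)
# after transpose((2, 0, 1)), matching the CHW layout expected downstream:
#
#     np.zeros((480, 640, 3)).transpose((2, 0, 1)).shape  # (3, 480, 640)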

@benchmark.timeit
class ToBatch:
    """
    Class for batch processing of data dictionaries.

    Args:
        ordered_required_keys (Optional[Tuple[str, ...]]): A tuple of keys that
            must be present in the input data dictionaries, in the order in
            which they are batched.
    """

    def __init__(self, ordered_required_keys: Optional[Tuple[str, ...]] = None):
        self.ordered_required_keys = ordered_required_keys

    def apply(
        self, datas: List[dict], key: str, dtype: np.dtype = np.float32
    ) -> np.ndarray:
        """
        Apply batch processing to a list of data dictionaries.

        Args:
            datas (List[dict]): A list of data dictionaries to process.
            key (str): The key in the data dictionaries to extract and batch.
            dtype (np.dtype): The desired data type of the output array
                (default is np.float32).

        Returns:
            np.ndarray: A numpy array containing the batched data.

        Raises:
            KeyError: If the specified key is not found in any of the data dictionaries.
        """
        if key == "img_size":
            # Reverse [w, h] to the [h, w] size expected by det models
            img_sizes = [data[key][::-1] for data in datas]
            return np.stack(img_sizes, axis=0).astype(dtype=dtype, copy=False)
        elif key == "scale_factors":
            # [h, w] scale factors for det models, defaulting to [1.0, 1.0]
            scale_factors = [data.get(key, [1.0, 1.0])[::-1] for data in datas]
            return np.stack(scale_factors, axis=0).astype(dtype=dtype, copy=False)
        else:
            return np.stack([data[key] for data in datas], axis=0).astype(
                dtype=dtype, copy=False
            )

    def __call__(self, datas: List[dict]) -> Sequence[ndarray]:
        return [self.apply(datas, key) for key in self.ordered_required_keys]
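
# Batching sketch: with ordered_required_keys=("img", "img_size"), two samples
# whose data["img_size"] is [640, 480] ([w, h]) are stacked into an array of
# shape (2, 2) holding [[480, 640], [480, 640]] (reversed to [h, w]), while two
# CHW "img" tensors of shape (3, 480, 640) stack into (2, 3, 480, 640).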

@benchmark.timeit
class DetPad:
    """
    Pad the image to a specified size.

    Args:
        size (List[int]): Target image size as [h, w].
        fill_value (List[float]): RGB value of the padded area.
            Defaults to (114.0, 114.0, 114.0).
    """

    def __init__(
        self,
        size: List[int],
        fill_value: List[Union[int, float]] = [114.0, 114.0, 114.0],
    ):
        super().__init__()
        if isinstance(size, int):
            size = [size, size]
        self.size = size
        self.fill_value = fill_value

    def apply(self, img: ndarray) -> ndarray:
        im = img
        im_h, im_w = im.shape[:2]
        h, w = self.size
        if h == im_h and w == im_w:
            return im

        # Fill a canvas with the pad value, then paste the image into its
        # top-left corner.
        canvas = np.ones((h, w, 3), dtype=np.float32)
        canvas *= np.array(self.fill_value, dtype=np.float32)
        canvas[0:im_h, 0:im_w, :] = im.astype(np.float32)
        return canvas

    def __call__(self, datas: List[dict]) -> List[dict]:
        for data in datas:
            data["img"] = self.apply(data["img"])
        return datas

@benchmark.timeit
class PadStride:
    """Pad the image so that its height and width are divisible by the stride;
    replaces `PadBatch(pad_to_stride, pad_gt)` from the original config for
    models with an FPN.

    Args:
        stride (int): Models with an FPN require image shape % stride == 0.
    """

    def __init__(self, stride: int = 0):
        super().__init__()
        self.coarsest_stride = stride

    def apply(self, img: ndarray):
        """
        Args:
            img (np.ndarray): Image in CHW format.

        Returns:
            np.ndarray: Processed image, zero-padded on the bottom and right to
                the next multiple of the stride.
        """
        im = img
        coarsest_stride = self.coarsest_stride
        if coarsest_stride <= 0:
            return img
        im_c, im_h, im_w = im.shape
        pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
        pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
        padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
        padding_im[:, :im_h, :im_w] = im
        return padding_im

    def __call__(self, datas: List[dict]) -> List[dict]:
        for data in datas:
            data["img"] = self.apply(data["img"])
        return datas
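
# Padding arithmetic: with stride=32, a CHW image of shape (3, 500, 600) is
# zero-padded to (3, 512, 608), since ceil(500/32)*32 == 512 and
# ceil(600/32)*32 == 608.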

def rotate_point(pt: List[float], angle_rad: float) -> List[float]:
    """Rotate a point by an angle.

    Args:
        pt (List[float]): 2-dimensional point to be rotated.
        angle_rad (float): Rotation angle in radians.

    Returns:
        List[float]: Rotated point.
    """
    assert len(pt) == 2
    sn, cs = np.sin(angle_rad), np.cos(angle_rad)
    new_x = pt[0] * cs - pt[1] * sn
    new_y = pt[0] * sn + pt[1] * cs
    rotated_pt = [new_x, new_y]
    return rotated_pt
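
# Sanity check: rotating (1, 0) by pi/2 radians (90 degrees counterclockwise)
# gives approximately (0, 1):
#
#     rotate_point([1.0, 0.0], np.pi / 2)  # ~[0.0, 1.0]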

def _get_3rd_point(a: ndarray, b: ndarray) -> ndarray:
    """To calculate the affine matrix, three pairs of points are required. This
    function is used to get the 3rd point, given 2D points a & b.

    The 3rd point is defined by rotating the vector `a - b` by 90 degrees
    anticlockwise, using b as the rotation center.

    Args:
        a (np.ndarray): point (x, y)
        b (np.ndarray): point (x, y)

    Returns:
        np.ndarray: The 3rd point.
    """
    assert len(a) == 2
    assert len(b) == 2
    direction = a - b
    third_pt = b + np.array([-direction[1], direction[0]], dtype=np.float32)
    return third_pt

@function_requires_deps("opencv-contrib-python")
def get_affine_transform(
    center: ndarray,
    input_size: Union[Number, Tuple[Number, Number], ndarray],
    rot: float,
    output_size: ndarray,
    shift: Tuple[float, float] = (0.0, 0.0),
    inv: bool = False,
):
    """Get the affine transform matrix, given the center/scale/rot/output_size.

    Args:
        center (np.ndarray[2, ]): Center of the bounding box (x, y).
        input_size (np.ndarray[2, ]): Scale of the bounding box
            wrt [width, height].
        rot (float): Rotation angle (degrees).
        output_size (np.ndarray[2, ]): Size of the destination heatmaps.
        shift (Tuple[float, float]): Shift translation ratio wrt the
            width/height. Defaults to (0., 0.).
        inv (bool): Option to invert the affine transform direction.
            (inv=False: src->dst or inv=True: dst->src)

    Returns:
        np.ndarray: The transform matrix.
    """
    assert len(center) == 2
    assert len(output_size) == 2
    assert len(shift) == 2
    if not isinstance(input_size, (ndarray, list)):
        input_size = np.array([input_size, input_size], dtype=np.float32)
    scale_tmp = input_size

    shift = np.array(shift)
    src_w = scale_tmp[0]
    dst_w = output_size[0]
    dst_h = output_size[1]

    rot_rad = np.pi * rot / 180
    src_dir = rotate_point([0.0, src_w * -0.5], rot_rad)
    dst_dir = np.array([0.0, dst_w * -0.5])

    src = np.zeros((3, 2), dtype=np.float32)
    src[0, :] = center + scale_tmp * shift
    src[1, :] = center + src_dir + scale_tmp * shift
    src[2, :] = _get_3rd_point(src[0, :], src[1, :])

    dst = np.zeros((3, 2), dtype=np.float32)
    dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
    dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
    dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])

    if inv:
        trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
    else:
        trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))

    return trans
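
# Behavior sketch: with rot=0 and shift=(0, 0), the source center maps to the
# center of the output, e.g.
#
#     trans = get_affine_transform(
#         np.array([320.0, 240.0]), 480, 0, np.array([512, 512])
#     )
#     trans @ np.array([320.0, 240.0, 1.0])  # ~[256.0, 256.0]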

@benchmark.timeit
@class_requires_deps("opencv-contrib-python")
class WarpAffine:
    """Apply a warp affine transformation to the image based on the given parameters.

    Args:
        keep_res (bool): Whether to keep the original resolution aspect ratio during transformation.
        pad (int): Padding value used when keep_res is True.
        input_h (int): Target height for the input image when keep_res is False.
        input_w (int): Target width for the input image when keep_res is False.
        scale (float): Scale factor for resizing.
        shift (float): Shift factor for the transformation.
        down_ratio (int): Downsampling ratio for the output image.
    """

    def __init__(
        self,
        keep_res=False,
        pad=31,
        input_h=512,
        input_w=512,
        scale=0.4,
        shift=0.1,
        down_ratio=4,
    ):
        super().__init__()
        self.keep_res = keep_res
        self.pad = pad
        self.input_h = input_h
        self.input_w = input_w
        self.scale = scale
        self.shift = shift
        self.down_ratio = down_ratio

    def apply(self, img: ndarray):
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        h, w = img.shape[:2]

        if self.keep_res:
            # True in detection eval/infer: round the input size up to the
            # next multiple of (pad + 1), e.g. 32 when pad is 31.
            input_h = (h | self.pad) + 1
            input_w = (w | self.pad) + 1
            s = np.array([input_w, input_h], dtype=np.float32)
            c = np.array([w // 2, h // 2], dtype=np.float32)
        else:
            # False in CenterTrack MOT eval/infer
            s = max(h, w) * 1.0
            input_h, input_w = self.input_h, self.input_w
            c = np.array([w / 2.0, h / 2.0], dtype=np.float32)

        trans_input = get_affine_transform(c, s, 0, [input_w, input_h])
        inp = cv2.warpAffine(
            img, trans_input, (input_w, input_h), flags=cv2.INTER_LINEAR
        )
        return inp

    def __call__(self, datas: List[dict]) -> List[dict]:
        for data in datas:
            ori_img = data["img"]
            if "ori_img_size" not in data:
                data["ori_img_size"] = [ori_img.shape[1], ori_img.shape[0]]

            img = self.apply(ori_img)
            data["img"] = img
        return datas

def restructured_boxes(
    boxes: ndarray, labels: List[str], img_size: Tuple[int, int]
) -> Boxes:
    """
    Restructure the given bounding boxes and labels based on the image size.

    Args:
        boxes (ndarray): A 2D array of bounding boxes, each represented as
            [cls_id, score, xmin, ymin, xmax, ymax].
        labels (List[str]): A list of class labels corresponding to the class ids.
        img_size (Tuple[int, int]): A tuple representing the width and height of the image.

    Returns:
        Boxes: A list of dictionaries, each containing 'cls_id', 'label', 'score',
            and 'coordinate' keys.
    """
    box_list = []
    w, h = img_size

    for box in boxes:
        xmin, ymin, xmax, ymax = box[2:]
        # Clip the box to the image and drop degenerate boxes.
        xmin = max(0, xmin)
        ymin = max(0, ymin)
        xmax = min(w, xmax)
        ymax = min(h, ymax)
        if xmax <= xmin or ymax <= ymin:
            continue
        box_list.append(
            {
                "cls_id": int(box[0]),
                "label": labels[int(box[0])],
                "score": float(box[1]),
                "coordinate": [xmin, ymin, xmax, ymax],
            }
        )

    return box_list
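
# Output sketch: with a hypothetical label list where labels[0] == "person",
# the raw row [0, 0.98, -5.0, 10.0, 120.0, 2000.0] on a 640x480 image is
# clipped and restructured into
#
#     {"cls_id": 0, "label": "person", "score": 0.98,
#      "coordinate": [0, 10.0, 120.0, 480]}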

def restructured_rotated_boxes(
    boxes: ndarray, labels: List[str], img_size: Tuple[int, int]
) -> Boxes:
    """
    Restructure the given rotated bounding boxes and labels based on the image size.

    Args:
        boxes (ndarray): A 2D array of rotated bounding boxes, each represented as
            [cls_id, score, x1, y1, x2, y2, x3, y3, x4, y4].
        labels (List[str]): A list of class labels corresponding to the class ids.
        img_size (Tuple[int, int]): A tuple representing the width and height of the image.

    Returns:
        Boxes: A list of dictionaries, each containing 'cls_id', 'label', 'score',
            and 'coordinate' keys.
    """
    box_list = []
    w, h = img_size

    assert boxes.shape[1] == 10, "The shape of rotated boxes should be [N, 10]"
    for box in boxes:
        x1, y1, x2, y2, x3, y3, x4, y4 = box[2:]
        # Clip every corner to the image bounds.
        x1 = min(max(0, x1), w)
        y1 = min(max(0, y1), h)
        x2 = min(max(0, x2), w)
        y2 = min(max(0, y2), h)
        x3 = min(max(0, x3), w)
        y3 = min(max(0, y3), h)
        x4 = min(max(0, x4), w)
        y4 = min(max(0, y4), h)
        box_list.append(
            {
                "cls_id": int(box[0]),
                "label": labels[int(box[0])],
                "score": float(box[1]),
                "coordinate": [x1, y1, x2, y2, x3, y3, x4, y4],
            }
        )

    return box_list

def unclip_boxes(boxes, unclip_ratio=None):
    """
    Expand bounding boxes around their centers by an unclip ratio.

    Parameters:
    - boxes: np.ndarray of shape (N, 6), where each row is
      (class_id, score, x1, y1, x2, y2).
    - unclip_ratio: tuple of (width_ratio, height_ratio), or a dict mapping a
      class_id to such a tuple. Optional.

    Returns:
    - expanded_boxes: np.ndarray of shape (N, 6) in the same row format.
    """
    if unclip_ratio is None:
        return boxes

    if isinstance(unclip_ratio, dict):
        expanded_boxes = []
        for box in boxes:
            class_id, score, x1, y1, x2, y2 = box
            if class_id in unclip_ratio:
                width_ratio, height_ratio = unclip_ratio[class_id]

                width = x2 - x1
                height = y2 - y1

                new_w = width * width_ratio
                new_h = height * height_ratio
                center_x = x1 + width / 2
                center_y = y1 + height / 2

                new_x1 = center_x - new_w / 2
                new_y1 = center_y - new_h / 2
                new_x2 = center_x + new_w / 2
                new_y2 = center_y + new_h / 2

                expanded_boxes.append([class_id, score, new_x1, new_y1, new_x2, new_y2])
            else:
                expanded_boxes.append(box)
        return np.array(expanded_boxes)
    else:
        widths = boxes[:, 4] - boxes[:, 2]
        heights = boxes[:, 5] - boxes[:, 3]

        new_w = widths * unclip_ratio[0]
        new_h = heights * unclip_ratio[1]
        center_x = boxes[:, 2] + widths / 2
        center_y = boxes[:, 3] + heights / 2

        new_x1 = center_x - new_w / 2
        new_y1 = center_y - new_h / 2
        new_x2 = center_x + new_w / 2
        new_y2 = center_y + new_h / 2

        expanded_boxes = np.column_stack(
            (boxes[:, 0], boxes[:, 1], new_x1, new_y1, new_x2, new_y2)
        )
        return expanded_boxes
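
# Expansion arithmetic: with unclip_ratio=(1.2, 1.0), the box
# (x1, y1, x2, y2) == (100, 100, 200, 150) keeps its center (150, 125); its
# width grows from 100 to 120 and its height stays 50, yielding
# (90, 100, 210, 150).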

def iou(box1, box2):
    """Compute the Intersection over Union (IoU) of two bounding boxes."""
    x1, y1, x2, y2 = box1
    x1_p, y1_p, x2_p, y2_p = box2

    # Compute the intersection coordinates
    x1_i = max(x1, x1_p)
    y1_i = max(y1, y1_p)
    x2_i = min(x2, x2_p)
    y2_i = min(y2, y2_p)

    # Compute the area of intersection (inclusive-pixel convention, hence +1)
    inter_area = max(0, x2_i - x1_i + 1) * max(0, y2_i - y1_i + 1)
    # Compute the area of both bounding boxes
    box1_area = (x2 - x1 + 1) * (y2 - y1 + 1)
    box2_area = (x2_p - x1_p + 1) * (y2_p - y1_p + 1)
    # Compute the IoU
    iou_value = inter_area / float(box1_area + box2_area - inter_area)
    return iou_value
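
# IoU arithmetic (inclusive-pixel convention, hence the +1 terms): boxes
# (0, 0, 9, 9) and (5, 0, 14, 9) each cover 100 pixels and share a 5x10 strip
# of 50, so iou(box1, box2) == 50 / (100 + 100 - 50) == 1/3.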

def nms(boxes, iou_same=0.6, iou_diff=0.95):
    """Perform Non-Maximum Suppression (NMS) with different IoU thresholds for
    same-class and different-class box pairs, returning the kept indices."""
    # Extract scores and sort indices by score in descending order
    scores = boxes[:, 1]
    indices = np.argsort(scores)[::-1]
    selected_boxes = []

    while len(indices) > 0:
        current = indices[0]
        current_box = boxes[current]
        current_class = current_box[0]
        current_coords = current_box[2:]
        selected_boxes.append(current)
        indices = indices[1:]

        filtered_indices = []
        for i in indices:
            box = boxes[i]
            box_class = box[0]
            box_coords = box[2:]
            iou_value = iou(current_coords, box_coords)
            threshold = iou_same if current_class == box_class else iou_diff
            # If the IoU is below the threshold, keep the box
            if iou_value < threshold:
                filtered_indices.append(i)
        indices = filtered_indices

    return selected_boxes
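
# Suppression sketch: rows are (class_id, score, x1, y1, x2, y2). Of two
# heavily overlapping same-class boxes, only the higher-scoring one survives
# the default iou_same=0.6, while the looser iou_diff=0.95 lets boxes of
# different classes coexist unless they almost coincide. Note the return value
# is a list of indices into `boxes`, not the boxes themselves.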

def is_contained(box1, box2):
    """Check whether box1 is (almost entirely) contained within box2."""
    _, _, x1, y1, x2, y2 = box1
    _, _, x1_p, y1_p, x2_p, y2_p = box2
    box1_area = (x2 - x1) * (y2 - y1)

    xi1 = max(x1, x1_p)
    yi1 = max(y1, y1_p)
    xi2 = min(x2, x2_p)
    yi2 = min(y2, y2_p)
    inter_width = max(0, xi2 - xi1)
    inter_height = max(0, yi2 - yi1)
    intersect_area = inter_width * inter_height

    # Fraction of box1 covered by the intersection (not a symmetric IoU)
    containment = intersect_area / box1_area if box1_area > 0 else 0
    return containment >= 0.9

def check_containment(boxes, formula_index=None, category_index=None, mode=None):
    """Check pairwise containment relationships among boxes, returning two 0/1
    arrays: which boxes contain another box, and which are contained by
    another box."""
    n = len(boxes)
    contains_other = np.zeros(n, dtype=int)
    contained_by_other = np.zeros(n, dtype=int)
    for i in range(n):
        for j in range(n):
            if i == j:
                continue
            # Skip pairs where box i is a formula and box j is not.
            if formula_index is not None:
                if boxes[i][0] == formula_index and boxes[j][0] != formula_index:
                    continue
            if category_index is not None and mode is not None:
                if mode == "large" and boxes[j][0] == category_index:
                    if is_contained(boxes[i], boxes[j]):
                        contained_by_other[i] = 1
                        contains_other[j] = 1
                if mode == "small" and boxes[i][0] == category_index:
                    if is_contained(boxes[i], boxes[j]):
                        contained_by_other[i] = 1
                        contains_other[j] = 1
            else:
                if is_contained(boxes[i], boxes[j]):
                    contained_by_other[i] = 1
                    contains_other[j] = 1
    return contains_other, contained_by_other

@benchmark.timeit
class DetPostProcess:
    """Post-processes detection results, including thresholding, non-maximum
    suppression (NMS), and restructuring the boxes based on the task type
    (normal or rotated object detection).
    """

    def __init__(self, labels: Optional[List[str]] = None) -> None:
        """Initialize the DetPostProcess class.

        Args:
            labels (Optional[List[str]], optional): The list of labels for the
                detection categories. Defaults to None.
        """
        super().__init__()
        self.labels = labels

    def apply(
        self,
        boxes: ndarray,
        img_size: Tuple[int, int],
        threshold: Union[float, dict],
        layout_nms: Optional[bool],
        layout_unclip_ratio: Optional[Union[float, Tuple[float, float], dict]],
        layout_merge_bboxes_mode: Optional[Union[str, dict]],
    ) -> Boxes:
        """Apply post-processing to the detection boxes.

        Args:
            boxes (ndarray): The input detection boxes with scores.
            img_size (tuple): The original image size as (w, h).
            threshold (Union[float, dict]): Score threshold, either global or
                keyed by category id.
            layout_nms (Optional[bool]): Whether to apply class-aware NMS.
            layout_unclip_ratio (Optional[Union[float, Tuple[float, float], dict]]):
                Ratio(s) by which to expand the boxes around their centers.
            layout_merge_bboxes_mode (Optional[Union[str, dict]]): How to resolve
                nested boxes: 'union', 'large', or 'small', either globally or
                keyed by category id.

        Returns:
            Boxes: The post-processed detection boxes.
        """
        # Filter out low-score boxes.
        if isinstance(threshold, float):
            expect_boxes = (boxes[:, 1] > threshold) & (boxes[:, 0] > -1)
            boxes = boxes[expect_boxes, :]
        elif isinstance(threshold, dict):
            category_filtered_boxes = []
            for cat_id in np.unique(boxes[:, 0]):
                category_boxes = boxes[boxes[:, 0] == cat_id]
                category_threshold = threshold.get(int(cat_id), 0.5)
                selected_indices = (category_boxes[:, 1] > category_threshold) & (
                    category_boxes[:, 0] > -1
                )
                category_filtered_boxes.append(category_boxes[selected_indices])
            boxes = (
                np.vstack(category_filtered_boxes)
                if category_filtered_boxes
                else np.array([])
            )

        if layout_nms:
            selected_indices = nms(boxes[:, :6], iou_same=0.6, iou_diff=0.98)
            boxes = np.array(boxes[selected_indices])

        # Drop near-full-page "image" boxes; boxes.shape[1] == 6 is object
        # detection, 8 is ordered object detection.
        filter_large_image = True
        if filter_large_image and len(boxes) > 1 and boxes.shape[1] in [6, 8]:
            if img_size[0] > img_size[1]:
                area_thres = 0.82
            else:
                area_thres = 0.93
            image_index = self.labels.index("image") if "image" in self.labels else None
            img_area = img_size[0] * img_size[1]
            filtered_boxes = []
            for box in boxes:
                label_index, score, xmin, ymin, xmax, ymax = box[:6]
                if label_index == image_index:
                    xmin = max(0, xmin)
                    ymin = max(0, ymin)
                    xmax = min(img_size[0], xmax)
                    ymax = min(img_size[1], ymax)
                    box_area = (xmax - xmin) * (ymax - ymin)
                    if box_area <= area_thres * img_area:
                        filtered_boxes.append(box)
                else:
                    filtered_boxes.append(box)
            if len(filtered_boxes) == 0:
                filtered_boxes = boxes
            boxes = np.array(filtered_boxes)

        if layout_merge_bboxes_mode:
            formula_index = (
                self.labels.index("formula") if "formula" in self.labels else None
            )
            if isinstance(layout_merge_bboxes_mode, str):
                assert layout_merge_bboxes_mode in [
                    "union",
                    "large",
                    "small",
                ], f"The value of `layout_merge_bboxes_mode` must be one of ['union', 'large', 'small'], but got {layout_merge_bboxes_mode}"
                if layout_merge_bboxes_mode == "union":
                    pass
                else:
                    contains_other, contained_by_other = check_containment(
                        boxes[:, :6], formula_index
                    )
                    if layout_merge_bboxes_mode == "large":
                        boxes = boxes[contained_by_other == 0]
                    elif layout_merge_bboxes_mode == "small":
                        boxes = boxes[(contains_other == 0) | (contained_by_other == 1)]
            elif isinstance(layout_merge_bboxes_mode, dict):
                keep_mask = np.ones(len(boxes), dtype=bool)
                for category_index, layout_mode in layout_merge_bboxes_mode.items():
                    assert layout_mode in [
                        "union",
                        "large",
                        "small",
                    ], f"The value of `layout_merge_bboxes_mode` must be one of ['union', 'large', 'small'], but got {layout_mode}"
                    if layout_mode == "union":
                        pass
                    else:
                        if layout_mode == "large":
                            contains_other, contained_by_other = check_containment(
                                boxes[:, :6],
                                formula_index,
                                category_index,
                                mode=layout_mode,
                            )
                            # Remove boxes that are contained by other boxes
                            keep_mask &= contained_by_other == 0
                        elif layout_mode == "small":
                            contains_other, contained_by_other = check_containment(
                                boxes[:, :6],
                                formula_index,
                                category_index,
                                mode=layout_mode,
                            )
                            # Keep boxes that do not contain others or are contained by others
                            keep_mask &= (contains_other == 0) | (
                                contained_by_other == 1
                            )
                boxes = boxes[keep_mask]

        if boxes.size == 0:
            return np.array([])

        if boxes.shape[1] == 8:
            # Sort ordered-detection boxes by their order columns, then drop them.
            sorted_idx = np.lexsort((-boxes[:, 7], boxes[:, 6]))
            sorted_boxes = boxes[sorted_idx]
            boxes = sorted_boxes[:, :6]

        if layout_unclip_ratio:
            if isinstance(layout_unclip_ratio, float):
                layout_unclip_ratio = (layout_unclip_ratio, layout_unclip_ratio)
            elif isinstance(layout_unclip_ratio, (tuple, list)):
                assert (
                    len(layout_unclip_ratio) == 2
                ), "The length of `layout_unclip_ratio` should be 2."
            elif isinstance(layout_unclip_ratio, dict):
                pass
            else:
                raise ValueError(
                    f"The type of `layout_unclip_ratio` must be float, Tuple[float, float] or Dict[int, Tuple[float, float]], but got {type(layout_unclip_ratio)}."
                )
            boxes = unclip_boxes(boxes, layout_unclip_ratio)

        if boxes.shape[1] == 6:
            # Normal object detection
            boxes = restructured_boxes(boxes, self.labels, img_size)
        elif boxes.shape[1] == 10:
            # Rotated object detection
            boxes = restructured_rotated_boxes(boxes, self.labels, img_size)
        else:
            # Unexpected input box shape
            raise ValueError(
                f"The number of columns in boxes should be 6 or 10, instead of {boxes.shape[1]}"
            )
        return boxes

    def __call__(
        self,
        batch_outputs: List[dict],
        datas: List[dict],
        threshold: Optional[Union[float, dict]] = None,
        layout_nms: Optional[bool] = None,
        layout_unclip_ratio: Optional[Union[float, Tuple[float, float]]] = None,
        layout_merge_bboxes_mode: Optional[str] = None,
    ) -> List[Boxes]:
        """Apply the post-processing to a batch of outputs.

        Args:
            batch_outputs (List[dict]): The list of detection outputs.
            datas (List[dict]): The list of input data.

        Returns:
            List[Boxes]: The list of post-processed detection boxes.
        """
        outputs = []
        for data, output in zip(datas, batch_outputs):
            boxes = self.apply(
                output["boxes"],
                data["ori_img_size"],
                threshold,
                layout_nms,
                layout_unclip_ratio,
                layout_merge_bboxes_mode,
            )
            outputs.append(boxes)
        return outputs
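
# End-to-end sketch (illustrative; the Resize/Normalize constructor arguments,
# "sample.jpg", labels, and batch_outputs are placeholders, not exact
# signatures or data):
#
#     datas = ReadImage(format="RGB")(["sample.jpg"])
#     datas = Resize(target_size=[640, 640])(datas)
#     datas = Normalize()(datas)
#     datas = ToCHWImage()(datas)
#     inputs = ToBatch(ordered_required_keys=("img_size", "img", "scale_factors"))(datas)
#     # ...run the detector on `inputs` to obtain batch_outputs, then:
#     boxes = DetPostProcess(labels)(batch_outputs, datas, threshold=0.5)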