image_common.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from pathlib import Path

import numpy as np
import cv2

from .....utils.download import download
from .....utils.cache import CACHE_DIR
from ..transform import BaseTransform
from ..io.readers import ImageReader
from ..io.writers import ImageWriter
from . import image_functions as F

__all__ = [
    "ReadImage",
    "Flip",
    "Crop",
    "Resize",
    "ResizeByLong",
    "ResizeByShort",
    "Pad",
    "Normalize",
    "ToCHWImage",
]


def _check_image_size(input_):
    """check image size"""
    if not (
        isinstance(input_, (list, tuple))
        and len(input_) == 2
        and isinstance(input_[0], int)
        and isinstance(input_[1], int)
    ):
        raise TypeError(f"{input_} cannot represent a valid image size.")


class ReadImage(BaseTransform):
    """Load image from the file."""

    _FLAGS_DICT = {
        "BGR": cv2.IMREAD_COLOR,
        "RGB": cv2.IMREAD_COLOR,
        "GRAY": cv2.IMREAD_GRAYSCALE,
    }

    def __init__(self, format="BGR"):
        """
        Initialize the instance.

        Args:
            format (str, optional): Target color format to convert the image to.
                Choices are 'BGR', 'RGB', and 'GRAY'. Default: 'BGR'.
        """
        super().__init__()
        self.format = format
        flags = self._FLAGS_DICT[self.format]
        self._reader = ImageReader(backend="opencv", flags=flags)
        self._writer = ImageWriter(backend="opencv")

    def apply(self, data):
        """apply"""
        if "image" in data:
            # An in-memory image takes precedence: dump it to the cache
            # directory so downstream components always see a file path.
            img = data["image"]
            img_path = (Path(CACHE_DIR) / "predict_input" / "tmp_img.jpg").as_posix()
            self._writer.write(img_path, img)
            data["input_path"] = img_path
            data["original_image"] = img
            data["original_image_size"] = [img.shape[1], img.shape[0]]
            return data
        elif "input_path" not in data:
            raise KeyError(f"Key {repr('input_path')} is required, but not found.")
        im_path = data["input_path"]
        # XXX: auto download for url
        im_path = self._download_from_url(im_path)
        blob = self._reader.read(im_path)
        if self.format == "RGB":
            if blob.ndim != 3:
                raise RuntimeError("Array is not 3-dimensional.")
            # BGR to RGB
            blob = blob[..., ::-1]
        data["input_path"] = im_path
        data["image"] = blob
        data["original_image"] = blob
        data["original_image_size"] = [blob.shape[1], blob.shape[0]]
        return data

    def _download_from_url(self, in_path):
        if in_path.startswith("http"):
            file_name = Path(in_path).name
            save_path = Path(CACHE_DIR) / "predict_input" / file_name
            download(in_path, save_path, overwrite=True)
            return save_path.as_posix()
        return in_path

    @classmethod
    def get_input_keys(cls):
        """get input keys"""
        # input_path: Path of the image.
        # image: Image in hw or hwc format (used if already loaded).
        return [["input_path"], ["image"]]

    @classmethod
    def get_output_keys(cls):
        """get output keys"""
        # image: Image in hw or hwc format.
        # original_image: Original image in hw or hwc format.
        # original_image_size: Width and height of the original image.
        return ["image", "original_image", "original_image_size"]
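

# Usage sketch (illustrative only): how ReadImage is typically driven. The
# file name "demo.jpg" is an assumption, not something this module provides.
def _example_read_image():
    reader = ReadImage(format="RGB")
    data = reader.apply({"input_path": "demo.jpg"})
    # data["image"] is an HWC ndarray in RGB order; "original_image_size"
    # holds [width, height] of the file on disk.
    return data["image"], data["original_image_size"]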


class GetImageInfo(BaseTransform):
    """Get Image Info"""

    def __init__(self):
        super().__init__()

    def apply(self, data):
        """apply"""
        blob = data["image"]
        data["original_image"] = blob
        data["original_image_size"] = [blob.shape[1], blob.shape[0]]
        return data

    @classmethod
    def get_input_keys(cls):
        """get input keys"""
        # image: Image in hw or hwc format.
        return ["image"]

    @classmethod
    def get_output_keys(cls):
        """get output keys"""
        # original_image: Original image in hw or hwc format.
        # original_image_size: Width and height of the original image.
        return ["original_image", "original_image_size"]


class Flip(BaseTransform):
    """Flip the image vertically or horizontally."""

    def __init__(self, mode="H"):
        """
        Initialize the instance.

        Args:
            mode (str, optional): 'H' for horizontal flipping and 'V' for vertical
                flipping. Default: 'H'.
        """
        super().__init__()
        if mode not in ("H", "V"):
            raise ValueError("`mode` should be 'H' or 'V'.")
        self.mode = mode

    def apply(self, data):
        """apply"""
        im = data["image"]
        if self.mode == "H":
            im = F.flip_h(im)
        elif self.mode == "V":
            im = F.flip_v(im)
        data["image"] = im
        return data

    @classmethod
    def get_input_keys(cls):
        """get input keys"""
        # image: Image in hw or hwc format.
        return ["image"]

    @classmethod
    def get_output_keys(cls):
        """get output keys"""
        # image: Image in hw or hwc format.
        return ["image"]
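

# Usage sketch (illustrative only): horizontal flipping mirrors the image
# left-to-right; the 2x2 array is a stand-in.
def _example_flip():
    img = np.array([[[1], [2]], [[3], [4]]], dtype=np.uint8)  # 2x2, one channel
    out = Flip(mode="H").apply({"image": img})
    # Assuming F.flip_h behaves like np.flip(..., axis=1), the columns swap:
    # [[2, 1], [4, 3]].
    return out["image"]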


class Crop(BaseTransform):
    """Crop region from the image."""

    def __init__(self, crop_size, mode="C"):
        """
        Initialize the instance.

        Args:
            crop_size (list|tuple|int): Width and height of the region to crop.
            mode (str, optional): 'C' for cropping the center part and 'TL' for
                cropping the top left part. Default: 'C'.
        """
        super().__init__()
        if isinstance(crop_size, int):
            crop_size = [crop_size, crop_size]
        _check_image_size(crop_size)
        self.crop_size = crop_size
        if mode not in ("C", "TL"):
            raise ValueError("`mode` should be 'C' or 'TL'.")
        self.mode = mode

    def apply(self, data):
        """apply"""
        im = data["image"]
        h, w = im.shape[:2]
        cw, ch = self.crop_size
        if w < cw or h < ch:
            raise ValueError(
                f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
            )
        if self.mode == "C":
            x1 = max(0, (w - cw) // 2)
            y1 = max(0, (h - ch) // 2)
        elif self.mode == "TL":
            x1, y1 = 0, 0
        x2 = min(w, x1 + cw)
        y2 = min(h, y1 + ch)
        coords = (x1, y1, x2, y2)
        im = F.slice(im, coords=coords)
        data["image"] = im
        data["image_size"] = [im.shape[1], im.shape[0]]
        return data

    @classmethod
    def get_input_keys(cls):
        """get input keys"""
        # image: Image in hw or hwc format.
        return ["image"]

    @classmethod
    def get_output_keys(cls):
        """get output keys"""
        # image: Image in hw or hwc format.
        # image_size: Width and height of the image.
        return ["image", "image_size"]
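

# Usage sketch (illustrative only): center-cropping a 320x256 stand-in image
# to 224x224.
def _example_center_crop():
    img = np.zeros((256, 320, 3), dtype=np.uint8)  # H=256, W=320
    crop = Crop(crop_size=224, mode="C")
    out = crop.apply({"image": img})
    # x1 = (320 - 224) // 2 = 48, y1 = (256 - 224) // 2 = 16, so the kept
    # window is (48, 16, 272, 240) and the result is 224x224.
    return out["image"], out["image_size"]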


class _BaseResize(BaseTransform):
    _INTERP_DICT = {
        "NEAREST": cv2.INTER_NEAREST,
        "LINEAR": cv2.INTER_LINEAR,
        "CUBIC": cv2.INTER_CUBIC,
        "AREA": cv2.INTER_AREA,
        "LANCZOS4": cv2.INTER_LANCZOS4,
    }

    def __init__(self, size_divisor, interp):
        super().__init__()
        if size_divisor is not None:
            assert isinstance(
                size_divisor, int
            ), "`size_divisor` should be None or int."
        self.size_divisor = size_divisor
        try:
            interp = self._INTERP_DICT[interp]
        except KeyError:
            raise ValueError(
                "`interp` should be one of {}.".format(self._INTERP_DICT.keys())
            )
        self.interp = interp

    @staticmethod
    def _rescale_size(img_size, target_size):
        """rescale size"""
        # Scale uniformly so that the longer side fits the target's longer
        # side and the shorter side fits the target's shorter side, whichever
        # is more restrictive.
        scale = min(max(target_size) / max(img_size), min(target_size) / min(img_size))
        rescaled_size = [round(i * scale) for i in img_size]
        return rescaled_size, scale
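

# Worked example (illustrative only) of `_BaseResize._rescale_size`: for a
# 1280x720 image and a (1333, 800) target, the long-side ratio is
# 1333/1280 ≈ 1.0414 and the short-side ratio is 800/720 ≈ 1.1111, so the
# smaller ratio wins.
def _example_rescale_size():
    rescaled, scale = _BaseResize._rescale_size((1280, 720), (1333, 800))
    # rescaled == [1333, 750] (rounded), scale ≈ 1.0414
    return rescaled, scale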


class Resize(_BaseResize):
    """Resize the image."""

    def __init__(
        self, target_size, keep_ratio=False, size_divisor=None, interp="LINEAR"
    ):
        """
        Initialize the instance.

        Args:
            target_size (list|tuple|int): Target width and height.
            keep_ratio (bool, optional): Whether to keep the aspect ratio of resized
                image. Default: False.
            size_divisor (int|None, optional): Divisor of resized image size.
                Default: None.
            interp (str, optional): Interpolation method. Choices are 'NEAREST',
                'LINEAR', 'CUBIC', 'AREA', and 'LANCZOS4'. Default: 'LINEAR'.
        """
        super().__init__(size_divisor=size_divisor, interp=interp)
        if isinstance(target_size, int):
            target_size = [target_size, target_size]
        _check_image_size(target_size)
        self.target_size = target_size
        self.keep_ratio = keep_ratio

    def apply(self, data):
        """apply"""
        target_size = self.target_size
        im = data["image"]
        original_size = im.shape[:2]
        if self.keep_ratio:
            h, w = im.shape[0:2]
            target_size, _ = self._rescale_size((w, h), self.target_size)
        if self.size_divisor:
            target_size = [
                math.ceil(i / self.size_divisor) * self.size_divisor
                for i in target_size
            ]
        # `target_size` is (width, height) while `original_size` is (height, width).
        im_scale_w, im_scale_h = [
            target_size[0] / original_size[1],
            target_size[1] / original_size[0],
        ]
        im = F.resize(im, target_size, interp=self.interp)
        data["image"] = im
        data["image_size"] = [im.shape[1], im.shape[0]]
        data["scale_factors"] = [im_scale_w, im_scale_h]
        return data

    @classmethod
    def get_input_keys(cls):
        """get input keys"""
        # image: Image in hw or hwc format.
        return ["image"]

    @classmethod
    def get_output_keys(cls):
        """get output keys"""
        # image: Image in hw or hwc format.
        # image_size: Width and height of the image.
        # scale_factors: Scale factors for image width and height.
        return ["image", "image_size", "scale_factors"]
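

# Usage sketch (illustrative only): aspect-ratio-preserving resize with a
# stride constraint, as a detector backbone might require. The 640x480 input
# is a stand-in.
def _example_keep_ratio_resize():
    img = np.zeros((480, 640, 3), dtype=np.uint8)  # H=480, W=640
    resize = Resize(target_size=(320, 320), keep_ratio=True, size_divisor=32)
    out = resize.apply({"image": img})
    # The long side is scaled to 320, the short side to 240, and both are then
    # rounded up to a multiple of 32, giving a 320x256 image.
    return out["image_size"], out["scale_factors"]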


class ResizeByLong(_BaseResize):
    """
    Proportionally resize the image by specifying the target length of the
    longest side.
    """

    def __init__(self, target_long_edge, size_divisor=None, interp="LINEAR"):
        """
        Initialize the instance.

        Args:
            target_long_edge (int): Target length of the longest side of image.
            size_divisor (int|None, optional): Divisor of resized image size.
                Default: None.
            interp (str, optional): Interpolation method. Choices are 'NEAREST',
                'LINEAR', 'CUBIC', 'AREA', and 'LANCZOS4'. Default: 'LINEAR'.
        """
        super().__init__(size_divisor=size_divisor, interp=interp)
        self.target_long_edge = target_long_edge

    def apply(self, data):
        """apply"""
        im = data["image"]
        h, w = im.shape[:2]
        scale = self.target_long_edge / max(h, w)
        h_resize = round(h * scale)
        w_resize = round(w * scale)
        if self.size_divisor is not None:
            h_resize = math.ceil(h_resize / self.size_divisor) * self.size_divisor
            w_resize = math.ceil(w_resize / self.size_divisor) * self.size_divisor
        im = F.resize(im, (w_resize, h_resize), interp=self.interp)
        data["image"] = im
        data["image_size"] = [im.shape[1], im.shape[0]]
        return data

    @classmethod
    def get_input_keys(cls):
        """get input keys"""
        # image: Image in hw or hwc format.
        return ["image"]

    @classmethod
    def get_output_keys(cls):
        """get output keys"""
        # image: Image in hw or hwc format.
        # image_size: Width and height of the image.
        return ["image", "image_size"]


class ResizeByShort(_BaseResize):
    """
    Proportionally resize the image by specifying the target length of the
    shortest side.
    """

    def __init__(self, target_short_edge, size_divisor=None, interp="LINEAR"):
        """
        Initialize the instance.

        Args:
            target_short_edge (int): Target length of the shortest side of image.
            size_divisor (int|None, optional): Divisor of resized image size.
                Default: None.
            interp (str, optional): Interpolation method. Choices are 'NEAREST',
                'LINEAR', 'CUBIC', 'AREA', and 'LANCZOS4'. Default: 'LINEAR'.
        """
        super().__init__(size_divisor=size_divisor, interp=interp)
        self.target_short_edge = target_short_edge

    def apply(self, data):
        """apply"""
        im = data["image"]
        h, w = im.shape[:2]
        scale = self.target_short_edge / min(h, w)
        h_resize = round(h * scale)
        w_resize = round(w * scale)
        if self.size_divisor is not None:
            h_resize = math.ceil(h_resize / self.size_divisor) * self.size_divisor
            w_resize = math.ceil(w_resize / self.size_divisor) * self.size_divisor
        im = F.resize(im, (w_resize, h_resize), interp=self.interp)
        data["image"] = im
        data["image_size"] = [im.shape[1], im.shape[0]]
        return data

    @classmethod
    def get_input_keys(cls):
        """get input keys"""
        # image: Image in hw or hwc format.
        return ["image"]

    @classmethod
    def get_output_keys(cls):
        """get output keys"""
        # image: Image in hw or hwc format.
        # image_size: Width and height of the image.
        return ["image", "image_size"]
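

# Usage sketch (illustrative only): ResizeByShort scales a 1920x1080 stand-in
# so that its shorter side becomes 736; ResizeByLong works the same way but
# keys off the longer side.
def _example_resize_by_short():
    img = np.zeros((1080, 1920, 3), dtype=np.uint8)
    out = ResizeByShort(target_short_edge=736, size_divisor=32).apply({"image": img})
    # scale = 736 / 1080 ≈ 0.6815, so 1920 -> 1308, then both sides are rounded
    # up to multiples of 32: image_size == [1312, 736].
    return out["image_size"]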


class Pad(BaseTransform):
    """Pad the image."""

    def __init__(self, target_size, val=127.5):
        """
        Initialize the instance.

        Args:
            target_size (list|tuple|int): Target width and height of the image after
                padding.
            val (float, optional): Value to fill the padded area. Default: 127.5.
        """
        super().__init__()
        if isinstance(target_size, int):
            target_size = [target_size, target_size]
        _check_image_size(target_size)
        self.target_size = target_size
        self.val = val

    def apply(self, data):
        """apply"""
        im = data["image"]
        h, w = im.shape[:2]
        tw, th = self.target_size
        ph = th - h
        pw = tw - w
        if ph < 0 or pw < 0:
            raise ValueError(
                f"Input image ({w}, {h}) smaller than the target size ({tw}, {th})."
            )
        im = F.pad(im, pad=(0, ph, 0, pw), val=self.val)
        data["image"] = im
        data["image_size"] = [im.shape[1], im.shape[0]]
        return data

    @classmethod
    def get_input_keys(cls):
        """get input keys"""
        # image: Image in hw or hwc format.
        return ["image"]

    @classmethod
    def get_output_keys(cls):
        """get output keys"""
        # image: Image in hw or hwc format.
        # image_size: Width and height of the image.
        return ["image", "image_size"]
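

# Usage sketch (illustrative only): a resize/pad pair is a common way to batch
# images of different shapes; the 608x608 canvas and fill value here are
# assumptions, not values mandated by this module.
def _example_pad_to_canvas():
    img = np.zeros((416, 608, 3), dtype=np.uint8)  # H=416, W=608
    out = Pad(target_size=608, val=114).apply({"image": img})
    # 192 rows of padding are added to reach the 608x608 target.
    return out["image_size"]  # [608, 608]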


class Normalize(BaseTransform):
    """Normalize the image."""

    def __init__(self, scale=1.0 / 255, mean=0.5, std=0.5, preserve_dtype=False):
        """
        Initialize the instance.

        Args:
            scale (float, optional): Scaling factor to apply to the image before
                applying normalization. Default: 1/255.
            mean (float|tuple|list, optional): Means for each channel of the image.
                Default: 0.5.
            std (float|tuple|list, optional): Standard deviations for each channel
                of the image. Default: 0.5.
            preserve_dtype (bool, optional): Whether to preserve the original dtype
                of the image. Default: False.
        """
        super().__init__()
        self.scale = np.float32(scale)
        if isinstance(mean, float):
            mean = [mean]
        self.mean = np.asarray(mean).astype("float32")
        if isinstance(std, float):
            std = [std]
        self.std = np.asarray(std).astype("float32")
        self.preserve_dtype = preserve_dtype

    def apply(self, data):
        """apply"""
        im = data["image"]
        old_type = im.dtype
        # XXX: If `old_type` has higher precision than float32,
        # we will lose some precision.
        im = im.astype("float32", copy=False)
        # Normalize in place: out = (im * scale - mean) / std.
        im *= self.scale
        im -= self.mean
        im /= self.std
        if self.preserve_dtype:
            im = im.astype(old_type, copy=False)
        data["image"] = im
        return data

    @classmethod
    def get_input_keys(cls):
        """get input keys"""
        # image: Image in hw or hwc format.
        return ["image"]

    @classmethod
    def get_output_keys(cls):
        """get output keys"""
        # image: Image in hw or hwc format.
        return ["image"]
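

# Worked example (illustrative only): with the defaults, a uint8 pixel value
# of 255 maps to (255 / 255 - 0.5) / 0.5 = 1.0 and 0 maps to -1.0, i.e. the
# image is squeezed into [-1, 1].
def _example_normalize():
    img = np.array([[[0, 128, 255]]], dtype=np.uint8)  # a single RGB-like pixel
    out = Normalize().apply({"image": img})
    # out["image"] ≈ [[[-1.0, 0.0039, 1.0]]], dtype float32
    return out["image"]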


class ToCHWImage(BaseTransform):
    """Reorder the dimensions of the image from HWC to CHW."""

    def apply(self, data):
        """apply"""
        im = data["image"]
        im = im.transpose((2, 0, 1))
        data["image"] = im
        return data

    @classmethod
    def get_input_keys(cls):
        """get input keys"""
        # image: Image in hwc format.
        return ["image"]

    @classmethod
    def get_output_keys(cls):
        """get output keys"""
        # image: Image in chw format.
        return ["image"]
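

# Pipeline sketch (illustrative only): the transforms above are designed to be
# chained on a shared `data` dict. The ordering below mirrors a typical
# classification preprocessing pipeline and is an assumption, not something
# this module enforces.
def _example_preprocess_pipeline(input_path):
    data = {"input_path": input_path}
    for transform in (
        ReadImage(format="RGB"),
        ResizeByShort(target_short_edge=256),
        Crop(crop_size=224, mode="C"),
        Normalize(scale=1.0 / 255, mean=0.5, std=0.5),
        ToCHWImage(),
    ):
        data = transform.apply(data)
    # data["image"] is now a float32 CHW array ready to be fed to a model.
    return data["image"]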


def rotate_point(pt, angle_rad):
    """Rotate a point by an angle.

    Args:
        pt (list[float]): 2 dimensional point to be rotated
        angle_rad (float): rotation angle by radian

    Returns:
        list[float]: Rotated point.
    """
    assert len(pt) == 2
    sn, cs = np.sin(angle_rad), np.cos(angle_rad)
    new_x = pt[0] * cs - pt[1] * sn
    new_y = pt[0] * sn + pt[1] * cs
    rotated_pt = [new_x, new_y]
    return rotated_pt


def _get_3rd_point(a, b):
    """To calculate the affine matrix, three pairs of points are required. This
    function is used to get the 3rd point, given 2D points a & b.

    The 3rd point is defined by rotating vector `a - b` by 90 degrees
    anticlockwise, using b as the rotation center.

    Args:
        a (np.ndarray): point (x, y)
        b (np.ndarray): point (x, y)

    Returns:
        np.ndarray: The 3rd point.
    """
    assert len(a) == 2
    assert len(b) == 2
    direction = a - b
    third_pt = b + np.array([-direction[1], direction[0]], dtype=np.float32)
    return third_pt


def get_affine_transform(
    center, input_size, rot, output_size, shift=(0.0, 0.0), inv=False
):
    """Get the affine transform matrix, given the center/input_size/rot/output_size.

    Args:
        center (np.ndarray[2, ]): Center of the bounding box (x, y).
        input_size (np.ndarray[2, ] | float): Size of the input region
            wrt [width, height].
        rot (float): Rotation angle (degree).
        output_size (np.ndarray[2, ]): Size of the destination heatmaps.
        shift (tuple): Shift translation ratio wrt the width/height.
            Default (0., 0.).
        inv (bool): Option to inverse the affine transform direction.
            (inv=False: src->dst or inv=True: dst->src)

    Returns:
        np.ndarray: The transform matrix.
    """
    assert len(center) == 2
    assert len(output_size) == 2
    assert len(shift) == 2

    if not isinstance(input_size, (np.ndarray, list)):
        input_size = np.array([input_size, input_size], dtype=np.float32)
    scale_tmp = input_size

    shift = np.array(shift)
    src_w = scale_tmp[0]
    dst_w = output_size[0]
    dst_h = output_size[1]

    rot_rad = np.pi * rot / 180
    src_dir = rotate_point([0.0, src_w * -0.5], rot_rad)
    dst_dir = np.array([0.0, dst_w * -0.5])

    # Three corresponding point pairs define the affine map: the box center,
    # a point offset above the center, and a third point obtained by a
    # 90-degree rotation, for both the source and the destination planes.
    src = np.zeros((3, 2), dtype=np.float32)
    src[0, :] = center + scale_tmp * shift
    src[1, :] = center + src_dir + scale_tmp * shift
    src[2, :] = _get_3rd_point(src[0, :], src[1, :])

    dst = np.zeros((3, 2), dtype=np.float32)
    dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
    dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
    dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])

    if inv:
        trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
    else:
        trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
    return trans
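

# Worked example (illustrative only): with no rotation, the transform simply
# scales and translates a region centered at (320, 240) with a 480-pixel
# extent onto a 128x128 patch. The numbers are assumptions chosen for
# readability.
def _example_affine_transform():
    center = np.array([320.0, 240.0], dtype=np.float32)
    trans = get_affine_transform(center, 480, rot=0, output_size=[128, 128])
    # trans is a 2x3 matrix; applying it to `center` yields the patch center
    # (64, 64), and the scale along each axis is 128 / 480.
    warped_center = trans @ np.array([320.0, 240.0, 1.0])
    return trans, warped_center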


class WarpAffine(object):
    """Apply a warp-affine transformation to the image."""

    def __init__(
        self,
        keep_res=False,
        pad=31,
        input_h=512,
        input_w=512,
        scale=0.4,
        shift=0.1,
        down_ratio=4,
    ):
        self.keep_res = keep_res
        self.pad = pad
        self.input_h = input_h
        self.input_w = input_w
        self.scale = scale
        self.shift = shift
        self.down_ratio = down_ratio

    def __call__(self, data):
        im = data["image"]
        img = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
        h, w = img.shape[:2]

        if self.keep_res:
            # True in detection eval/infer
            input_h = (h | self.pad) + 1
            input_w = (w | self.pad) + 1
            s = np.array([input_w, input_h], dtype=np.float32)
            c = np.array([w // 2, h // 2], dtype=np.float32)
        else:
            # False in centertrack eval_mot/eval_mot
            s = max(h, w) * 1.0
            input_h, input_w = self.input_h, self.input_w
            c = np.array([w / 2.0, h / 2.0], dtype=np.float32)

        trans_input = get_affine_transform(c, s, 0, [input_w, input_h])
        img = cv2.resize(img, (w, h))  # same size as the input, effectively a copy
        inp = cv2.warpAffine(
            img, trans_input, (input_w, input_h), flags=cv2.INTER_LINEAR
        )

        if not self.keep_res:
            out_h = input_h // self.down_ratio
            out_w = input_w // self.down_ratio
            # Computed for the down-sampled output resolution; not applied here.
            trans_output = get_affine_transform(c, s, 0, [out_w, out_h])

        data["image"] = inp
        im_scale_w, im_scale_h = [input_w / w, input_h / h]
        data["image_size"] = [inp.shape[1], inp.shape[0]]
        data["scale_factors"] = [im_scale_w, im_scale_h]
        return data

    @classmethod
    def get_input_keys(cls):
        """get input keys"""
        # image: Image in hwc format.
        return ["image"]

    @classmethod
    def get_output_keys(cls):
        """get output keys"""
        # image: Image in hwc format.
        # image_size: Width and height of the image.
        # scale_factors: Scale factors for image width and height.
        return ["image", "image_size", "scale_factors"]
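

# Usage sketch (illustrative only): the CenterNet-style path with
# keep_res=False warps any input onto a fixed 512x512 canvas; the input array
# here is a stand-in.
def _example_warp_affine():
    rgb = np.zeros((720, 1280, 3), dtype=np.uint8)
    out = WarpAffine(keep_res=False, input_h=512, input_w=512)({"image": rgb})
    # out["image"] has shape (512, 512, 3) in BGR order;
    # out["scale_factors"] == [512 / 1280, 512 / 720].
    return out["image"].shape, out["scale_factors"]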