image_common.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585
  1. # !/usr/bin/env python3
  2. # -*- coding: UTF-8 -*-
  3. ################################################################################
  4. #
  5. # Copyright (c) 2024 Baidu.com, Inc. All Rights Reserved
  6. #
  7. ################################################################################
  8. """
  9. Author: PaddlePaddle Authors
  10. """
  11. import math
  12. from pathlib import Path
  13. import numpy as np
  14. import cv2
  15. from .....utils.download import download
  16. from .....utils.cache import CACHE_DIR
  17. from ..transform import BaseTransform
  18. from ..io.readers import ImageReader
  19. from ..io.writers import ImageWriter
  20. from . import image_functions as F
  21. __all__ = [
  22. 'ReadImage', 'Flip', 'Crop', 'Resize', 'ResizeByLong', 'ResizeByShort',
  23. 'Pad', 'Normalize', 'ToCHWImage'
  24. ]
  25. def _check_image_size(input_):
  26. """ check image size """
  27. if not (isinstance(input_, (list, tuple)) and len(input_) == 2 and
  28. isinstance(input_[0], int) and isinstance(input_[1], int)):
  29. raise TypeError(f"{input_} cannot represent a valid image size.")
  30. class ReadImage(BaseTransform):
  31. """Load image from the file."""
  32. _FLAGS_DICT = {
  33. 'BGR': cv2.IMREAD_COLOR,
  34. 'RGB': cv2.IMREAD_COLOR,
  35. 'GRAY': cv2.IMREAD_GRAYSCALE
  36. }
  37. def __init__(self, format='BGR'):
  38. """
  39. Initialize the instance.
  40. Args:
  41. format (str, optional): Target color format to convert the image to.
  42. Choices are 'BGR', 'RGB', and 'GRAY'. Default: 'BGR'.
  43. """
  44. super().__init__()
  45. self.format = format
  46. flags = self._FLAGS_DICT[self.format]
  47. self._reader = ImageReader(backend='opencv', flags=flags)
  48. self._writer = ImageWriter(backend='opencv')
  49. def apply(self, data):
  50. """ apply """
  51. if 'image' in data:
  52. img = data['image']
  53. img_path = (Path(CACHE_DIR) / "predict_input" /
  54. "tmp_img.jpg").as_posix()
  55. self._writer.write(img_path, img)
  56. data['input_path'] = img_path
  57. data['original_image'] = img
  58. data['original_image_size'] = [img.shape[1], img.shape[0]]
  59. return data
  60. elif 'input_path' not in data:
  61. raise KeyError(
  62. f"Key {repr('input_path')} is required, but not found.")
  63. im_path = data['input_path']
  64. # XXX: auto download for url
  65. im_path = self._download_from_url(im_path)
  66. blob = self._reader.read(im_path)
  67. if self.format == 'RGB':
  68. if blob.ndim != 3:
  69. raise RuntimeError("Array is not 3-dimensional.")
  70. # BGR to RGB
  71. blob = blob[..., ::-1]
  72. data['input_path'] = im_path
  73. data['image'] = blob
  74. data['original_image'] = blob
  75. data['original_image_size'] = [blob.shape[1], blob.shape[0]]
  76. return data
  77. def _download_from_url(self, in_path):
  78. if in_path.startswith("http"):
  79. file_name = Path(in_path).name
  80. save_path = Path(CACHE_DIR) / "predict_input" / file_name
  81. download(in_path, save_path, overwrite=True)
  82. return save_path.as_posix()
  83. return in_path
  84. @classmethod
  85. def get_input_keys(cls):
  86. """ get input keys """
  87. # input_path: Path of the image.
  88. return [['input_path'], ['image']]
  89. @classmethod
  90. def get_output_keys(cls):
  91. """ get output keys """
  92. # image: Image in hw or hwc format.
  93. # original_image: Original image in hw or hwc format.
  94. # original_image_size: Width and height of the original image.
  95. return ['image', 'original_image', 'original_image_size']
  96. class GetImageInfo(BaseTransform):
  97. """Get Image Info
  98. """
  99. def __init__(self):
  100. super().__init__()
  101. def apply(self, data):
  102. """ apply """
  103. blob = data['image']
  104. data['original_image'] = blob
  105. data['original_image_size'] = [blob.shape[1], blob.shape[0]]
  106. return data
  107. @classmethod
  108. def get_input_keys(cls):
  109. """ get input keys """
  110. # input_path: Path of the image.
  111. return ['image']
  112. @classmethod
  113. def get_output_keys(cls):
  114. """ get output keys """
  115. # image: Image in hw or hwc format.
  116. # original_image: Original image in hw or hwc format.
  117. # original_image_size: Width and height of the original image.
  118. return ['original_image', 'original_image_size']
  119. class Flip(BaseTransform):
  120. """Flip the image vertically or horizontally."""
  121. def __init__(self, mode='H'):
  122. """
  123. Initialize the instance.
  124. Args:
  125. mode (str, optional): 'H' for horizontal flipping and 'V' for vertical
  126. flipping. Default: 'H'.
  127. """
  128. super().__init__()
  129. if mode not in ('H', 'V'):
  130. raise ValueError("`mode` should be 'H' or 'V'.")
  131. self.mode = mode
  132. def apply(self, data):
  133. """ apply """
  134. im = data['image']
  135. if self.mode == 'H':
  136. im = F.flip_h(im)
  137. elif self.mode == 'V':
  138. im = F.flip_v(im)
  139. data['image'] = im
  140. return data
  141. @classmethod
  142. def get_input_keys(cls):
  143. """ get input keys """
  144. # image: Image in hw or hwc format.
  145. return ['image']
  146. @classmethod
  147. def get_output_keys(cls):
  148. """ get output keys """
  149. # image: Image in hw or hwc format.
  150. return ['image']
  151. class Crop(BaseTransform):
  152. """Crop region from the image."""
  153. def __init__(self, crop_size, mode='C'):
  154. """
  155. Initialize the instance.
  156. Args:
  157. crop_size (list|tuple|int): Width and height of the region to crop.
  158. mode (str, optional): 'C' for cropping the center part and 'TL' for
  159. cropping the top left part. Default: 'C'.
  160. """
  161. super().__init__()
  162. if isinstance(crop_size, int):
  163. crop_size = [crop_size, crop_size]
  164. _check_image_size(crop_size)
  165. self.crop_size = crop_size
  166. if mode not in ('C', 'TL'):
  167. raise ValueError("Unsupported interpolation method")
  168. self.mode = mode
  169. def apply(self, data):
  170. """ apply """
  171. im = data['image']
  172. h, w = im.shape[:2]
  173. cw, ch = self.crop_size
  174. if self.mode == 'C':
  175. x1 = max(0, (w - cw) // 2)
  176. y1 = max(0, (h - ch) // 2)
  177. elif self.mode == 'TL':
  178. x1, y1 = 0, 0
  179. x2 = min(w, x1 + cw)
  180. y2 = min(h, y1 + ch)
  181. coords = (x1, y1, x2, y2)
  182. if coords == (0, 0, w, h):
  183. raise ValueError(
  184. f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
  185. )
  186. im = F.slice(im, coords=coords)
  187. data['image'] = im
  188. data['image_size'] = [im.shape[1], im.shape[0]]
  189. return data
  190. @classmethod
  191. def get_input_keys(cls):
  192. """ get input keys """
  193. # image: Image in hw or hwc format.
  194. return ['image']
  195. @classmethod
  196. def get_output_keys(cls):
  197. """ get output keys """
  198. # image: Image in hw or hwc format.
  199. # image_size: Width and height of the image.
  200. return ['image', 'image_size']
  201. class _BaseResize(BaseTransform):
  202. _INTERP_DICT = {
  203. 'NEAREST': cv2.INTER_NEAREST,
  204. 'LINEAR': cv2.INTER_LINEAR,
  205. 'CUBIC': cv2.INTER_CUBIC,
  206. 'AREA': cv2.INTER_AREA,
  207. 'LANCZOS4': cv2.INTER_LANCZOS4
  208. }
  209. def __init__(self, size_divisor, interp):
  210. super().__init__()
  211. if size_divisor is not None:
  212. assert isinstance(size_divisor,
  213. int), "`size_divisor` should be None or int."
  214. self.size_divisor = size_divisor
  215. try:
  216. interp = self._INTERP_DICT[interp]
  217. except KeyError:
  218. raise ValueError("`interp` should be one of {}.".format(
  219. self._INTERP_DICT.keys()))
  220. self.interp = interp
  221. @staticmethod
  222. def _rescale_size(img_size, target_size):
  223. """ rescale size """
  224. scale = min(
  225. max(target_size) / max(img_size), min(target_size) / min(img_size))
  226. rescaled_size = [round(i * scale) for i in img_size]
  227. return rescaled_size, scale
  228. class Resize(_BaseResize):
  229. """Resize the image."""
  230. def __init__(self,
  231. target_size,
  232. keep_ratio=False,
  233. size_divisor=None,
  234. interp='LINEAR'):
  235. """
  236. Initialize the instance.
  237. Args:
  238. target_size (list|tuple|int): Target width and height.
  239. keep_ratio (bool, optional): Whether to keep the aspect ratio of resized
  240. image. Default: False.
  241. size_divisor (int|None, optional): Divisor of resized image size.
  242. Default: None.
  243. interp (str, optional): Interpolation method. Choices are 'NEAREST',
  244. 'LINEAR', 'CUBIC', 'AREA', and 'LANCZOS4'. Default: 'LINEAR'.
  245. """
  246. super().__init__(size_divisor=size_divisor, interp=interp)
  247. if isinstance(target_size, int):
  248. target_size = [target_size, target_size]
  249. _check_image_size(target_size)
  250. self.target_size = target_size
  251. self.keep_ratio = keep_ratio
  252. def apply(self, data):
  253. """ apply """
  254. target_size = self.target_size
  255. im = data['image']
  256. original_size = im.shape[:2]
  257. if self.keep_ratio:
  258. h, w = im.shape[0:2]
  259. target_size, _ = self._rescale_size((w, h), self.target_size)
  260. if self.size_divisor:
  261. target_size = [
  262. math.ceil(i / self.size_divisor) * self.size_divisor
  263. for i in target_size
  264. ]
  265. im_scale_w, im_scale_h = [
  266. target_size[1] / original_size[1], target_size[0] / original_size[0]
  267. ]
  268. im = F.resize(im, target_size, interp=self.interp)
  269. data['image'] = im
  270. data['image_size'] = [im.shape[1], im.shape[0]]
  271. data['scale_factors'] = [im_scale_w, im_scale_h]
  272. return data
  273. @classmethod
  274. def get_input_keys(cls):
  275. """ get input keys """
  276. # image: Image in hw or hwc format.
  277. return ['image']
  278. @classmethod
  279. def get_output_keys(cls):
  280. """ get output keys """
  281. # image: Image in hw or hwc format.
  282. # image_size: Width and height of the image.
  283. # scale_factors: Scale factors for image width and height.
  284. return ['image', 'image_size', 'scale_factors']
  285. class ResizeByLong(_BaseResize):
  286. """
  287. Proportionally resize the image by specifying the target length of the
  288. longest side.
  289. """
  290. def __init__(self, target_long_edge, size_divisor=None, interp='LINEAR'):
  291. """
  292. Initialize the instance.
  293. Args:
  294. target_long_edge (int): Target length of the longest side of image.
  295. size_divisor (int|None, optional): Divisor of resized image size.
  296. Default: None.
  297. interp (str, optional): Interpolation method. Choices are 'NEAREST',
  298. 'LINEAR', 'CUBIC', 'AREA', and 'LANCZOS4'. Default: 'LINEAR'.
  299. """
  300. super().__init__(size_divisor=size_divisor, interp=interp)
  301. self.target_long_edge = target_long_edge
  302. def apply(self, data):
  303. """ apply """
  304. im = data['image']
  305. h, w = im.shape[:2]
  306. scale = self.target_long_edge / max(h, w)
  307. h_resize = round(h * scale)
  308. w_resize = round(w * scale)
  309. if self.size_divisor is not None:
  310. h_resize = math.ceil(h_resize /
  311. self.size_divisor) * self.size_divisor
  312. w_resize = math.ceil(w_resize /
  313. self.size_divisor) * self.size_divisor
  314. im = F.resize(im, (w_resize, h_resize), interp=self.interp)
  315. data['image'] = im
  316. data['image_size'] = [im.shape[1], im.shape[0]]
  317. return data
  318. @classmethod
  319. def get_input_keys(cls):
  320. """ get input keys """
  321. # image: Image in hw or hwc format.
  322. return ['image']
  323. @classmethod
  324. def get_output_keys(cls):
  325. """ get output keys """
  326. # image: Image in hw or hwc format.
  327. # image_size: Width and height of the image.
  328. return ['image', 'image_size']
  329. class ResizeByShort(_BaseResize):
  330. """
  331. Proportionally resize the image by specifying the target length of the
  332. shortest side.
  333. """
  334. def __init__(self, target_short_edge, size_divisor=None, interp='LINEAR'):
  335. """
  336. Initialize the instance.
  337. Args:
  338. target_short_edge (int): Target length of the shortest side of image.
  339. size_divisor (int|None, optional): Divisor of resized image size.
  340. Default: None.
  341. interp (str, optional): Interpolation method. Choices are 'NEAREST',
  342. 'LINEAR', 'CUBIC', 'AREA', and 'LANCZOS4'. Default: 'LINEAR'.
  343. """
  344. super().__init__(size_divisor=size_divisor, interp=interp)
  345. self.target_short_edge = target_short_edge
  346. def apply(self, data):
  347. """ apply """
  348. im = data['image']
  349. h, w = im.shape[:2]
  350. scale = self.target_short_edge / min(h, w)
  351. h_resize = round(h * scale)
  352. w_resize = round(w * scale)
  353. if self.size_divisor is not None:
  354. h_resize = math.ceil(h_resize /
  355. self.size_divisor) * self.size_divisor
  356. w_resize = math.ceil(w_resize /
  357. self.size_divisor) * self.size_divisor
  358. im = F.resize(im, (w_resize, h_resize), interp=self.interp)
  359. data['image'] = im
  360. data['image_size'] = [im.shape[1], im.shape[0]]
  361. return data
  362. @classmethod
  363. def get_input_keys(cls):
  364. """ get input keys """
  365. # image: Image in hw or hwc format.
  366. return ['image']
  367. @classmethod
  368. def get_output_keys(cls):
  369. """ get output keys """
  370. # image: Image in hw or hwc format.
  371. # image_size: Width and height of the image.
  372. return ['image', 'image_size']
  373. class Pad(BaseTransform):
  374. """Pad the image."""
  375. def __init__(self, target_size, val=127.5):
  376. """
  377. Initialize the instance.
  378. Args:
  379. target_size (list|tuple|int): Target width and height of the image after
  380. padding.
  381. val (float, optional): Value to fill the padded area. Default: 127.5.
  382. """
  383. super().__init__()
  384. if isinstance(target_size, int):
  385. target_size = [target_size, target_size]
  386. _check_image_size(target_size)
  387. self.target_size = target_size
  388. self.val = val
  389. def apply(self, data):
  390. """ apply """
  391. im = data['image']
  392. h, w = im.shape[:2]
  393. tw, th = self.target_size
  394. ph = th - h
  395. pw = tw - w
  396. if ph < 0 or pw < 0:
  397. raise ValueError(
  398. f"Input image ({w}, {h}) smaller than the target size ({tw}, {th})."
  399. )
  400. else:
  401. im = F.pad(im, pad=(0, ph, 0, pw), val=self.val)
  402. data['image'] = im
  403. data['image_size'] = [im.shape[1], im.shape[0]]
  404. return data
  405. @classmethod
  406. def get_input_keys(cls):
  407. """ get input keys """
  408. # image: Image in hw or hwc format.
  409. return ['image']
  410. @classmethod
  411. def get_output_keys(cls):
  412. """ get output keys """
  413. # image: Image in hw or hwc format.
  414. # image_size: Width and height of the image.
  415. return ['image', 'image_size']
  416. class Normalize(BaseTransform):
  417. """Normalize the image."""
  418. def __init__(self, scale=1. / 255, mean=0.5, std=0.5, preserve_dtype=False):
  419. """
  420. Initialize the instance.
  421. Args:
  422. scale (float, optional): Scaling factor to apply to the image before
  423. applying normalization. Default: 1/255.
  424. mean (float|tuple|list, optional): Means for each channel of the image.
  425. Default: 0.5.
  426. std (float|tuple|list, optional): Standard deviations for each channel
  427. of the image. Default: 0.5.
  428. preserve_dtype (bool, optional): Whether to preserve the original dtype
  429. of the image.
  430. """
  431. super().__init__()
  432. self.scale = np.float32(scale)
  433. if isinstance(mean, float):
  434. mean = [mean]
  435. self.mean = np.asarray(mean).astype('float32')
  436. if isinstance(std, float):
  437. std = [std]
  438. self.std = np.asarray(std).astype('float32')
  439. self.preserve_dtype = preserve_dtype
  440. def apply(self, data):
  441. """ apply """
  442. im = data['image']
  443. old_type = im.dtype
  444. # XXX: If `old_type` has higher precision than float32,
  445. # we will lose some precision.
  446. im = im.astype('float32', copy=False)
  447. im *= self.scale
  448. im -= self.mean
  449. im /= self.std
  450. if self.preserve_dtype:
  451. im = im.astype(old_type, copy=False)
  452. data['image'] = im
  453. return data
  454. @classmethod
  455. def get_input_keys(cls):
  456. """ get input keys """
  457. # image: Image in hw or hwc format.
  458. return ['image']
  459. @classmethod
  460. def get_output_keys(cls):
  461. """ get output keys """
  462. # image: Image in hw or hwc format.
  463. return ['image']
  464. class ToCHWImage(BaseTransform):
  465. """Reorder the dimensions of the image from HWC to CHW."""
  466. def apply(self, data):
  467. """ apply """
  468. im = data['image']
  469. im = im.transpose((2, 0, 1))
  470. data['image'] = im
  471. return data
  472. @classmethod
  473. def get_input_keys(cls):
  474. """ get input keys """
  475. # image: Image in hwc format.
  476. return ['image']
  477. @classmethod
  478. def get_output_keys(cls):
  479. """ get output keys """
  480. # image: Image in chw format.
  481. return ['image']