image_common.py 18 KB


  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. from pathlib import Path
  16. import numpy as np
  17. import cv2
  18. from .....utils.download import download
  19. from .....utils.cache import CACHE_DIR
  20. from ..transform import BaseTransform
  21. from ..io.readers import ImageReader
  22. from ..io.writers import ImageWriter
  23. from . import image_functions as F
  24. __all__ = [
  25. 'ReadImage', 'Flip', 'Crop', 'Resize', 'ResizeByLong', 'ResizeByShort',
  26. 'Pad', 'Normalize', 'ToCHWImage'
  27. ]
  28. def _check_image_size(input_):
  29. """ check image size """
  30. if not (isinstance(input_, (list, tuple)) and len(input_) == 2 and
  31. isinstance(input_[0], int) and isinstance(input_[1], int)):
  32. raise TypeError(f"{input_} cannot represent a valid image size.")
  33. class ReadImage(BaseTransform):
  34. """Load image from the file."""
  35. _FLAGS_DICT = {
  36. 'BGR': cv2.IMREAD_COLOR,
  37. 'RGB': cv2.IMREAD_COLOR,
  38. 'GRAY': cv2.IMREAD_GRAYSCALE
  39. }
  40. def __init__(self, format='BGR'):
  41. """
  42. Initialize the instance.
  43. Args:
  44. format (str, optional): Target color format to convert the image to.
  45. Choices are 'BGR', 'RGB', and 'GRAY'. Default: 'BGR'.
  46. """
  47. super().__init__()
  48. self.format = format
  49. flags = self._FLAGS_DICT[self.format]
  50. self._reader = ImageReader(backend='opencv', flags=flags)
  51. self._writer = ImageWriter(backend='opencv')
  52. def apply(self, data):
  53. """ apply """
  54. if 'image' in data:
  55. img = data['image']
  56. img_path = (Path(CACHE_DIR) / "predict_input" /
  57. "tmp_img.jpg").as_posix()
  58. self._writer.write(img_path, img)
  59. data['input_path'] = img_path
  60. data['original_image'] = img
  61. data['original_image_size'] = [img.shape[1], img.shape[0]]
  62. return data
  63. elif 'input_path' not in data:
  64. raise KeyError(
  65. f"Key {repr('input_path')} is required, but not found.")
  66. im_path = data['input_path']
  67. # XXX: auto download for url
  68. im_path = self._download_from_url(im_path)
  69. blob = self._reader.read(im_path)
  70. if self.format == 'RGB':
  71. if blob.ndim != 3:
  72. raise RuntimeError("Array is not 3-dimensional.")
  73. # BGR to RGB
  74. blob = blob[..., ::-1]
  75. data['input_path'] = im_path
  76. data['image'] = blob
  77. data['original_image'] = blob
  78. data['original_image_size'] = [blob.shape[1], blob.shape[0]]
  79. return data
  80. def _download_from_url(self, in_path):
  81. if in_path.startswith("http"):
  82. file_name = Path(in_path).name
  83. save_path = Path(CACHE_DIR) / "predict_input" / file_name
  84. download(in_path, save_path, overwrite=True)
  85. return save_path.as_posix()
  86. return in_path
  87. @classmethod
  88. def get_input_keys(cls):
  89. """ get input keys """
  90. # input_path: Path of the image.
  91. return [['input_path'], ['image']]
  92. @classmethod
  93. def get_output_keys(cls):
  94. """ get output keys """
  95. # image: Image in hw or hwc format.
  96. # original_image: Original image in hw or hwc format.
  97. # original_image_size: Width and height of the original image.
  98. return ['image', 'original_image', 'original_image_size']
  99. class GetImageInfo(BaseTransform):
  100. """Get Image Info
  101. """
  102. def __init__(self):
  103. super().__init__()
  104. def apply(self, data):
  105. """ apply """
  106. blob = data['image']
  107. data['original_image'] = blob
  108. data['original_image_size'] = [blob.shape[1], blob.shape[0]]
  109. return data
  110. @classmethod
  111. def get_input_keys(cls):
  112. """ get input keys """
  113. # input_path: Path of the image.
  114. return ['image']
  115. @classmethod
  116. def get_output_keys(cls):
  117. """ get output keys """
  118. # image: Image in hw or hwc format.
  119. # original_image: Original image in hw or hwc format.
  120. # original_image_size: Width and height of the original image.
  121. return ['original_image', 'original_image_size']
  122. class Flip(BaseTransform):
  123. """Flip the image vertically or horizontally."""
  124. def __init__(self, mode='H'):
  125. """
  126. Initialize the instance.
  127. Args:
  128. mode (str, optional): 'H' for horizontal flipping and 'V' for vertical
  129. flipping. Default: 'H'.
  130. """
  131. super().__init__()
  132. if mode not in ('H', 'V'):
  133. raise ValueError("`mode` should be 'H' or 'V'.")
  134. self.mode = mode
  135. def apply(self, data):
  136. """ apply """
  137. im = data['image']
  138. if self.mode == 'H':
  139. im = F.flip_h(im)
  140. elif self.mode == 'V':
  141. im = F.flip_v(im)
  142. data['image'] = im
  143. return data
  144. @classmethod
  145. def get_input_keys(cls):
  146. """ get input keys """
  147. # image: Image in hw or hwc format.
  148. return ['image']
  149. @classmethod
  150. def get_output_keys(cls):
  151. """ get output keys """
  152. # image: Image in hw or hwc format.
  153. return ['image']
  154. class Crop(BaseTransform):
  155. """Crop region from the image."""
  156. def __init__(self, crop_size, mode='C'):
  157. """
  158. Initialize the instance.
  159. Args:
  160. crop_size (list|tuple|int): Width and height of the region to crop.
  161. mode (str, optional): 'C' for cropping the center part and 'TL' for
  162. cropping the top left part. Default: 'C'.
  163. """
  164. super().__init__()
  165. if isinstance(crop_size, int):
  166. crop_size = [crop_size, crop_size]
  167. _check_image_size(crop_size)
  168. self.crop_size = crop_size
  169. if mode not in ('C', 'TL'):
  170. raise ValueError("Unsupported interpolation method")
  171. self.mode = mode
  172. def apply(self, data):
  173. """ apply """
  174. im = data['image']
  175. h, w = im.shape[:2]
  176. cw, ch = self.crop_size
  177. if self.mode == 'C':
  178. x1 = max(0, (w - cw) // 2)
  179. y1 = max(0, (h - ch) // 2)
  180. elif self.mode == 'TL':
  181. x1, y1 = 0, 0
  182. x2 = min(w, x1 + cw)
  183. y2 = min(h, y1 + ch)
  184. coords = (x1, y1, x2, y2)
  185. if coords == (0, 0, w, h):
  186. raise ValueError(
  187. f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
  188. )
  189. im = F.slice(im, coords=coords)
  190. data['image'] = im
  191. data['image_size'] = [im.shape[1], im.shape[0]]
  192. return data
  193. @classmethod
  194. def get_input_keys(cls):
  195. """ get input keys """
  196. # image: Image in hw or hwc format.
  197. return ['image']
  198. @classmethod
  199. def get_output_keys(cls):
  200. """ get output keys """
  201. # image: Image in hw or hwc format.
  202. # image_size: Width and height of the image.
  203. return ['image', 'image_size']
  204. class _BaseResize(BaseTransform):
  205. _INTERP_DICT = {
  206. 'NEAREST': cv2.INTER_NEAREST,
  207. 'LINEAR': cv2.INTER_LINEAR,
  208. 'CUBIC': cv2.INTER_CUBIC,
  209. 'AREA': cv2.INTER_AREA,
  210. 'LANCZOS4': cv2.INTER_LANCZOS4
  211. }
  212. def __init__(self, size_divisor, interp):
  213. super().__init__()
  214. if size_divisor is not None:
  215. assert isinstance(size_divisor,
  216. int), "`size_divisor` should be None or int."
  217. self.size_divisor = size_divisor
  218. try:
  219. interp = self._INTERP_DICT[interp]
  220. except KeyError:
  221. raise ValueError("`interp` should be one of {}.".format(
  222. self._INTERP_DICT.keys()))
  223. self.interp = interp
  224. @staticmethod
  225. def _rescale_size(img_size, target_size):
  226. """ rescale size """
  227. scale = min(
  228. max(target_size) / max(img_size), min(target_size) / min(img_size))
  229. rescaled_size = [round(i * scale) for i in img_size]
  230. return rescaled_size, scale
  231. class Resize(_BaseResize):
  232. """Resize the image."""
  233. def __init__(self,
  234. target_size,
  235. keep_ratio=False,
  236. size_divisor=None,
  237. interp='LINEAR'):
  238. """
  239. Initialize the instance.
  240. Args:
  241. target_size (list|tuple|int): Target width and height.
  242. keep_ratio (bool, optional): Whether to keep the aspect ratio of resized
  243. image. Default: False.
  244. size_divisor (int|None, optional): Divisor of resized image size.
  245. Default: None.
  246. interp (str, optional): Interpolation method. Choices are 'NEAREST',
  247. 'LINEAR', 'CUBIC', 'AREA', and 'LANCZOS4'. Default: 'LINEAR'.
  248. """
  249. super().__init__(size_divisor=size_divisor, interp=interp)
  250. if isinstance(target_size, int):
  251. target_size = [target_size, target_size]
  252. _check_image_size(target_size)
  253. self.target_size = target_size
  254. self.keep_ratio = keep_ratio
  255. def apply(self, data):
  256. """ apply """
  257. target_size = self.target_size
  258. im = data['image']
  259. original_size = im.shape[:2]
  260. if self.keep_ratio:
  261. h, w = im.shape[0:2]
  262. target_size, _ = self._rescale_size((w, h), self.target_size)
  263. if self.size_divisor:
  264. target_size = [
  265. math.ceil(i / self.size_divisor) * self.size_divisor
  266. for i in target_size
  267. ]
  268. im_scale_w, im_scale_h = [
  269. target_size[1] / original_size[1], target_size[0] / original_size[0]
  270. ]
  271. im = F.resize(im, target_size, interp=self.interp)
  272. data['image'] = im
  273. data['image_size'] = [im.shape[1], im.shape[0]]
  274. data['scale_factors'] = [im_scale_w, im_scale_h]
  275. return data
  276. @classmethod
  277. def get_input_keys(cls):
  278. """ get input keys """
  279. # image: Image in hw or hwc format.
  280. return ['image']
  281. @classmethod
  282. def get_output_keys(cls):
  283. """ get output keys """
  284. # image: Image in hw or hwc format.
  285. # image_size: Width and height of the image.
  286. # scale_factors: Scale factors for image width and height.
  287. return ['image', 'image_size', 'scale_factors']
  288. class ResizeByLong(_BaseResize):
  289. """
  290. Proportionally resize the image by specifying the target length of the
  291. longest side.
  292. """
  293. def __init__(self, target_long_edge, size_divisor=None, interp='LINEAR'):
  294. """
  295. Initialize the instance.
  296. Args:
  297. target_long_edge (int): Target length of the longest side of image.
  298. size_divisor (int|None, optional): Divisor of resized image size.
  299. Default: None.
  300. interp (str, optional): Interpolation method. Choices are 'NEAREST',
  301. 'LINEAR', 'CUBIC', 'AREA', and 'LANCZOS4'. Default: 'LINEAR'.
  302. """
  303. super().__init__(size_divisor=size_divisor, interp=interp)
  304. self.target_long_edge = target_long_edge
  305. def apply(self, data):
  306. """ apply """
  307. im = data['image']
  308. h, w = im.shape[:2]
  309. scale = self.target_long_edge / max(h, w)
  310. h_resize = round(h * scale)
  311. w_resize = round(w * scale)
  312. if self.size_divisor is not None:
  313. h_resize = math.ceil(h_resize /
  314. self.size_divisor) * self.size_divisor
  315. w_resize = math.ceil(w_resize /
  316. self.size_divisor) * self.size_divisor
  317. im = F.resize(im, (w_resize, h_resize), interp=self.interp)
  318. data['image'] = im
  319. data['image_size'] = [im.shape[1], im.shape[0]]
  320. return data
  321. @classmethod
  322. def get_input_keys(cls):
  323. """ get input keys """
  324. # image: Image in hw or hwc format.
  325. return ['image']
  326. @classmethod
  327. def get_output_keys(cls):
  328. """ get output keys """
  329. # image: Image in hw or hwc format.
  330. # image_size: Width and height of the image.
  331. return ['image', 'image_size']
  332. class ResizeByShort(_BaseResize):
  333. """
  334. Proportionally resize the image by specifying the target length of the
  335. shortest side.
  336. """
  337. def __init__(self, target_short_edge, size_divisor=None, interp='LINEAR'):
  338. """
  339. Initialize the instance.
  340. Args:
  341. target_short_edge (int): Target length of the shortest side of image.
  342. size_divisor (int|None, optional): Divisor of resized image size.
  343. Default: None.
  344. interp (str, optional): Interpolation method. Choices are 'NEAREST',
  345. 'LINEAR', 'CUBIC', 'AREA', and 'LANCZOS4'. Default: 'LINEAR'.
  346. """
  347. super().__init__(size_divisor=size_divisor, interp=interp)
  348. self.target_short_edge = target_short_edge
  349. def apply(self, data):
  350. """ apply """
  351. im = data['image']
  352. h, w = im.shape[:2]
  353. scale = self.target_short_edge / min(h, w)
  354. h_resize = round(h * scale)
  355. w_resize = round(w * scale)
  356. if self.size_divisor is not None:
  357. h_resize = math.ceil(h_resize /
  358. self.size_divisor) * self.size_divisor
  359. w_resize = math.ceil(w_resize /
  360. self.size_divisor) * self.size_divisor
  361. im = F.resize(im, (w_resize, h_resize), interp=self.interp)
  362. data['image'] = im
  363. data['image_size'] = [im.shape[1], im.shape[0]]
  364. return data
  365. @classmethod
  366. def get_input_keys(cls):
  367. """ get input keys """
  368. # image: Image in hw or hwc format.
  369. return ['image']
  370. @classmethod
  371. def get_output_keys(cls):
  372. """ get output keys """
  373. # image: Image in hw or hwc format.
  374. # image_size: Width and height of the image.
  375. return ['image', 'image_size']
  376. class Pad(BaseTransform):
  377. """Pad the image."""
  378. def __init__(self, target_size, val=127.5):
  379. """
  380. Initialize the instance.
  381. Args:
  382. target_size (list|tuple|int): Target width and height of the image after
  383. padding.
  384. val (float, optional): Value to fill the padded area. Default: 127.5.
  385. """
  386. super().__init__()
  387. if isinstance(target_size, int):
  388. target_size = [target_size, target_size]
  389. _check_image_size(target_size)
  390. self.target_size = target_size
  391. self.val = val
  392. def apply(self, data):
  393. """ apply """
  394. im = data['image']
  395. h, w = im.shape[:2]
  396. tw, th = self.target_size
  397. ph = th - h
  398. pw = tw - w
  399. if ph < 0 or pw < 0:
  400. raise ValueError(
  401. f"Input image ({w}, {h}) smaller than the target size ({tw}, {th})."
  402. )
  403. else:
  404. im = F.pad(im, pad=(0, ph, 0, pw), val=self.val)
  405. data['image'] = im
  406. data['image_size'] = [im.shape[1], im.shape[0]]
  407. return data
  408. @classmethod
  409. def get_input_keys(cls):
  410. """ get input keys """
  411. # image: Image in hw or hwc format.
  412. return ['image']
  413. @classmethod
  414. def get_output_keys(cls):
  415. """ get output keys """
  416. # image: Image in hw or hwc format.
  417. # image_size: Width and height of the image.
  418. return ['image', 'image_size']
  419. class Normalize(BaseTransform):
  420. """Normalize the image."""
  421. def __init__(self, scale=1. / 255, mean=0.5, std=0.5, preserve_dtype=False):
  422. """
  423. Initialize the instance.
  424. Args:
  425. scale (float, optional): Scaling factor to apply to the image before
  426. applying normalization. Default: 1/255.
  427. mean (float|tuple|list, optional): Means for each channel of the image.
  428. Default: 0.5.
  429. std (float|tuple|list, optional): Standard deviations for each channel
  430. of the image. Default: 0.5.
  431. preserve_dtype (bool, optional): Whether to preserve the original dtype
  432. of the image.
  433. """
  434. super().__init__()
  435. self.scale = np.float32(scale)
  436. if isinstance(mean, float):
  437. mean = [mean]
  438. self.mean = np.asarray(mean).astype('float32')
  439. if isinstance(std, float):
  440. std = [std]
  441. self.std = np.asarray(std).astype('float32')
  442. self.preserve_dtype = preserve_dtype
  443. def apply(self, data):
  444. """ apply """
  445. im = data['image']
  446. old_type = im.dtype
  447. # XXX: If `old_type` has higher precision than float32,
  448. # we will lose some precision.
  449. im = im.astype('float32', copy=False)
  450. im *= self.scale
  451. im -= self.mean
  452. im /= self.std
  453. if self.preserve_dtype:
  454. im = im.astype(old_type, copy=False)
  455. data['image'] = im
  456. return data
  457. @classmethod
  458. def get_input_keys(cls):
  459. """ get input keys """
  460. # image: Image in hw or hwc format.
  461. return ['image']
  462. @classmethod
  463. def get_output_keys(cls):
  464. """ get output keys """
  465. # image: Image in hw or hwc format.
  466. return ['image']
  467. class ToCHWImage(BaseTransform):
  468. """Reorder the dimensions of the image from HWC to CHW."""
  469. def apply(self, data):
  470. """ apply """
  471. im = data['image']
  472. im = im.transpose((2, 0, 1))
  473. data['image'] = im
  474. return data
  475. @classmethod
  476. def get_input_keys(cls):
  477. """ get input keys """
  478. # image: Image in hwc format.
  479. return ['image']
  480. @classmethod
  481. def get_output_keys(cls):
  482. """ get output keys """
  483. # image: Image in chw format.
  484. return ['image']