__init__.py 16 KB


  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import math
  16. from pathlib import Path
  17. import numpy as np
  18. import cv2
  19. from .....utils.download import download
  20. from .....utils.cache import CACHE_DIR
  21. from ....utils.io import ImageReader, ImageWriter
  22. from ...base import BaseComponent
  23. from . import funcs as F
  24. __all__ = [
  25. "ReadImage",
  26. "Flip",
  27. "Crop",
  28. "Resize",
  29. "ResizeByLong",
  30. "ResizeByShort",
  31. "Pad",
  32. "Normalize",
  33. "ToCHWImage",
  34. ]
  35. def _check_image_size(input_):
  36. """check image size"""
  37. if not (
  38. isinstance(input_, (list, tuple))
  39. and len(input_) == 2
  40. and isinstance(input_[0], int)
  41. and isinstance(input_[1], int)
  42. ):
  43. raise TypeError(f"{input_} cannot represent a valid image size.")
  44. class ReadImage(BaseComponent):
  45. """Load image from the file."""
  46. INPUT_KEYS = ["img"]
  47. OUTPUT_KEYS = ["img", "img_size"]
  48. DEAULT_INPUTS = {"img": "img"}
  49. DEAULT_OUTPUTS = {"img": "img", "img_path": "img_path", "img_size": "img_size"}
  50. _FLAGS_DICT = {
  51. "BGR": cv2.IMREAD_COLOR,
  52. "RGB": cv2.IMREAD_COLOR,
  53. "GRAY": cv2.IMREAD_GRAYSCALE,
  54. }
  55. SUFFIX = ["jpg", "png", "jpeg", "JPEG", "JPG", "bmp"]
  56. def __init__(self, batch_size=1, format="BGR"):
  57. """
  58. Initialize the instance.
  59. Args:
  60. format (str, optional): Target color format to convert the image to.
  61. Choices are 'BGR', 'RGB', and 'GRAY'. Default: 'BGR'.
  62. """
  63. super().__init__()
  64. self.batch_size = batch_size
  65. self.format = format
  66. flags = self._FLAGS_DICT[self.format]
  67. self._reader = ImageReader(backend="opencv", flags=flags)
  68. self._writer = ImageWriter(backend="opencv")
  69. def apply(self, img):
  70. """apply"""
  71. if not isinstance(img, str):
  72. img_path = (Path(CACHE_DIR) / "predict_input" / "tmp_img.jpg").as_posix()
  73. self._writer.write(img_path, img)
  74. yield [
  75. {
  76. "img_path": img_path,
  77. "img": img,
  78. "img_size": [img.shape[1], img.shape[0]],
  79. }
  80. ]
  81. else:
  82. img_path = img
  83. # XXX: auto download for url
  84. img_path = self._download_from_url(img_path)
  85. image_list = self._get_image_list(img_path)
  86. batch = []
  87. for img_path in image_list:
  88. img = self._read_img(img_path)
  89. batch.append(img)
  90. if len(batch) >= self.batch_size:
  91. yield batch
  92. batch = []
  93. if len(batch) > 0:
  94. yield batch
  95. def _read_img(self, img_path):
  96. blob = self._reader.read(img_path)
  97. if blob is None:
  98. raise Exception("Image read Error")
  99. if self.format == "RGB":
  100. if blob.ndim != 3:
  101. raise RuntimeError("Array is not 3-dimensional.")
  102. # BGR to RGB
  103. blob = blob[..., ::-1]
  104. return {
  105. "img_path": img_path,
  106. "img": blob,
  107. "img_size": [blob.shape[1], blob.shape[0]],
  108. }
  109. def _download_from_url(self, in_path):
  110. if in_path.startswith("http"):
  111. file_name = Path(in_path).name
  112. save_path = Path(CACHE_DIR) / "predict_input" / file_name
  113. download(in_path, save_path, overwrite=True)
  114. return save_path.as_posix()
  115. return in_path
  116. def _get_image_list(self, img_file):
  117. imgs_lists = []
  118. if img_file is None or not os.path.exists(img_file):
  119. raise Exception(f"Not found any img file in path: {img_file}")
  120. if os.path.isfile(img_file) and img_file.split(".")[-1] in self.SUFFIX:
  121. imgs_lists.append(img_file)
  122. elif os.path.isdir(img_file):
  123. for root, dirs, files in os.walk(img_file):
  124. for single_file in files:
  125. if single_file.split(".")[-1] in self.SUFFIX:
  126. imgs_lists.append(os.path.join(root, single_file))
  127. if len(imgs_lists) == 0:
  128. raise Exception("not found any img file in {}".format(img_file))
  129. imgs_lists = sorted(imgs_lists)
  130. return imgs_lists
  131. class GetImageInfo(BaseComponent):
  132. """Get Image Info"""
  133. INPUT_KEYS = "img"
  134. OUTPUT_KEYS = "img_size"
  135. DEAULT_INPUTS = {"img": "img"}
  136. DEAULT_OUTPUTS = {"img_size": "img_size"}
  137. def __init__(self):
  138. super().__init__()
  139. def apply(self, img):
  140. """apply"""
  141. return {"img_size": [img.shape[1], img.shape[0]]}
  142. class Flip(BaseComponent):
  143. """Flip the image vertically or horizontally."""
  144. INPUT_KEYS = "img"
  145. OUTPUT_KEYS = "img"
  146. DEAULT_INPUTS = {"img": "img"}
  147. DEAULT_OUTPUTS = {"img": "img"}
  148. def __init__(self, mode="H"):
  149. """
  150. Initialize the instance.
  151. Args:
  152. mode (str, optional): 'H' for horizontal flipping and 'V' for vertical
  153. flipping. Default: 'H'.
  154. """
  155. super().__init__()
  156. if mode not in ("H", "V"):
  157. raise ValueError("`mode` should be 'H' or 'V'.")
  158. self.mode = mode
  159. def apply(self, img):
  160. """apply"""
  161. if self.mode == "H":
  162. img = F.flip_h(img)
  163. elif self.mode == "V":
  164. img = F.flip_v(img)
  165. return {"img": img}
  166. class Crop(BaseComponent):
  167. """Crop region from the image."""
  168. INPUT_KEYS = "img"
  169. OUTPUT_KEYS = ["img", "img_size"]
  170. DEAULT_INPUTS = {"img": "img"}
  171. DEAULT_OUTPUTS = {"img": "img", "img_size": "img_size"}
  172. def __init__(self, crop_size, mode="C"):
  173. """
  174. Initialize the instance.
  175. Args:
  176. crop_size (list|tuple|int): Width and height of the region to crop.
  177. mode (str, optional): 'C' for cropping the center part and 'TL' for
  178. cropping the top left part. Default: 'C'.
  179. """
  180. super().__init__()
  181. if isinstance(crop_size, int):
  182. crop_size = [crop_size, crop_size]
  183. _check_image_size(crop_size)
  184. self.crop_size = crop_size
  185. if mode not in ("C", "TL"):
  186. raise ValueError("Unsupported interpolation method")
  187. self.mode = mode
  188. def apply(self, img):
  189. """apply"""
  190. h, w = img.shape[:2]
  191. cw, ch = self.crop_size
  192. if self.mode == "C":
  193. x1 = max(0, (w - cw) // 2)
  194. y1 = max(0, (h - ch) // 2)
  195. elif self.mode == "TL":
  196. x1, y1 = 0, 0
  197. x2 = min(w, x1 + cw)
  198. y2 = min(h, y1 + ch)
  199. coords = (x1, y1, x2, y2)
  200. if coords == (0, 0, w, h):
  201. raise ValueError(
  202. f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
  203. )
  204. img = F.slice(img, coords=coords)
  205. return {"img": img, "img_size": [img.shape[1], img.shape[0]]}
  206. class _BaseResize(BaseComponent):
  207. _INTERP_DICT = {
  208. "NEAREST": cv2.INTER_NEAREST,
  209. "LINEAR": cv2.INTER_LINEAR,
  210. "CUBIC": cv2.INTER_CUBIC,
  211. "AREA": cv2.INTER_AREA,
  212. "LANCZOS4": cv2.INTER_LANCZOS4,
  213. }
  214. def __init__(self, size_divisor, interp):
  215. super().__init__()
  216. if size_divisor is not None:
  217. assert isinstance(
  218. size_divisor, int
  219. ), "`size_divisor` should be None or int."
  220. self.size_divisor = size_divisor
  221. try:
  222. interp = self._INTERP_DICT[interp]
  223. except KeyError:
  224. raise ValueError(
  225. "`interp` should be one of {}.".format(self._INTERP_DICT.keys())
  226. )
  227. self.interp = interp
  228. @staticmethod
  229. def _rescale_size(img_size, target_size):
  230. """rescale size"""
  231. scale = min(max(target_size) / max(img_size), min(target_size) / min(img_size))
  232. rescaled_size = [round(i * scale) for i in img_size]
  233. return rescaled_size, scale
  234. class Resize(_BaseResize):
  235. """Resize the image."""
  236. INPUT_KEYS = "img"
  237. OUTPUT_KEYS = ["img", "img_size", "scale_factors"]
  238. DEAULT_INPUTS = {"img": "img"}
  239. DEAULT_OUTPUTS = {
  240. "img": "img",
  241. "img_size": "img_size",
  242. "scale_factors": "scale_factors",
  243. }
  244. def __init__(
  245. self, target_size, keep_ratio=False, size_divisor=None, interp="LINEAR"
  246. ):
  247. """
  248. Initialize the instance.
  249. Args:
  250. target_size (list|tuple|int): Target width and height.
  251. keep_ratio (bool, optional): Whether to keep the aspect ratio of resized
  252. image. Default: False.
  253. size_divisor (int|None, optional): Divisor of resized image size.
  254. Default: None.
  255. interp (str, optional): Interpolation method. Choices are 'NEAREST',
  256. 'LINEAR', 'CUBIC', 'AREA', and 'LANCZOS4'. Default: 'LINEAR'.
  257. """
  258. super().__init__(size_divisor=size_divisor, interp=interp)
  259. if isinstance(target_size, int):
  260. target_size = [target_size, target_size]
  261. _check_image_size(target_size)
  262. self.target_size = target_size
  263. self.keep_ratio = keep_ratio
  264. def apply(self, img):
  265. """apply"""
  266. target_size = self.target_size
  267. original_size = img.shape[:2]
  268. if self.keep_ratio:
  269. h, w = img.shape[0:2]
  270. target_size, _ = self._rescale_size((w, h), self.target_size)
  271. if self.size_divisor:
  272. target_size = [
  273. math.ceil(i / self.size_divisor) * self.size_divisor
  274. for i in target_size
  275. ]
  276. img_scale_w, img_scale_h = [
  277. target_size[1] / original_size[1],
  278. target_size[0] / original_size[0],
  279. ]
  280. img = F.resize(img, target_size, interp=self.interp)
  281. return {
  282. "img": img,
  283. "img_size": [img.shape[1], img.shape[0]],
  284. "scale_factors": [img_scale_w, img_scale_h],
  285. }
  286. class ResizeByLong(_BaseResize):
  287. """
  288. Proportionally resize the image by specifying the target length of the
  289. longest side.
  290. """
  291. INPUT_KEYS = "img"
  292. OUTPUT_KEYS = ["img", "img_size"]
  293. DEAULT_INPUTS = {"img": "img"}
  294. DEAULT_OUTPUTS = {"img": "img", "img_size": "img_size"}
  295. def __init__(self, target_long_edge, size_divisor=None, interp="LINEAR"):
  296. """
  297. Initialize the instance.
  298. Args:
  299. target_long_edge (int): Target length of the longest side of image.
  300. size_divisor (int|None, optional): Divisor of resized image size.
  301. Default: None.
  302. interp (str, optional): Interpolation method. Choices are 'NEAREST',
  303. 'LINEAR', 'CUBIC', 'AREA', and 'LANCZOS4'. Default: 'LINEAR'.
  304. """
  305. super().__init__(size_divisor=size_divisor, interp=interp)
  306. self.target_long_edge = target_long_edge
  307. def apply(self, img):
  308. """apply"""
  309. h, w = img.shape[:2]
  310. scale = self.target_long_edge / max(h, w)
  311. h_resize = round(h * scale)
  312. w_resize = round(w * scale)
  313. if self.size_divisor is not None:
  314. h_resize = math.ceil(h_resize / self.size_divisor) * self.size_divisor
  315. w_resize = math.ceil(w_resize / self.size_divisor) * self.size_divisor
  316. img = F.resize(img, (w_resize, h_resize), interp=self.interp)
  317. return {"img": img, "img_size": [img.shape[1], img.shape[0]]}
  318. class ResizeByShort(_BaseResize):
  319. """
  320. Proportionally resize the image by specifying the target length of the
  321. shortest side.
  322. """
  323. INPUT_KEYS = "img"
  324. OUTPUT_KEYS = ["img", "img_size"]
  325. DEAULT_INPUTS = {"img": "img"}
  326. DEAULT_OUTPUTS = {"img": "img", "img_size": "img_size"}
  327. def __init__(self, target_short_edge, size_divisor=None, interp="LINEAR"):
  328. """
  329. Initialize the instance.
  330. Args:
  331. target_short_edge (int): Target length of the shortest side of image.
  332. size_divisor (int|None, optional): Divisor of resized image size.
  333. Default: None.
  334. interp (str, optional): Interpolation method. Choices are 'NEAREST',
  335. 'LINEAR', 'CUBIC', 'AREA', and 'LANCZOS4'. Default: 'LINEAR'.
  336. """
  337. super().__init__(size_divisor=size_divisor, interp=interp)
  338. self.target_short_edge = target_short_edge
  339. def apply(self, img):
  340. """apply"""
  341. h, w = img.shape[:2]
  342. scale = self.target_short_edge / min(h, w)
  343. h_resize = round(h * scale)
  344. w_resize = round(w * scale)
  345. if self.size_divisor is not None:
  346. h_resize = math.ceil(h_resize / self.size_divisor) * self.size_divisor
  347. w_resize = math.ceil(w_resize / self.size_divisor) * self.size_divisor
  348. img = F.resize(img, (w_resize, h_resize), interp=self.interp)
  349. return {"img": img, "img_size": [img.shape[1], img.shape[0]]}
  350. class Pad(BaseComponent):
  351. """Pad the image."""
  352. INPUT_KEYS = "img"
  353. OUTPUT_KEYS = ["img", "img_size"]
  354. DEAULT_INPUTS = {"img": "img"}
  355. DEAULT_OUTPUTS = {"img": "img", "img_size": "img_size"}
  356. def __init__(self, target_size, val=127.5):
  357. """
  358. Initialize the instance.
  359. Args:
  360. target_size (list|tuple|int): Target width and height of the image after
  361. padding.
  362. val (float, optional): Value to fill the padded area. Default: 127.5.
  363. """
  364. super().__init__()
  365. if isinstance(target_size, int):
  366. target_size = [target_size, target_size]
  367. _check_image_size(target_size)
  368. self.target_size = target_size
  369. self.val = val
  370. def apply(self, img):
  371. """apply"""
  372. h, w = img.shape[:2]
  373. tw, th = self.target_size
  374. ph = th - h
  375. pw = tw - w
  376. if ph < 0 or pw < 0:
  377. raise ValueError(
  378. f"Input image ({w}, {h}) smaller than the target size ({tw}, {th})."
  379. )
  380. else:
  381. img = F.pad(img, pad=(0, ph, 0, pw), val=self.val)
  382. return {"img": img, "img_size": [img.shape[1], img.shape[0]]}
  383. class Normalize(BaseComponent):
  384. """Normalize the image."""
  385. INPUT_KEYS = "img"
  386. OUTPUT_KEYS = "img"
  387. DEAULT_INPUTS = {"img": "img"}
  388. DEAULT_OUTPUTS = {"img": "img"}
  389. def __init__(self, scale=1.0 / 255, mean=0.5, std=0.5, preserve_dtype=False):
  390. """
  391. Initialize the instance.
  392. Args:
  393. scale (float, optional): Scaling factor to apply to the image before
  394. applying normalization. Default: 1/255.
  395. mean (float|tuple|list, optional): Means for each channel of the image.
  396. Default: 0.5.
  397. std (float|tuple|list, optional): Standard deviations for each channel
  398. of the image. Default: 0.5.
  399. preserve_dtype (bool, optional): Whether to preserve the original dtype
  400. of the image.
  401. """
  402. super().__init__()
  403. self.scale = np.float32(scale)
  404. if isinstance(mean, float):
  405. mean = [mean]
  406. self.mean = np.asarray(mean).astype("float32")
  407. if isinstance(std, float):
  408. std = [std]
  409. self.std = np.asarray(std).astype("float32")
  410. self.preserve_dtype = preserve_dtype
  411. def apply(self, img):
  412. """apply"""
  413. old_type = img.dtype
  414. # XXX: If `old_type` has higher precision than float32,
  415. # we will lose some precision.
  416. img = img.astype("float32", copy=False)
  417. img *= self.scale
  418. img -= self.mean
  419. img /= self.std
  420. if self.preserve_dtype:
  421. img = img.astype(old_type, copy=False)
  422. return {"img": img}
  423. class ToCHWImage(BaseComponent):
  424. """Reorder the dimensions of the image from HWC to CHW."""
  425. INPUT_KEYS = "img"
  426. OUTPUT_KEYS = "img"
  427. DEAULT_INPUTS = {"img": "img"}
  428. DEAULT_OUTPUTS = {"img": "img"}
  429. def apply(self, img):
  430. """apply"""
  431. img = img.transpose((2, 0, 1))
  432. return {"img": img}