common.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. from pathlib import Path
  16. from copy import deepcopy
  17. import numpy as np
  18. import cv2
  19. from .....utils.cache import CACHE_DIR, temp_file_manager
  20. from ....utils.io import ImageReader, ImageWriter, PDFReader
  21. from ...utils.mixin import BatchSizeMixin
  22. from ...base import BaseComponent
  23. from ..read_data import _BaseRead
  24. from . import funcs as F
  25. __all__ = [
  26. "ReadImage",
  27. "Flip",
  28. "Crop",
  29. "Resize",
  30. "ResizeByLong",
  31. "ResizeByShort",
  32. "Pad",
  33. "Normalize",
  34. "ToCHWImage",
  35. "PadStride",
  36. ]
  37. def _check_image_size(input_):
  38. """check image size"""
  39. if not (
  40. isinstance(input_, (list, tuple))
  41. and len(input_) == 2
  42. and isinstance(input_[0], int)
  43. and isinstance(input_[1], int)
  44. ):
  45. raise TypeError(f"{input_} cannot represent a valid image size.")
  46. class ReadImage(_BaseRead):
  47. """Load image from the file."""
  48. INPUT_KEYS = ["img"]
  49. OUTPUT_KEYS = ["img", "img_size", "ori_img", "ori_img_size"]
  50. DEAULT_INPUTS = {"img": "img"}
  51. DEAULT_OUTPUTS = {
  52. "img": "img",
  53. "input_path": "input_path",
  54. "img_size": "img_size",
  55. "ori_img": "ori_img",
  56. "ori_img_size": "ori_img_size",
  57. }
  58. _FLAGS_DICT = {
  59. "BGR": cv2.IMREAD_COLOR,
  60. "RGB": cv2.IMREAD_COLOR,
  61. "GRAY": cv2.IMREAD_GRAYSCALE,
  62. }
  63. SUFFIX = ["jpg", "png", "jpeg", "JPEG", "JPG", "bmp", "PDF", "pdf"]
  64. def __init__(self, batch_size=1, format="BGR"):
  65. """
  66. Initialize the instance.
  67. Args:
  68. format (str, optional): Target color format to convert the image to.
  69. Choices are 'BGR', 'RGB', and 'GRAY'. Default: 'BGR'.
  70. """
  71. super().__init__(batch_size)
  72. self.format = format
  73. flags = self._FLAGS_DICT[self.format]
  74. self._img_reader = ImageReader(backend="opencv", flags=flags)
  75. self._pdf_reader = PDFReader()
  76. self._writer = ImageWriter(backend="opencv")
  77. def apply(self, img):
  78. """apply"""
  79. if isinstance(img, np.ndarray):
  80. with temp_file_manager.temp_file_context(
  81. delete=True, suffix=".png"
  82. ) as temp_file:
  83. img_path = Path(temp_file.name)
  84. self._writer.write(img_path, img)
  85. yield [
  86. {
  87. "input_path": img_path,
  88. "img": img,
  89. "img_size": [img.shape[1], img.shape[0]],
  90. "ori_img": deepcopy(img),
  91. "ori_img_size": deepcopy([img.shape[1], img.shape[0]]),
  92. }
  93. ]
  94. elif isinstance(img, str):
  95. file_path = img
  96. file_path = self._download_from_url(file_path)
  97. file_list = self._get_files_list(file_path)
  98. batch = []
  99. for file_path in file_list:
  100. img = self._read(file_path)
  101. batch.extend(img)
  102. if len(batch) >= self.batch_size:
  103. yield batch
  104. batch = []
  105. if len(batch) > 0:
  106. yield batch
  107. else:
  108. raise TypeError(
  109. f"ReadImage only supports the following types:\n"
  110. f"1. str, indicating a image file path or a directory containing image files.\n"
  111. f"2. numpy.ndarray.\n"
  112. f"However, got type: {type(img).__name__}."
  113. )
  114. def _read(self, file_path):
  115. if str(file_path).lower().endswith(".pdf"):
  116. return self._read_pdf(file_path)
  117. else:
  118. return self._read_img(file_path)
  119. def _read_img(self, img_path):
  120. blob = self._img_reader.read(img_path)
  121. if blob is None:
  122. raise Exception("Image read Error")
  123. if self.format == "RGB":
  124. if blob.ndim != 3:
  125. raise RuntimeError("Array is not 3-dimensional.")
  126. # BGR to RGB
  127. blob = blob[..., ::-1]
  128. return [
  129. {
  130. "input_path": img_path,
  131. "img": blob,
  132. "img_size": [blob.shape[1], blob.shape[0]],
  133. "ori_img": deepcopy(blob),
  134. "ori_img_size": deepcopy([blob.shape[1], blob.shape[0]]),
  135. }
  136. ]
  137. def _read_pdf(self, pdf_path):
  138. img_list = self._pdf_reader.read(pdf_path)
  139. return [
  140. {
  141. "input_path": pdf_path,
  142. "img": img,
  143. "img_size": [img.shape[1], img.shape[0]],
  144. "ori_img": deepcopy(img),
  145. "ori_img_size": deepcopy([img.shape[1], img.shape[0]]),
  146. }
  147. for img in img_list
  148. ]
  149. class GetImageInfo(BaseComponent):
  150. """Get Image Info"""
  151. INPUT_KEYS = "img"
  152. OUTPUT_KEYS = "img_size"
  153. DEAULT_INPUTS = {"img": "img"}
  154. DEAULT_OUTPUTS = {"img_size": "img_size"}
  155. def __init__(self):
  156. super().__init__()
  157. def apply(self, img):
  158. """apply"""
  159. return {"img_size": [img.shape[1], img.shape[0]]}
  160. class Flip(BaseComponent):
  161. """Flip the image vertically or horizontally."""
  162. INPUT_KEYS = "img"
  163. OUTPUT_KEYS = "img"
  164. DEAULT_INPUTS = {"img": "img"}
  165. DEAULT_OUTPUTS = {"img": "img"}
  166. def __init__(self, mode="H"):
  167. """
  168. Initialize the instance.
  169. Args:
  170. mode (str, optional): 'H' for horizontal flipping and 'V' for vertical
  171. flipping. Default: 'H'.
  172. """
  173. super().__init__()
  174. if mode not in ("H", "V"):
  175. raise ValueError("`mode` should be 'H' or 'V'.")
  176. self.mode = mode
  177. def apply(self, img):
  178. """apply"""
  179. if self.mode == "H":
  180. img = F.flip_h(img)
  181. elif self.mode == "V":
  182. img = F.flip_v(img)
  183. return {"img": img}
  184. class Crop(BaseComponent):
  185. """Crop region from the image."""
  186. INPUT_KEYS = "img"
  187. OUTPUT_KEYS = ["img", "img_size"]
  188. DEAULT_INPUTS = {"img": "img"}
  189. DEAULT_OUTPUTS = {"img": "img", "img_size": "img_size"}
  190. def __init__(self, crop_size, mode="C"):
  191. """
  192. Initialize the instance.
  193. Args:
  194. crop_size (list|tuple|int): Width and height of the region to crop.
  195. mode (str, optional): 'C' for cropping the center part and 'TL' for
  196. cropping the top left part. Default: 'C'.
  197. """
  198. super().__init__()
  199. if isinstance(crop_size, int):
  200. crop_size = [crop_size, crop_size]
  201. _check_image_size(crop_size)
  202. self.crop_size = crop_size
  203. if mode not in ("C", "TL"):
  204. raise ValueError("Unsupported interpolation method")
  205. self.mode = mode
  206. def apply(self, img):
  207. """apply"""
  208. h, w = img.shape[:2]
  209. cw, ch = self.crop_size
  210. if self.mode == "C":
  211. x1 = max(0, (w - cw) // 2)
  212. y1 = max(0, (h - ch) // 2)
  213. elif self.mode == "TL":
  214. x1, y1 = 0, 0
  215. x2 = min(w, x1 + cw)
  216. y2 = min(h, y1 + ch)
  217. coords = (x1, y1, x2, y2)
  218. if coords == (0, 0, w, h):
  219. raise ValueError(
  220. f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
  221. )
  222. img = F.slice(img, coords=coords)
  223. return {"img": img, "img_size": [img.shape[1], img.shape[0]]}
  224. class _BaseResize(BaseComponent):
  225. _INTERP_DICT = {
  226. "NEAREST": cv2.INTER_NEAREST,
  227. "LINEAR": cv2.INTER_LINEAR,
  228. "CUBIC": cv2.INTER_CUBIC,
  229. "AREA": cv2.INTER_AREA,
  230. "LANCZOS4": cv2.INTER_LANCZOS4,
  231. }
  232. def __init__(self, size_divisor, interp):
  233. super().__init__()
  234. if size_divisor is not None:
  235. assert isinstance(
  236. size_divisor, int
  237. ), "`size_divisor` should be None or int."
  238. self.size_divisor = size_divisor
  239. try:
  240. interp = self._INTERP_DICT[interp]
  241. except KeyError:
  242. raise ValueError(
  243. "`interp` should be one of {}.".format(self._INTERP_DICT.keys())
  244. )
  245. self.interp = interp
  246. @staticmethod
  247. def _rescale_size(img_size, target_size):
  248. """rescale size"""
  249. scale = min(max(target_size) / max(img_size), min(target_size) / min(img_size))
  250. rescaled_size = [round(i * scale) for i in img_size]
  251. return rescaled_size, scale
  252. class Resize(_BaseResize):
  253. """Resize the image."""
  254. INPUT_KEYS = "img"
  255. OUTPUT_KEYS = ["img", "img_size", "scale_factors"]
  256. DEAULT_INPUTS = {"img": "img"}
  257. DEAULT_OUTPUTS = {
  258. "img": "img",
  259. "img_size": "img_size",
  260. "scale_factors": "scale_factors",
  261. }
  262. def __init__(
  263. self, target_size, keep_ratio=False, size_divisor=None, interp="LINEAR"
  264. ):
  265. """
  266. Initialize the instance.
  267. Args:
  268. target_size (list|tuple|int): Target width and height.
  269. keep_ratio (bool, optional): Whether to keep the aspect ratio of resized
  270. image. Default: False.
  271. size_divisor (int|None, optional): Divisor of resized image size.
  272. Default: None.
  273. interp (str, optional): Interpolation method. Choices are 'NEAREST',
  274. 'LINEAR', 'CUBIC', 'AREA', and 'LANCZOS4'. Default: 'LINEAR'.
  275. """
  276. super().__init__(size_divisor=size_divisor, interp=interp)
  277. if isinstance(target_size, int):
  278. target_size = [target_size, target_size]
  279. _check_image_size(target_size)
  280. self.target_size = target_size
  281. self.keep_ratio = keep_ratio
  282. def apply(self, img):
  283. """apply"""
  284. target_size = self.target_size
  285. original_size = img.shape[:2]
  286. if self.keep_ratio:
  287. h, w = img.shape[0:2]
  288. target_size, _ = self._rescale_size((h, w), self.target_size)
  289. if self.size_divisor:
  290. target_size = [
  291. math.ceil(i / self.size_divisor) * self.size_divisor
  292. for i in target_size
  293. ]
  294. img_scale_w, img_scale_h = [
  295. target_size[1] / original_size[1],
  296. target_size[0] / original_size[0],
  297. ]
  298. img = F.resize(img, target_size, interp=self.interp)
  299. return {
  300. "img": img,
  301. "img_size": [img.shape[1], img.shape[0]],
  302. "scale_factors": [img_scale_w, img_scale_h],
  303. }
  304. class ResizeByLong(_BaseResize):
  305. """
  306. Proportionally resize the image by specifying the target length of the
  307. longest side.
  308. """
  309. INPUT_KEYS = "img"
  310. OUTPUT_KEYS = ["img", "img_size"]
  311. DEAULT_INPUTS = {"img": "img"}
  312. DEAULT_OUTPUTS = {"img": "img", "img_size": "img_size"}
  313. def __init__(self, target_long_edge, size_divisor=None, interp="LINEAR"):
  314. """
  315. Initialize the instance.
  316. Args:
  317. target_long_edge (int): Target length of the longest side of image.
  318. size_divisor (int|None, optional): Divisor of resized image size.
  319. Default: None.
  320. interp (str, optional): Interpolation method. Choices are 'NEAREST',
  321. 'LINEAR', 'CUBIC', 'AREA', and 'LANCZOS4'. Default: 'LINEAR'.
  322. """
  323. super().__init__(size_divisor=size_divisor, interp=interp)
  324. self.target_long_edge = target_long_edge
  325. def apply(self, img):
  326. """apply"""
  327. h, w = img.shape[:2]
  328. scale = self.target_long_edge / max(h, w)
  329. h_resize = round(h * scale)
  330. w_resize = round(w * scale)
  331. if self.size_divisor is not None:
  332. h_resize = math.ceil(h_resize / self.size_divisor) * self.size_divisor
  333. w_resize = math.ceil(w_resize / self.size_divisor) * self.size_divisor
  334. img = F.resize(img, (w_resize, h_resize), interp=self.interp)
  335. return {"img": img, "img_size": [img.shape[1], img.shape[0]]}
  336. class ResizeByShort(_BaseResize):
  337. """
  338. Proportionally resize the image by specifying the target length of the
  339. shortest side.
  340. """
  341. INPUT_KEYS = "img"
  342. OUTPUT_KEYS = ["img", "img_size"]
  343. DEAULT_INPUTS = {"img": "img"}
  344. DEAULT_OUTPUTS = {"img": "img", "img_size": "img_size"}
  345. def __init__(self, target_short_edge, size_divisor=None, interp="LINEAR"):
  346. """
  347. Initialize the instance.
  348. Args:
  349. target_short_edge (int): Target length of the shortest side of image.
  350. size_divisor (int|None, optional): Divisor of resized image size.
  351. Default: None.
  352. interp (str, optional): Interpolation method. Choices are 'NEAREST',
  353. 'LINEAR', 'CUBIC', 'AREA', and 'LANCZOS4'. Default: 'LINEAR'.
  354. """
  355. super().__init__(size_divisor=size_divisor, interp=interp)
  356. self.target_short_edge = target_short_edge
  357. def apply(self, img):
  358. """apply"""
  359. h, w = img.shape[:2]
  360. scale = self.target_short_edge / min(h, w)
  361. h_resize = round(h * scale)
  362. w_resize = round(w * scale)
  363. if self.size_divisor is not None:
  364. h_resize = math.ceil(h_resize / self.size_divisor) * self.size_divisor
  365. w_resize = math.ceil(w_resize / self.size_divisor) * self.size_divisor
  366. img = F.resize(img, (w_resize, h_resize), interp=self.interp)
  367. return {"img": img, "img_size": [img.shape[1], img.shape[0]]}
  368. class Pad(BaseComponent):
  369. """Pad the image."""
  370. INPUT_KEYS = "img"
  371. OUTPUT_KEYS = ["img", "img_size"]
  372. DEAULT_INPUTS = {"img": "img"}
  373. DEAULT_OUTPUTS = {"img": "img", "img_size": "img_size"}
  374. def __init__(self, target_size, val=127.5):
  375. """
  376. Initialize the instance.
  377. Args:
  378. target_size (list|tuple|int): Target width and height of the image after
  379. padding.
  380. val (float, optional): Value to fill the padded area. Default: 127.5.
  381. """
  382. super().__init__()
  383. if isinstance(target_size, int):
  384. target_size = [target_size, target_size]
  385. _check_image_size(target_size)
  386. self.target_size = target_size
  387. self.val = val
  388. def apply(self, img):
  389. """apply"""
  390. h, w = img.shape[:2]
  391. tw, th = self.target_size
  392. ph = th - h
  393. pw = tw - w
  394. if ph < 0 or pw < 0:
  395. raise ValueError(
  396. f"Input image ({w}, {h}) smaller than the target size ({tw}, {th})."
  397. )
  398. else:
  399. img = F.pad(img, pad=(0, ph, 0, pw), val=self.val)
  400. return {"img": img, "img_size": [img.shape[1], img.shape[0]]}
  401. class PadStride(BaseComponent):
  402. """padding image for model with FPN , instead PadBatch(pad_to_stride, pad_gt) in original config
  403. Args:
  404. stride (bool): model with FPN need image shape % stride == 0
  405. """
  406. INPUT_KEYS = "img"
  407. OUTPUT_KEYS = "img"
  408. DEAULT_INPUTS = {"img": "img"}
  409. DEAULT_OUTPUTS = {"img": "img"}
  410. def __init__(self, stride=0):
  411. super().__init__()
  412. self.coarsest_stride = stride
  413. def apply(self, img):
  414. """
  415. Args:
  416. im (np.ndarray): image (np.ndarray)
  417. Returns:
  418. im (np.ndarray): processed image (np.ndarray)
  419. """
  420. im = img
  421. coarsest_stride = self.coarsest_stride
  422. if coarsest_stride <= 0:
  423. return {"img": im}
  424. im_c, im_h, im_w = im.shape
  425. pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
  426. pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
  427. padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
  428. padding_im[:, :im_h, :im_w] = im
  429. return {"img": padding_im}
  430. class Normalize(BaseComponent):
  431. """Normalize the image."""
  432. INPUT_KEYS = "img"
  433. OUTPUT_KEYS = "img"
  434. DEAULT_INPUTS = {"img": "img"}
  435. DEAULT_OUTPUTS = {"img": "img"}
  436. def __init__(self, scale=1.0 / 255, mean=0.5, std=0.5, preserve_dtype=False):
  437. """
  438. Initialize the instance.
  439. Args:
  440. scale (float, optional): Scaling factor to apply to the image before
  441. applying normalization. Default: 1/255.
  442. mean (float|tuple|list, optional): Means for each channel of the image.
  443. Default: 0.5.
  444. std (float|tuple|list, optional): Standard deviations for each channel
  445. of the image. Default: 0.5.
  446. preserve_dtype (bool, optional): Whether to preserve the original dtype
  447. of the image.
  448. """
  449. super().__init__()
  450. self.scale = np.float32(scale)
  451. if isinstance(mean, float):
  452. mean = [mean]
  453. self.mean = np.asarray(mean).astype("float32")
  454. if isinstance(std, float):
  455. std = [std]
  456. self.std = np.asarray(std).astype("float32")
  457. self.preserve_dtype = preserve_dtype
  458. def apply(self, img):
  459. """apply"""
  460. old_type = img.dtype
  461. # XXX: If `old_type` has higher precision than float32,
  462. # we will lose some precision.
  463. img = img.astype("float32", copy=False)
  464. img *= self.scale
  465. img -= self.mean
  466. img /= self.std
  467. if self.preserve_dtype:
  468. img = img.astype(old_type, copy=False)
  469. return {"img": img}
  470. class ToCHWImage(BaseComponent):
  471. """Reorder the dimensions of the image from HWC to CHW."""
  472. INPUT_KEYS = "img"
  473. OUTPUT_KEYS = "img"
  474. DEAULT_INPUTS = {"img": "img"}
  475. DEAULT_OUTPUTS = {"img": "img"}
  476. def apply(self, img):
  477. """apply"""
  478. img = img.transpose((2, 0, 1))
  479. return {"img": img}