common.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ast
import math
from pathlib import Path
from copy import deepcopy

import numpy as np
import cv2

from .....utils.flags import (
    INFER_BENCHMARK,
    INFER_BENCHMARK_ITER,
    INFER_BENCHMARK_DATA_SIZE,
)
from .....utils.cache import CACHE_DIR, temp_file_manager
from ....utils.io import ImageReader, ImageWriter, PDFReader
from ...base import BaseComponent
from ..read_data import _BaseRead
from . import funcs as F

__all__ = [
    "ReadImage",
    "Flip",
    "Crop",
    "Resize",
    "ResizeByLong",
    "ResizeByShort",
    "Pad",
    "Normalize",
    "ToCHWImage",
    "PadStride",
]


def _check_image_size(input_):
    """check image size"""
    if not (
        isinstance(input_, (list, tuple))
        and len(input_) == 2
        and isinstance(input_[0], int)
        and isinstance(input_[1], int)
    ):
        raise TypeError(f"{input_} cannot represent a valid image size.")


class ReadImage(_BaseRead):
    """Load image from the file."""

    INPUT_KEYS = ["img"]
    OUTPUT_KEYS = ["img", "img_size", "ori_img", "ori_img_size"]
    DEAULT_INPUTS = {"img": "img"}
    DEAULT_OUTPUTS = {
        "img": "img",
        "input_path": "input_path",
        "img_size": "img_size",
        "ori_img": "ori_img",
        "ori_img_size": "ori_img_size",
    }

    _FLAGS_DICT = {
        "BGR": cv2.IMREAD_COLOR,
        "RGB": cv2.IMREAD_COLOR,
        "GRAY": cv2.IMREAD_GRAYSCALE,
    }

    SUFFIX = ["jpg", "png", "jpeg", "JPEG", "JPG", "bmp", "PDF", "pdf"]

    def __init__(self, batch_size=1, format="BGR"):
        """
        Initialize the instance.

        Args:
            batch_size (int, optional): Number of samples per yielded batch.
                Default: 1.
            format (str, optional): Target color format to convert the image to.
                Choices are 'BGR', 'RGB', and 'GRAY'. Default: 'BGR'.
        """
        super().__init__(batch_size)
        self.format = format
        flags = self._FLAGS_DICT[self.format]
        self._img_reader = ImageReader(backend="opencv", flags=flags)
        self._pdf_reader = PDFReader()
        self._writer = ImageWriter(backend="opencv")

    def apply(self, img):
        """apply"""

        def rand_data():
            def parse_size(s):
                res = ast.literal_eval(s)
                if isinstance(res, int):
                    return (res, res)
                else:
                    assert isinstance(res, (tuple, list))
                    assert len(res) == 2
                    assert all(isinstance(item, int) for item in res)
                    return res

            size = parse_size(INFER_BENCHMARK_DATA_SIZE)
            return np.random.randint(0, 256, (*size, 3), dtype=np.uint8)

        def process_ndarray(img):
            with temp_file_manager.temp_file_context(suffix=".png") as temp_file:
                img_path = Path(temp_file.name)
                self._writer.write(img_path, img)
                if self.format == "RGB":
                    img = img[:, :, ::-1]
                return {
                    "input_path": img_path,
                    "img": img,
                    "img_size": [img.shape[1], img.shape[0]],
                    "ori_img": deepcopy(img),
                    "ori_img_size": deepcopy([img.shape[1], img.shape[0]]),
                }

        if INFER_BENCHMARK and img is None:
            for _ in range(INFER_BENCHMARK_ITER):
                yield [process_ndarray(rand_data()) for _ in range(self.batch_size)]
        elif isinstance(img, np.ndarray):
            yield [process_ndarray(img)]
        elif isinstance(img, str):
            file_path = img
            file_path = self._download_from_url(file_path)
            file_list = self._get_files_list(file_path)
            batch = []
            for file_path in file_list:
                img = self._read(file_path)
                batch.extend(img)
                if len(batch) >= self.batch_size:
                    yield batch
                    batch = []
            if len(batch) > 0:
                yield batch
        else:
            raise TypeError(
                f"ReadImage only supports the following types:\n"
                f"1. str, indicating an image file path or a directory containing image files.\n"
                f"2. numpy.ndarray.\n"
                f"However, got type: {type(img).__name__}."
            )

    def _read(self, file_path):
        if str(file_path).lower().endswith(".pdf"):
            return self._read_pdf(file_path)
        else:
            return self._read_img(file_path)

    def _read_img(self, img_path):
        blob = self._img_reader.read(img_path)
        if blob is None:
            raise Exception("Image read Error")
        if self.format == "RGB":
            if blob.ndim != 3:
                raise RuntimeError("Array is not 3-dimensional.")
            # BGR to RGB
            blob = blob[..., ::-1]
        return [
            {
                "input_path": img_path,
                "img": blob,
                "img_size": [blob.shape[1], blob.shape[0]],
                "ori_img": deepcopy(blob),
                "ori_img_size": deepcopy([blob.shape[1], blob.shape[0]]),
            }
        ]

    def _read_pdf(self, pdf_path):
        img_list = self._pdf_reader.read(pdf_path)
        return [
            {
                "input_path": pdf_path,
                "img": img,
                "img_size": [img.shape[1], img.shape[0]],
                "ori_img": deepcopy(img),
                "ori_img_size": deepcopy([img.shape[1], img.shape[0]]),
            }
            for img in img_list
        ]
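

# Illustrative usage sketch (comments only, not executed on import; the path below is
# a placeholder):
#
#     reader = ReadImage(batch_size=2, format="RGB")
#     for batch in reader.apply("path/to/image_dir_or_file"):
#         for item in batch:
#             print(item["input_path"], item["img_size"])  # img_size is [width, height]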


class GetImageInfo(BaseComponent):
    """Get Image Info"""

    INPUT_KEYS = "img"
    OUTPUT_KEYS = "img_size"
    DEAULT_INPUTS = {"img": "img"}
    DEAULT_OUTPUTS = {"img_size": "img_size"}

    def __init__(self):
        super().__init__()

    def apply(self, img):
        """apply"""
        return {"img_size": [img.shape[1], img.shape[0]]}


class Flip(BaseComponent):
    """Flip the image vertically or horizontally."""

    INPUT_KEYS = "img"
    OUTPUT_KEYS = "img"
    DEAULT_INPUTS = {"img": "img"}
    DEAULT_OUTPUTS = {"img": "img"}

    def __init__(self, mode="H"):
        """
        Initialize the instance.

        Args:
            mode (str, optional): 'H' for horizontal flipping and 'V' for vertical
                flipping. Default: 'H'.
        """
        super().__init__()
        if mode not in ("H", "V"):
            raise ValueError("`mode` should be 'H' or 'V'.")
        self.mode = mode

    def apply(self, img):
        """apply"""
        if self.mode == "H":
            img = F.flip_h(img)
        elif self.mode == "V":
            img = F.flip_v(img)
        return {"img": img}
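

# Illustrative sketch (comments only; `img` stands for any HWC array accepted by
# F.flip_h / F.flip_v):
#
#     mirrored = Flip(mode="H").apply(img)["img"]  # left-right flip
#     flipped = Flip(mode="V").apply(img)["img"]   # top-bottom flip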


class Crop(BaseComponent):
    """Crop region from the image."""

    INPUT_KEYS = "img"
    OUTPUT_KEYS = ["img", "img_size"]
    DEAULT_INPUTS = {"img": "img"}
    DEAULT_OUTPUTS = {"img": "img", "img_size": "img_size"}

    def __init__(self, crop_size, mode="C"):
        """
        Initialize the instance.

        Args:
            crop_size (list|tuple|int): Width and height of the region to crop.
            mode (str, optional): 'C' for cropping the center part and 'TL' for
                cropping the top left part. Default: 'C'.
        """
        super().__init__()
        if isinstance(crop_size, int):
            crop_size = [crop_size, crop_size]
        _check_image_size(crop_size)
        self.crop_size = crop_size
        if mode not in ("C", "TL"):
            raise ValueError("`mode` should be 'C' or 'TL'.")
        self.mode = mode

    def apply(self, img):
        """apply"""
        h, w = img.shape[:2]
        cw, ch = self.crop_size
        # Fail early if the image cannot contain the requested crop region.
        if w < cw or h < ch:
            raise ValueError(
                f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
            )
        if self.mode == "C":
            x1 = max(0, (w - cw) // 2)
            y1 = max(0, (h - ch) // 2)
        elif self.mode == "TL":
            x1, y1 = 0, 0
        x2 = min(w, x1 + cw)
        y2 = min(h, y1 + ch)
        coords = (x1, y1, x2, y2)
        img = F.slice(img, coords=coords)
        return {"img": img, "img_size": [img.shape[1], img.shape[0]]}
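

# Worked example of the center-crop arithmetic above (illustrative numbers only):
# a 640x480 (WxH) image with crop_size=(224, 224) and mode="C" gives
#     x1 = (640 - 224) // 2 = 208, y1 = (480 - 224) // 2 = 128
#     coords = (208, 128, 432, 352), i.e. a 224x224 region cut out by F.slice.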


class _BaseResize(BaseComponent):
    _INTERP_DICT = {
        "NEAREST": cv2.INTER_NEAREST,
        "LINEAR": cv2.INTER_LINEAR,
        "CUBIC": cv2.INTER_CUBIC,
        "AREA": cv2.INTER_AREA,
        "LANCZOS4": cv2.INTER_LANCZOS4,
    }

    def __init__(self, size_divisor, interp):
        super().__init__()
        if size_divisor is not None:
            assert isinstance(
                size_divisor, int
            ), "`size_divisor` should be None or int."
        self.size_divisor = size_divisor
        try:
            interp = self._INTERP_DICT[interp]
        except KeyError:
            raise ValueError(
                "`interp` should be one of {}.".format(self._INTERP_DICT.keys())
            )
        self.interp = interp

    @staticmethod
    def _rescale_size(img_size, target_size):
        """rescale size"""
        scale = min(max(target_size) / max(img_size), min(target_size) / min(img_size))
        rescaled_size = [round(i * scale) for i in img_size]
        return rescaled_size, scale
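

# Worked example of `_rescale_size` (illustrative numbers only):
# img_size=(640, 480), target_size=(320, 320):
#     scale = min(320 / 640, 320 / 480) = 0.5
#     rescaled_size = [320, 240]
# i.e. the tighter constraint wins, so the rescaled image fits inside the target box.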


class Resize(_BaseResize):
    """Resize the image."""

    INPUT_KEYS = "img"
    OUTPUT_KEYS = ["img", "img_size", "scale_factors"]
    DEAULT_INPUTS = {"img": "img"}
    DEAULT_OUTPUTS = {
        "img": "img",
        "img_size": "img_size",
        "scale_factors": "scale_factors",
    }

    def __init__(
        self, target_size, keep_ratio=False, size_divisor=None, interp="LINEAR"
    ):
        """
        Initialize the instance.

        Args:
            target_size (list|tuple|int): Target width and height.
            keep_ratio (bool, optional): Whether to keep the aspect ratio of resized
                image. Default: False.
            size_divisor (int|None, optional): Divisor of resized image size.
                Default: None.
            interp (str, optional): Interpolation method. Choices are 'NEAREST',
                'LINEAR', 'CUBIC', 'AREA', and 'LANCZOS4'. Default: 'LINEAR'.
        """
        super().__init__(size_divisor=size_divisor, interp=interp)
        if isinstance(target_size, int):
            target_size = [target_size, target_size]
        _check_image_size(target_size)
        self.target_size = target_size
        self.keep_ratio = keep_ratio

    def apply(self, img):
        """apply"""
        target_size = self.target_size
        original_size = img.shape[:2][::-1]
        if self.keep_ratio:
            h, w = img.shape[0:2]
            target_size, _ = self._rescale_size((w, h), self.target_size)
        if self.size_divisor:
            target_size = [
                math.ceil(i / self.size_divisor) * self.size_divisor
                for i in target_size
            ]
        img_scale_w, img_scale_h = [
            target_size[0] / original_size[0],
            target_size[1] / original_size[1],
        ]
        img = F.resize(img, target_size, interp=self.interp)
        return {
            "img": img,
            "img_size": [img.shape[1], img.shape[0]],
            "scale_factors": [img_scale_w, img_scale_h],
        }
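

# Illustrative usage sketch (comments only; `img` is assumed to be an HWC array):
#
#     out = Resize(target_size=640, keep_ratio=True, size_divisor=32).apply(img)
#     out["img_size"]       # [width, height] after resizing
#     out["scale_factors"]  # [w_scale, h_scale] relative to the original size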


class ResizeByLong(_BaseResize):
    """
    Proportionally resize the image by specifying the target length of the
    longest side.
    """

    INPUT_KEYS = "img"
    OUTPUT_KEYS = ["img", "img_size"]
    DEAULT_INPUTS = {"img": "img"}
    DEAULT_OUTPUTS = {"img": "img", "img_size": "img_size"}

    def __init__(self, target_long_edge, size_divisor=None, interp="LINEAR"):
        """
        Initialize the instance.

        Args:
            target_long_edge (int): Target length of the longest side of image.
            size_divisor (int|None, optional): Divisor of resized image size.
                Default: None.
            interp (str, optional): Interpolation method. Choices are 'NEAREST',
                'LINEAR', 'CUBIC', 'AREA', and 'LANCZOS4'. Default: 'LINEAR'.
        """
        super().__init__(size_divisor=size_divisor, interp=interp)
        self.target_long_edge = target_long_edge

    def apply(self, img):
        """apply"""
        h, w = img.shape[:2]
        scale = self.target_long_edge / max(h, w)
        h_resize = round(h * scale)
        w_resize = round(w * scale)
        if self.size_divisor is not None:
            h_resize = math.ceil(h_resize / self.size_divisor) * self.size_divisor
            w_resize = math.ceil(w_resize / self.size_divisor) * self.size_divisor
        img = F.resize(img, (w_resize, h_resize), interp=self.interp)
        return {"img": img, "img_size": [img.shape[1], img.shape[0]]}


class ResizeByShort(_BaseResize):
    """
    Proportionally resize the image by specifying the target length of the
    shortest side.
    """

    INPUT_KEYS = "img"
    OUTPUT_KEYS = ["img", "img_size"]
    DEAULT_INPUTS = {"img": "img"}
    DEAULT_OUTPUTS = {"img": "img", "img_size": "img_size"}

    def __init__(self, target_short_edge, size_divisor=None, interp="LINEAR"):
        """
        Initialize the instance.

        Args:
            target_short_edge (int): Target length of the shortest side of image.
            size_divisor (int|None, optional): Divisor of resized image size.
                Default: None.
            interp (str, optional): Interpolation method. Choices are 'NEAREST',
                'LINEAR', 'CUBIC', 'AREA', and 'LANCZOS4'. Default: 'LINEAR'.
        """
        super().__init__(size_divisor=size_divisor, interp=interp)
        self.target_short_edge = target_short_edge

    def apply(self, img):
        """apply"""
        h, w = img.shape[:2]
        scale = self.target_short_edge / min(h, w)
        h_resize = round(h * scale)
        w_resize = round(w * scale)
        if self.size_divisor is not None:
            h_resize = math.ceil(h_resize / self.size_divisor) * self.size_divisor
            w_resize = math.ceil(w_resize / self.size_divisor) * self.size_divisor
        img = F.resize(img, (w_resize, h_resize), interp=self.interp)
        return {"img": img, "img_size": [img.shape[1], img.shape[0]]}
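

# Worked example of the short-side resize with `size_divisor` (illustrative numbers only):
# a 1280x720 (WxH) image with target_short_edge=608 and size_divisor=32 gives
#     scale = 608 / 720 ≈ 0.844
#     h_resize = round(720 * scale) = 608   -> ceil(608 / 32) * 32 = 608
#     w_resize = round(1280 * scale) = 1081 -> ceil(1081 / 32) * 32 = 1088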


class Pad(BaseComponent):
    """Pad the image."""

    INPUT_KEYS = "img"
    OUTPUT_KEYS = ["img", "img_size"]
    DEAULT_INPUTS = {"img": "img"}
    DEAULT_OUTPUTS = {"img": "img", "img_size": "img_size"}

    def __init__(self, target_size, val=127.5):
        """
        Initialize the instance.

        Args:
            target_size (list|tuple|int): Target width and height of the image after
                padding.
            val (float, optional): Value to fill the padded area. Default: 127.5.
        """
        super().__init__()
        if isinstance(target_size, int):
            target_size = [target_size, target_size]
        _check_image_size(target_size)
        self.target_size = target_size
        self.val = val

    def apply(self, img):
        """apply"""
        h, w = img.shape[:2]
        tw, th = self.target_size
        ph = th - h
        pw = tw - w
        if ph < 0 or pw < 0:
            raise ValueError(
                f"Input image ({w}, {h}) smaller than the target size ({tw}, {th})."
            )
        else:
            img = F.pad(img, pad=(0, ph, 0, pw), val=self.val)
        return {"img": img, "img_size": [img.shape[1], img.shape[0]]}
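

# Worked example of the padding arithmetic above (illustrative numbers only):
# a 600x400 (WxH) image with target_size=(640, 640) gives ph=240 and pw=40, i.e.
# F.pad(img, pad=(0, 240, 0, 40), val=127.5); the extra rows/columns are filled with
# `val` (presumably appended on the bottom and right, given the (0, ph, 0, pw) order).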


class PadStride(BaseComponent):
    """Pad the image so that its spatial dimensions are multiples of `stride`.

    Used by models with an FPN; it replaces `PadBatch(pad_to_stride, pad_gt)` in the
    original config.

    Args:
        stride (int): Models with an FPN require the image shape to satisfy
            shape % stride == 0.
    """

    INPUT_KEYS = "img"
    OUTPUT_KEYS = "img"
    DEAULT_INPUTS = {"img": "img"}
    DEAULT_OUTPUTS = {"img": "img"}

    def __init__(self, stride=0):
        super().__init__()
        self.coarsest_stride = stride

    def apply(self, img):
        """
        Args:
            img (np.ndarray): Image in CHW layout.

        Returns:
            dict: Contains the zero-padded image (np.ndarray).
        """
        im = img
        coarsest_stride = self.coarsest_stride
        if coarsest_stride <= 0:
            return {"img": im}
        im_c, im_h, im_w = im.shape
        pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
        pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
        padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
        padding_im[:, :im_h, :im_w] = im
        return {"img": padding_im}
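

# Worked example (illustrative numbers only): a CHW array of shape (3, 500, 375) with
# stride=32 is zero-padded to (3, 512, 384), since ceil(500/32)*32 = 512 and
# ceil(375/32)*32 = 384. Note that PadStride expects CHW input (e.g. after ToCHWImage),
# unlike the HWC components above.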


class Normalize(BaseComponent):
    """Normalize the image."""

    INPUT_KEYS = "img"
    OUTPUT_KEYS = "img"
    DEAULT_INPUTS = {"img": "img"}
    DEAULT_OUTPUTS = {"img": "img"}

    def __init__(self, scale=1.0 / 255, mean=0.5, std=0.5, preserve_dtype=False):
        """
        Initialize the instance.

        Args:
            scale (float, optional): Scaling factor to apply to the image before
                applying normalization. Default: 1/255.
            mean (float|tuple|list, optional): Means for each channel of the image.
                Default: 0.5.
            std (float|tuple|list, optional): Standard deviations for each channel
                of the image. Default: 0.5.
            preserve_dtype (bool, optional): Whether to preserve the original dtype
                of the image. Default: False.
        """
        super().__init__()
        self.scale = np.float32(scale)
        if isinstance(mean, float):
            mean = [mean]
        self.mean = np.asarray(mean).astype("float32")
        if isinstance(std, float):
            std = [std]
        self.std = np.asarray(std).astype("float32")
        self.preserve_dtype = preserve_dtype

    def apply(self, img):
        """apply"""
        old_type = img.dtype
        # XXX: If `old_type` has higher precision than float32,
        # we will lose some precision.
        img = img.astype("float32", copy=False)
        img *= self.scale
        img -= self.mean
        img /= self.std
        if self.preserve_dtype:
            img = img.astype(old_type, copy=False)
        return {"img": img}
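

# Worked example with the defaults (scale=1/255, mean=0.5, std=0.5):
# a uint8 pixel value of 255 maps to (255 / 255 - 0.5) / 0.5 = 1.0 and 0 maps to -1.0,
# so the output lies in [-1, 1] as float32 unless `preserve_dtype` casts it back.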


class ToCHWImage(BaseComponent):
    """Reorder the dimensions of the image from HWC to CHW."""

    INPUT_KEYS = "img"
    OUTPUT_KEYS = "img"
    DEAULT_INPUTS = {"img": "img"}
    DEAULT_OUTPUTS = {"img": "img"}

    def apply(self, img):
        """apply"""
        img = img.transpose((2, 0, 1))
        return {"img": img}
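

# Illustrative example: an HWC array of shape (480, 640, 3) becomes a CHW array of
# shape (3, 480, 640), which is the layout expected by PadStride above.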