processors.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. import numpy as np
  16. import cv2
  17. from PIL import Image
  18. from ..base import PyOnlyProcessor
  19. __all__ = [
  20. "GetImageInfo",
  21. "Flip",
  22. "Crop",
  23. "Resize",
  24. "ResizeByLong",
  25. "ResizeByShort",
  26. "Pad",
  27. "PadStride",
  28. "Normalize",
  29. "ToCHWImage",
  30. "LaTeXOCRReisizeNormImg",
  31. ]
  32. def _resize(im, target_size, interp):
  33. w, h = target_size
  34. im = cv2.resize(im, (w, h), interpolation=interp)
  35. return im
  36. def _flip_h(im):
  37. if len(im.shape) == 3:
  38. im = im[:, ::-1, :]
  39. elif len(im.shape) == 2:
  40. im = im[:, ::-1]
  41. return im
  42. def _flip_v(im):
  43. if len(im.shape) == 3:
  44. im = im[::-1, :, :]
  45. elif len(im.shape) == 2:
  46. im = im[::-1, :]
  47. return im
  48. def _slice(im, coords):
  49. x1, y1, x2, y2 = coords
  50. im = im[y1:y2, x1:x2, ...]
  51. return im
  52. def _pad(im, pad, val):
  53. if isinstance(pad, int):
  54. pad = [pad] * 4
  55. if len(pad) != 4:
  56. raise ValueError
  57. chns = 1 if im.ndim == 2 else im.shape[2]
  58. im = cv2.copyMakeBorder(im, *pad, cv2.BORDER_CONSTANT, value=(val,) * chns)
  59. return im
  60. def _check_image_size(input_):
  61. if not (
  62. isinstance(input_, (list, tuple))
  63. and len(input_) == 2
  64. and isinstance(input_[0], int)
  65. and isinstance(input_[1], int)
  66. ):
  67. raise TypeError(f"{input_} cannot represent a valid image size.")
  68. class GetImageInfo(PyOnlyProcessor):
  69. def __call__(self, data):
  70. img = data["img"]
  71. return {**data, "img_size": [img.shape[1], img.shape[0]]}
  72. class Flip(PyOnlyProcessor):
  73. def __init__(self, mode="H"):
  74. super().__init__()
  75. if mode not in ("H", "V"):
  76. raise ValueError("`mode` should be 'H' or 'V'.")
  77. self._mode = mode
  78. def __call__(self, data):
  79. img = data["img"]
  80. if self._mode == "H":
  81. img = _flip_h(img)
  82. elif self._mode == "V":
  83. img = _flip_v(img)
  84. return {**data, "img": img}
  85. class Crop(PyOnlyProcessor):
  86. def __init__(self, crop_size, mode="C"):
  87. super().__init__()
  88. if isinstance(crop_size, int):
  89. crop_size = [crop_size, crop_size]
  90. _check_image_size(crop_size)
  91. self._crop_size = crop_size
  92. if mode not in ("C", "TL"):
  93. raise ValueError("Unsupported interpolation method")
  94. self._mode = mode
  95. def __call__(self, data):
  96. img = data["img"]
  97. h, w = img.shape[:2]
  98. cw, ch = self._crop_size
  99. if self._mode == "C":
  100. x1 = max(0, (w - cw) // 2)
  101. y1 = max(0, (h - ch) // 2)
  102. elif self._mode == "TL":
  103. x1, y1 = 0, 0
  104. x2 = min(w, x1 + cw)
  105. y2 = min(h, y1 + ch)
  106. coords = (x1, y1, x2, y2)
  107. if coords == (0, 0, w, h):
  108. raise ValueError(
  109. f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
  110. )
  111. img = _slice(img, coords=coords)
  112. return {**data, "img": img, "img_size": [img.shape[1], img.shape[0]]}
  113. class _BaseResize(PyOnlyProcessor):
  114. _INTERP_DICT = {
  115. "NEAREST": cv2.INTER_NEAREST,
  116. "LINEAR": cv2.INTER_LINEAR,
  117. "CUBIC": cv2.INTER_CUBIC,
  118. "AREA": cv2.INTER_AREA,
  119. "LANCZOS4": cv2.INTER_LANCZOS4,
  120. }
  121. def __init__(self, size_divisor, interp):
  122. super().__init__()
  123. if size_divisor is not None:
  124. assert isinstance(
  125. size_divisor, int
  126. ), "`size_divisor` should be None or int."
  127. self._size_divisor = size_divisor
  128. try:
  129. interp = self._INTERP_DICT[interp]
  130. except KeyError:
  131. raise ValueError(
  132. "`interp` should be one of {}.".format(self._INTERP_DICT.keys())
  133. )
  134. self._interp = interp
  135. @staticmethod
  136. def _rescale_size(img_size, target_size):
  137. scale = min(max(target_size) / max(img_size), min(target_size) / min(img_size))
  138. rescaled_size = [round(i * scale) for i in img_size]
  139. return rescaled_size, scale
  140. class Resize(_BaseResize):
  141. def __init__(
  142. self, target_size, keep_ratio=False, size_divisor=None, interp="LINEAR"
  143. ):
  144. super().__init__(size_divisor=size_divisor, interp=interp)
  145. if isinstance(target_size, int):
  146. target_size = [target_size, target_size]
  147. _check_image_size(target_size)
  148. self._target_size = target_size
  149. self._keep_ratio = keep_ratio
  150. def __call__(self, data):
  151. img = data["img"]
  152. target_size = self._target_size
  153. original_size = img.shape[:2][::-1]
  154. if self._keep_ratio:
  155. h, w = img.shape[0:2]
  156. target_size, _ = self._rescale_size((w, h), self._target_size)
  157. if self._size_divisor:
  158. target_size = [
  159. math.ceil(i / self._size_divisor) * self._size_divisor
  160. for i in target_size
  161. ]
  162. img_scale_w, img_scale_h = [
  163. target_size[0] / original_size[0],
  164. target_size[1] / original_size[1],
  165. ]
  166. img = _resize(img, target_size, interp=self._interp)
  167. return {
  168. **data,
  169. "img": img,
  170. "img_size": [img.shape[1], img.shape[0]],
  171. "scale_factors": [img_scale_w, img_scale_h],
  172. }
  173. class ResizeByLong(_BaseResize):
  174. def __init__(self, target_long_edge, size_divisor=None, interp="LINEAR"):
  175. super().__init__(size_divisor=size_divisor, interp=interp)
  176. self._target_long_edge = target_long_edge
  177. def __call__(self, data):
  178. img = data["img"]
  179. h, w = img.shape[:2]
  180. scale = self._target_long_edge / max(h, w)
  181. h_resize = round(h * scale)
  182. w_resize = round(w * scale)
  183. if self._size_divisor is not None:
  184. h_resize = math.ceil(h_resize / self._size_divisor) * self._size_divisor
  185. w_resize = math.ceil(w_resize / self._size_divisor) * self._size_divisor
  186. img = _resize(img, (w_resize, h_resize), interp=self._interp)
  187. return {**data, "img": img, "img_size": [img.shape[1], img.shape[0]]}
  188. class ResizeByShort(_BaseResize):
  189. INPUT_KEYS = "img"
  190. OUTPUT_KEYS = ["img", "img_size"]
  191. DEAULT_INPUTS = {"img": "img"}
  192. DEAULT_OUTPUTS = {"img": "img", "img_size": "img_size"}
  193. def __init__(self, target_short_edge, size_divisor=None, interp="LINEAR"):
  194. super().__init__(size_divisor=size_divisor, interp=interp)
  195. self._target_short_edge = target_short_edge
  196. def __call__(self, data):
  197. img = data["img"]
  198. h, w = img.shape[:2]
  199. scale = self._target_short_edge / min(h, w)
  200. h_resize = round(h * scale)
  201. w_resize = round(w * scale)
  202. if self._size_divisor is not None:
  203. h_resize = math.ceil(h_resize / self._size_divisor) * self._size_divisor
  204. w_resize = math.ceil(w_resize / self._size_divisor) * self._size_divisor
  205. img = _resize(img, (w_resize, h_resize), interp=self._interp)
  206. return {**data, "img": img, "img_size": [img.shape[1], img.shape[0]]}
  207. class Pad(PyOnlyProcessor):
  208. def __init__(self, target_size, val=127.5):
  209. super().__init__()
  210. if isinstance(target_size, int):
  211. target_size = [target_size, target_size]
  212. _check_image_size(target_size)
  213. self._target_size = target_size
  214. self._val = val
  215. def __call__(self, data):
  216. img = data["img"]
  217. h, w = img.shape[:2]
  218. tw, th = self._target_size
  219. ph = th - h
  220. pw = tw - w
  221. if ph < 0 or pw < 0:
  222. raise ValueError(
  223. f"Input image ({w}, {h}) smaller than the target size ({tw}, {th})."
  224. )
  225. else:
  226. img = _pad(img, pad=(0, ph, 0, pw), val=self._val)
  227. return {**data, "img": img, "img_size": [img.shape[1], img.shape[0]]}
  228. class PadStride(PyOnlyProcessor):
  229. INPUT_KEYS = "img"
  230. OUTPUT_KEYS = "img"
  231. DEAULT_INPUTS = {"img": "img"}
  232. DEAULT_OUTPUTS = {"img": "img"}
  233. def __init__(self, stride=0):
  234. super().__init__()
  235. self._coarsest_stride = stride
  236. def __call__(self, data):
  237. img = data["img"]
  238. im = img
  239. coarsest_stride = self._coarsest_stride
  240. if coarsest_stride <= 0:
  241. return {"img": im}
  242. im_c, im_h, im_w = im.shape
  243. pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
  244. pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
  245. padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
  246. padding_im[:, :im_h, :im_w] = im
  247. return {**data, "img": padding_im}
  248. class Normalize(PyOnlyProcessor):
  249. def __init__(self, scale=1.0 / 255, mean=0.5, std=0.5, preserve_dtype=False):
  250. super().__init__()
  251. self._scale = np.float32(scale)
  252. if isinstance(mean, float):
  253. mean = [mean]
  254. self._mean = np.asarray(mean).astype("float32")
  255. if isinstance(std, float):
  256. std = [std]
  257. self._std = np.asarray(std).astype("float32")
  258. self._preserve_dtype = preserve_dtype
  259. def __call__(self, data):
  260. img = data["img"]
  261. old_type = img.dtype
  262. # XXX: If `old_type` has higher precision than float32,
  263. # we will lose some precision.
  264. img = img.astype("float32", copy=False)
  265. img *= self._scale
  266. img -= self._mean
  267. img /= self._std
  268. if self._preserve_dtype:
  269. img = img.astype(old_type, copy=False)
  270. return {**data, "img": img}
  271. class ToCHWImage(PyOnlyProcessor):
  272. def __call__(self, data):
  273. img = data["img"]
  274. img = img.transpose((2, 0, 1))
  275. return {**data, "img": img}
  276. class BGR2RGB(PyOnlyProcessor):
  277. def __call__(self, data):
  278. img = data["img"]
  279. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  280. return {**data, "img": img}
  281. class LaTeXOCRReisizeNormImg(PyOnlyProcessor):
  282. """for ocr image resize and normalization"""
  283. def __init__(self, rec_image_shape=(3, 48, 320)):
  284. super().__init__()
  285. self.rec_image_shape = rec_image_shape
  286. def pad_(self, img, divable=32):
  287. threshold = 128
  288. data = np.array(img.convert("LA"))
  289. if data[..., -1].var() == 0:
  290. data = (data[..., 0]).astype(np.uint8)
  291. else:
  292. data = (255 - data[..., -1]).astype(np.uint8)
  293. data = (data - data.min()) / (data.max() - data.min()) * 255
  294. if data.mean() > threshold:
  295. # To invert the text to white
  296. gray = 255 * (data < threshold).astype(np.uint8)
  297. else:
  298. gray = 255 * (data > threshold).astype(np.uint8)
  299. data = 255 - data
  300. coords = cv2.findNonZero(gray) # Find all non-zero points (text)
  301. a, b, w, h = cv2.boundingRect(coords) # Find minimum spanning bounding box
  302. rect = data[b : b + h, a : a + w]
  303. im = Image.fromarray(rect).convert("L")
  304. dims = []
  305. for x in [w, h]:
  306. div, mod = divmod(x, divable)
  307. dims.append(divable * (div + (1 if mod > 0 else 0)))
  308. padded = Image.new("L", dims, 255)
  309. padded.paste(im, (0, 0, im.size[0], im.size[1]))
  310. return padded
  311. def minmax_size_(
  312. self,
  313. img,
  314. max_dimensions,
  315. min_dimensions,
  316. ):
  317. if max_dimensions is not None:
  318. ratios = [a / b for a, b in zip(img.size, max_dimensions)]
  319. if any([r > 1 for r in ratios]):
  320. size = np.array(img.size) // max(ratios)
  321. img = img.resize(tuple(size.astype(int)), Image.BILINEAR)
  322. if min_dimensions is not None:
  323. # hypothesis: there is a dim in img smaller than min_dimensions, and return a proper dim >= min_dimensions
  324. padded_size = [
  325. max(img_dim, min_dim)
  326. for img_dim, min_dim in zip(img.size, min_dimensions)
  327. ]
  328. if padded_size != list(img.size): # assert hypothesis
  329. padded_im = Image.new("L", padded_size, 255)
  330. padded_im.paste(img, img.getbbox())
  331. img = padded_im
  332. return img
  333. def norm_img_latexocr(self, img):
  334. # CAN only predict gray scale image
  335. shape = (1, 1, 3)
  336. mean = [0.7931, 0.7931, 0.7931]
  337. std = [0.1738, 0.1738, 0.1738]
  338. scale = np.float32(1.0 / 255.0)
  339. min_dimensions = [32, 32]
  340. max_dimensions = [672, 192]
  341. mean = np.array(mean).reshape(shape).astype("float32")
  342. std = np.array(std).reshape(shape).astype("float32")
  343. im_h, im_w = img.shape[:2]
  344. if (
  345. min_dimensions[0] <= im_w <= max_dimensions[0]
  346. and min_dimensions[1] <= im_h <= max_dimensions[1]
  347. ):
  348. pass
  349. else:
  350. img = Image.fromarray(np.uint8(img))
  351. img = self.minmax_size_(self.pad_(img), max_dimensions, min_dimensions)
  352. img = np.array(img)
  353. im_h, im_w = img.shape[:2]
  354. img = np.dstack([img, img, img])
  355. img = (img.astype("float32") * scale - mean) / std
  356. img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  357. divide_h = math.ceil(im_h / 16) * 16
  358. divide_w = math.ceil(im_w / 16) * 16
  359. img = np.pad(
  360. img, ((0, divide_h - im_h), (0, divide_w - im_w)), constant_values=(1, 1)
  361. )
  362. img = img[:, :, np.newaxis].transpose(2, 0, 1)
  363. img = img.astype("float32")
  364. return img
  365. def __call__(self, data):
  366. """apply"""
  367. img = data["img"]
  368. img = self.norm_img_latexocr(img)
  369. return {"img": img}