transforms.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import os.path as osp
  16. import re
  17. import numpy as np
  18. from PIL import Image
  19. import cv2
  20. import math
  21. import lazy_paddle as paddle
  22. import json
  23. import tempfile
  24. from tokenizers import Tokenizer as TokenizerFast
  25. from ....utils import logging
  26. from ...base.predictor import BaseTransform
  27. from ...base.predictor.io.writers import TextWriter
  28. from .keys import TextRecKeys as K
  29. __all__ = [
  30. "OCRReisizeNormImg",
  31. "LaTeXOCRReisizeNormImg",
  32. "CTCLabelDecode",
  33. "LaTeXOCRDecode",
  34. "SaveTextRecResults",
  35. ]
  36. class OCRReisizeNormImg(BaseTransform):
  37. """for ocr image resize and normalization"""
  38. def __init__(self, rec_image_shape=[3, 48, 320]):
  39. super().__init__()
  40. self.rec_image_shape = rec_image_shape
  41. def resize_norm_img(self, img, max_wh_ratio):
  42. """resize and normalize the img"""
  43. imgC, imgH, imgW = self.rec_image_shape
  44. assert imgC == img.shape[2]
  45. imgW = int((imgH * max_wh_ratio))
  46. h, w = img.shape[:2]
  47. ratio = w / float(h)
  48. if math.ceil(imgH * ratio) > imgW:
  49. resized_w = imgW
  50. else:
  51. resized_w = int(math.ceil(imgH * ratio))
  52. resized_image = cv2.resize(img, (resized_w, imgH))
  53. resized_image = resized_image.astype("float32")
  54. resized_image = resized_image.transpose((2, 0, 1)) / 255
  55. resized_image -= 0.5
  56. resized_image /= 0.5
  57. padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
  58. padding_im[:, :, 0:resized_w] = resized_image
  59. return padding_im
  60. def apply(self, data):
  61. """apply"""
  62. imgC, imgH, imgW = self.rec_image_shape
  63. max_wh_ratio = imgW / imgH
  64. w, h = data[K.ORI_IM_SIZE]
  65. wh_ratio = w * 1.0 / h
  66. max_wh_ratio = max(max_wh_ratio, wh_ratio)
  67. data[K.IMAGE] = self.resize_norm_img(data[K.IMAGE], max_wh_ratio)
  68. return data
  69. @classmethod
  70. def get_input_keys(cls):
  71. """get input keys"""
  72. return [K.IMAGE, K.ORI_IM_SIZE]
  73. @classmethod
  74. def get_output_keys(cls):
  75. """get output keys"""
  76. return [K.IMAGE]
  77. class LaTeXOCRReisizeNormImg(BaseTransform):
  78. """for ocr image resize and normalization"""
  79. def __init__(self, rec_image_shape=[3, 48, 320]):
  80. super().__init__()
  81. self.rec_image_shape = rec_image_shape
  82. def pad_(self, img, divable=32):
  83. threshold = 128
  84. data = np.array(img.convert("LA"))
  85. if data[..., -1].var() == 0:
  86. data = (data[..., 0]).astype(np.uint8)
  87. else:
  88. data = (255 - data[..., -1]).astype(np.uint8)
  89. data = (data - data.min()) / (data.max() - data.min()) * 255
  90. if data.mean() > threshold:
  91. # To invert the text to white
  92. gray = 255 * (data < threshold).astype(np.uint8)
  93. else:
  94. gray = 255 * (data > threshold).astype(np.uint8)
  95. data = 255 - data
  96. coords = cv2.findNonZero(gray) # Find all non-zero points (text)
  97. a, b, w, h = cv2.boundingRect(coords) # Find minimum spanning bounding box
  98. rect = data[b : b + h, a : a + w]
  99. im = Image.fromarray(rect).convert("L")
  100. dims = []
  101. for x in [w, h]:
  102. div, mod = divmod(x, divable)
  103. dims.append(divable * (div + (1 if mod > 0 else 0)))
  104. padded = Image.new("L", dims, 255)
  105. padded.paste(im, (0, 0, im.size[0], im.size[1]))
  106. return padded
  107. def minmax_size_(
  108. self,
  109. img,
  110. max_dimensions,
  111. min_dimensions,
  112. ):
  113. if max_dimensions is not None:
  114. ratios = [a / b for a, b in zip(img.size, max_dimensions)]
  115. if any([r > 1 for r in ratios]):
  116. size = np.array(img.size) // max(ratios)
  117. img = img.resize(tuple(size.astype(int)), Image.BILINEAR)
  118. if min_dimensions is not None:
  119. # hypothesis: there is a dim in img smaller than min_dimensions, and return a proper dim >= min_dimensions
  120. padded_size = [
  121. max(img_dim, min_dim)
  122. for img_dim, min_dim in zip(img.size, min_dimensions)
  123. ]
  124. if padded_size != list(img.size): # assert hypothesis
  125. padded_im = Image.new("L", padded_size, 255)
  126. padded_im.paste(img, img.getbbox())
  127. img = padded_im
  128. return img
  129. def norm_img_latexocr(self, img):
  130. # CAN only predict gray scale image
  131. shape = (1, 1, 3)
  132. mean = [0.7931, 0.7931, 0.7931]
  133. std = [0.1738, 0.1738, 0.1738]
  134. scale = np.float32(1.0 / 255.0)
  135. min_dimensions = [32, 32]
  136. max_dimensions = [672, 192]
  137. mean = np.array(mean).reshape(shape).astype("float32")
  138. std = np.array(std).reshape(shape).astype("float32")
  139. im_h, im_w = img.shape[:2]
  140. if (
  141. min_dimensions[0] <= im_w <= max_dimensions[0]
  142. and min_dimensions[1] <= im_h <= max_dimensions[1]
  143. ):
  144. pass
  145. else:
  146. img = Image.fromarray(np.uint8(img))
  147. img = self.minmax_size_(self.pad_(img), max_dimensions, min_dimensions)
  148. img = np.array(img)
  149. im_h, im_w = img.shape[:2]
  150. img = np.dstack([img, img, img])
  151. img = (img.astype("float32") * scale - mean) / std
  152. img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  153. divide_h = math.ceil(im_h / 16) * 16
  154. divide_w = math.ceil(im_w / 16) * 16
  155. img = np.pad(
  156. img, ((0, divide_h - im_h), (0, divide_w - im_w)), constant_values=(1, 1)
  157. )
  158. img = img[:, :, np.newaxis].transpose(2, 0, 1)
  159. img = img.astype("float32")
  160. return img
  161. def apply(self, data):
  162. """apply"""
  163. data[K.IMAGE] = self.norm_img_latexocr(data[K.IMAGE])
  164. return data
  165. @classmethod
  166. def get_input_keys(cls):
  167. """get input keys"""
  168. return [K.IMAGE, K.ORI_IM_SIZE]
  169. @classmethod
  170. def get_output_keys(cls):
  171. """get output keys"""
  172. return [K.IMAGE]
  173. class BaseRecLabelDecode(BaseTransform):
  174. """Convert between text-label and text-index"""
  175. def __init__(self, character_str=None, use_space_char=True):
  176. self.reverse = False
  177. character_list = (
  178. list(character_str)
  179. if character_str is not None
  180. else list("0123456789abcdefghijklmnopqrstuvwxyz")
  181. )
  182. if use_space_char:
  183. character_list.append(" ")
  184. character_list = self.add_special_char(character_list)
  185. self.dict = {}
  186. for i, char in enumerate(character_list):
  187. self.dict[char] = i
  188. self.character = character_list
  189. def pred_reverse(self, pred):
  190. """pred_reverse"""
  191. pred_re = []
  192. c_current = ""
  193. for c in pred:
  194. if not bool(re.search("[a-zA-Z0-9 :*./%+-]", c)):
  195. if c_current != "":
  196. pred_re.append(c_current)
  197. pred_re.append(c)
  198. c_current = ""
  199. else:
  200. c_current += c
  201. if c_current != "":
  202. pred_re.append(c_current)
  203. return "".join(pred_re[::-1])
  204. def add_special_char(self, character_list):
  205. """add_special_char"""
  206. return character_list
  207. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  208. """convert text-index into text-label."""
  209. result_list = []
  210. ignored_tokens = self.get_ignored_tokens()
  211. batch_size = len(text_index)
  212. for batch_idx in range(batch_size):
  213. selection = np.ones(len(text_index[batch_idx]), dtype=bool)
  214. if is_remove_duplicate:
  215. selection[1:] = text_index[batch_idx][1:] != text_index[batch_idx][:-1]
  216. for ignored_token in ignored_tokens:
  217. selection &= text_index[batch_idx] != ignored_token
  218. char_list = [
  219. self.character[text_id] for text_id in text_index[batch_idx][selection]
  220. ]
  221. if text_prob is not None:
  222. conf_list = text_prob[batch_idx][selection]
  223. else:
  224. conf_list = [1] * len(selection)
  225. if len(conf_list) == 0:
  226. conf_list = [0]
  227. text = "".join(char_list)
  228. if self.reverse: # for arabic rec
  229. text = self.pred_reverse(text)
  230. result_list.append((text, np.mean(conf_list).tolist()))
  231. return result_list
  232. def get_ignored_tokens(self):
  233. """get_ignored_tokens"""
  234. return [0] # for ctc blank
  235. def apply(self, data):
  236. """apply"""
  237. preds = data[K.REC_PROBS]
  238. if isinstance(preds, tuple) or isinstance(preds, list):
  239. preds = preds[-1]
  240. preds_idx = preds.argmax(axis=2)
  241. preds_prob = preds.max(axis=2)
  242. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
  243. data[K.REC_TEXT] = []
  244. data[K.REC_SCORE] = []
  245. for t in text:
  246. data[K.REC_TEXT].append(t[0])
  247. data[K.REC_SCORE].append(t[1])
  248. return data
  249. @classmethod
  250. def get_input_keys(cls):
  251. """get_input_keys"""
  252. return [K.REC_PROBS]
  253. @classmethod
  254. def get_output_keys(cls):
  255. """get_output_keys"""
  256. return [K.REC_TEXT, K.REC_SCORE]
  257. class CTCLabelDecode(BaseRecLabelDecode):
  258. """Convert between text-label and text-index"""
  259. def __init__(self, post_process_cfg=None, use_space_char=True):
  260. assert post_process_cfg["name"] == "CTCLabelDecode"
  261. character_list = post_process_cfg["character_dict"]
  262. super().__init__(character_list, use_space_char=use_space_char)
  263. def apply(self, data):
  264. """apply"""
  265. preds = data[K.REC_PROBS]
  266. if isinstance(preds, tuple) or isinstance(preds, list):
  267. preds = preds[-1]
  268. preds_idx = preds.argmax(axis=2)
  269. preds_prob = preds.max(axis=2)
  270. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
  271. data[K.REC_TEXT] = []
  272. data[K.REC_SCORE] = []
  273. for t in text:
  274. data[K.REC_TEXT].append(t[0])
  275. data[K.REC_SCORE].append(t[1])
  276. return data
  277. def add_special_char(self, character_list):
  278. """add_special_char"""
  279. character_list = ["blank"] + character_list
  280. return character_list
  281. @classmethod
  282. def get_input_keys(cls):
  283. """get_input_keys"""
  284. return [K.REC_PROBS]
  285. @classmethod
  286. def get_output_keys(cls):
  287. """get_output_keys"""
  288. return [K.REC_TEXT, K.REC_SCORE]
  289. class LaTeXOCRDecode(object):
  290. """Convert between latex-symbol and symbol-index"""
  291. def __init__(self, post_process_cfg=None, **kwargs):
  292. assert post_process_cfg["name"] == "LaTeXOCRDecode"
  293. super(LaTeXOCRDecode, self).__init__()
  294. character_list = post_process_cfg["character_dict"]
  295. temp_path = tempfile.gettempdir()
  296. rec_char_dict_path = os.path.join(temp_path, "latexocr_tokenizer.json")
  297. try:
  298. with open(rec_char_dict_path, "w") as f:
  299. json.dump(character_list, f)
  300. except Exception as e:
  301. print(f"创建 latexocr_tokenizer.json 文件失败, 原因{str(e)}")
  302. self.tokenizer = TokenizerFast.from_file(rec_char_dict_path)
  303. def post_process(self, s):
  304. text_reg = r"(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})"
  305. letter = "[a-zA-Z]"
  306. noletter = "[\W_^\d]"
  307. names = [x[0].replace(" ", "") for x in re.findall(text_reg, s)]
  308. s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
  309. news = s
  310. while True:
  311. s = news
  312. news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter), r"\1\2", s)
  313. news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter), r"\1\2", news)
  314. news = re.sub(r"(%s)\s+?(%s)" % (letter, noletter), r"\1\2", news)
  315. if news == s:
  316. break
  317. return s
  318. def decode(self, tokens):
  319. if len(tokens.shape) == 1:
  320. tokens = tokens[None, :]
  321. dec = [self.tokenizer.decode(tok) for tok in tokens]
  322. dec_str_list = [
  323. "".join(detok.split(" "))
  324. .replace("Ġ", " ")
  325. .replace("[EOS]", "")
  326. .replace("[BOS]", "")
  327. .replace("[PAD]", "")
  328. .strip()
  329. for detok in dec
  330. ]
  331. return [str(self.post_process(dec_str)) for dec_str in dec_str_list]
  332. def __call__(self, data):
  333. preds = data[K.REC_PROBS]
  334. text = self.decode(preds)
  335. data[K.REC_TEXT] = text[0]
  336. return data
  337. class SaveTextRecResults(BaseTransform):
  338. """SaveTextRecResults"""
  339. _TEXT_REC_RES_SUFFIX = "_text_rec"
  340. _FILE_EXT = ".txt"
  341. def __init__(self, save_dir):
  342. super().__init__()
  343. self.save_dir = save_dir
  344. # We use python backend to save text object
  345. self._writer = TextWriter(backend="python")
  346. def apply(self, data):
  347. """apply"""
  348. ori_path = data[K.IM_PATH]
  349. file_name = os.path.basename(ori_path)
  350. file_name = self._replace_ext(file_name, self._FILE_EXT)
  351. text_rec_res_save_path = os.path.join(self.save_dir, file_name)
  352. rec_res = ""
  353. for text, score in zip(data[K.REC_TEXT], data[K.REC_SCORE]):
  354. line = text + "\t" + str(score) + "\n"
  355. rec_res += line
  356. text_rec_res_save_path = self._add_suffix(
  357. text_rec_res_save_path, self._TEXT_REC_RES_SUFFIX
  358. )
  359. self._write_txt(text_rec_res_save_path, rec_res)
  360. return data
  361. @classmethod
  362. def get_input_keys(cls):
  363. """get_input_keys"""
  364. return [K.IM_PATH, K.REC_TEXT, K.REC_SCORE]
  365. @classmethod
  366. def get_output_keys(cls):
  367. """get_output_keys"""
  368. return []
  369. def _write_txt(self, path, txt_str):
  370. """_write_txt"""
  371. if os.path.exists(path):
  372. logging.warning(f"{path} already exists. Overwriting it.")
  373. self._writer.write(path, txt_str)
  374. @staticmethod
  375. def _add_suffix(path, suffix):
  376. """_add_suffix"""
  377. stem, ext = os.path.splitext(path)
  378. return stem + suffix + ext
  379. @staticmethod
  380. def _replace_ext(path, new_ext):
  381. """_replace_ext"""
  382. stem, _ = os.path.splitext(path)
  383. return stem + new_ext
  384. class PrintResult(BaseTransform):
  385. """Print Result Transform"""
  386. def apply(self, data):
  387. """apply"""
  388. logging.info("The prediction result is:")
  389. logging.info(data[K.REC_TEXT])
  390. return data
  391. @classmethod
  392. def get_input_keys(cls):
  393. """get input keys"""
  394. return [K.REC_TEXT]
  395. @classmethod
  396. def get_output_keys(cls):
  397. """get output keys"""
  398. return []