# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import os.path as osp
import re
import math
import json
import tempfile

import cv2
import numpy as np
from PIL import Image
from tokenizers import Tokenizer as TokenizerFast

from ....utils import logging
from ..base import BaseComponent

__all__ = [
    "OCRReisizeNormImg",
    "LaTeXOCRReisizeNormImg",
    "CTCLabelDecode",
    "LaTeXOCRDecode",
]


class OCRReisizeNormImg(BaseComponent):
    """for ocr image resize and normalization"""

    INPUT_KEYS = ["img", "img_size"]
    OUTPUT_KEYS = ["img"]
    DEAULT_INPUTS = {"img": "img", "img_size": "img_size"}
    DEAULT_OUTPUTS = {"img": "img"}

    def __init__(self, rec_image_shape=[3, 48, 320]):
        super().__init__()
        self.rec_image_shape = rec_image_shape

    def resize_norm_img(self, img, max_wh_ratio):
        """resize and normalize the img"""
        imgC, imgH, imgW = self.rec_image_shape
        assert imgC == img.shape[2]
        imgW = int(imgH * max_wh_ratio)

        # resize to the target height, preserving aspect ratio up to imgW
        h, w = img.shape[:2]
        ratio = w / float(h)
        if math.ceil(imgH * ratio) > imgW:
            resized_w = imgW
        else:
            resized_w = int(math.ceil(imgH * ratio))
        resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype("float32")
        # HWC -> CHW, scale to [0, 1], then normalize to [-1, 1]
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        # right-pad the width dimension with zeros up to imgW
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    def apply(self, img, img_size):
        """apply"""
        imgC, imgH, imgW = self.rec_image_shape
        max_wh_ratio = imgW / imgH
        w, h = img_size[:2]
        wh_ratio = w * 1.0 / h
        max_wh_ratio = max(max_wh_ratio, wh_ratio)
        img = self.resize_norm_img(img, max_wh_ratio)
        return {"img": img}
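
# A minimal usage sketch (illustrative only; `crop` is assumed to be an HWC
# uint8 text-line crop, with img_size given as its original (w, h)):
#
#     resizer = OCRReisizeNormImg(rec_image_shape=[3, 48, 320])
#     out = resizer.apply(crop, (crop.shape[1], crop.shape[0]))
#     # out["img"] is a float32 CHW array, height 48, zero-padded on the right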


class LaTeXOCRReisizeNormImg(BaseComponent):
    """for ocr image resize and normalization"""

    INPUT_KEYS = "img"
    OUTPUT_KEYS = "img"
    DEAULT_INPUTS = {"img": "img"}
    DEAULT_OUTPUTS = {"img": "img"}

    def __init__(self, rec_image_shape=[3, 48, 320]):
        super().__init__()
        self.rec_image_shape = rec_image_shape

    def pad_(self, img, divable=32):
        threshold = 128
        data = np.array(img.convert("LA"))
        if data[..., -1].var() == 0:
            data = (data[..., 0]).astype(np.uint8)
        else:
            data = (255 - data[..., -1]).astype(np.uint8)
        data = (data - data.min()) / (data.max() - data.min()) * 255
        if data.mean() > threshold:
            # To invert the text to white
            gray = 255 * (data < threshold).astype(np.uint8)
        else:
            gray = 255 * (data > threshold).astype(np.uint8)
            data = 255 - data
        coords = cv2.findNonZero(gray)  # Find all non-zero points (text)
        a, b, w, h = cv2.boundingRect(coords)  # Find minimum spanning bounding box
        rect = data[b : b + h, a : a + w]
        im = Image.fromarray(rect).convert("L")
        dims = []
        # round each side up to the nearest multiple of `divable`
        for x in [w, h]:
            div, mod = divmod(x, divable)
            dims.append(divable * (div + (1 if mod > 0 else 0)))
        padded = Image.new("L", dims, 255)
        padded.paste(im, (0, 0, im.size[0], im.size[1]))
        return padded
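
    # For example, a 100x45 text bounding box is padded to 128x64 when
    # divable=32; the padding is white (255), so the text keeps its scale.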

    def minmax_size_(
        self,
        img,
        max_dimensions,
        min_dimensions,
    ):
        if max_dimensions is not None:
            ratios = [a / b for a, b in zip(img.size, max_dimensions)]
            if any([r > 1 for r in ratios]):
                size = np.array(img.size) // max(ratios)
                img = img.resize(tuple(size.astype(int)), Image.BILINEAR)
        if min_dimensions is not None:
            # hypothesis: there is a dim in img smaller than min_dimensions,
            # and return a proper dim >= min_dimensions
            padded_size = [
                max(img_dim, min_dim)
                for img_dim, min_dim in zip(img.size, min_dimensions)
            ]
            if padded_size != list(img.size):  # assert hypothesis
                padded_im = Image.new("L", padded_size, 255)
                padded_im.paste(img, img.getbbox())
                img = padded_im
        return img

    def norm_img_latexocr(self, img):
        # CAN only predict gray scale image
        shape = (1, 1, 3)
        mean = [0.7931, 0.7931, 0.7931]
        std = [0.1738, 0.1738, 0.1738]
        scale = np.float32(1.0 / 255.0)
        min_dimensions = [32, 32]
        max_dimensions = [672, 192]
        mean = np.array(mean).reshape(shape).astype("float32")
        std = np.array(std).reshape(shape).astype("float32")

        im_h, im_w = img.shape[:2]
        if (
            min_dimensions[0] <= im_w <= max_dimensions[0]
            and min_dimensions[1] <= im_h <= max_dimensions[1]
        ):
            pass
        else:
            # pad and resize into the supported range, then stack the
            # grayscale result back to three channels for normalization
            img = Image.fromarray(np.uint8(img))
            img = self.minmax_size_(self.pad_(img), max_dimensions, min_dimensions)
            img = np.array(img)
            im_h, im_w = img.shape[:2]
            img = np.dstack([img, img, img])
        img = (img.astype("float32") * scale - mean) / std
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        # pad height and width up to multiples of 16
        divide_h = math.ceil(im_h / 16) * 16
        divide_w = math.ceil(im_w / 16) * 16
        img = np.pad(
            img, ((0, divide_h - im_h), (0, divide_w - im_w)), constant_values=(1, 1)
        )
        img = img[:, :, np.newaxis].transpose(2, 0, 1)
        img = img.astype("float32")
        return img

    def apply(self, img):
        """apply"""
        img = self.norm_img_latexocr(img)
        return {"img": img}
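
# A minimal usage sketch (illustrative only; `formula_img` is assumed to be an
# HWC uint8 image of a printed formula):
#
#     normalizer = LaTeXOCRReisizeNormImg()
#     out = normalizer.apply(formula_img)
#     # out["img"] is a float32 (1, H, W) array with H and W multiples of 16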


class BaseRecLabelDecode(BaseComponent):
    """Convert between text-label and text-index"""

    INPUT_KEYS = ["pred"]
    OUTPUT_KEYS = ["rec_text", "rec_score"]
    DEAULT_INPUTS = {"pred": "pred"}
    DEAULT_OUTPUTS = {"rec_text": "rec_text", "rec_score": "rec_score"}

    ENABLE_BATCH = True

    def __init__(self, character_str=None, use_space_char=True):
        super().__init__()
        self.reverse = False
        character_list = (
            list(character_str)
            if character_str is not None
            else list("0123456789abcdefghijklmnopqrstuvwxyz")
        )
        if use_space_char:
            character_list.append(" ")
        character_list = self.add_special_char(character_list)
        self.dict = {}
        for i, char in enumerate(character_list):
            self.dict[char] = i
        self.character = character_list

    def pred_reverse(self, pred):
        """Reverse the chunk order of `pred` for right-to-left scripts,
        keeping alphanumeric runs intact."""
        pred_re = []
        c_current = ""
        for c in pred:
            if not bool(re.search("[a-zA-Z0-9 :*./%+-]", c)):
                if c_current != "":
                    pred_re.append(c_current)
                pred_re.append(c)
                c_current = ""
            else:
                c_current += c
        if c_current != "":
            pred_re.append(c_current)
        return "".join(pred_re[::-1])
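
    # For example (using "≡" as a stand-in for a non-Latin character),
    # pred_reverse("ab≡cd") splits into ["ab", "≡", "cd"] and returns
    # "cd≡ab": the chunks are reversed, not the characters inside them.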

    def add_special_char(self, character_list):
        """add_special_char"""
        return character_list

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """convert text-index into text-label."""
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            selection = np.ones(len(text_index[batch_idx]), dtype=bool)
            # CTC-style collapsing: drop repeated indices, then ignored tokens
            if is_remove_duplicate:
                selection[1:] = text_index[batch_idx][1:] != text_index[batch_idx][:-1]
            for ignored_token in ignored_tokens:
                selection &= text_index[batch_idx] != ignored_token
            char_list = [
                self.character[text_id] for text_id in text_index[batch_idx][selection]
            ]
            if text_prob is not None:
                conf_list = text_prob[batch_idx][selection]
            else:
                conf_list = [1] * len(selection)
            if len(conf_list) == 0:
                conf_list = [0]
            text = "".join(char_list)
            if self.reverse:  # for arabic rec
                text = self.pred_reverse(text)
            result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

    def get_ignored_tokens(self):
        """get_ignored_tokens"""
        return [0]  # for ctc blank

    def apply(self, pred):
        """apply"""
        # model outputs may arrive as a tuple/list; use the last element
        preds = pred[-1] if isinstance(pred, (tuple, list)) else pred
        preds = np.array(preds)
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
        return [{"rec_text": t[0], "rec_score": t[1]} for t in text]
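
# A toy decode example (illustrative only): with self.character set to
# ["blank", "a", "b"], decode(np.array([[1, 1, 0, 2]]), is_remove_duplicate=True)
# collapses the repeated 1, drops the blank 0, and yields [("ab", 1.0)].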


class CTCLabelDecode(BaseRecLabelDecode):
    """Convert between text-label and text-index"""

    def __init__(self, character_list=None, use_space_char=True):
        super().__init__(character_list, use_space_char=use_space_char)

    def apply(self, pred):
        """apply"""
        preds = np.array(pred[0])
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
        return [{"rec_text": t[0], "rec_score": t[1]} for t in text]

    def add_special_char(self, character_list):
        """add_special_char"""
        # reserve index 0 for the CTC blank token
        character_list = ["blank"] + character_list
        return character_list
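
# A minimal usage sketch (illustrative only; `logits` is assumed to be a model
# output of shape (batch, seq_len, num_classes) and `charset` the model's
# character list):
#
#     decoder = CTCLabelDecode(character_list=charset)
#     results = decoder.apply([logits])
#     # -> [{"rec_text": "...", "rec_score": 0.98}, ...]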


class LaTeXOCRDecode(BaseComponent):
    """Convert between latex-symbol and symbol-index"""

    INPUT_KEYS = ["pred"]
    OUTPUT_KEYS = ["rec_text"]
    DEAULT_INPUTS = {"pred": "pred"}
    DEAULT_OUTPUTS = {"rec_text": "rec_text"}

    def __init__(self, character_list=None):
        super().__init__()
        # dump the tokenizer spec to a temp file so TokenizerFast can load it
        temp_path = tempfile.gettempdir()
        rec_char_dict_path = os.path.join(temp_path, "latexocr_tokenizer.json")
        try:
            with open(rec_char_dict_path, "w") as f:
                json.dump(character_list, f)
        except Exception as e:
            logging.error(f"Failed to create latexocr_tokenizer.json: {e}")
        self.tokenizer = TokenizerFast.from_file(rec_char_dict_path)

    def post_process(self, s):
        text_reg = r"(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})"
        letter = "[a-zA-Z]"
        noletter = r"[\W_^\d]"
        # strip spaces inside \operatorname/\mathrm/\text/\mathbf groups
        names = [x[0].replace(" ", "") for x in re.findall(text_reg, s)]
        s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
        news = s
        # collapse whitespace between tokens until the string is stable
        while True:
            s = news
            news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter), r"\1\2", s)
            news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter), r"\1\2", news)
            news = re.sub(r"(%s)\s+?(%s)" % (letter, noletter), r"\1\2", news)
            if news == s:
                break
        return s
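
    # For example, post_process(r"\mathrm { x } + 1") returns r"\mathrm{x}+1":
    # the group is compacted first, then inter-token whitespace is removed.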

    def decode(self, tokens):
        if len(tokens.shape) == 1:
            tokens = tokens[None, :]
        dec = [self.tokenizer.decode(tok) for tok in tokens]
        dec_str_list = [
            # "Ġ" is the BPE marker for a leading space; special tokens are dropped
            "".join(detok.split(" "))
            .replace("Ġ", " ")
            .replace("[EOS]", "")
            .replace("[BOS]", "")
            .replace("[PAD]", "")
            .strip()
            for detok in dec
        ]
        return [str(self.post_process(dec_str)) for dec_str in dec_str_list]

    def apply(self, pred):
        preds = np.array(pred[0])
        text = self.decode(preds)
        return {"rec_text": text[0]}
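
# A minimal usage sketch (illustrative only; `token_ids` is assumed to be an
# integer array of predicted token ids, and `tokenizer_spec` a HuggingFace
# tokenizers JSON spec as a dict):
#
#     decoder = LaTeXOCRDecode(character_list=tokenizer_spec)
#     result = decoder.apply([token_ids])
#     # -> {"rec_text": "\\frac{a}{b}"}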