processors.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import os.path as osp
  16. from typing import List
  17. import re
  18. import numpy as np
  19. from PIL import Image
  20. import cv2
  21. import math
  22. import json
  23. import tempfile
  24. from tokenizers import Tokenizer as TokenizerFast
  25. from ....utils import logging
  26. from ...utils.benchmark import benchmark
  @benchmark.timeit
  class OCRReisizeNormImg:
      """Resize and normalize text-line images for OCR recognition.

      Two paths exist:
        * dynamic (``input_shape is None``): resize to height H keeping the
          image's aspect ratio, then right-pad with zeros to a width derived
          from the larger of the default and the image's own width/height ratio;
        * static: resize every image to exactly ``input_shape``'s (H, W).

      Both paths scale pixels with ``(x / 255 - 0.5) / 0.5`` and return float32
      arrays in CHW order.

      Args:
          rec_image_shape: ``[C, H, W]`` used by the dynamic path. Defaults to
              ``[3, 48, 320]``.
          input_shape: optional ``[C, H, W]``; when given the static path is used.
          max_img_w: cap on the padded width of the dynamic path (default 3200,
              the previously hard-coded value).
      """

      def __init__(self, rec_image_shape=None, input_shape=None, max_img_w=3200):
          super().__init__()
          # ``None`` sentinel instead of a shared mutable default list.
          self.rec_image_shape = (
              [3, 48, 320] if rec_image_shape is None else rec_image_shape
          )
          self.input_shape = input_shape
          self.max_imgW = max_img_w

      @staticmethod
      def _normalize_chw(resized_image):
          """Convert a resized HWC image to float32 CHW scaled as (x/255-0.5)/0.5."""
          if resized_image.ndim == 2:
              # OpenCV's resize returns a 2-D array for single-channel input;
              # restore the channel axis so the HWC -> CHW transpose applies.
              resized_image = resized_image[:, :, np.newaxis]
          chw = resized_image.astype("float32").transpose((2, 0, 1)) / 255
          chw -= 0.5
          chw /= 0.5
          return chw

      def resize_norm_img(self, img, max_wh_ratio):
          """Resize ``img`` to height H, normalize it, and right-pad with zeros.

          Args:
              img: image indexed (H, W, C); C must equal ``rec_image_shape[0]``.
              max_wh_ratio: width/height ratio giving the padded width
                  ``int(H * max_wh_ratio)``, capped at ``self.max_imgW``.

          Returns:
              np.ndarray: float32 array of shape (C, H, padded_width).
          """
          imgC, imgH, _ = self.rec_image_shape
          assert imgC == img.shape[2], f"expected {imgC} channels, got {img.shape[2]}"
          imgW = int(imgH * max_wh_ratio)
          if imgW > self.max_imgW:
              # Over the cap: resize to exactly the capped width
              # (the aspect ratio is not preserved in this case).
              imgW = self.max_imgW
              resized_w = self.max_imgW
          else:
              h, w = img.shape[:2]
              ratio = w / float(h)
              # Keep the image's own aspect ratio, but never exceed imgW.
              resized_w = min(imgW, int(math.ceil(imgH * ratio)))
          resized_image = cv2.resize(img, (resized_w, imgH))
          padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
          padding_im[:, :, 0:resized_w] = self._normalize_chw(resized_image)
          return padding_im

      def __call__(self, imgs):
          """Resize and normalize every image in ``imgs``; returns a list."""
          if self.input_shape is None:
              return [self.resize(img) for img in imgs]
          return [self.staticResize(img) for img in imgs]

      def resize(self, img):
          """Dynamic path for a single image (see ``resize_norm_img``)."""
          _, imgH, imgW = self.rec_image_shape
          h, w = img.shape[:2]
          # Pad to whichever is wider: the default shape or this image.
          max_wh_ratio = max(imgW / imgH, w * 1.0 / h)
          return self.resize_norm_img(img, max_wh_ratio)

      def staticResize(self, img):
          """Static path: resize to exactly ``input_shape``'s (H, W), no padding.

          Returns float32 CHW, matching the dynamic path. Previously the dtype
          followed the input (e.g. float64 for uint8 images).
          """
          _, imgH, imgW = self.input_shape
          resized_image = cv2.resize(img, (int(imgW), int(imgH)))
          return self._normalize_chw(resized_image)
  @benchmark.timeit
  class BaseRecLabelDecode:
      """Convert between text-label and text-index.

      Args:
          character_str: iterable of characters forming the character table;
              defaults to the digits followed by lowercase ASCII letters.
          use_space_char: append ``" "`` to the table when True.
      """

      # Runs of these characters keep their internal order in ``pred_reverse``.
      _LTR_CHAR_RE = re.compile(r"[a-zA-Z0-9 :*./%+-]")

      def __init__(self, character_str=None, use_space_char=True):
          super().__init__()
          self.reverse = False
          character_list = (
              list(character_str)
              if character_str is not None
              else list("0123456789abcdefghijklmnopqrstuvwxyz")
          )
          if use_space_char:
              character_list.append(" ")
          character_list = self.add_special_char(character_list)
          # Character -> index lookup (a later duplicate overrides an earlier one).
          self.dict = {char: i for i, char in enumerate(character_list)}
          self.character = character_list

      def pred_reverse(self, pred):
          """Reverse ``pred`` for right-to-left recognition (used when ``self.reverse``).

          Maximal runs of characters matching ``_LTR_CHAR_RE`` keep their
          internal order. Every other character forms its own segment. The
          order of the segments is then reversed.
          """
          pred_re = []
          c_current = ""
          for c in pred:
              if self._LTR_CHAR_RE.search(c) is None:
                  if c_current != "":
                      pred_re.append(c_current)
                  pred_re.append(c)
                  c_current = ""
              else:
                  c_current += c
          if c_current != "":
              pred_re.append(c_current)
          return "".join(pred_re[::-1])

      def add_special_char(self, character_list):
          """Hook for subclasses to add special tokens; identity here."""
          return character_list

      def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
          """Convert text-index into text-label.

          Args:
              text_index: per-sample sequences of character indices.
              text_prob: optional per-sample, per-step confidences aligned with
                  ``text_index``.
              is_remove_duplicate: drop an index equal to its predecessor.

          Returns:
              list[tuple[str, float]]: decoded text and mean confidence per
              sample. The confidence is 0 when no character survives.
          """
          result_list = []
          ignored_tokens = self.get_ignored_tokens()
          for batch_idx, seq in enumerate(text_index):
              seq = np.asarray(seq)
              selection = np.ones(len(seq), dtype=bool)
              if is_remove_duplicate:
                  selection[1:] = seq[1:] != seq[:-1]
              for ignored_token in ignored_tokens:
                  selection &= seq != ignored_token
              char_list = [self.character[text_id] for text_id in seq[selection]]
              if text_prob is not None:
                  conf_list = text_prob[batch_idx][selection]
              else:
                  # One unit confidence per *kept* character. The old code used
                  # the full sequence length, so all-blank output scored 1.0.
                  conf_list = [1] * len(char_list)
              if len(conf_list) == 0:
                  conf_list = [0]
              text = "".join(char_list)
              if self.reverse:  # for arabic rec
                  text = self.pred_reverse(text)
              result_list.append((text, np.mean(conf_list).tolist()))
          return result_list

      def get_ignored_tokens(self):
          """Return the indices that ``decode`` drops: ``[0]`` (CTC blank slot)."""
          # NOTE(review): index 0 is a blank only when a subclass prepends one in
          # add_special_char. In this base class, index 0 is the first real
          # character ("0" by default). Confirm the base class is never used directly.
          return [0]

      def __call__(self, pred):
          """Greedily decode ``pred`` into texts and mean confidences.

          ``decode`` walks batch rows of 1-D index sequences, so ``pred`` must
          be 3-D, e.g. (batch, T, num_classes). A tuple/list whose last element
          is 3-D is treated as a collection of model outputs, and only the last
          one is decoded. The old check ran after ``np.array`` and never fired.

          Returns:
              tuple[list[str], list[float]]: texts and their scores.
          """
          if (
              isinstance(pred, (tuple, list))
              and len(pred) > 0
              and np.ndim(pred[-1]) == 3
          ):
              pred = pred[-1]
          preds = np.array(pred)
          preds_idx = preds.argmax(axis=-1)
          preds_prob = preds.max(axis=-1)
          decoded = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
          texts = [text for text, _ in decoded]
          scores = [score for _, score in decoded]
          return texts, scores
  @benchmark.timeit
  class CTCLabelDecode(BaseRecLabelDecode):
      """CTC label decoder.

      A ``"blank"`` entry is prepended to the character table, so it sits at
      index 0, the index the base class ignores while decoding.
      """

      def __init__(self, character_list=None, use_space_char=True):
          super().__init__(character_list, use_space_char=use_space_char)

      def __call__(self, pred):
          """Greedy-decode ``pred[0]`` and return ``(texts, scores)``.

          NOTE(review): presumably ``pred`` is a sequence of model outputs whose
          first element holds the class scores; confirm with the caller.
          """
          step_scores = np.array(pred[0])
          best_idx = step_scores.argmax(axis=-1)
          best_prob = step_scores.max(axis=-1)
          decoded = self.decode(best_idx, best_prob, is_remove_duplicate=True)
          return [item[0] for item in decoded], [item[1] for item in decoded]

      def add_special_char(self, character_list):
          """Return a new list with the CTC ``"blank"`` token placed first."""
          return ["blank"] + character_list
  @benchmark.timeit
  class ToBatch:
      """Zero-pad CHW images on the right to a shared width and stack them."""

      def __pad_imgs(self, imgs: List[np.ndarray]) -> List[np.ndarray]:
          """Right-pad each image's width axis (axis 2) with zeros.

          Args:
              imgs: images indexed as (C, H, W).

          Returns:
              New padded images, all as wide as the widest input.
          """
          widest = max(img.shape[2] for img in imgs)

          def _pad_right(img: np.ndarray) -> np.ndarray:
              _, _, cur_w = img.shape
              return np.pad(
                  img,
                  ((0, 0), (0, 0), (0, widest - cur_w)),
                  mode="constant",
                  constant_values=0,
              )

          return [_pad_right(img) for img in imgs]

      def __call__(self, imgs: List[np.ndarray]) -> List[np.ndarray]:
          """Pad ``imgs`` and stack them along a new leading batch axis.

          Returns:
              A one-element list holding the float32 batch array.
          """
          stacked = np.stack(self.__pad_imgs(imgs), axis=0)
          return [stacked.astype(dtype=np.float32, copy=False)]