processors.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. import re
  16. from typing import List
  17. import numpy as np
  18. from ....utils.deps import class_requires_deps, is_dep_available
  19. from ...utils.benchmark import benchmark
  20. if is_dep_available("opencv-contrib-python"):
  21. import cv2
  22. @benchmark.timeit
  23. @class_requires_deps("opencv-contrib-python")
  24. class OCRReisizeNormImg:
  25. """for ocr image resize and normalization"""
  26. def __init__(self, rec_image_shape=[3, 48, 320], input_shape=None):
  27. super().__init__()
  28. self.rec_image_shape = rec_image_shape
  29. self.input_shape = input_shape
  30. self.max_imgW = 3200
  31. def resize_norm_img(self, img, max_wh_ratio):
  32. """resize and normalize the img"""
  33. imgC, imgH, imgW = self.rec_image_shape
  34. assert imgC == img.shape[2]
  35. imgW = int((imgH * max_wh_ratio))
  36. if imgW > self.max_imgW:
  37. resized_image = cv2.resize(img, (self.max_imgW, imgH))
  38. resized_w = self.max_imgW
  39. imgW = self.max_imgW
  40. else:
  41. h, w = img.shape[:2]
  42. ratio = w / float(h)
  43. if math.ceil(imgH * ratio) > imgW:
  44. resized_w = imgW
  45. else:
  46. resized_w = int(math.ceil(imgH * ratio))
  47. resized_image = cv2.resize(img, (resized_w, imgH))
  48. resized_image = resized_image.astype("float32")
  49. resized_image = resized_image.transpose((2, 0, 1)) / 255
  50. resized_image -= 0.5
  51. resized_image /= 0.5
  52. padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
  53. padding_im[:, :, 0:resized_w] = resized_image
  54. return padding_im
  55. def __call__(self, imgs):
  56. """apply"""
  57. if self.input_shape is None:
  58. return [self.resize(img) for img in imgs]
  59. else:
  60. return [self.staticResize(img) for img in imgs]
  61. def resize(self, img):
  62. imgC, imgH, imgW = self.rec_image_shape
  63. max_wh_ratio = imgW / imgH
  64. h, w = img.shape[:2]
  65. wh_ratio = w * 1.0 / h
  66. max_wh_ratio = max(max_wh_ratio, wh_ratio)
  67. img = self.resize_norm_img(img, max_wh_ratio)
  68. return img
  69. def staticResize(self, img):
  70. imgC, imgH, imgW = self.input_shape
  71. resized_image = cv2.resize(img, (int(imgW), int(imgH)))
  72. resized_image = resized_image.transpose((2, 0, 1)) / 255
  73. resized_image -= 0.5
  74. resized_image /= 0.5
  75. return resized_image
  76. @benchmark.timeit
  77. class BaseRecLabelDecode:
  78. """Convert between text-label and text-index"""
  79. def __init__(self, character_str=None, use_space_char=True):
  80. super().__init__()
  81. self.reverse = False
  82. character_list = (
  83. list(character_str)
  84. if character_str is not None
  85. else list("0123456789abcdefghijklmnopqrstuvwxyz")
  86. )
  87. if use_space_char:
  88. character_list.append(" ")
  89. character_list = self.add_special_char(character_list)
  90. self.dict = {}
  91. for i, char in enumerate(character_list):
  92. self.dict[char] = i
  93. self.character = character_list
  94. def pred_reverse(self, pred):
  95. """pred_reverse"""
  96. pred_re = []
  97. c_current = ""
  98. for c in pred:
  99. if not bool(re.search("[a-zA-Z0-9 :*./%+-]", c)):
  100. if c_current != "":
  101. pred_re.append(c_current)
  102. pred_re.append(c)
  103. c_current = ""
  104. else:
  105. c_current += c
  106. if c_current != "":
  107. pred_re.append(c_current)
  108. return "".join(pred_re[::-1])
  109. def add_special_char(self, character_list):
  110. """add_special_char"""
  111. return character_list
  112. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  113. """convert text-index into text-label."""
  114. result_list = []
  115. ignored_tokens = self.get_ignored_tokens()
  116. batch_size = len(text_index)
  117. for batch_idx in range(batch_size):
  118. selection = np.ones(len(text_index[batch_idx]), dtype=bool)
  119. if is_remove_duplicate:
  120. selection[1:] = text_index[batch_idx][1:] != text_index[batch_idx][:-1]
  121. for ignored_token in ignored_tokens:
  122. selection &= text_index[batch_idx] != ignored_token
  123. char_list = [
  124. self.character[text_id] for text_id in text_index[batch_idx][selection]
  125. ]
  126. if text_prob is not None:
  127. conf_list = text_prob[batch_idx][selection]
  128. else:
  129. conf_list = [1] * len(selection)
  130. if len(conf_list) == 0:
  131. conf_list = [0]
  132. text = "".join(char_list)
  133. if self.reverse: # for arabic rec
  134. text = self.pred_reverse(text)
  135. result_list.append((text, np.mean(conf_list).tolist()))
  136. return result_list
  137. def get_ignored_tokens(self):
  138. """get_ignored_tokens"""
  139. return [0] # for ctc blank
  140. def __call__(self, pred):
  141. """apply"""
  142. preds = np.array(pred)
  143. if isinstance(preds, tuple) or isinstance(preds, list):
  144. preds = preds[-1]
  145. preds_idx = preds.argmax(axis=-1)
  146. preds_prob = preds.max(axis=-1)
  147. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
  148. texts = []
  149. scores = []
  150. for t in text:
  151. texts.append(t[0])
  152. scores.append(t[1])
  153. return texts, scores
  154. @benchmark.timeit
  155. class CTCLabelDecode(BaseRecLabelDecode):
  156. """Convert between text-label and text-index"""
  157. def __init__(self, character_list=None, use_space_char=True):
  158. super().__init__(character_list, use_space_char=use_space_char)
  159. def __call__(self, pred):
  160. """apply"""
  161. preds = np.array(pred[0])
  162. preds_idx = preds.argmax(axis=-1)
  163. preds_prob = preds.max(axis=-1)
  164. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
  165. texts = []
  166. scores = []
  167. for t in text:
  168. texts.append(t[0])
  169. scores.append(t[1])
  170. return texts, scores
  171. def add_special_char(self, character_list):
  172. """add_special_char"""
  173. character_list = ["blank"] + character_list
  174. return character_list
  175. @benchmark.timeit
  176. class ToBatch:
  177. """A class for batching and padding images to a uniform width."""
  178. def __pad_imgs(self, imgs: List[np.ndarray]) -> List[np.ndarray]:
  179. """Pad images to the maximum width in the batch.
  180. Args:
  181. imgs (list of np.ndarrays): List of images to pad.
  182. Returns:
  183. list of np.ndarrays: List of padded images.
  184. """
  185. max_width = max(img.shape[2] for img in imgs)
  186. padded_imgs = []
  187. for img in imgs:
  188. _, height, width = img.shape
  189. pad_width = max_width - width
  190. padded_img = np.pad(
  191. img,
  192. ((0, 0), (0, 0), (0, pad_width)),
  193. mode="constant",
  194. constant_values=0,
  195. )
  196. padded_imgs.append(padded_img)
  197. return padded_imgs
  198. def __call__(self, imgs: List[np.ndarray]) -> List[np.ndarray]:
  199. """Call method to pad images and stack them into a batch.
  200. Args:
  201. imgs (list of np.ndarrays): List of images to process.
  202. Returns:
  203. list of np.ndarrays: List containing a stacked tensor of the padded images.
  204. """
  205. imgs = self.__pad_imgs(imgs)
  206. return [np.stack(imgs, axis=0).astype(dtype=np.float32, copy=False)]