transforms.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import os.path as osp
  16. import re
  17. import numpy as np
  18. from PIL import Image
  19. import cv2
  20. import math
  21. import paddle
  22. from ....utils import logging
  23. from ...base.predictor import BaseTransform
  24. from ...base.predictor.io.writers import TextWriter
  25. from .keys import TextRecKeys as K
  26. __all__ = ['OCRReisizeNormImg', 'CTCLabelDecode', 'SaveTextRecResults']
  27. class OCRReisizeNormImg(BaseTransform):
  28. """ for ocr image resize and normalization """
  29. def __init__(self, rec_image_shape=[3, 48, 320]):
  30. super().__init__()
  31. self.rec_image_shape = rec_image_shape
  32. def resize_norm_img(self, img, max_wh_ratio):
  33. """ resize and normalize the img """
  34. imgC, imgH, imgW = self.rec_image_shape
  35. assert imgC == img.shape[2]
  36. imgW = int((imgH * max_wh_ratio))
  37. h, w = img.shape[:2]
  38. ratio = w / float(h)
  39. if math.ceil(imgH * ratio) > imgW:
  40. resized_w = imgW
  41. else:
  42. resized_w = int(math.ceil(imgH * ratio))
  43. resized_image = cv2.resize(img, (resized_w, imgH))
  44. resized_image = resized_image.astype('float32')
  45. resized_image = resized_image.transpose((2, 0, 1)) / 255
  46. resized_image -= 0.5
  47. resized_image /= 0.5
  48. padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
  49. padding_im[:, :, 0:resized_w] = resized_image
  50. return padding_im
  51. def apply(self, data):
  52. """ apply """
  53. imgC, imgH, imgW = self.rec_image_shape
  54. max_wh_ratio = imgW / imgH
  55. w, h = data[K.ORI_IM_SIZE]
  56. wh_ratio = w * 1.0 / h
  57. max_wh_ratio = max(max_wh_ratio, wh_ratio)
  58. data[K.IMAGE] = self.resize_norm_img(data[K.IMAGE], max_wh_ratio)
  59. return data
  60. @classmethod
  61. def get_input_keys(cls):
  62. """ get input keys """
  63. return [K.IMAGE, K.ORI_IM_SIZE]
  64. @classmethod
  65. def get_output_keys(cls):
  66. """ get output keys """
  67. return [K.IMAGE]
  68. class BaseRecLabelDecode(BaseTransform):
  69. """ Convert between text-label and text-index """
  70. def __init__(self, character_str=None, use_space_char=True):
  71. self.reverse = False
  72. character_list = list(
  73. character_str) if character_str is not None else list(
  74. "0123456789abcdefghijklmnopqrstuvwxyz")
  75. if use_space_char:
  76. character_list.append(" ")
  77. character_list = self.add_special_char(character_list)
  78. self.dict = {}
  79. for i, char in enumerate(character_list):
  80. self.dict[char] = i
  81. self.character = character_list
  82. def pred_reverse(self, pred):
  83. """ pred_reverse """
  84. pred_re = []
  85. c_current = ''
  86. for c in pred:
  87. if not bool(re.search('[a-zA-Z0-9 :*./%+-]', c)):
  88. if c_current != '':
  89. pred_re.append(c_current)
  90. pred_re.append(c)
  91. c_current = ''
  92. else:
  93. c_current += c
  94. if c_current != '':
  95. pred_re.append(c_current)
  96. return ''.join(pred_re[::-1])
  97. def add_special_char(self, character_list):
  98. """ add_special_char """
  99. return character_list
  100. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  101. """ convert text-index into text-label. """
  102. result_list = []
  103. ignored_tokens = self.get_ignored_tokens()
  104. batch_size = len(text_index)
  105. for batch_idx in range(batch_size):
  106. selection = np.ones(len(text_index[batch_idx]), dtype=bool)
  107. if is_remove_duplicate:
  108. selection[1:] = text_index[batch_idx][1:] != text_index[
  109. batch_idx][:-1]
  110. for ignored_token in ignored_tokens:
  111. selection &= text_index[batch_idx] != ignored_token
  112. char_list = [
  113. self.character[text_id]
  114. for text_id in text_index[batch_idx][selection]
  115. ]
  116. if text_prob is not None:
  117. conf_list = text_prob[batch_idx][selection]
  118. else:
  119. conf_list = [1] * len(selection)
  120. if len(conf_list) == 0:
  121. conf_list = [0]
  122. text = ''.join(char_list)
  123. if self.reverse: # for arabic rec
  124. text = self.pred_reverse(text)
  125. result_list.append((text, np.mean(conf_list).tolist()))
  126. return result_list
  127. def get_ignored_tokens(self):
  128. """ get_ignored_tokens """
  129. return [0] # for ctc blank
  130. def apply(self, data):
  131. """ apply """
  132. preds = data[K.REC_PROBS]
  133. if isinstance(preds, tuple) or isinstance(preds, list):
  134. preds = preds[-1]
  135. preds_idx = preds.argmax(axis=2)
  136. preds_prob = preds.max(axis=2)
  137. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
  138. data[K.REC_TEXT] = []
  139. data[K.REC_SCORE] = []
  140. for t in text:
  141. data[K.REC_TEXT].append(t[0])
  142. data[K.REC_SCORE].append(t[1])
  143. return data
  144. @classmethod
  145. def get_input_keys(cls):
  146. """ get_input_keys """
  147. return [K.REC_PROBS]
  148. @classmethod
  149. def get_output_keys(cls):
  150. """ get_output_keys """
  151. return [K.REC_TEXT, K.REC_SCORE]
  152. class CTCLabelDecode(BaseRecLabelDecode):
  153. """ Convert between text-label and text-index """
  154. def __init__(self, post_process_cfg=None, use_space_char=True):
  155. assert post_process_cfg['name'] == 'CTCLabelDecode'
  156. character_list = post_process_cfg['character_dict']
  157. super().__init__(character_list, use_space_char=use_space_char)
  158. def apply(self, data):
  159. """ apply """
  160. preds = data[K.REC_PROBS]
  161. if isinstance(preds, tuple) or isinstance(preds, list):
  162. preds = preds[-1]
  163. preds_idx = preds.argmax(axis=2)
  164. preds_prob = preds.max(axis=2)
  165. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
  166. data[K.REC_TEXT] = []
  167. data[K.REC_SCORE] = []
  168. for t in text:
  169. data[K.REC_TEXT].append(t[0])
  170. data[K.REC_SCORE].append(t[1])
  171. return data
  172. def add_special_char(self, character_list):
  173. """ add_special_char """
  174. character_list = ['blank'] + character_list
  175. return character_list
  176. @classmethod
  177. def get_input_keys(cls):
  178. """ get_input_keys """
  179. return [K.REC_PROBS]
  180. @classmethod
  181. def get_output_keys(cls):
  182. """ get_output_keys """
  183. return [K.REC_TEXT, K.REC_SCORE]
  184. class SaveTextRecResults(BaseTransform):
  185. """ SaveTextRecResults """
  186. _TEXT_REC_RES_SUFFIX = '_text_rec'
  187. _FILE_EXT = '.txt'
  188. def __init__(self, save_dir):
  189. super().__init__()
  190. self.save_dir = save_dir
  191. # We use python backend to save text object
  192. self._writer = TextWriter(backend='python')
  193. def apply(self, data):
  194. """ apply """
  195. ori_path = data[K.IM_PATH]
  196. file_name = os.path.basename(ori_path)
  197. file_name = self._replace_ext(file_name, self._FILE_EXT)
  198. text_rec_res_save_path = os.path.join(self.save_dir, file_name)
  199. rec_res = ''
  200. for text, score in zip(data[K.REC_TEXT], data[K.REC_SCORE]):
  201. line = text + '\t' + str(score) + '\n'
  202. rec_res += line
  203. text_rec_res_save_path = self._add_suffix(text_rec_res_save_path,
  204. self._TEXT_REC_RES_SUFFIX)
  205. self._write_txt(text_rec_res_save_path, rec_res)
  206. return data
  207. @classmethod
  208. def get_input_keys(cls):
  209. """ get_input_keys """
  210. return [K.IM_PATH, K.REC_TEXT, K.REC_SCORE]
  211. @classmethod
  212. def get_output_keys(cls):
  213. """ get_output_keys """
  214. return []
  215. def _write_txt(self, path, txt_str):
  216. """ _write_txt """
  217. if os.path.exists(path):
  218. logging.warning(f"{path} already exists. Overwriting it.")
  219. self._writer.write(path, txt_str)
  220. @staticmethod
  221. def _add_suffix(path, suffix):
  222. """ _add_suffix """
  223. stem, ext = os.path.splitext(path)
  224. return stem + suffix + ext
  225. @staticmethod
  226. def _replace_ext(path, new_ext):
  227. """ _replace_ext """
  228. stem, _ = os.path.splitext(path)
  229. return stem + new_ext
  230. class PrintResult(BaseTransform):
  231. """ Print Result Transform """
  232. def apply(self, data):
  233. """ apply """
  234. logging.info("The prediction result is:")
  235. logging.info(data[K.REC_TEXT])
  236. return data
  237. @classmethod
  238. def get_input_keys(cls):
  239. """ get input keys """
  240. return [K.REC_TEXT]
  241. @classmethod
  242. def get_output_keys(cls):
  243. """ get output keys """
  244. return []