transforms.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import os
  15. import os.path as osp
  16. import re
  17. import numpy as np
  18. from PIL import Image
  19. import cv2
  20. import math
  21. import paddle
  22. from ....utils import logging
  23. from ...base.predictor import BaseTransform
  24. from ...base.predictor.io.writers import TextWriter
  25. from .keys import TextRecKeys as K
# Names exported by `from <module> import *`.
# NOTE(review): BaseRecLabelDecode and PrintResult are defined in this module
# but are not listed here - confirm that leaving them out is intentional.
__all__ = ['OCRReisizeNormImg', 'CTCLabelDecode', 'SaveTextRecResults']
  27. class OCRReisizeNormImg(BaseTransform):
  28. """ for ocr image resize and normalization """
  29. def __init__(self, rec_image_shape=[3, 48, 320]):
  30. super().__init__()
  31. self.rec_image_shape = rec_image_shape
  32. def resize_norm_img(self, img, max_wh_ratio):
  33. """ resize and normalize the img """
  34. imgC, imgH, imgW = self.rec_image_shape
  35. assert imgC == img.shape[2]
  36. imgW = int((imgH * max_wh_ratio))
  37. h, w = img.shape[:2]
  38. ratio = w / float(h)
  39. if math.ceil(imgH * ratio) > imgW:
  40. resized_w = imgW
  41. else:
  42. resized_w = int(math.ceil(imgH * ratio))
  43. resized_image = cv2.resize(img, (resized_w, imgH))
  44. resized_image = resized_image.astype('float32')
  45. resized_image = resized_image.transpose((2, 0, 1)) / 255
  46. resized_image -= 0.5
  47. resized_image /= 0.5
  48. padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
  49. padding_im[:, :, 0:resized_w] = resized_image
  50. return padding_im
  51. def apply(self, data):
  52. """ apply """
  53. imgC, imgH, imgW = self.rec_image_shape
  54. max_wh_ratio = imgW / imgH
  55. w, h = data[K.ORI_IM_SIZE]
  56. wh_ratio = w * 1.0 / h
  57. max_wh_ratio = max(max_wh_ratio, wh_ratio)
  58. data[K.IMAGE] = self.resize_norm_img(data[K.IMAGE], max_wh_ratio)
  59. return data
  60. @classmethod
  61. def get_input_keys(cls):
  62. """ get input keys """
  63. return [K.IMAGE, K.ORI_IM_SIZE]
  64. @classmethod
  65. def get_output_keys(cls):
  66. """ get output keys """
  67. return [K.IMAGE]
  68. class BaseRecLabelDecode(BaseTransform):
  69. """ Convert between text-label and text-index """
  70. def __init__(self, character_str=None):
  71. self.reverse = False
  72. dict_character = character_str if character_str is not None else "0123456789abcdefghijklmnopqrstuvwxyz"
  73. dict_character = self.add_special_char(dict_character)
  74. self.dict = {}
  75. for i, char in enumerate(dict_character):
  76. self.dict[char] = i
  77. self.character = dict_character
  78. def pred_reverse(self, pred):
  79. """ pred_reverse """
  80. pred_re = []
  81. c_current = ''
  82. for c in pred:
  83. if not bool(re.search('[a-zA-Z0-9 :*./%+-]', c)):
  84. if c_current != '':
  85. pred_re.append(c_current)
  86. pred_re.append(c)
  87. c_current = ''
  88. else:
  89. c_current += c
  90. if c_current != '':
  91. pred_re.append(c_current)
  92. return ''.join(pred_re[::-1])
  93. def add_special_char(self, dict_character):
  94. """ add_special_char """
  95. return dict_character
  96. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  97. """ convert text-index into text-label. """
  98. result_list = []
  99. ignored_tokens = self.get_ignored_tokens()
  100. batch_size = len(text_index)
  101. for batch_idx in range(batch_size):
  102. selection = np.ones(len(text_index[batch_idx]), dtype=bool)
  103. if is_remove_duplicate:
  104. selection[1:] = text_index[batch_idx][1:] != text_index[
  105. batch_idx][:-1]
  106. for ignored_token in ignored_tokens:
  107. selection &= text_index[batch_idx] != ignored_token
  108. char_list = [
  109. self.character[text_id]
  110. for text_id in text_index[batch_idx][selection]
  111. ]
  112. if text_prob is not None:
  113. conf_list = text_prob[batch_idx][selection]
  114. else:
  115. conf_list = [1] * len(selection)
  116. if len(conf_list) == 0:
  117. conf_list = [0]
  118. text = ''.join(char_list)
  119. if self.reverse: # for arabic rec
  120. text = self.pred_reverse(text)
  121. result_list.append((text, np.mean(conf_list).tolist()))
  122. return result_list
  123. def get_ignored_tokens(self):
  124. """ get_ignored_tokens """
  125. return [0] # for ctc blank
  126. def apply(self, data):
  127. """ apply """
  128. preds = data[K.REC_PROBS]
  129. if isinstance(preds, tuple) or isinstance(preds, list):
  130. preds = preds[-1]
  131. preds_idx = preds.argmax(axis=2)
  132. preds_prob = preds.max(axis=2)
  133. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
  134. data[K.REC_TEXT] = []
  135. data[K.REC_SCORE] = []
  136. for t in text:
  137. data[K.REC_TEXT].append(t[0])
  138. data[K.REC_SCORE].append(t[1])
  139. return data
  140. @classmethod
  141. def get_input_keys(cls):
  142. """ get_input_keys """
  143. return [K.REC_PROBS]
  144. @classmethod
  145. def get_output_keys(cls):
  146. """ get_output_keys """
  147. return [K.REC_TEXT, K.REC_SCORE]
  148. class CTCLabelDecode(BaseRecLabelDecode):
  149. """ Convert between text-label and text-index """
  150. def __init__(self, post_process_cfg=None):
  151. assert post_process_cfg['name'] == 'CTCLabelDecode'
  152. character_str = post_process_cfg['character_dict']
  153. super().__init__(character_str)
  154. def apply(self, data):
  155. """ apply """
  156. preds = data[K.REC_PROBS]
  157. if isinstance(preds, tuple) or isinstance(preds, list):
  158. preds = preds[-1]
  159. preds_idx = preds.argmax(axis=2)
  160. preds_prob = preds.max(axis=2)
  161. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
  162. data[K.REC_TEXT] = []
  163. data[K.REC_SCORE] = []
  164. for t in text:
  165. data[K.REC_TEXT].append(t[0])
  166. data[K.REC_SCORE].append(t[1])
  167. return data
  168. def add_special_char(self, dict_character):
  169. """ add_special_char """
  170. dict_character = ['blank'] + dict_character
  171. return dict_character
  172. @classmethod
  173. def get_input_keys(cls):
  174. """ get_input_keys """
  175. return [K.REC_PROBS]
  176. @classmethod
  177. def get_output_keys(cls):
  178. """ get_output_keys """
  179. return [K.REC_TEXT, K.REC_SCORE]
  180. class SaveTextRecResults(BaseTransform):
  181. """ SaveTextRecResults """
  182. _TEXT_REC_RES_SUFFIX = '_text_rec'
  183. _FILE_EXT = '.txt'
  184. def __init__(self, save_dir):
  185. super().__init__()
  186. self.save_dir = save_dir
  187. # We use python backend to save text object
  188. self._writer = TextWriter(backend='python')
  189. def apply(self, data):
  190. """ apply """
  191. ori_path = data[K.IM_PATH]
  192. file_name = os.path.basename(ori_path)
  193. file_name = self._replace_ext(file_name, self._FILE_EXT)
  194. text_rec_res_save_path = os.path.join(self.save_dir, file_name)
  195. rec_res = ''
  196. for text, score in zip(data[K.REC_TEXT], data[K.REC_SCORE]):
  197. line = text + '\t' + str(score) + '\n'
  198. rec_res += line
  199. text_rec_res_save_path = self._add_suffix(text_rec_res_save_path,
  200. self._TEXT_REC_RES_SUFFIX)
  201. self._write_txt(text_rec_res_save_path, rec_res)
  202. return data
  203. @classmethod
  204. def get_input_keys(cls):
  205. """ get_input_keys """
  206. return [K.IM_PATH, K.REC_TEXT, K.REC_SCORE]
  207. @classmethod
  208. def get_output_keys(cls):
  209. """ get_output_keys """
  210. return []
  211. def _write_txt(self, path, txt_str):
  212. """ _write_txt """
  213. if os.path.exists(path):
  214. logging.warning(f"{path} already exists. Overwriting it.")
  215. self._writer.write(path, txt_str)
  216. @staticmethod
  217. def _add_suffix(path, suffix):
  218. """ _add_suffix """
  219. stem, ext = os.path.splitext(path)
  220. return stem + suffix + ext
  221. @staticmethod
  222. def _replace_ext(path, new_ext):
  223. """ _replace_ext """
  224. stem, _ = os.path.splitext(path)
  225. return stem + new_ext
  226. class PrintResult(BaseTransform):
  227. """ Print Result Transform """
  228. def apply(self, data):
  229. """ apply """
  230. logging.info("The prediction result is:")
  231. logging.info(data[K.REC_TEXT])
  232. return data
  233. @classmethod
  234. def get_input_keys(cls):
  235. """ get input keys """
  236. return [K.REC_TEXT]
  237. @classmethod
  238. def get_output_keys(cls):
  239. """ get output keys """
  240. return []