transforms.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import os.path as osp
  16. import numpy as np
  17. from PIL import Image
  18. import cv2
  19. import paddle
  20. from ....utils import logging
  21. from ...base import BaseTransform
  22. from ...base.predictor.io.writers import ImageWriter
  23. from .keys import TableRecKeys as K
  24. __all__ = ['TableLabelDecode', 'TableMasterLabelDecode', 'SaveTableResults']
  25. class TableLabelDecode(BaseTransform):
  26. """ decode the table model outputs(probs) to character str"""
  27. def __init__(self,
  28. character_dict_type='TableAttn_ch',
  29. merge_no_span_structure=True):
  30. dict_character = []
  31. supported_dict = ['TableAttn_ch', 'TableAttn_en', 'TableMaster']
  32. if character_dict_type == 'TableAttn_ch':
  33. character_dict_name = 'table_structure_dict_ch.txt'
  34. elif character_dict_type == 'TableAttn_en':
  35. character_dict_name = 'table_structure_dict.txt'
  36. elif character_dict_type == 'TableMaster':
  37. character_dict_name = 'table_master_structure_dict.txt'
  38. else:
  39. assert False, " character_dict_type must in %s " \
  40. % supported_dict
  41. character_dict_path = osp.abspath(
  42. osp.join(osp.dirname(__file__), character_dict_name))
  43. with open(character_dict_path, "rb") as fin:
  44. lines = fin.readlines()
  45. for line in lines:
  46. line = line.decode('utf-8').strip("\n").strip("\r\n")
  47. dict_character.append(line)
  48. if merge_no_span_structure:
  49. if "<td></td>" not in dict_character:
  50. dict_character.append("<td></td>")
  51. if "<td>" in dict_character:
  52. dict_character.remove("<td>")
  53. dict_character = self.add_special_char(dict_character)
  54. self.dict = {}
  55. for i, char in enumerate(dict_character):
  56. self.dict[char] = i
  57. self.character = dict_character
  58. self.td_token = ['<td>', '<td', '<td></td>']
  59. def add_special_char(self, dict_character):
  60. """ add_special_char """
  61. self.beg_str = "sos"
  62. self.end_str = "eos"
  63. dict_character = dict_character
  64. dict_character = [self.beg_str] + dict_character + [self.end_str]
  65. return dict_character
  66. def get_ignored_tokens(self):
  67. """ get_ignored_tokens """
  68. beg_idx = self.get_beg_end_flag_idx("beg")
  69. end_idx = self.get_beg_end_flag_idx("end")
  70. return [beg_idx, end_idx]
  71. def get_beg_end_flag_idx(self, beg_or_end):
  72. """ get_beg_end_flag_idx """
  73. if beg_or_end == "beg":
  74. idx = np.array(self.dict[self.beg_str])
  75. elif beg_or_end == "end":
  76. idx = np.array(self.dict[self.end_str])
  77. else:
  78. assert False, "unsupported type %s in get_beg_end_flag_idx" \
  79. % beg_or_end
  80. return idx
  81. def apply(self, data):
  82. """ apply """
  83. shape_list = data[K.SHAPE_LIST]
  84. structure_probs = data[K.STRUCTURE_PROB]
  85. bbox_preds = data[K.LOC_PROB]
  86. if isinstance(structure_probs, paddle.Tensor):
  87. structure_probs = structure_probs.numpy()
  88. if isinstance(bbox_preds, paddle.Tensor):
  89. bbox_preds = bbox_preds.numpy()
  90. post_result = self.decode(structure_probs, bbox_preds, shape_list)
  91. structure_str_list = post_result['structure_batch_list'][0]
  92. bbox_list = post_result['bbox_batch_list'][0]
  93. structure_str_list = structure_str_list[0]
  94. structure_str_list = [
  95. '<html>', '<body>', '<table>'
  96. ] + structure_str_list + ['</table>', '</body>', '</html>']
  97. data[K.BBOX_RES] = bbox_list
  98. data[K.HTML_RES] = structure_str_list
  99. return data
  100. @classmethod
  101. def get_input_keys(cls):
  102. """ get input keys """
  103. return [K.STRUCTURE_PROB, K.LOC_PROB, K.SHAPE_LIST]
  104. @classmethod
  105. def get_output_keys(cls):
  106. """ get output keys """
  107. return [K.BBOX_RES, K.HTML_RES]
  108. def decode(self, structure_probs, bbox_preds, shape_list):
  109. """convert text-label into text-index.
  110. """
  111. ignored_tokens = self.get_ignored_tokens()
  112. end_idx = self.dict[self.end_str]
  113. structure_idx = structure_probs.argmax(axis=2)
  114. structure_probs = structure_probs.max(axis=2)
  115. structure_batch_list = []
  116. bbox_batch_list = []
  117. batch_size = len(structure_idx)
  118. for batch_idx in range(batch_size):
  119. structure_list = []
  120. bbox_list = []
  121. score_list = []
  122. for idx in range(len(structure_idx[batch_idx])):
  123. char_idx = int(structure_idx[batch_idx][idx])
  124. if idx > 0 and char_idx == end_idx:
  125. break
  126. if char_idx in ignored_tokens:
  127. continue
  128. text = self.character[char_idx]
  129. if text in self.td_token:
  130. bbox = bbox_preds[batch_idx, idx]
  131. bbox = self._bbox_decode(bbox, shape_list[batch_idx])
  132. bbox_list.append(bbox)
  133. structure_list.append(text)
  134. score_list.append(structure_probs[batch_idx, idx])
  135. structure_batch_list.append([structure_list, np.mean(score_list)])
  136. bbox_batch_list.append(np.array(bbox_list))
  137. result = {
  138. 'bbox_batch_list': bbox_batch_list,
  139. 'structure_batch_list': structure_batch_list,
  140. }
  141. return result
  142. def decode_label(self, batch):
  143. """convert text-label into text-index.
  144. """
  145. structure_idx = batch[1]
  146. gt_bbox_list = batch[2]
  147. shape_list = batch[-1]
  148. ignored_tokens = self.get_ignored_tokens()
  149. end_idx = self.dict[self.end_str]
  150. structure_batch_list = []
  151. bbox_batch_list = []
  152. batch_size = len(structure_idx)
  153. for batch_idx in range(batch_size):
  154. structure_list = []
  155. bbox_list = []
  156. for idx in range(len(structure_idx[batch_idx])):
  157. char_idx = int(structure_idx[batch_idx][idx])
  158. if idx > 0 and char_idx == end_idx:
  159. break
  160. if char_idx in ignored_tokens:
  161. continue
  162. structure_list.append(self.character[char_idx])
  163. bbox = gt_bbox_list[batch_idx][idx]
  164. if bbox.sum() != 0:
  165. bbox = self._bbox_decode(bbox, shape_list[batch_idx])
  166. bbox_list.append(bbox)
  167. structure_batch_list.append(structure_list)
  168. bbox_batch_list.append(bbox_list)
  169. result = {
  170. 'bbox_batch_list': bbox_batch_list,
  171. 'structure_batch_list': structure_batch_list,
  172. }
  173. return result
  174. def _bbox_decode(self, bbox, shape):
  175. w, h = shape[:2]
  176. bbox[0::2] *= w
  177. bbox[1::2] *= h
  178. return bbox
  179. class TableMasterLabelDecode(TableLabelDecode):
  180. """ decode the table model outputs(probs) to character str"""
  181. def __init__(self,
  182. character_dict_type='TableMaster',
  183. box_shape='pad',
  184. merge_no_span_structure=True):
  185. super(TableMasterLabelDecode, self).__init__(character_dict_type,
  186. merge_no_span_structure)
  187. self.box_shape = box_shape
  188. assert box_shape in [
  189. 'ori', 'pad'
  190. ], 'The shape used for box normalization must be ori or pad'
  191. def add_special_char(self, dict_character):
  192. """ add_special_char """
  193. self.beg_str = '<SOS>'
  194. self.end_str = '<EOS>'
  195. self.unknown_str = '<UKN>'
  196. self.pad_str = '<PAD>'
  197. dict_character = dict_character
  198. dict_character = dict_character + [
  199. self.unknown_str, self.beg_str, self.end_str, self.pad_str
  200. ]
  201. return dict_character
  202. def get_ignored_tokens(self):
  203. """ get_ignored_tokens """
  204. pad_idx = self.dict[self.pad_str]
  205. start_idx = self.dict[self.beg_str]
  206. end_idx = self.dict[self.end_str]
  207. unknown_idx = self.dict[self.unknown_str]
  208. return [start_idx, end_idx, pad_idx, unknown_idx]
  209. def _bbox_decode(self, bbox, shape):
  210. """ _bbox_decode """
  211. h, w, ratio_h, ratio_w, pad_h, pad_w = shape
  212. if self.box_shape == 'pad':
  213. h, w = pad_h, pad_w
  214. bbox[0::2] *= w
  215. bbox[1::2] *= h
  216. bbox[0::2] /= ratio_w
  217. bbox[1::2] /= ratio_h
  218. x, y, w, h = bbox
  219. x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2
  220. bbox = np.array([x1, y1, x2, y2])
  221. return bbox
  222. class SaveTableResults(BaseTransform):
  223. """ SaveTableResults """
  224. _TABLE_RES_SUFFIX = '_bbox'
  225. _FILE_EXT = '.png'
  226. # _DEFAULT_FILE_NAME = 'table_res_out.png'
  227. def __init__(self, save_dir):
  228. super().__init__()
  229. self.save_dir = save_dir
  230. # We use pillow backend to save both numpy arrays and PIL Image objects
  231. self._writer = ImageWriter(backend='pillow')
  232. def apply(self, data):
  233. """ apply """
  234. ori_path = data[K.IM_PATH]
  235. bbox_res = data[K.BBOX_RES]
  236. file_name = os.path.basename(ori_path)
  237. file_name = self._replace_ext(file_name, self._FILE_EXT)
  238. table_res_save_path = os.path.join(self.save_dir, file_name)
  239. if len(bbox_res) > 0 and len(bbox_res[0]) == 4:
  240. vis_img = self.draw_rectangle(data[K.ORI_IM], bbox_res)
  241. else:
  242. vis_img = self.draw_bbox(data[K.ORI_IM], bbox_res)
  243. table_res_save_path = self._add_suffix(table_res_save_path,
  244. self._TABLE_RES_SUFFIX)
  245. self._write_im(table_res_save_path, vis_img)
  246. return data
  247. @classmethod
  248. def get_input_keys(cls):
  249. """ get input keys """
  250. return [K.IM_PATH, K.ORI_IM, K.BBOX_RES]
  251. @classmethod
  252. def get_output_keys(cls):
  253. """ get output keys """
  254. return []
  255. def _write_im(self, path, im):
  256. """ write image """
  257. if os.path.exists(path):
  258. logging.warning(f"{path} already exists. Overwriting it.")
  259. self._writer.write(path, im)
  260. @staticmethod
  261. def _add_suffix(path, suffix):
  262. """ _add_suffix """
  263. stem, ext = os.path.splitext(path)
  264. return stem + suffix + ext
  265. @staticmethod
  266. def _replace_ext(path, new_ext):
  267. """ _replace_ext """
  268. stem, _ = os.path.splitext(path)
  269. return stem + new_ext
  270. def draw_rectangle(self, img_path, boxes):
  271. """ draw_rectangle """
  272. boxes = np.array(boxes)
  273. img = cv2.imread(img_path)
  274. img_show = img.copy()
  275. for box in boxes.astype(int):
  276. x1, y1, x2, y2 = box
  277. cv2.rectangle(img_show, (x1, y1), (x2, y2), (255, 0, 0), 2)
  278. return img_show
  279. def draw_bbox(self, image, boxes):
  280. """ draw_bbox """
  281. for box in boxes:
  282. box = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
  283. image = cv2.polylines(np.array(image), [box], True, (255, 0, 0), 2)
  284. return image
  285. class PrintResult(BaseTransform):
  286. """ Print Result Transform """
  287. def apply(self, data):
  288. """ apply """
  289. logging.info("The prediction result is:")
  290. logging.info(data[K.BOXES])
  291. return data
  292. @classmethod
  293. def get_input_keys(cls):
  294. """ get input keys """
  295. return [K.BOXES]
  296. @classmethod
  297. def get_output_keys(cls):
  298. """ get output keys """
  299. return []