transforms.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import os.path as osp
  16. import numpy as np
  17. import cv2
  18. import paddle
  19. from .keys import TableRecKeys as K
  20. from ...base import BaseTransform
  21. from ...base.predictor.io.writers import ImageWriter
  22. from ....utils import logging
  23. __all__ = ['TableLabelDecode', 'TableMasterLabelDecode', 'SaveTableResults']
  24. class TableLabelDecode(BaseTransform):
  25. """ decode the table model outputs(probs) to character str"""
  26. def __init__(self,
  27. character_dict_type='TableAttn_ch',
  28. merge_no_span_structure=True):
  29. dict_character = []
  30. supported_dict = ['TableAttn_ch', 'TableAttn_en', 'TableMaster']
  31. if character_dict_type == 'TableAttn_ch':
  32. character_dict_name = 'table_structure_dict_ch.txt'
  33. elif character_dict_type == 'TableAttn_en':
  34. character_dict_name = 'table_structure_dict.txt'
  35. elif character_dict_type == 'TableMaster':
  36. character_dict_name = 'table_master_structure_dict.txt'
  37. else:
  38. assert False, " character_dict_type must in %s " \
  39. % supported_dict
  40. character_dict_path = osp.abspath(
  41. osp.join(osp.dirname(__file__), character_dict_name))
  42. with open(character_dict_path, "rb") as fin:
  43. lines = fin.readlines()
  44. for line in lines:
  45. line = line.decode('utf-8').strip("\n").strip("\r\n")
  46. dict_character.append(line)
  47. if merge_no_span_structure:
  48. if "<td></td>" not in dict_character:
  49. dict_character.append("<td></td>")
  50. if "<td>" in dict_character:
  51. dict_character.remove("<td>")
  52. dict_character = self.add_special_char(dict_character)
  53. self.dict = {}
  54. for i, char in enumerate(dict_character):
  55. self.dict[char] = i
  56. self.character = dict_character
  57. self.td_token = ['<td>', '<td', '<td></td>']
  58. def add_special_char(self, dict_character):
  59. """ add_special_char """
  60. self.beg_str = "sos"
  61. self.end_str = "eos"
  62. dict_character = dict_character
  63. dict_character = [self.beg_str] + dict_character + [self.end_str]
  64. return dict_character
  65. def get_ignored_tokens(self):
  66. """ get_ignored_tokens """
  67. beg_idx = self.get_beg_end_flag_idx("beg")
  68. end_idx = self.get_beg_end_flag_idx("end")
  69. return [beg_idx, end_idx]
  70. def get_beg_end_flag_idx(self, beg_or_end):
  71. """ get_beg_end_flag_idx """
  72. if beg_or_end == "beg":
  73. idx = np.array(self.dict[self.beg_str])
  74. elif beg_or_end == "end":
  75. idx = np.array(self.dict[self.end_str])
  76. else:
  77. assert False, "unsupported type %s in get_beg_end_flag_idx" \
  78. % beg_or_end
  79. return idx
  80. def apply(self, data):
  81. """ apply """
  82. shape_list = data[K.SHAPE_LIST]
  83. structure_probs = data[K.STRUCTURE_PROB]
  84. bbox_preds = data[K.LOC_PROB]
  85. if isinstance(structure_probs, paddle.Tensor):
  86. structure_probs = structure_probs.numpy()
  87. if isinstance(bbox_preds, paddle.Tensor):
  88. bbox_preds = bbox_preds.numpy()
  89. post_result = self.decode(structure_probs, bbox_preds, shape_list)
  90. structure_str_list = post_result['structure_batch_list'][0]
  91. bbox_list = post_result['bbox_batch_list'][0]
  92. structure_str_list = structure_str_list[0]
  93. structure_str_list = [
  94. '<html>', '<body>', '<table>'
  95. ] + structure_str_list + ['</table>', '</body>', '</html>']
  96. data[K.BBOX_RES] = bbox_list
  97. data[K.HTML_RES] = structure_str_list
  98. return data
  99. @classmethod
  100. def get_input_keys(cls):
  101. """ get input keys """
  102. return [K.STRUCTURE_PROB, K.LOC_PROB, K.SHAPE_LIST]
  103. @classmethod
  104. def get_output_keys(cls):
  105. """ get output keys """
  106. return [K.BBOX_RES, K.HTML_RES]
  107. def decode(self, structure_probs, bbox_preds, shape_list):
  108. """convert text-label into text-index.
  109. """
  110. ignored_tokens = self.get_ignored_tokens()
  111. end_idx = self.dict[self.end_str]
  112. structure_idx = structure_probs.argmax(axis=2)
  113. structure_probs = structure_probs.max(axis=2)
  114. structure_batch_list = []
  115. bbox_batch_list = []
  116. batch_size = len(structure_idx)
  117. for batch_idx in range(batch_size):
  118. structure_list = []
  119. bbox_list = []
  120. score_list = []
  121. for idx in range(len(structure_idx[batch_idx])):
  122. char_idx = int(structure_idx[batch_idx][idx])
  123. if idx > 0 and char_idx == end_idx:
  124. break
  125. if char_idx in ignored_tokens:
  126. continue
  127. text = self.character[char_idx]
  128. if text in self.td_token:
  129. bbox = bbox_preds[batch_idx, idx]
  130. bbox = self._bbox_decode(bbox, shape_list[batch_idx])
  131. bbox_list.append(bbox)
  132. structure_list.append(text)
  133. score_list.append(structure_probs[batch_idx, idx])
  134. structure_batch_list.append([structure_list, np.mean(score_list)])
  135. bbox_batch_list.append(np.array(bbox_list))
  136. result = {
  137. 'bbox_batch_list': bbox_batch_list,
  138. 'structure_batch_list': structure_batch_list,
  139. }
  140. return result
  141. def decode_label(self, batch):
  142. """convert text-label into text-index.
  143. """
  144. structure_idx = batch[1]
  145. gt_bbox_list = batch[2]
  146. shape_list = batch[-1]
  147. ignored_tokens = self.get_ignored_tokens()
  148. end_idx = self.dict[self.end_str]
  149. structure_batch_list = []
  150. bbox_batch_list = []
  151. batch_size = len(structure_idx)
  152. for batch_idx in range(batch_size):
  153. structure_list = []
  154. bbox_list = []
  155. for idx in range(len(structure_idx[batch_idx])):
  156. char_idx = int(structure_idx[batch_idx][idx])
  157. if idx > 0 and char_idx == end_idx:
  158. break
  159. if char_idx in ignored_tokens:
  160. continue
  161. structure_list.append(self.character[char_idx])
  162. bbox = gt_bbox_list[batch_idx][idx]
  163. if bbox.sum() != 0:
  164. bbox = self._bbox_decode(bbox, shape_list[batch_idx])
  165. bbox_list.append(bbox)
  166. structure_batch_list.append(structure_list)
  167. bbox_batch_list.append(bbox_list)
  168. result = {
  169. 'bbox_batch_list': bbox_batch_list,
  170. 'structure_batch_list': structure_batch_list,
  171. }
  172. return result
  173. def _bbox_decode(self, bbox, shape):
  174. w, h = shape[:2]
  175. bbox[0::2] *= w
  176. bbox[1::2] *= h
  177. return bbox
  178. class TableMasterLabelDecode(TableLabelDecode):
  179. """ decode the table model outputs(probs) to character str"""
  180. def __init__(self,
  181. character_dict_type='TableMaster',
  182. box_shape='pad',
  183. merge_no_span_structure=True):
  184. super(TableMasterLabelDecode, self).__init__(character_dict_type,
  185. merge_no_span_structure)
  186. self.box_shape = box_shape
  187. assert box_shape in [
  188. 'ori', 'pad'
  189. ], 'The shape used for box normalization must be ori or pad'
  190. def add_special_char(self, dict_character):
  191. """ add_special_char """
  192. self.beg_str = '<SOS>'
  193. self.end_str = '<EOS>'
  194. self.unknown_str = '<UKN>'
  195. self.pad_str = '<PAD>'
  196. dict_character = dict_character
  197. dict_character = dict_character + [
  198. self.unknown_str, self.beg_str, self.end_str, self.pad_str
  199. ]
  200. return dict_character
  201. def get_ignored_tokens(self):
  202. """ get_ignored_tokens """
  203. pad_idx = self.dict[self.pad_str]
  204. start_idx = self.dict[self.beg_str]
  205. end_idx = self.dict[self.end_str]
  206. unknown_idx = self.dict[self.unknown_str]
  207. return [start_idx, end_idx, pad_idx, unknown_idx]
  208. def _bbox_decode(self, bbox, shape):
  209. """ _bbox_decode """
  210. h, w, ratio_h, ratio_w, pad_h, pad_w = shape
  211. if self.box_shape == 'pad':
  212. h, w = pad_h, pad_w
  213. bbox[0::2] *= w
  214. bbox[1::2] *= h
  215. bbox[0::2] /= ratio_w
  216. bbox[1::2] /= ratio_h
  217. x, y, w, h = bbox
  218. x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2
  219. bbox = np.array([x1, y1, x2, y2])
  220. return bbox
  221. class SaveTableResults(BaseTransform):
  222. """ SaveTableResults """
  223. _TABLE_RES_SUFFIX = '_bbox'
  224. _FILE_EXT = '.png'
  225. # _DEFAULT_FILE_NAME = 'table_res_out.png'
  226. def __init__(self, save_dir):
  227. super().__init__()
  228. self.save_dir = save_dir
  229. # We use pillow backend to save both numpy arrays and PIL Image objects
  230. self._writer = ImageWriter(backend='pillow')
  231. def apply(self, data):
  232. """ apply """
  233. ori_path = data[K.IM_PATH]
  234. bbox_res = data[K.BBOX_RES]
  235. file_name = os.path.basename(ori_path)
  236. file_name = self._replace_ext(file_name, self._FILE_EXT)
  237. table_res_save_path = os.path.join(self.save_dir, file_name)
  238. if len(bbox_res) > 0 and len(bbox_res[0]) == 4:
  239. vis_img = self.draw_rectangle(data[K.ORI_IM], bbox_res)
  240. else:
  241. vis_img = self.draw_bbox(data[K.ORI_IM], bbox_res)
  242. table_res_save_path = self._add_suffix(table_res_save_path,
  243. self._TABLE_RES_SUFFIX)
  244. self._write_im(table_res_save_path, vis_img)
  245. return data
  246. @classmethod
  247. def get_input_keys(cls):
  248. """ get input keys """
  249. return [K.IM_PATH, K.ORI_IM, K.BBOX_RES]
  250. @classmethod
  251. def get_output_keys(cls):
  252. """ get output keys """
  253. return []
  254. def _write_im(self, path, im):
  255. """ write image """
  256. if os.path.exists(path):
  257. logging.warning(f"{path} already exists. Overwriting it.")
  258. self._writer.write(path, im)
  259. @staticmethod
  260. def _add_suffix(path, suffix):
  261. """ _add_suffix """
  262. stem, ext = os.path.splitext(path)
  263. return stem + suffix + ext
  264. @staticmethod
  265. def _replace_ext(path, new_ext):
  266. """ _replace_ext """
  267. stem, _ = os.path.splitext(path)
  268. return stem + new_ext
  269. def draw_rectangle(self, img_path, boxes):
  270. """ draw_rectangle """
  271. boxes = np.array(boxes)
  272. img = cv2.imread(img_path)
  273. img_show = img.copy()
  274. for box in boxes.astype(int):
  275. x1, y1, x2, y2 = box
  276. cv2.rectangle(img_show, (x1, y1), (x2, y2), (255, 0, 0), 2)
  277. return img_show
  278. def draw_bbox(self, image, boxes):
  279. """ draw_bbox """
  280. for box in boxes:
  281. box = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
  282. image = cv2.polylines(np.array(image), [box], True, (255, 0, 0), 2)
  283. return image
  284. class PrintResult(BaseTransform):
  285. """ Print Result Transform """
  286. def apply(self, data):
  287. """ apply """
  288. logging.info("The prediction result is:")
  289. logging.info(data[K.BOXES])
  290. return data
  291. @classmethod
  292. def get_input_keys(cls):
  293. """ get input keys """
  294. return [K.BOXES]
  295. @classmethod
  296. def get_output_keys(cls):
  297. """ get output keys """
  298. return []