table_rec.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. from ..base import BaseComponent
  16. __all__ = ["TableLabelDecode", "TableMasterLabelDecode"]
  17. class TableLabelDecode(BaseComponent):
  18. """decode the table model outputs(probs) to character str"""
  19. ENABLE_BATCH = True
  20. INPUT_KEYS = ["pred", "ori_img_size"]
  21. OUTPUT_KEYS = ["bbox", "structure"]
  22. DEAULT_INPUTS = {"pred": "pred", "ori_img_size": "ori_img_size"}
  23. DEAULT_OUTPUTS = {"bbox": "bbox", "structure": "structure"}
  24. def __init__(self, merge_no_span_structure=True, dict_character=[]):
  25. super().__init__()
  26. if merge_no_span_structure:
  27. if "<td></td>" not in dict_character:
  28. dict_character.append("<td></td>")
  29. if "<td>" in dict_character:
  30. dict_character.remove("<td>")
  31. dict_character = self.add_special_char(dict_character)
  32. self.dict = {}
  33. for i, char in enumerate(dict_character):
  34. self.dict[char] = i
  35. self.character = dict_character
  36. self.td_token = ["<td>", "<td", "<td></td>"]
  37. def add_special_char(self, dict_character):
  38. """add_special_char"""
  39. self.beg_str = "sos"
  40. self.end_str = "eos"
  41. dict_character = dict_character
  42. dict_character = [self.beg_str] + dict_character + [self.end_str]
  43. return dict_character
  44. def get_ignored_tokens(self):
  45. """get_ignored_tokens"""
  46. beg_idx = self.get_beg_end_flag_idx("beg")
  47. end_idx = self.get_beg_end_flag_idx("end")
  48. return [beg_idx, end_idx]
  49. def get_beg_end_flag_idx(self, beg_or_end):
  50. """get_beg_end_flag_idx"""
  51. if beg_or_end == "beg":
  52. idx = np.array(self.dict[self.beg_str])
  53. elif beg_or_end == "end":
  54. idx = np.array(self.dict[self.end_str])
  55. else:
  56. assert False, "unsupported type %s in get_beg_end_flag_idx" % beg_or_end
  57. return idx
  58. def apply(self, pred, ori_img_size):
  59. """apply"""
  60. bbox_preds, structure_probs = [], []
  61. for bbox_pred, stru_prob in pred:
  62. bbox_preds.append(bbox_pred)
  63. structure_probs.append(stru_prob)
  64. bbox_preds = np.array(bbox_preds)
  65. structure_probs = np.array(structure_probs)
  66. bbox_list, structure_str_list = self.decode(
  67. structure_probs, bbox_preds, ori_img_size
  68. )
  69. structure_str_list = [
  70. (
  71. ["<html>", "<body>", "<table>"]
  72. + structure
  73. + ["</table>", "</body>", "</html>"]
  74. )
  75. for structure in structure_str_list
  76. ]
  77. return [
  78. {"bbox": bbox, "structure": structure}
  79. for bbox, structure in zip(bbox_list, structure_str_list)
  80. ]
  81. def decode(self, structure_probs, bbox_preds, shape_list):
  82. """convert text-label into text-index."""
  83. ignored_tokens = self.get_ignored_tokens()
  84. end_idx = self.dict[self.end_str]
  85. structure_idx = structure_probs.argmax(axis=2)
  86. structure_probs = structure_probs.max(axis=2)
  87. structure_batch_list = []
  88. bbox_batch_list = []
  89. batch_size = len(structure_idx)
  90. for batch_idx in range(batch_size):
  91. structure_list = []
  92. bbox_list = []
  93. score_list = []
  94. for idx in range(len(structure_idx[batch_idx])):
  95. char_idx = int(structure_idx[batch_idx][idx])
  96. if idx > 0 and char_idx == end_idx:
  97. break
  98. if char_idx in ignored_tokens:
  99. continue
  100. text = self.character[char_idx]
  101. if text in self.td_token:
  102. bbox = bbox_preds[batch_idx, idx]
  103. bbox = self._bbox_decode(bbox, shape_list[batch_idx])
  104. bbox_list.append(bbox.tolist())
  105. structure_list.append(text)
  106. score_list.append(structure_probs[batch_idx, idx])
  107. structure_batch_list.append([structure_list, float(np.mean(score_list))])
  108. bbox_batch_list.append(bbox_list)
  109. return bbox_batch_list, structure_batch_list
  110. def decode_label(self, batch):
  111. """convert text-label into text-index."""
  112. structure_idx = batch[1]
  113. gt_bbox_list = batch[2]
  114. shape_list = batch[-1]
  115. ignored_tokens = self.get_ignored_tokens()
  116. end_idx = self.dict[self.end_str]
  117. structure_batch_list = []
  118. bbox_batch_list = []
  119. batch_size = len(structure_idx)
  120. for batch_idx in range(batch_size):
  121. structure_list = []
  122. bbox_list = []
  123. for idx in range(len(structure_idx[batch_idx])):
  124. char_idx = int(structure_idx[batch_idx][idx])
  125. if idx > 0 and char_idx == end_idx:
  126. break
  127. if char_idx in ignored_tokens:
  128. continue
  129. structure_list.append(self.character[char_idx])
  130. bbox = gt_bbox_list[batch_idx][idx]
  131. if bbox.sum() != 0:
  132. bbox = self._bbox_decode(bbox, shape_list[batch_idx])
  133. bbox_list.append(bbox.tolist())
  134. structure_batch_list.append(structure_list)
  135. bbox_batch_list.append(bbox_list)
  136. return bbox_batch_list, structure_batch_list
  137. def _bbox_decode(self, bbox, shape):
  138. w, h = shape[:2]
  139. bbox[0::2] *= w
  140. bbox[1::2] *= h
  141. return bbox
  142. class TableMasterLabelDecode(TableLabelDecode):
  143. """decode the table model outputs(probs) to character str"""
  144. def __init__(
  145. self,
  146. character_dict_type="TableMaster",
  147. box_shape="pad",
  148. merge_no_span_structure=True,
  149. ):
  150. super(TableMasterLabelDecode, self).__init__(
  151. character_dict_type, merge_no_span_structure
  152. )
  153. self.box_shape = box_shape
  154. assert box_shape in [
  155. "ori",
  156. "pad",
  157. ], "The shape used for box normalization must be ori or pad"
  158. def add_special_char(self, dict_character):
  159. """add_special_char"""
  160. self.beg_str = "<SOS>"
  161. self.end_str = "<EOS>"
  162. self.unknown_str = "<UKN>"
  163. self.pad_str = "<PAD>"
  164. dict_character = dict_character
  165. dict_character = dict_character + [
  166. self.unknown_str,
  167. self.beg_str,
  168. self.end_str,
  169. self.pad_str,
  170. ]
  171. return dict_character
  172. def get_ignored_tokens(self):
  173. """get_ignored_tokens"""
  174. pad_idx = self.dict[self.pad_str]
  175. start_idx = self.dict[self.beg_str]
  176. end_idx = self.dict[self.end_str]
  177. unknown_idx = self.dict[self.unknown_str]
  178. return [start_idx, end_idx, pad_idx, unknown_idx]
  179. def _bbox_decode(self, bbox, shape):
  180. """_bbox_decode"""
  181. h, w, ratio_h, ratio_w, pad_h, pad_w = shape
  182. if self.box_shape == "pad":
  183. h, w = pad_h, pad_w
  184. bbox[0::2] *= w
  185. bbox[1::2] *= h
  186. bbox[0::2] /= ratio_w
  187. bbox[1::2] /= ratio_h
  188. x, y, w, h = bbox
  189. x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2
  190. bbox = np.array([x1, y1, x2, y2])
  191. return bbox