table_rec.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. from ..base import BaseComponent
  16. __all__ = ["TableLabelDecode"]
  17. class TableLabelDecode(BaseComponent):
  18. """decode the table model outputs(probs) to character str"""
  19. ENABLE_BATCH = True
  20. INPUT_KEYS = ["pred", "img_size", "ori_img_size"]
  21. OUTPUT_KEYS = ["bbox", "structure", "structure_score"]
  22. DEAULT_INPUTS = {
  23. "pred": "pred",
  24. "img_size": "img_size",
  25. "ori_img_size": "ori_img_size",
  26. }
  27. DEAULT_OUTPUTS = {
  28. "bbox": "bbox",
  29. "structure": "structure",
  30. "structure_score": "structure_score",
  31. }
  32. def __init__(self, model_name, merge_no_span_structure=True, dict_character=[]):
  33. super().__init__()
  34. if merge_no_span_structure:
  35. if "<td></td>" not in dict_character:
  36. dict_character.append("<td></td>")
  37. if "<td>" in dict_character:
  38. dict_character.remove("<td>")
  39. self.model_name = model_name
  40. dict_character = self.add_special_char(dict_character)
  41. self.dict = {}
  42. for i, char in enumerate(dict_character):
  43. self.dict[char] = i
  44. self.character = dict_character
  45. self.td_token = ["<td>", "<td", "<td></td>"]
  46. def add_special_char(self, dict_character):
  47. """add_special_char"""
  48. self.beg_str = "sos"
  49. self.end_str = "eos"
  50. dict_character = dict_character
  51. dict_character = [self.beg_str] + dict_character + [self.end_str]
  52. return dict_character
  53. def get_ignored_tokens(self):
  54. """get_ignored_tokens"""
  55. beg_idx = self.get_beg_end_flag_idx("beg")
  56. end_idx = self.get_beg_end_flag_idx("end")
  57. return [beg_idx, end_idx]
  58. def get_beg_end_flag_idx(self, beg_or_end):
  59. """get_beg_end_flag_idx"""
  60. if beg_or_end == "beg":
  61. idx = np.array(self.dict[self.beg_str])
  62. elif beg_or_end == "end":
  63. idx = np.array(self.dict[self.end_str])
  64. else:
  65. assert False, "unsupported type %s in get_beg_end_flag_idx" % beg_or_end
  66. return idx
  67. def apply(self, pred, img_size, ori_img_size):
  68. """apply"""
  69. bbox_preds, structure_probs = [], []
  70. for bbox_pred, stru_prob in pred:
  71. bbox_preds.append(bbox_pred)
  72. structure_probs.append(stru_prob)
  73. bbox_preds = np.array(bbox_preds)
  74. structure_probs = np.array(structure_probs)
  75. bbox_list, structure_str_list, structure_score = self.decode(
  76. structure_probs, bbox_preds, img_size, ori_img_size
  77. )
  78. structure_str_list = [
  79. (
  80. ["<html>", "<body>", "<table>"]
  81. + structure
  82. + ["</table>", "</body>", "</html>"]
  83. )
  84. for structure in structure_str_list
  85. ]
  86. return [
  87. {"bbox": bbox, "structure": structure, "structure_score": structure_score}
  88. for bbox, structure in zip(bbox_list, structure_str_list)
  89. ]
  90. def decode(self, structure_probs, bbox_preds, padding_size, ori_img_size):
  91. """convert text-label into text-index."""
  92. ignored_tokens = self.get_ignored_tokens()
  93. end_idx = self.dict[self.end_str]
  94. structure_idx = structure_probs.argmax(axis=2)
  95. structure_probs = structure_probs.max(axis=2)
  96. structure_batch_list = []
  97. bbox_batch_list = []
  98. batch_size = len(structure_idx)
  99. for batch_idx in range(batch_size):
  100. structure_list = []
  101. bbox_list = []
  102. score_list = []
  103. for idx in range(len(structure_idx[batch_idx])):
  104. char_idx = int(structure_idx[batch_idx][idx])
  105. if idx > 0 and char_idx == end_idx:
  106. break
  107. if char_idx in ignored_tokens:
  108. continue
  109. text = self.character[char_idx]
  110. if text in self.td_token:
  111. bbox = bbox_preds[batch_idx, idx]
  112. bbox = self._bbox_decode(
  113. bbox, padding_size[batch_idx], ori_img_size[batch_idx]
  114. )
  115. bbox_list.append(bbox.astype(int))
  116. structure_list.append(text)
  117. score_list.append(structure_probs[batch_idx, idx])
  118. structure_batch_list.append(structure_list)
  119. structure_score = np.mean(score_list)
  120. bbox_batch_list.append(bbox_list)
  121. return bbox_batch_list, structure_batch_list, structure_score
  122. def decode_label(self, batch):
  123. """convert text-label into text-index."""
  124. structure_idx = batch[1]
  125. gt_bbox_list = batch[2]
  126. shape_list = batch[-1]
  127. ignored_tokens = self.get_ignored_tokens()
  128. end_idx = self.dict[self.end_str]
  129. structure_batch_list = []
  130. bbox_batch_list = []
  131. batch_size = len(structure_idx)
  132. for batch_idx in range(batch_size):
  133. structure_list = []
  134. bbox_list = []
  135. for idx in range(len(structure_idx[batch_idx])):
  136. char_idx = int(structure_idx[batch_idx][idx])
  137. if idx > 0 and char_idx == end_idx:
  138. break
  139. if char_idx in ignored_tokens:
  140. continue
  141. structure_list.append(self.character[char_idx])
  142. bbox = gt_bbox_list[batch_idx][idx]
  143. if bbox.sum() != 0:
  144. bbox = self._bbox_decode(bbox, shape_list[batch_idx])
  145. bbox_list.append(bbox.astype(int))
  146. structure_batch_list.append(structure_list)
  147. bbox_batch_list.append(bbox_list)
  148. return bbox_batch_list, structure_batch_list
  149. def _bbox_decode(self, bbox, padding_shape, ori_shape):
  150. if self.model_name == "SLANet":
  151. w, h = ori_shape
  152. bbox[0::2] *= w
  153. bbox[1::2] *= h
  154. else:
  155. w, h = padding_shape
  156. ori_w, ori_h = ori_shape
  157. ratio_w = w / ori_w
  158. ratio_h = h / ori_h
  159. ratio = min(ratio_w, ratio_h)
  160. bbox[0::2] *= w
  161. bbox[1::2] *= h
  162. bbox[0::2] /= ratio
  163. bbox[1::2] /= ratio
  164. return bbox