processors.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import cv2
import numpy as np
from numpy import ndarray

from ..common.vision import funcs as F
from ...utils.benchmark import benchmark


@benchmark.timeit
class Pad:
    """Pad the image."""

    def __init__(self, target_size, val=127.5):
        """
        Initialize the instance.

        Args:
            target_size (list|tuple|int): Target width and height of the image after
                padding.
            val (float, optional): Value to fill the padded area. Default: 127.5.
        """
        super().__init__()
        if isinstance(target_size, int):
            target_size = [target_size, target_size]
        self.target_size = target_size
        self.val = val

    def apply(self, img):
        """apply"""
        h, w = img.shape[:2]
        tw, th = self.target_size
        ph = th - h
        pw = tw - w
        if ph < 0 or pw < 0:
            raise ValueError(
                f"Input image ({w}, {h}) is smaller than the target size ({tw}, {th})."
            )
        else:
            img = F.pad(img, pad=(0, ph, 0, pw), val=self.val)
        return [img, [img.shape[1], img.shape[0]]]

    def __call__(self, imgs):
        """apply"""
        return [self.apply(img) for img in imgs]
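

# A minimal usage sketch of `Pad` (illustrative only; the input shape below is
# hypothetical, and `F.pad` is assumed to take a (top, bottom, left, right)
# padding tuple, as suggested by the call above):
#
#     pad_op = Pad(target_size=488, val=127.5)
#     img = np.zeros((320, 240, 3), dtype=np.uint8)   # (h, w, c)
#     padded, (out_w, out_h) = pad_op([img])[0]
#     # padded.shape == (488, 488, 3); (out_w, out_h) == (488, 488)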


@benchmark.timeit
class TableLabelDecode:
    """decode the table model outputs (probs) to character str"""

    ENABLE_BATCH = True

    INPUT_KEYS = ["pred", "img_size", "ori_img_size"]
    OUTPUT_KEYS = ["bbox", "structure", "structure_score"]
    DEAULT_INPUTS = {
        "pred": "pred",
        "img_size": "img_size",
        "ori_img_size": "ori_img_size",
    }
    DEAULT_OUTPUTS = {
        "bbox": "bbox",
        "structure": "structure",
        "structure_score": "structure_score",
    }

    def __init__(self, model_name, merge_no_span_structure=True, dict_character=[]):
        super().__init__()
        if merge_no_span_structure:
            if "<td></td>" not in dict_character:
                dict_character.append("<td></td>")
            if "<td>" in dict_character:
                dict_character.remove("<td>")
        self.model_name = model_name

        dict_character = self.add_special_char(dict_character)
        self.dict = {}
        for i, char in enumerate(dict_character):
            self.dict[char] = i
        self.character = dict_character
        self.td_token = ["<td>", "<td", "<td></td>"]

    def add_special_char(self, dict_character):
        """add_special_char"""
        self.beg_str = "sos"
        self.end_str = "eos"
        dict_character = [self.beg_str] + dict_character + [self.end_str]
        return dict_character

    def get_ignored_tokens(self):
        """get_ignored_tokens"""
        beg_idx = self.get_beg_end_flag_idx("beg")
        end_idx = self.get_beg_end_flag_idx("end")
        return [beg_idx, end_idx]

    def get_beg_end_flag_idx(self, beg_or_end):
        """get_beg_end_flag_idx"""
        if beg_or_end == "beg":
            idx = np.array(self.dict[self.beg_str])
        elif beg_or_end == "end":
            idx = np.array(self.dict[self.end_str])
        else:
            assert False, "unsupported type %s in get_beg_end_flag_idx" % beg_or_end
        return idx

    def __call__(self, pred, img_size, ori_img_size):
        """apply"""
        # Repack the raw predictions into batched arrays of bbox and structure probs.
        bbox_preds, structure_probs = [], []
        for i in range(len(pred[0][0])):
            bbox_preds.append(pred[0][0][i])
            structure_probs.append(pred[1][0][i])
        bbox_preds = np.array([bbox_preds])
        structure_probs = np.array([structure_probs])

        bbox_list, structure_str_list, structure_score = self.decode(
            structure_probs, bbox_preds, img_size, ori_img_size
        )
        # Wrap each decoded structure sequence in the fixed outer HTML tags.
        structure_str_list = [
            (
                ["<html>", "<body>", "<table>"]
                + structure
                + ["</table>", "</body>", "</html>"]
            )
            for structure in structure_str_list
        ]
        return [
            {"bbox": bbox, "structure": structure, "structure_score": structure_score}
            for bbox, structure in zip(bbox_list, structure_str_list)
        ]
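
    # Illustrative shape of the per-table dict returned above (all values are
    # hypothetical):
    #
    #     {
    #         "bbox": [np.array([x1, y1, ..., x4, y4]), ...],   # one box per cell token
    #         "structure": ["<html>", "<body>", "<table>", "<tr>", "<td></td>", ...,
    #                       "</table>", "</body>", "</html>"],
    #         "structure_score": 0.97,
    #     }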

    def decode(self, structure_probs, bbox_preds, padding_size, ori_img_size):
        """convert text-label into text-index."""
        ignored_tokens = self.get_ignored_tokens()
        end_idx = self.dict[self.end_str]

        structure_idx = structure_probs.argmax(axis=2)
        structure_probs = structure_probs.max(axis=2)

        structure_batch_list = []
        bbox_batch_list = []
        batch_size = len(structure_idx)
        for batch_idx in range(batch_size):
            structure_list = []
            bbox_list = []
            score_list = []
            for idx in range(len(structure_idx[batch_idx])):
                char_idx = int(structure_idx[batch_idx][idx])
                # Stop at the first end token (the very first step is never
                # treated as an end token).
                if idx > 0 and char_idx == end_idx:
                    break
                if char_idx in ignored_tokens:
                    continue
                text = self.character[char_idx]
                # Only cell tokens carry an associated bbox prediction.
                if text in self.td_token:
                    bbox = bbox_preds[batch_idx, idx]
                    bbox = self._bbox_decode(
                        bbox, padding_size[batch_idx], ori_img_size[batch_idx]
                    )
                    bbox_list.append(bbox.astype(int))
                structure_list.append(text)
                score_list.append(structure_probs[batch_idx, idx])
            structure_batch_list.append(structure_list)
            structure_score = np.mean(score_list)
            bbox_batch_list.append(bbox_list)
        return bbox_batch_list, structure_batch_list, structure_score

    def decode_label(self, batch):
        """convert text-label into text-index."""
        structure_idx = batch[1]
        gt_bbox_list = batch[2]
        shape_list = batch[-1]
        ignored_tokens = self.get_ignored_tokens()
        end_idx = self.dict[self.end_str]

        structure_batch_list = []
        bbox_batch_list = []
        batch_size = len(structure_idx)
        for batch_idx in range(batch_size):
            structure_list = []
            bbox_list = []
            for idx in range(len(structure_idx[batch_idx])):
                char_idx = int(structure_idx[batch_idx][idx])
                if idx > 0 and char_idx == end_idx:
                    break
                if char_idx in ignored_tokens:
                    continue
                structure_list.append(self.character[char_idx])

                bbox = gt_bbox_list[batch_idx][idx]
                if bbox.sum() != 0:
                    bbox = self._bbox_decode(bbox, shape_list[batch_idx])
                    bbox_list.append(bbox.astype(int))
            structure_batch_list.append(structure_list)
            bbox_batch_list.append(bbox_list)
        return bbox_batch_list, structure_batch_list

    def _bbox_decode(self, bbox, padding_shape, ori_shape):
        if self.model_name == "SLANet":
            # SLANet bboxes are normalized to the original image size.
            w, h = ori_shape
            bbox[0::2] *= w
            bbox[1::2] *= h
        else:
            # Otherwise the bboxes are normalized to the padded/resized canvas;
            # scale to that canvas, then undo the keep-aspect-ratio resize factor.
            w, h = padding_shape
            ori_w, ori_h = ori_shape
            ratio_w = w / ori_w
            ratio_h = h / ori_h
            ratio = min(ratio_w, ratio_h)
            bbox[0::2] *= w
            bbox[1::2] *= h
            bbox[0::2] /= ratio
            bbox[1::2] /= ratio
        return bbox
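

# Worked example of the non-SLANet branch of `_bbox_decode` (numbers are
# hypothetical): with padding_shape (w, h) = (488, 488) and ori_shape
# (ori_w, ori_h) = (976, 488), ratio = min(488 / 976, 488 / 488) = 0.5.
# A normalized x-coordinate of 0.25 is first scaled to the padded canvas
# (0.25 * 488 = 122) and then mapped back through the resize ratio
# (122 / 0.5 = 244), which matches 0.25 * 976 in the original image.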