# table_structure_unitable.py

import re
import time

import cv2
import numpy as np
import torch
from PIL import Image
from tokenizers import Tokenizer
from torchvision import transforms

from .unitable_modules import Encoder, GPTFastDecoder

IMG_SIZE = 448
EOS_TOKEN = "<eos>"
BBOX_TOKENS = [f"bbox-{i}" for i in range(IMG_SIZE + 1)]
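# BBOX_TOKENS quantize coordinates onto the 448-px input grid: "bbox-0" ...
# "bbox-448", one token per integer pixel position, so a cell box is emitted
# as four consecutive tokens (xmin, ymin, xmax, ymax).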
HTML_BBOX_HTML_TOKENS = [
    "<td></td>",
    "<td>[",
    "]</td>",
    "<td",
    ">[",
    "></td>",
    "<tr>",
    "</tr>",
    "<tbody>",
    "</tbody>",
    "<thead>",
    "</thead>",
    ' rowspan="2"',
    ' rowspan="3"',
    ' rowspan="4"',
    ' rowspan="5"',
    ' rowspan="6"',
    ' rowspan="7"',
    ' rowspan="8"',
    ' rowspan="9"',
    ' rowspan="10"',
    ' rowspan="11"',
    ' rowspan="12"',
    ' rowspan="13"',
    ' rowspan="14"',
    ' rowspan="15"',
    ' rowspan="16"',
    ' rowspan="17"',
    ' rowspan="18"',
    ' rowspan="19"',
    ' colspan="2"',
    ' colspan="3"',
    ' colspan="4"',
    ' colspan="5"',
    ' colspan="6"',
    ' colspan="7"',
    ' colspan="8"',
    ' colspan="9"',
    ' colspan="10"',
    ' colspan="11"',
    ' colspan="12"',
    ' colspan="13"',
    ' colspan="14"',
    ' colspan="15"',
    ' colspan="16"',
    ' colspan="17"',
    ' colspan="18"',
    ' colspan="19"',
    ' colspan="25"',
]
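# Structural vocabulary: a plain cell decodes as a bare "<td></td>", while a
# cell with a box decodes as "<td>[" (or "<td" + span attributes + ">[")
# followed by four bbox tokens and the closing "]</td>". Row/column spans
# cover 2-19, plus a single colspan="25" outlier.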
VALID_HTML_BBOX_TOKENS = [EOS_TOKEN] + HTML_BBOX_HTML_TOKENS + BBOX_TOKENS
TASK_TOKENS = [
    "[table]",
    "[html]",
    "[cell]",
    "[bbox]",
    "[cell+bbox]",
    "[html+bbox]",
]
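# Only the "[html+bbox]" task prefix is used below: the decoder is prompted to
# produce HTML structure and cell boxes jointly. The other task tokens exist
# in the vocabulary and are stripped from the output during decoding.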


class TableStructureUnitable:
    def __init__(self, config):
        # Expected config keys: model_path.{vocab,encoder,decoder}, use_cuda
        # and, optionally, device.
        vocab_path = config["model_path"]["vocab"]
        encoder_path = config["model_path"]["encoder"]
        decoder_path = config["model_path"]["decoder"]
        device = config.get("device", "cuda:0") if config["use_cuda"] else "cpu"

        self.vocab = Tokenizer.from_file(vocab_path)
        # Ids of all tokens that are valid in [html+bbox] output.
        self.token_white_list = [
            self.vocab.token_to_id(i) for i in VALID_HTML_BBOX_TOKENS
        ]
        self.bbox_token_ids = set(self.vocab.token_to_id(i) for i in BBOX_TOKENS)
        self.bbox_close_html_token = self.vocab.token_to_id("]</td>")
        self.prefix_token_id = self.vocab.token_to_id("[html+bbox]")
        self.eos_id = self.vocab.token_to_id(EOS_TOKEN)
        self.max_seq_len = 1024
        self.device = device
        self.img_size = IMG_SIZE

        # Init encoder.
        encoder_state_dict = torch.load(encoder_path, map_location=device)
        self.encoder = Encoder()
        self.encoder.load_state_dict(encoder_state_dict)
        self.encoder.eval().to(device)

        # Init decoder.
        decoder_state_dict = torch.load(decoder_path, map_location=device)
        self.decoder = GPTFastDecoder()
        self.decoder.load_state_dict(decoder_state_dict)
        self.decoder.eval().to(device)

        # Image transform: resize to the model's fixed input size and
        # normalize with dataset-specific channel statistics.
        self.transform = transforms.Compose(
            [
                transforms.Resize((IMG_SIZE, IMG_SIZE)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.86597056, 0.88463002, 0.87491087],
                    std=[0.20686628, 0.18201602, 0.18485524],
                ),
            ]
        )
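
    # Inference pipeline: the encoder turns the resized image into a memory of
    # visual features; the decoder then autoregressively emits a mixed stream
    # of HTML structure tokens and bbox coordinate tokens, which
    # decode_tokens() splits back into an HTML token list and an (N, 8) array.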
    @torch.inference_mode()
    def __call__(self, image: np.ndarray):
        """Predict table structure tokens and cell boxes for a BGR image.

        Returns (structure_str_list, bboxes, elapsed_seconds), where bboxes is
        an (N, 8) array of corner points in original-image pixels.
        """
        start_time = time.time()
        ori_h, ori_w = image.shape[:2]
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(image)
        image = self.transform(image).unsqueeze(0).to(self.device)

        self.decoder.setup_caches(
            max_batch_size=1,
            max_seq_length=self.max_seq_len,
            dtype=image.dtype,
            device=self.device,
        )
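        # setup_caches preallocates the decoder's KV cache for a batch of one,
        # so each autoregressive step in loop_decode attends over cached keys
        # and values instead of re-running the full prefix.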
        # Seed the sequence with the task prefix and decode autoregressively.
        context = (
            torch.tensor([self.prefix_token_id], dtype=torch.int32)
            .repeat(1, 1)
            .to(self.device)
        )
        eos_id_tensor = torch.tensor(self.eos_id, dtype=torch.int32).to(self.device)
        memory = self.encoder(image)
        context = self.loop_decode(context, eos_id_tensor, memory)
        bboxes, html_tokens = self.decode_tokens(context)
        bboxes = bboxes.astype(np.float32)
        if bboxes.size == 0:
            # No cells decoded; keep a well-formed empty (0, 8) array so the
            # slicing below does not fail.
            bboxes = bboxes.reshape(0, 8)

        # Rescale boxes from the 448x448 model grid back to the original image.
        scale_h = ori_h / self.img_size
        scale_w = ori_w / self.img_size
        bboxes[:, 0::2] *= scale_w  # scale x coordinates
        bboxes[:, 1::2] *= scale_h  # scale y coordinates
        bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, ori_w - 1)
        bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, ori_h - 1)

        structure_str_list = (
            ["<html>", "<body>", "<table>"]
            + html_tokens
            + ["</table>", "</body>", "</html>"]
        )
        return structure_str_list, bboxes, time.time() - start_time
    def decode_tokens(self, context):
        """Parse the decoded token string into HTML fragments and cell boxes."""
        pred_html = context[0]
        pred_html = pred_html.detach().cpu().numpy()
        pred_html = self.vocab.decode(pred_html, skip_special_tokens=False)

        # Keep everything before <eos> and drop special/task tokens.
        seq = pred_html.split("<eos>")[0]
        token_black_list = ["<eos>", "<pad>", *TASK_TOKENS]
        for i in token_black_list:
            seq = seq.replace(i, "")

        tr_pattern = re.compile(r"<tr>(.*?)</tr>", re.DOTALL)
        td_pattern = re.compile(r"<td(.*?)>(.*?)</td>", re.DOTALL)
        bbox_pattern = re.compile(r"\[ bbox-(\d+) bbox-(\d+) bbox-(\d+) bbox-(\d+) \]")

        decoded_list = []
        bbox_coords = []
        # Find all <tr> tags (parse the cleaned sequence, not the raw string).
        for tr_match in tr_pattern.finditer(seq):
            tr_content = tr_match.group(1)
            decoded_list.append("<tr>")
            # Find all <td> tags.
            for td_match in td_pattern.finditer(tr_content):
                td_attrs = td_match.group(1).strip()
                td_content = td_match.group(2).strip()
                if td_attrs:
                    decoded_list.append("<td")
                    # rowspan and colspan may both be present; keep every attribute.
                    attrs_list = td_attrs.split()
                    for attr in attrs_list:
                        decoded_list.append(" " + attr)
                    decoded_list.append(">")
                    decoded_list.append("</td>")
                else:
                    decoded_list.append("<td></td>")
                # Extract the bbox coordinates.
                bbox_match = bbox_pattern.search(td_content)
                if bbox_match:
                    xmin, ymin, xmax, ymax = map(int, bbox_match.groups())
                    # Expand to four corner points, clockwise from the top-left.
                    coords = np.array([xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax])
                    bbox_coords.append(coords)
                else:
                    # Placeholder box so rows stay aligned for downstream steps.
                    bbox_coords.append(np.array([0, 0, 0, 0, 0, 0, 0, 0]))
            decoded_list.append("</tr>")

        bbox_coords_array = np.array(bbox_coords)
        return bbox_coords_array, decoded_list
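
    # Illustrative example (hypothetical decoder output): the sequence
    #   '<tr> <td>[ bbox-10 bbox-12 bbox-99 bbox-40 ]</td> <td colspan="2"></td> </tr>'
    # decodes to
    #   decoded_list = ["<tr>", "<td></td>", "<td", ' colspan="2"', ">", "</td>", "</tr>"]
    #   bbox_coords  = [[10, 12, 99, 12, 99, 40, 10, 40], [0, 0, 0, 0, 0, 0, 0, 0]]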
    def loop_decode(self, context, eos_id_tensor, memory):
        """Greedily decode one token per step until <eos> or max_seq_len."""
        box_token_count = 0
        for _ in range(self.max_seq_len):
            eos_flag = (context == eos_id_tensor).any(dim=1)
            if torch.all(eos_flag):
                break
            next_tokens = self.decoder(memory, context)
            if next_tokens[0].item() in self.bbox_token_ids:
                box_token_count += 1
                # A box is exactly four coordinate tokens; if a fifth arrives,
                # force the closing "]</td>" token instead.
                if box_token_count > 4:
                    next_tokens = torch.tensor(
                        [[self.bbox_close_html_token]],
                        dtype=torch.int32,
                        device=self.device,
                    )
                    box_token_count = 0
            else:
                # Any non-bbox token ends the current coordinate group.
                box_token_count = 0
            context = torch.cat([context, next_tokens], dim=1)
        return context
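

# Minimal usage sketch (the paths and config values are assumptions based on
# __init__ above, not a packaged example):
#
#     config = {
#         "model_path": {
#             "vocab": "models/vocab.json",
#             "encoder": "models/encoder.pth",
#             "decoder": "models/decoder.pth",
#         },
#         "use_cuda": False,
#     }
#     table_rec = TableStructureUnitable(config)
#     img = cv2.imread("table.jpg")  # BGR, as expected by __call__
#     structure_tokens, boxes, elapsed = table_rec(img)
#     print("".join(structure_tokens))
#     print(boxes.shape, f"{elapsed:.2f}s")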