transforms.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import os.path as osp
  16. import numpy as np
  17. import cv2
  18. import lazy_paddle as paddle
  19. from .keys import TableRecKeys as K
  20. from ...base import BaseTransform
  21. from ...base.predictor.io.writers import ImageWriter
  22. from ....utils import logging
  23. __all__ = ["TableLabelDecode", "TableMasterLabelDecode", "SaveTableResults"]
  24. class TableLabelDecode(BaseTransform):
  25. """decode the table model outputs(probs) to character str"""
  26. def __init__(
  27. self, character_dict_type="TableAttn_ch", merge_no_span_structure=True
  28. ):
  29. dict_character = []
  30. supported_dict = ["TableAttn_ch", "TableAttn_en", "TableMaster"]
  31. if character_dict_type == "TableAttn_ch":
  32. character_dict_name = "table_structure_dict_ch.txt"
  33. elif character_dict_type == "TableAttn_en":
  34. character_dict_name = "table_structure_dict.txt"
  35. elif character_dict_type == "TableMaster":
  36. character_dict_name = "table_master_structure_dict.txt"
  37. else:
  38. assert False, " character_dict_type must in %s " % supported_dict
  39. character_dict_path = osp.abspath(
  40. osp.join(osp.dirname(__file__), character_dict_name)
  41. )
  42. with open(character_dict_path, "rb") as fin:
  43. lines = fin.readlines()
  44. for line in lines:
  45. line = line.decode("utf-8").strip("\n").strip("\r\n")
  46. dict_character.append(line)
  47. if merge_no_span_structure:
  48. if "<td></td>" not in dict_character:
  49. dict_character.append("<td></td>")
  50. if "<td>" in dict_character:
  51. dict_character.remove("<td>")
  52. dict_character = self.add_special_char(dict_character)
  53. self.dict = {}
  54. for i, char in enumerate(dict_character):
  55. self.dict[char] = i
  56. self.character = dict_character
  57. self.td_token = ["<td>", "<td", "<td></td>"]
  58. def add_special_char(self, dict_character):
  59. """add_special_char"""
  60. self.beg_str = "sos"
  61. self.end_str = "eos"
  62. dict_character = dict_character
  63. dict_character = [self.beg_str] + dict_character + [self.end_str]
  64. return dict_character
  65. def get_ignored_tokens(self):
  66. """get_ignored_tokens"""
  67. beg_idx = self.get_beg_end_flag_idx("beg")
  68. end_idx = self.get_beg_end_flag_idx("end")
  69. return [beg_idx, end_idx]
  70. def get_beg_end_flag_idx(self, beg_or_end):
  71. """get_beg_end_flag_idx"""
  72. if beg_or_end == "beg":
  73. idx = np.array(self.dict[self.beg_str])
  74. elif beg_or_end == "end":
  75. idx = np.array(self.dict[self.end_str])
  76. else:
  77. assert False, "unsupported type %s in get_beg_end_flag_idx" % beg_or_end
  78. return idx
  79. def apply(self, data):
  80. """apply"""
  81. shape_list = data[K.SHAPE_LIST]
  82. structure_probs = data[K.STRUCTURE_PROB]
  83. bbox_preds = data[K.LOC_PROB]
  84. if isinstance(structure_probs, paddle.Tensor):
  85. structure_probs = structure_probs.numpy()
  86. if isinstance(bbox_preds, paddle.Tensor):
  87. bbox_preds = bbox_preds.numpy()
  88. post_result = self.decode(structure_probs, bbox_preds, shape_list)
  89. structure_str_list = post_result["structure_batch_list"][0]
  90. bbox_list = post_result["bbox_batch_list"][0]
  91. structure_str_list = structure_str_list[0]
  92. structure_str_list = (
  93. ["<html>", "<body>", "<table>"]
  94. + structure_str_list
  95. + ["</table>", "</body>", "</html>"]
  96. )
  97. data[K.BBOX_RES] = bbox_list
  98. data[K.HTML_RES] = structure_str_list
  99. return data
  100. @classmethod
  101. def get_input_keys(cls):
  102. """get input keys"""
  103. return [K.STRUCTURE_PROB, K.LOC_PROB, K.SHAPE_LIST]
  104. @classmethod
  105. def get_output_keys(cls):
  106. """get output keys"""
  107. return [K.BBOX_RES, K.HTML_RES]
  108. def decode(self, structure_probs, bbox_preds, shape_list):
  109. """convert text-label into text-index."""
  110. ignored_tokens = self.get_ignored_tokens()
  111. end_idx = self.dict[self.end_str]
  112. structure_idx = structure_probs.argmax(axis=2)
  113. structure_probs = structure_probs.max(axis=2)
  114. structure_batch_list = []
  115. bbox_batch_list = []
  116. batch_size = len(structure_idx)
  117. for batch_idx in range(batch_size):
  118. structure_list = []
  119. bbox_list = []
  120. score_list = []
  121. for idx in range(len(structure_idx[batch_idx])):
  122. char_idx = int(structure_idx[batch_idx][idx])
  123. if idx > 0 and char_idx == end_idx:
  124. break
  125. if char_idx in ignored_tokens:
  126. continue
  127. text = self.character[char_idx]
  128. if text in self.td_token:
  129. bbox = bbox_preds[batch_idx, idx]
  130. bbox = self._bbox_decode(bbox, shape_list[batch_idx])
  131. bbox_list.append(bbox)
  132. structure_list.append(text)
  133. score_list.append(structure_probs[batch_idx, idx])
  134. structure_batch_list.append([structure_list, np.mean(score_list)])
  135. bbox_batch_list.append(np.array(bbox_list))
  136. result = {
  137. "bbox_batch_list": bbox_batch_list,
  138. "structure_batch_list": structure_batch_list,
  139. }
  140. return result
  141. def decode_label(self, batch):
  142. """convert text-label into text-index."""
  143. structure_idx = batch[1]
  144. gt_bbox_list = batch[2]
  145. shape_list = batch[-1]
  146. ignored_tokens = self.get_ignored_tokens()
  147. end_idx = self.dict[self.end_str]
  148. structure_batch_list = []
  149. bbox_batch_list = []
  150. batch_size = len(structure_idx)
  151. for batch_idx in range(batch_size):
  152. structure_list = []
  153. bbox_list = []
  154. for idx in range(len(structure_idx[batch_idx])):
  155. char_idx = int(structure_idx[batch_idx][idx])
  156. if idx > 0 and char_idx == end_idx:
  157. break
  158. if char_idx in ignored_tokens:
  159. continue
  160. structure_list.append(self.character[char_idx])
  161. bbox = gt_bbox_list[batch_idx][idx]
  162. if bbox.sum() != 0:
  163. bbox = self._bbox_decode(bbox, shape_list[batch_idx])
  164. bbox_list.append(bbox)
  165. structure_batch_list.append(structure_list)
  166. bbox_batch_list.append(bbox_list)
  167. result = {
  168. "bbox_batch_list": bbox_batch_list,
  169. "structure_batch_list": structure_batch_list,
  170. }
  171. return result
  172. def _bbox_decode(self, bbox, shape):
  173. w, h = shape[:2]
  174. bbox[0::2] *= w
  175. bbox[1::2] *= h
  176. return bbox
  177. class TableMasterLabelDecode(TableLabelDecode):
  178. """decode the table model outputs(probs) to character str"""
  179. def __init__(
  180. self,
  181. character_dict_type="TableMaster",
  182. box_shape="pad",
  183. merge_no_span_structure=True,
  184. ):
  185. super(TableMasterLabelDecode, self).__init__(
  186. character_dict_type, merge_no_span_structure
  187. )
  188. self.box_shape = box_shape
  189. assert box_shape in [
  190. "ori",
  191. "pad",
  192. ], "The shape used for box normalization must be ori or pad"
  193. def add_special_char(self, dict_character):
  194. """add_special_char"""
  195. self.beg_str = "<SOS>"
  196. self.end_str = "<EOS>"
  197. self.unknown_str = "<UKN>"
  198. self.pad_str = "<PAD>"
  199. dict_character = dict_character
  200. dict_character = dict_character + [
  201. self.unknown_str,
  202. self.beg_str,
  203. self.end_str,
  204. self.pad_str,
  205. ]
  206. return dict_character
  207. def get_ignored_tokens(self):
  208. """get_ignored_tokens"""
  209. pad_idx = self.dict[self.pad_str]
  210. start_idx = self.dict[self.beg_str]
  211. end_idx = self.dict[self.end_str]
  212. unknown_idx = self.dict[self.unknown_str]
  213. return [start_idx, end_idx, pad_idx, unknown_idx]
  214. def _bbox_decode(self, bbox, shape):
  215. """_bbox_decode"""
  216. h, w, ratio_h, ratio_w, pad_h, pad_w = shape
  217. if self.box_shape == "pad":
  218. h, w = pad_h, pad_w
  219. bbox[0::2] *= w
  220. bbox[1::2] *= h
  221. bbox[0::2] /= ratio_w
  222. bbox[1::2] /= ratio_h
  223. x, y, w, h = bbox
  224. x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2
  225. bbox = np.array([x1, y1, x2, y2])
  226. return bbox
  227. class SaveTableResults(BaseTransform):
  228. """SaveTableResults"""
  229. _TABLE_RES_SUFFIX = "_bbox"
  230. _FILE_EXT = ".png"
  231. # _DEFAULT_FILE_NAME = 'table_res_out.png'
  232. def __init__(self, save_dir):
  233. super().__init__()
  234. self.save_dir = save_dir
  235. # We use pillow backend to save both numpy arrays and PIL Image objects
  236. self._writer = ImageWriter(backend="pillow")
  237. def apply(self, data):
  238. """apply"""
  239. ori_path = data[K.IM_PATH]
  240. bbox_res = data[K.BBOX_RES]
  241. file_name = os.path.basename(ori_path)
  242. file_name = self._replace_ext(file_name, self._FILE_EXT)
  243. table_res_save_path = os.path.join(self.save_dir, file_name)
  244. if len(bbox_res) > 0 and len(bbox_res[0]) == 4:
  245. vis_img = self.draw_rectangle(data[K.ORI_IM], bbox_res)
  246. else:
  247. vis_img = self.draw_bbox(data[K.ORI_IM], bbox_res)
  248. table_res_save_path = self._add_suffix(
  249. table_res_save_path, self._TABLE_RES_SUFFIX
  250. )
  251. self._write_im(table_res_save_path, vis_img)
  252. return data
  253. @classmethod
  254. def get_input_keys(cls):
  255. """get input keys"""
  256. return [K.IM_PATH, K.ORI_IM, K.BBOX_RES]
  257. @classmethod
  258. def get_output_keys(cls):
  259. """get output keys"""
  260. return []
  261. def _write_im(self, path, im):
  262. """write image"""
  263. if os.path.exists(path):
  264. logging.warning(f"{path} already exists. Overwriting it.")
  265. self._writer.write(path, im)
  266. @staticmethod
  267. def _add_suffix(path, suffix):
  268. """_add_suffix"""
  269. stem, ext = os.path.splitext(path)
  270. return stem + suffix + ext
  271. @staticmethod
  272. def _replace_ext(path, new_ext):
  273. """_replace_ext"""
  274. stem, _ = os.path.splitext(path)
  275. return stem + new_ext
  276. def draw_rectangle(self, img_path, boxes):
  277. """draw_rectangle"""
  278. boxes = np.array(boxes)
  279. img = cv2.imread(img_path)
  280. img_show = img.copy()
  281. for box in boxes.astype(int):
  282. x1, y1, x2, y2 = box
  283. cv2.rectangle(img_show, (x1, y1), (x2, y2), (255, 0, 0), 2)
  284. return img_show
  285. def draw_bbox(self, image, boxes):
  286. """draw_bbox"""
  287. for box in boxes:
  288. box = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
  289. image = cv2.polylines(np.array(image), [box], True, (255, 0, 0), 2)
  290. return image
  291. class PrintResult(BaseTransform):
  292. """Print Result Transform"""
  293. def apply(self, data):
  294. """apply"""
  295. logging.info("The prediction result is:")
  296. logging.info(data[K.BOXES])
  297. return data
  298. @classmethod
  299. def get_input_keys(cls):
  300. """get input keys"""
  301. return [K.BOXES]
  302. @classmethod
  303. def get_output_keys(cls):
  304. """get output keys"""
  305. return []