# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import os.path as osp

import numpy as np
from PIL import Image
import cv2
import paddle

from ....utils import logging
from ...base import BaseTransform
from ...base.predictor.io.writers import ImageWriter
from .keys import TableRecKeys as K

__all__ = ['TableLabelDecode', 'TableMasterLabelDecode', 'SaveTableResults']


class TableLabelDecode(BaseTransform):
    """Decode table-structure model outputs into structure tokens and boxes.

    The structure dictionary is read from a text file located next to this
    module; the file is selected by ``character_dict_type``.
    """

    # Dictionary type -> dictionary file name (resolved relative to this file).
    _DICT_FILES = {
        'TableAttn_ch': 'table_structure_dict_ch.txt',
        'TableAttn_en': 'table_structure_dict.txt',
        'TableMaster': 'table_master_structure_dict.txt',
    }

    def __init__(self,
                 character_dict_type='TableAttn_ch',
                 merge_no_span_structure=True):
        """Load the structure dictionary and build the token <-> index maps.

        Args:
            character_dict_type: One of the keys of ``_DICT_FILES``.
            merge_no_span_structure: When true, make sure the merged cell
                token ``<td></td>`` is in the dictionary and drop the bare
                ``<td>`` token.

        Raises:
            AssertionError: If ``character_dict_type`` is not supported.
        """
        super().__init__()
        supported_dict = list(self._DICT_FILES)
        if character_dict_type not in self._DICT_FILES:
            # Raised explicitly (not via `assert`) so the check is not
            # stripped when Python runs with -O.
            raise AssertionError(" character_dict_type must in %s " %
                                 supported_dict)
        character_dict_name = self._DICT_FILES[character_dict_type]

        character_dict_path = osp.abspath(
            osp.join(osp.dirname(__file__), character_dict_name))
        with open(character_dict_path, "rb") as fin:
            # Only line terminators are stripped; other whitespace is kept.
            dict_character = [
                raw.decode('utf-8').strip("\r\n") for raw in fin
            ]

        # NOTE(review): the token literals below must match the entries of
        # the structure dictionary files exactly -- confirm against them.
        if merge_no_span_structure:
            if "<td></td>" not in dict_character:
                dict_character.append("<td></td>")
            if "<td>" in dict_character:
                dict_character.remove("<td>")

        dict_character = self.add_special_char(dict_character)
        self.dict = {char: i for i, char in enumerate(dict_character)}
        self.character = dict_character
        # Structure tokens for which `decode` emits a bounding box.
        self.td_token = ['<td>', '<td', '<td></td>']

    def add_special_char(self, dict_character):
        """Return the dictionary with begin/end tokens wrapped around it.

        Also sets ``self.beg_str`` and ``self.end_str``.
        """
        self.beg_str = "sos"
        self.end_str = "eos"
        return [self.beg_str] + dict_character + [self.end_str]

    def get_ignored_tokens(self):
        """Return the indices of the tokens that are skipped when decoding."""
        beg_idx = self.get_beg_end_flag_idx("beg")
        end_idx = self.get_beg_end_flag_idx("end")
        return [beg_idx, end_idx]

    def get_beg_end_flag_idx(self, beg_or_end):
        """Return the index of the begin or end token as a 0-d numpy array.

        Args:
            beg_or_end: Either ``"beg"`` or ``"end"``.

        Raises:
            AssertionError: If ``beg_or_end`` is any other value.
        """
        if beg_or_end == "beg":
            return np.array(self.dict[self.beg_str])
        if beg_or_end == "end":
            return np.array(self.dict[self.end_str])
        raise AssertionError("unsupported type %s in get_beg_end_flag_idx" %
                             beg_or_end)

    def apply(self, data):
        """Decode the first sample of the batch and store the results.

        Reads ``K.SHAPE_LIST``, ``K.STRUCTURE_PROB`` and ``K.LOC_PROB`` from
        ``data``; writes ``K.BBOX_RES`` (decoded boxes) and ``K.HTML_RES``
        (structure tokens wrapped in html/body/table tags).
        """
        shape_list = data[K.SHAPE_LIST]
        structure_probs = data[K.STRUCTURE_PROB]
        bbox_preds = data[K.LOC_PROB]
        if isinstance(structure_probs, paddle.Tensor):
            structure_probs = structure_probs.numpy()
        if isinstance(bbox_preds, paddle.Tensor):
            bbox_preds = bbox_preds.numpy()

        post_result = self.decode(structure_probs, bbox_preds, shape_list)
        # Each entry of 'structure_batch_list' is [tokens, mean_score];
        # only the tokens of the first sample are used here.
        structure_tokens = post_result['structure_batch_list'][0][0]
        bbox_list = post_result['bbox_batch_list'][0]

        # NOTE(review): wrapper tags follow the upstream HTML layout --
        # confirm against consumers of K.HTML_RES.
        structure_str_list = (['<html>', '<body>', '<table>'] +
                              structure_tokens +
                              ['</table>', '</body>', '</html>'])
        data[K.BBOX_RES] = bbox_list
        data[K.HTML_RES] = structure_str_list
        return data

    @classmethod
    def get_input_keys(cls):
        """ get input keys """
        return [K.STRUCTURE_PROB, K.LOC_PROB, K.SHAPE_LIST]

    @classmethod
    def get_output_keys(cls):
        """ get output keys """
        return [K.BBOX_RES, K.HTML_RES]

    def decode(self, structure_probs, bbox_preds, shape_list):
        """Convert predicted token probabilities into tokens and boxes.

        Args:
            structure_probs: Array indexed as [batch, step, token]; argmax
                and max are taken over axis 2.
            bbox_preds: Array indexed as [batch, step] giving one box per
                step.
            shape_list: Per-sample shape info passed to ``_bbox_decode``.

        Returns:
            dict with ``'bbox_batch_list'`` (one array of boxes per sample)
            and ``'structure_batch_list'`` (one ``[tokens, mean_score]``
            pair per sample).
        """
        ignored_tokens = self.get_ignored_tokens()
        end_idx = self.dict[self.end_str]

        structure_idx = structure_probs.argmax(axis=2)
        structure_probs = structure_probs.max(axis=2)

        structure_batch_list = []
        bbox_batch_list = []
        for batch_idx, sample_idx in enumerate(structure_idx):
            structure_list = []
            bbox_list = []
            score_list = []
            for idx, raw_idx in enumerate(sample_idx):
                char_idx = int(raw_idx)
                # An end token anywhere but the first step terminates the
                # sequence.
                if idx > 0 and char_idx == end_idx:
                    break
                if char_idx in ignored_tokens:
                    continue
                text = self.character[char_idx]
                if text in self.td_token:
                    bbox = self._bbox_decode(bbox_preds[batch_idx, idx],
                                             shape_list[batch_idx])
                    bbox_list.append(bbox)
                structure_list.append(text)
                score_list.append(structure_probs[batch_idx, idx])
            structure_batch_list.append([structure_list, np.mean(score_list)])
            bbox_batch_list.append(np.array(bbox_list))
        return {
            'bbox_batch_list': bbox_batch_list,
            'structure_batch_list': structure_batch_list,
        }

    def decode_label(self, batch):
        """Convert ground-truth token indices into tokens and boxes.

        Args:
            batch: Sequence where ``batch[1]`` holds the token indices,
                ``batch[2]`` the ground-truth boxes and ``batch[-1]`` the
                per-sample shape info.

        Returns:
            dict with ``'bbox_batch_list'`` and ``'structure_batch_list'``
            (one list per sample each).
        """
        structure_idx = batch[1]
        gt_bbox_list = batch[2]
        shape_list = batch[-1]
        ignored_tokens = self.get_ignored_tokens()
        end_idx = self.dict[self.end_str]

        structure_batch_list = []
        bbox_batch_list = []
        for batch_idx, sample_idx in enumerate(structure_idx):
            structure_list = []
            bbox_list = []
            for idx, raw_idx in enumerate(sample_idx):
                char_idx = int(raw_idx)
                if idx > 0 and char_idx == end_idx:
                    break
                if char_idx in ignored_tokens:
                    continue
                structure_list.append(self.character[char_idx])

                bbox = gt_bbox_list[batch_idx][idx]
                # All-zero boxes are skipped.
                if bbox.sum() != 0:
                    bbox = self._bbox_decode(bbox, shape_list[batch_idx])
                    bbox_list.append(bbox)
            structure_batch_list.append(structure_list)
            bbox_batch_list.append(bbox_list)
        return {
            'bbox_batch_list': bbox_batch_list,
            'structure_batch_list': structure_batch_list,
        }

    def _bbox_decode(self, bbox, shape):
        """Scale a box by the first two entries of ``shape``.

        Even-indexed coordinates are multiplied by ``shape[0]`` and
        odd-indexed ones by ``shape[1]``. The input array is left untouched.
        """
        w, h = shape[:2]
        # Work on a copy: `bbox` is usually a view into the caller's
        # prediction / label array, which must not be rescaled in place.
        bbox = np.array(bbox)
        bbox[0::2] *= w
        bbox[1::2] *= h
        return bbox


class TableMasterLabelDecode(TableLabelDecode):
    """Decode TableMaster model outputs into structure tokens and boxes."""

    def __init__(self,
                 character_dict_type='TableMaster',
                 box_shape='pad',
                 merge_no_span_structure=True):
        """Initialise the decoder.

        Args:
            character_dict_type: Dictionary type, see ``TableLabelDecode``.
            box_shape: ``'ori'`` or ``'pad'``; selects which image size is
                used to de-normalise boxes in ``_bbox_decode``.
            merge_no_span_structure: See ``TableLabelDecode``.

        Raises:
            AssertionError: If ``box_shape`` is neither ``'ori'`` nor
                ``'pad'``.
        """
        super().__init__(character_dict_type, merge_no_span_structure)
        self.box_shape = box_shape
        if box_shape not in ('ori', 'pad'):
            # Explicit raise so the check survives `python -O`.
            raise AssertionError(
                'The shape used for box normalization must be ori or pad')

    def add_special_char(self, dict_character):
        """Return the dictionary with the four special tokens appended.

        Also sets ``self.beg_str``, ``self.end_str``, ``self.unknown_str``
        and ``self.pad_str``.
        """
        # NOTE(review): the four tokens must be distinct and match the
        # convention the model was trained with -- confirm.
        self.beg_str = '<SOS>'
        self.end_str = '<EOS>'
        self.unknown_str = '<UKN>'
        self.pad_str = '<PAD>'
        return dict_character + [
            self.unknown_str, self.beg_str, self.end_str, self.pad_str
        ]

    def get_ignored_tokens(self):
        """Return the indices of the start, end, pad and unknown tokens."""
        pad_idx = self.dict[self.pad_str]
        start_idx = self.dict[self.beg_str]
        end_idx = self.dict[self.end_str]
        unknown_idx = self.dict[self.unknown_str]
        return [start_idx, end_idx, pad_idx, unknown_idx]

    def _bbox_decode(self, bbox, shape):
        """De-normalise a centre-format box and return its corner form.

        Args:
            bbox: Four values unpacked as ``x, y, w, h`` after scaling.
            shape: Six values ``(h, w, ratio_h, ratio_w, pad_h, pad_w)``.

        Returns:
            A new array ``[x1, y1, x2, y2]``. The input is left untouched.
        """
        h, w, ratio_h, ratio_w, pad_h, pad_w = shape
        if self.box_shape == 'pad':
            h, w = pad_h, pad_w
        # Copy first so the caller's array is not rescaled in place.
        bbox = np.array(bbox)
        bbox[0::2] *= w
        bbox[1::2] *= h
        bbox[0::2] /= ratio_w
        bbox[1::2] /= ratio_h
        x, y, box_w, box_h = bbox
        half_w, half_h = box_w // 2, box_h // 2
        return np.array([x - half_w, y - half_h, x + half_w, y + half_h])


class SaveTableResults(BaseTransform):
    """Draw the predicted boxes on the image and write the result to disk."""

    _TABLE_RES_SUFFIX = '_bbox'
    _FILE_EXT = '.png'

    # _DEFAULT_FILE_NAME = 'table_res_out.png'

    def __init__(self, save_dir):
        """Store the output directory and create the image writer."""
        super().__init__()
        self.save_dir = save_dir
        # We use pillow backend to save both numpy arrays and PIL Image objects
        self._writer = ImageWriter(backend='pillow')

    def apply(self, data):
        """Render ``data[K.BBOX_RES]`` on ``data[K.ORI_IM]`` and save it.

        The output file is ``<save_dir>/<input stem>_bbox.png``. ``data`` is
        returned unchanged.
        """
        ori_path = data[K.IM_PATH]
        bbox_res = data[K.BBOX_RES]
        file_name = self._replace_ext(os.path.basename(ori_path),
                                      self._FILE_EXT)
        table_res_save_path = os.path.join(self.save_dir, file_name)

        # Boxes with four values are drawn as axis-aligned rectangles,
        # anything else as polygons.
        if len(bbox_res) > 0 and len(bbox_res[0]) == 4:
            vis_img = self.draw_rectangle(data[K.ORI_IM], bbox_res)
        else:
            vis_img = self.draw_bbox(data[K.ORI_IM], bbox_res)

        table_res_save_path = self._add_suffix(table_res_save_path,
                                               self._TABLE_RES_SUFFIX)
        self._write_im(table_res_save_path, vis_img)
        return data

    @classmethod
    def get_input_keys(cls):
        """ get input keys """
        return [K.IM_PATH, K.ORI_IM, K.BBOX_RES]

    @classmethod
    def get_output_keys(cls):
        """ get output keys """
        return []

    def _write_im(self, path, im):
        """Write ``im`` to ``path``, warning if the file already exists."""
        if os.path.exists(path):
            logging.warning(f"{path} already exists. Overwriting it.")
        self._writer.write(path, im)

    @staticmethod
    def _add_suffix(path, suffix):
        """Insert ``suffix`` between the stem and the extension of ``path``."""
        stem, ext = os.path.splitext(path)
        return stem + suffix + ext

    @staticmethod
    def _replace_ext(path, new_ext):
        """Return ``path`` with its extension replaced by ``new_ext``."""
        stem, _ = os.path.splitext(path)
        return stem + new_ext

    def draw_rectangle(self, img_path, boxes):
        """Draw ``[x1, y1, x2, y2]`` boxes and return the annotated image.

        Args:
            img_path: Either a path to an image file or an already loaded
                image (anything ``np.array`` accepts). ``apply`` passes
                ``data[K.ORI_IM]`` here, so both forms are supported.
            boxes: Iterable of four-value boxes.

        Raises:
            ValueError: If ``img_path`` is a path that cannot be read.
        """
        boxes = np.array(boxes)
        if isinstance(img_path, (str, os.PathLike)):
            img_show = cv2.imread(os.fspath(img_path))
            if img_show is None:
                # cv2.imread signals failure by returning None.
                raise ValueError(f"Failed to read image from {img_path!r}")
        else:
            # np.array copies, so the caller's image is not drawn on.
            img_show = np.array(img_path)
        # tolist() yields plain Python ints for the OpenCV point tuples.
        for x1, y1, x2, y2 in boxes.astype(int).tolist():
            cv2.rectangle(img_show, (x1, y1), (x2, y2), (255, 0, 0), 2)
        return img_show

    def draw_bbox(self, image, boxes):
        """Draw polygon boxes and return the annotated image.

        When ``boxes`` is empty, ``image`` is returned as passed in;
        otherwise a numpy copy with the polygons drawn is returned.
        """
        if len(boxes) == 0:
            return image
        # Copy once instead of once per box.
        canvas = np.array(image)
        for box in boxes:
            points = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
            canvas = cv2.polylines(canvas, [points], True, (255, 0, 0), 2)
        return canvas


class PrintResult(BaseTransform):
    """Log the prediction result stored under ``K.BOXES``."""

    def apply(self, data):
        """Log ``data[K.BOXES]`` and return ``data`` unchanged."""
        logging.info("The prediction result is:")
        logging.info(data[K.BOXES])
        return data

    @classmethod
    def get_input_keys(cls):
        """ get input keys """
        return [K.BOXES]

    @classmethod
    def get_output_keys(cls):
        """ get output keys """
        return []