# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import os.path as osp

import numpy as np
import cv2
import paddle

from .keys import TableRecKeys as K
from ...base import BaseTransform
from ...base.predictor.io.writers import ImageWriter
from ....utils import logging

__all__ = ['TableLabelDecode', 'TableMasterLabelDecode', 'SaveTableResults']


class TableLabelDecode(BaseTransform):
    """Decode table-structure model outputs (probabilities) into structure
    tokens and cell bounding boxes."""

    # Maps ``character_dict_type`` to the dictionary file that lives next to
    # this module.
    _DICT_FILES = {
        'TableAttn_ch': 'table_structure_dict_ch.txt',
        'TableAttn_en': 'table_structure_dict.txt',
        'TableMaster': 'table_master_structure_dict.txt',
    }

    def __init__(self,
                 character_dict_type='TableAttn_ch',
                 merge_no_span_structure=True):
        """Load the structure dictionary and build the token <-> index maps.

        Args:
            character_dict_type: One of the keys of ``_DICT_FILES``; selects
                which dictionary file is read.
            merge_no_span_structure: When true, the standalone ``"<td>"``
                token is replaced by the merged ``"<td></td>"`` token.

        Raises:
            ValueError: If ``character_dict_type`` is not supported.
        """
        super().__init__()
        if character_dict_type not in self._DICT_FILES:
            raise ValueError(
                "character_dict_type must be one of %s, got %r" %
                (list(self._DICT_FILES), character_dict_type))
        character_dict_path = osp.abspath(
            osp.join(
                osp.dirname(__file__), self._DICT_FILES[character_dict_type]))

        dict_character = []
        # Read as bytes and decode per line; one token per line.
        with open(character_dict_path, "rb") as fin:
            for line in fin:
                dict_character.append(line.decode('utf-8').strip("\r\n"))

        if merge_no_span_structure:
            if "<td></td>" not in dict_character:
                dict_character.append("<td></td>")
            if "<td>" in dict_character:
                dict_character.remove("<td>")

        dict_character = self.add_special_char(dict_character)
        # On duplicate tokens the last index wins.
        self.dict = {char: i for i, char in enumerate(dict_character)}
        self.character = dict_character
        # Tokens that open a cell; a bbox is decoded for each of them.
        # "<td" presumably opens a cell that carries span attributes.
        self.td_token = ['<td>', '<td', '<td></td>']

    def add_special_char(self, dict_character):
        """Return the dictionary wrapped with begin/end tokens.

        Sets ``self.beg_str`` ("sos") and ``self.end_str`` ("eos").
        """
        self.beg_str = "sos"
        self.end_str = "eos"
        return [self.beg_str] + dict_character + [self.end_str]

    def get_ignored_tokens(self):
        """Return the indices of tokens skipped while decoding."""
        beg_idx = self.get_beg_end_flag_idx("beg")
        end_idx = self.get_beg_end_flag_idx("end")
        return [beg_idx, end_idx]

    def get_beg_end_flag_idx(self, beg_or_end):
        """Return the index of the begin ("beg") or end ("end") token as a
        0-d numpy array.

        Raises:
            ValueError: If ``beg_or_end`` is neither "beg" nor "end".
        """
        if beg_or_end == "beg":
            return np.array(self.dict[self.beg_str])
        if beg_or_end == "end":
            return np.array(self.dict[self.end_str])
        raise ValueError(
            "unsupported type %s in get_beg_end_flag_idx" % beg_or_end)

    def apply(self, data):
        """Decode the first sample of the batch and store the results.

        Writes the decoded boxes to ``data[K.BBOX_RES]`` and the structure
        tokens, wrapped in html/body/table tags, to ``data[K.HTML_RES]``.
        """
        shape_list = data[K.SHAPE_LIST]
        structure_probs = data[K.STRUCTURE_PROB]
        bbox_preds = data[K.LOC_PROB]
        if isinstance(structure_probs, paddle.Tensor):
            structure_probs = structure_probs.numpy()
        if isinstance(bbox_preds, paddle.Tensor):
            bbox_preds = bbox_preds.numpy()

        post_result = self.decode(structure_probs, bbox_preds, shape_list)
        # Each entry is [tokens, mean_score]; only the first sample is used.
        structure_tokens = post_result['structure_batch_list'][0][0]
        bbox_list = post_result['bbox_batch_list'][0]

        data[K.BBOX_RES] = bbox_list
        data[K.HTML_RES] = (['<html>', '<body>', '<table>'] + structure_tokens
                            + ['</table>', '</body>', '</html>'])
        return data

    @classmethod
    def get_input_keys(cls):
        """ get input keys """
        return [K.STRUCTURE_PROB, K.LOC_PROB, K.SHAPE_LIST]

    @classmethod
    def get_output_keys(cls):
        """ get output keys """
        return [K.BBOX_RES, K.HTML_RES]

    def decode(self, structure_probs, bbox_preds, shape_list):
        """Convert predicted token probabilities into tokens and cell boxes.

        Args:
            structure_probs: Array reduced over axis 2 to pick a token index
                per step; presumably (batch, steps, num_tokens).
            bbox_preds: Box predictions indexed as ``[batch, step]``.
            shape_list: Per-sample shape info forwarded to ``_bbox_decode``.

        Returns:
            dict with ``'bbox_batch_list'`` (one ``np.ndarray`` of boxes per
            sample) and ``'structure_batch_list'`` (one ``[tokens,
            mean_score]`` pair per sample; the score is nan when no token was
            kept).
        """
        ignored_tokens = {int(tok) for tok in self.get_ignored_tokens()}
        end_idx = self.dict[self.end_str]
        structure_idx = structure_probs.argmax(axis=2)
        structure_scores = structure_probs.max(axis=2)

        structure_batch_list = []
        bbox_batch_list = []
        for batch_idx in range(len(structure_idx)):
            structure_list = []
            bbox_list = []
            score_list = []
            for idx, raw_idx in enumerate(structure_idx[batch_idx]):
                char_idx = int(raw_idx)
                # The end token only terminates decoding after the first step.
                if idx > 0 and char_idx == end_idx:
                    break
                if char_idx in ignored_tokens:
                    continue
                text = self.character[char_idx]
                if text in self.td_token:
                    # Copy first: _bbox_decode scales in place, which would
                    # otherwise mutate the caller's prediction array.
                    bbox = np.array(bbox_preds[batch_idx, idx])
                    bbox = self._bbox_decode(bbox, shape_list[batch_idx])
                    bbox_list.append(bbox)
                structure_list.append(text)
                score_list.append(structure_scores[batch_idx, idx])
            # Avoid np.mean's RuntimeWarning on an empty list.
            mean_score = np.mean(score_list) if score_list else np.nan
            structure_batch_list.append([structure_list, mean_score])
            bbox_batch_list.append(np.array(bbox_list))

        return {
            'bbox_batch_list': bbox_batch_list,
            'structure_batch_list': structure_batch_list,
        }

    def decode_label(self, batch):
        """Convert ground-truth label indices into tokens and cell boxes.

        Reads structure indices from ``batch[1]``, ground-truth boxes from
        ``batch[2]`` and shape info from ``batch[-1]``. All-zero boxes are
        skipped.

        Returns:
            dict with ``'bbox_batch_list'`` and ``'structure_batch_list'``
            (one list per sample each).
        """
        structure_idx = batch[1]
        gt_bbox_list = batch[2]
        shape_list = batch[-1]
        ignored_tokens = self.get_ignored_tokens()
        end_idx = self.dict[self.end_str]

        structure_batch_list = []
        bbox_batch_list = []
        for batch_idx in range(len(structure_idx)):
            structure_list = []
            bbox_list = []
            for idx in range(len(structure_idx[batch_idx])):
                char_idx = int(structure_idx[batch_idx][idx])
                if idx > 0 and char_idx == end_idx:
                    break
                if char_idx in ignored_tokens:
                    continue
                structure_list.append(self.character[char_idx])

                bbox = gt_bbox_list[batch_idx][idx]
                if bbox.sum() != 0:
                    bbox = self._bbox_decode(bbox, shape_list[batch_idx])
                    bbox_list.append(bbox)
            structure_batch_list.append(structure_list)
            bbox_batch_list.append(bbox_list)

        return {
            'bbox_batch_list': bbox_batch_list,
            'structure_batch_list': structure_batch_list,
        }

    def _bbox_decode(self, bbox, shape):
        """Scale a normalized bbox in place: even indices by ``shape[0]``,
        odd indices by ``shape[1]``. Returns the same array."""
        # NOTE(review): shape[:2] is unpacked as (w, h) here, while
        # TableMasterLabelDecode._bbox_decode unpacks the same shape entry as
        # (h, w, ...). Confirm which layout the preprocessing emits for
        # this model.
        w, h = shape[:2]
        bbox[0::2] *= w
        bbox[1::2] *= h
        return bbox


class TableMasterLabelDecode(TableLabelDecode):
    """Decode TableMaster outputs (probabilities) into structure tokens and
    cell bounding boxes."""

    def __init__(self,
                 character_dict_type='TableMaster',
                 box_shape='pad',
                 merge_no_span_structure=True):
        """
        Args:
            character_dict_type: Forwarded to ``TableLabelDecode``.
            box_shape: 'pad' scales boxes by the padded size, 'ori' by the
                original size (see ``_bbox_decode``).
            merge_no_span_structure: Forwarded to ``TableLabelDecode``.

        Raises:
            ValueError: If ``box_shape`` is not 'ori' or 'pad', or the
                dictionary type is unsupported.
        """
        if box_shape not in ('ori', 'pad'):
            raise ValueError(
                'The shape used for box normalization must be ori or pad')
        super(TableMasterLabelDecode, self).__init__(character_dict_type,
                                                     merge_no_span_structure)
        self.box_shape = box_shape

    def add_special_char(self, dict_character):
        """Append the unknown, begin, end and pad tokens (in that order)."""
        self.beg_str = '<SOS>'
        self.end_str = '<EOS>'
        self.unknown_str = '<UKN>'
        self.pad_str = '<PAD>'
        return dict_character + [
            self.unknown_str, self.beg_str, self.end_str, self.pad_str
        ]

    def get_ignored_tokens(self):
        """Return the indices of the begin, end, pad and unknown tokens."""
        pad_idx = self.dict[self.pad_str]
        start_idx = self.dict[self.beg_str]
        end_idx = self.dict[self.end_str]
        unknown_idx = self.dict[self.unknown_str]
        return [start_idx, end_idx, pad_idx, unknown_idx]

    def _bbox_decode(self, bbox, shape):
        """Map a normalized (cx, cy, w, h) box to (x1, y1, x2, y2) pixels.

        ``shape`` is unpacked as (h, w, ratio_h, ratio_w, pad_h, pad_w). With
        ``box_shape == 'pad'`` the padded size is used for scaling. The
        result is then divided by the resize ratios. ``bbox`` is modified in
        place before a new array is returned.
        """
        h, w, ratio_h, ratio_w, pad_h, pad_w = shape
        if self.box_shape == 'pad':
            h, w = pad_h, pad_w
        bbox[0::2] *= w
        bbox[1::2] *= h
        bbox[0::2] /= ratio_w
        bbox[1::2] /= ratio_h
        x, y, w, h = bbox
        x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2
        return np.array([x1, y1, x2, y2])


class SaveTableResults(BaseTransform):
    """Draw the decoded table cell boxes on the image and save it."""

    _TABLE_RES_SUFFIX = '_bbox'
    _FILE_EXT = '.png'

    def __init__(self, save_dir):
        super().__init__()
        self.save_dir = save_dir
        # We use pillow backend to save both numpy arrays and PIL Image objects
        self._writer = ImageWriter(backend='pillow')

    def apply(self, data):
        """Save a visualization to ``<save_dir>/<stem>_bbox.png``.

        4-value boxes are drawn as rectangles, anything else as polygons.
        ``data`` is returned unchanged.
        """
        ori_path = data[K.IM_PATH]
        bbox_res = data[K.BBOX_RES]
        file_name = self._replace_ext(
            os.path.basename(ori_path), self._FILE_EXT)
        save_path = self._add_suffix(
            os.path.join(self.save_dir, file_name), self._TABLE_RES_SUFFIX)

        if len(bbox_res) > 0 and len(bbox_res[0]) == 4:
            vis_img = self.draw_rectangle(data[K.ORI_IM], bbox_res)
        else:
            vis_img = self.draw_bbox(data[K.ORI_IM], bbox_res)
        self._write_im(save_path, vis_img)
        return data

    @classmethod
    def get_input_keys(cls):
        """ get input keys """
        return [K.IM_PATH, K.ORI_IM, K.BBOX_RES]

    @classmethod
    def get_output_keys(cls):
        """ get output keys """
        return []

    def _write_im(self, path, im):
        """Write ``im`` to ``path``, warning if the file is overwritten."""
        if os.path.exists(path):
            logging.warning(f"{path} already exists. Overwriting it.")
        self._writer.write(path, im)

    @staticmethod
    def _add_suffix(path, suffix):
        """Insert ``suffix`` between the path stem and its extension."""
        stem, ext = os.path.splitext(path)
        return stem + suffix + ext

    @staticmethod
    def _replace_ext(path, new_ext):
        """Replace the extension of ``path`` with ``new_ext``."""
        stem, _ = os.path.splitext(path)
        return stem + new_ext

    def draw_rectangle(self, img_path, boxes):
        """Draw (x1, y1, x2, y2) boxes on a copy of an image.

        Args:
            img_path: A path to an image file (read with ``cv2.imread``) or
                image data convertible with ``np.array``. ``apply`` passes
                ``data[K.ORI_IM]``, which ``draw_bbox`` also treats as image
                data.
            boxes: Iterable of 4-value boxes.

        Returns:
            The annotated image as a new array; the input is not modified.

        Raises:
            FileNotFoundError: If a path is given and cannot be read.
        """
        if isinstance(img_path, (str, os.PathLike)):
            img = cv2.imread(os.fspath(img_path))
            if img is None:
                raise FileNotFoundError(f"Failed to read image: {img_path}")
            img_show = img.copy()
        else:
            # np.array copies, so the caller's image is left untouched.
            img_show = np.array(img_path)
        for x1, y1, x2, y2 in np.array(boxes).astype(int):
            cv2.rectangle(img_show, (x1, y1), (x2, y2), (255, 0, 0), 2)
        return img_show

    def draw_bbox(self, image, boxes):
        """Draw closed polygons (flat x, y coordinate lists) on a copy of
        ``image``. Returns ``image`` itself when ``boxes`` is empty."""
        if len(boxes) == 0:
            return image
        # Convert (and copy) once instead of once per box.
        canvas = np.array(image)
        for box in boxes:
            pts = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
            canvas = cv2.polylines(canvas, [pts], True, (255, 0, 0), 2)
        return canvas


class PrintResult(BaseTransform):
    """ Print Result Transform """

    def apply(self, data):
        """Log ``data[K.BOXES]`` and return ``data`` unchanged."""
        logging.info("The prediction result is:")
        logging.info(data[K.BOXES])
        return data

    @classmethod
    def get_input_keys(cls):
        """ get input keys """
        return [K.BOXES]

    @classmethod
    def get_output_keys(cls):
        """ get output keys """
        return []