# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import cv2
import numpy as np
from numpy import ndarray

from ..common.vision import funcs as F


class Pad:
    """Pad images up to a fixed target size."""

    INPUT_KEYS = "img"
    OUTPUT_KEYS = ["img", "img_size"]
    DEAULT_INPUTS = {"img": "img"}
    DEAULT_OUTPUTS = {"img": "img", "img_size": "img_size"}

    def __init__(self, target_size, val=127.5):
        """
        Initialize the instance.

        Args:
            target_size (list|tuple|int): Target width and height of the image
                after padding. A single int is used for both dimensions.
            val (float, optional): Value to fill the padded area.
                Default: 127.5.
        """
        super().__init__()

        if isinstance(target_size, int):
            target_size = [target_size, target_size]
        self.target_size = target_size
        self.val = val

    def apply(self, img):
        """Pad a single image to the target size.

        Args:
            img: Image array whose first two dimensions are (height, width).

        Returns:
            list: ``[padded_img, [width, height]]`` where the size is read
            back from the padded image.

        Raises:
            ValueError: If the image exceeds the target size in either
                dimension, since padding can only grow an image.
        """
        h, w = img.shape[:2]
        tw, th = self.target_size
        ph = th - h
        pw = tw - w

        # A negative amount means the image is already bigger than the target.
        if ph < 0 or pw < 0:
            raise ValueError(
                f"Input image ({w}, {h}) larger than the target size ({tw}, {th})."
            )

        # NOTE(review): `F.pad` is defined outside this file; the tuple is
        # presumably (top, bottom, left, right) - confirm against the helper.
        img = F.pad(img, pad=(0, ph, 0, pw), val=self.val)
        return [img, [img.shape[1], img.shape[0]]]

    def __call__(self, imgs):
        """Pad every image in ``imgs`` and return the list of ``apply`` results."""
        return [self.apply(img) for img in imgs]


class TableLabelDecode:
    """Decode table-structure model outputs into structure tokens and boxes."""

    ENABLE_BATCH = True

    INPUT_KEYS = ["pred", "img_size", "ori_img_size"]
    OUTPUT_KEYS = ["bbox", "structure", "structure_score"]
    DEAULT_INPUTS = {
        "pred": "pred",
        "img_size": "img_size",
        "ori_img_size": "ori_img_size",
    }
    DEAULT_OUTPUTS = {
        "bbox": "bbox",
        "structure": "structure",
        "structure_score": "structure_score",
    }

    def __init__(self, model_name, merge_no_span_structure=True, dict_character=None):
        """
        Initialize the decoder.

        Args:
            model_name (str): Model name; ``"SLANet"`` selects a different
                box-rescaling rule in ``_bbox_decode``.
            merge_no_span_structure (bool, optional): Whether to adjust the
                vocabulary before the special tokens are added. Default: True.
            dict_character (list, optional): Structure-token vocabulary.
                The list passed in is never modified. Default: empty.
        """
        super().__init__()

        # Work on a private copy so neither the caller's list nor a shared
        # default is mutated by the append/remove below.
        dict_character = [] if dict_character is None else list(dict_character)

        if merge_no_span_structure:
            # NOTE(review): both literals are empty strings in this copy of the
            # file, which makes the two steps largely cancel out. It looks like
            # markup tokens may have been stripped - confirm against upstream.
            if "" not in dict_character:
                dict_character.append("")
            if "" in dict_character:
                dict_character.remove("")

        self.model_name = model_name

        dict_character = self.add_special_char(dict_character)
        # Token -> index lookup; on duplicate tokens the last index wins.
        self.dict = {}
        for i, char in enumerate(dict_character):
            self.dict[char] = i
        self.character = dict_character
        # Tokens for which `decode` emits a bounding box.
        # NOTE(review): these literals look stripped as well - confirm.
        self.td_token = ["", ""]

    def add_special_char(self, dict_character):
        """Return the vocabulary wrapped with the "sos" and "eos" tokens.

        Also records both tokens on the instance as ``beg_str``/``end_str``.
        """
        self.beg_str = "sos"
        self.end_str = "eos"
        return [self.beg_str] + dict_character + [self.end_str]

    def get_ignored_tokens(self):
        """Return the indices of the begin and end tokens as ``[beg, end]``."""
        beg_idx = self.get_beg_end_flag_idx("beg")
        end_idx = self.get_beg_end_flag_idx("end")
        return [beg_idx, end_idx]

    def get_beg_end_flag_idx(self, beg_or_end):
        """Return the vocabulary index of the begin or end token.

        Args:
            beg_or_end (str): Either ``"beg"`` or ``"end"``.

        Returns:
            numpy.ndarray: Zero-dimensional array holding the index.

        Raises:
            AssertionError: If ``beg_or_end`` is any other value.
        """
        if beg_or_end == "beg":
            return np.array(self.dict[self.beg_str])
        if beg_or_end == "end":
            return np.array(self.dict[self.end_str])
        # Raised explicitly (instead of `assert False`) so the check still
        # fires when Python runs with -O; the exception type is unchanged.
        raise AssertionError(
            "unsupported type %s in get_beg_end_flag_idx" % beg_or_end
        )

    def __call__(self, pred, img_size, ori_img_size):
        """Decode raw predictions into per-sample result dicts.

        Only ``pred[0][0]`` (boxes) and ``pred[1][0]`` (token probabilities)
        are read, and they are decoded as a batch of one.

        Args:
            pred: Model outputs indexed as described above.
            img_size: Per-sample size of the padded model input.
            ori_img_size: Per-sample size of the original image.

        Returns:
            list[dict]: One dict per decoded sample with keys ``"bbox"``,
            ``"structure"`` and ``"structure_score"``.
        """
        bbox_seq = pred[0][0]
        prob_seq = pred[1][0]
        # The box sequence length drives both arrays, as it always has.
        steps = range(len(bbox_seq))
        bbox_preds = np.array([[bbox_seq[i] for i in steps]])
        structure_probs = np.array([[prob_seq[i] for i in steps]])

        bbox_list, structure_str_list, structure_score = self.decode(
            structure_probs, bbox_preds, img_size, ori_img_size
        )

        # Wrap every token sequence in fixed leading and trailing tokens.
        # NOTE(review): these literals are empty here and one held a bare line
        # break, reproduced as "\n"; they look like stripped markup - confirm.
        structure_str_list = [
            ["", "", ""] + structure + ["\n", "", ""]
            for structure in structure_str_list
        ]
        return [
            {"bbox": bbox, "structure": structure, "structure_score": structure_score}
            for bbox, structure in zip(bbox_list, structure_str_list)
        ]

    def decode(self, structure_probs, bbox_preds, padding_size, ori_img_size):
        """Convert predicted token probabilities into tokens and boxes.

        Args:
            structure_probs (numpy.ndarray): Token probabilities; the argmax
                and max are taken over axis 2.
            bbox_preds (numpy.ndarray): Predicted boxes, indexed as
                ``[batch, step]``.
            padding_size: Per-sample padded size passed to ``_bbox_decode``.
            ori_img_size: Per-sample original size passed to ``_bbox_decode``.

        Returns:
            tuple: ``(bbox_batch_list, structure_batch_list, structure_score)``.
        """
        ignored_tokens = self.get_ignored_tokens()
        end_idx = self.dict[self.end_str]

        structure_idx = structure_probs.argmax(axis=2)
        structure_probs = structure_probs.max(axis=2)

        structure_batch_list = []
        bbox_batch_list = []
        for batch_idx, sample_idx in enumerate(structure_idx):
            structure_list = []
            bbox_list = []
            score_list = []
            for idx, raw_idx in enumerate(sample_idx):
                char_idx = int(raw_idx)
                # Stop at the end token, except at the very first step.
                if idx > 0 and char_idx == end_idx:
                    break
                if char_idx in ignored_tokens:
                    continue
                text = self.character[char_idx]
                if text in self.td_token:
                    bbox = bbox_preds[batch_idx, idx]
                    bbox = self._bbox_decode(
                        bbox, padding_size[batch_idx], ori_img_size[batch_idx]
                    )
                    bbox_list.append(bbox.astype(int))
                structure_list.append(text)
                score_list.append(structure_probs[batch_idx, idx])
            structure_batch_list.append(structure_list)
            # NOTE(review): overwritten on every sample, so only the last
            # sample's mean is returned; an empty score_list yields NaN.
            structure_score = np.mean(score_list)
            bbox_batch_list.append(bbox_list)

        return bbox_batch_list, structure_batch_list, structure_score

    def decode_label(self, batch):
        """Convert label token indices and boxes into tokens and boxes.

        Args:
            batch: Sequence read as ``batch[1]`` (token indices),
                ``batch[2]`` (boxes) and ``batch[-1]`` (shapes).

        Returns:
            tuple: ``(bbox_batch_list, structure_batch_list)``.
        """
        structure_idx = batch[1]
        gt_bbox_list = batch[2]
        shape_list = batch[-1]
        ignored_tokens = self.get_ignored_tokens()
        end_idx = self.dict[self.end_str]

        structure_batch_list = []
        bbox_batch_list = []
        for batch_idx, sample_idx in enumerate(structure_idx):
            structure_list = []
            bbox_list = []
            for idx, raw_idx in enumerate(sample_idx):
                char_idx = int(raw_idx)
                # Stop at the end token, except at the very first step.
                if idx > 0 and char_idx == end_idx:
                    break
                if char_idx in ignored_tokens:
                    continue
                structure_list.append(self.character[char_idx])

                bbox = gt_bbox_list[batch_idx][idx]
                # All-zero boxes are skipped.
                if bbox.sum() != 0:
                    # NOTE(review): `_bbox_decode` requires both a padded and
                    # an original shape, so this two-argument call raises
                    # TypeError. The layout of `shape_list` is not visible
                    # here, so the call is left for the owner to fix.
                    bbox = self._bbox_decode(bbox, shape_list[batch_idx])
                    bbox_list.append(bbox.astype(int))
            structure_batch_list.append(structure_list)
            bbox_batch_list.append(bbox_list)

        return bbox_batch_list, structure_batch_list

    def _bbox_decode(self, bbox, padding_shape, ori_shape):
        """Rescale a box in place and return it.

        Even-indexed entries are scaled by a width and odd-indexed entries by
        a height. For ``"SLANet"`` the original size is used directly; for
        other models the padded size is used and the result is divided by the
        smaller of the padded/original ratios.

        Args:
            bbox (numpy.ndarray): Box coordinates; modified in place.
            padding_shape: ``(w, h)`` of the padded model input.
            ori_shape: ``(w, h)`` of the original image.

        Returns:
            numpy.ndarray: The same ``bbox`` object after rescaling.
        """
        if self.model_name == "SLANet":
            w, h = ori_shape
            bbox[0::2] *= w
            bbox[1::2] *= h
        else:
            w, h = padding_shape
            ori_w, ori_h = ori_shape
            ratio_w = w / ori_w
            ratio_h = h / ori_h
            ratio = min(ratio_w, ratio_h)
            bbox[0::2] *= w
            bbox[1::2] *= h
            bbox[0::2] /= ratio
            bbox[1::2] /= ratio
        return bbox