# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cv2
import numpy as np
from numpy import ndarray
from ..common.vision import funcs as F
class Pad:
    """Pad images up to a fixed target size."""

    INPUT_KEYS = "img"
    OUTPUT_KEYS = ["img", "img_size"]
    # NOTE: "DEAULT" spelling kept as-is; these attributes are presumably
    # looked up by name elsewhere in the pipeline.
    DEAULT_INPUTS = {"img": "img"}
    DEAULT_OUTPUTS = {"img": "img", "img_size": "img_size"}

    def __init__(self, target_size, val=127.5):
        """
        Initialize the instance.

        Args:
            target_size (list|tuple|int): Target width and height of the image
                after padding. An int is used for both width and height.
            val (float, optional): Value to fill the padded area. Default: 127.5.
        """
        super().__init__()
        if isinstance(target_size, int):
            target_size = [target_size, target_size]
        self.target_size = target_size
        self.val = val

    def apply(self, img):
        """Pad a single image to ``self.target_size``.

        Args:
            img: Image array whose first two dimensions are read as
                (height, width).

        Returns:
            list: ``[padded_img, [width, height]]`` of the padded image.

        Raises:
            ValueError: If the image is larger than the target size in
                either dimension (padding cannot shrink an image).
        """
        h, w = img.shape[:2]
        tw, th = self.target_size
        ph = th - h
        pw = tw - w
        if ph < 0 or pw < 0:
            raise ValueError(
                f"Input image ({w}, {h}) is larger than the target size ({tw}, {th})."
            )
        # NOTE(review): pad tuple is presumably (top, bottom, left, right),
        # i.e. padding only on the bottom/right edges -- confirm against F.pad.
        img = F.pad(img, pad=(0, ph, 0, pw), val=self.val)
        return [img, [img.shape[1], img.shape[0]]]

    def __call__(self, imgs):
        """Apply :meth:`apply` to each image and return the list of results."""
        return [self.apply(img) for img in imgs]
class TableLabelDecode:
    """Decode table-structure model outputs into HTML tokens and cell boxes.

    Token ids predicted by the model are mapped back to strings through
    ``self.character``. Every token listed in ``self.td_token`` (a table cell)
    is additionally paired with a decoded bounding box.
    """

    ENABLE_BATCH = True
    INPUT_KEYS = ["pred", "img_size", "ori_img_size"]
    OUTPUT_KEYS = ["bbox", "structure", "structure_score"]
    # NOTE: "DEAULT" spelling kept as-is; these attributes are presumably
    # looked up by name elsewhere in the pipeline.
    DEAULT_INPUTS = {
        "pred": "pred",
        "img_size": "img_size",
        "ori_img_size": "ori_img_size",
    }
    DEAULT_OUTPUTS = {
        "bbox": "bbox",
        "structure": "structure",
        "structure_score": "structure_score",
    }

    def __init__(self, model_name, merge_no_span_structure=True, dict_character=None):
        """Build the token <-> index lookup tables.

        Args:
            model_name (str): Model name; ``"SLANet"`` selects a different
                box rescaling in ``_bbox_decode``.
            merge_no_span_structure (bool): If True, make sure the merged
                ``"<td></td>"`` token is in the vocabulary and drop the bare
                ``"<td>"`` token.
            dict_character (list[str] | None): Structure vocabulary. It is
                copied and never modified in place. Defaults to empty.
        """
        super().__init__()
        # Work on a private copy so neither the caller's list nor a shared
        # default object is mutated by the append/remove below.
        dict_character = [] if dict_character is None else list(dict_character)
        if merge_no_span_structure:
            if "<td></td>" not in dict_character:
                dict_character.append("<td></td>")
            if "<td>" in dict_character:
                dict_character.remove("<td>")
        self.model_name = model_name
        dict_character = self.add_special_char(dict_character)
        self.dict = {char: i for i, char in enumerate(dict_character)}
        self.character = dict_character
        # Tokens that open a table cell; each one is paired with a bbox.
        self.td_token = ["<td>", "<td", "<td></td>"]

    def add_special_char(self, dict_character):
        """Return the vocabulary wrapped with the "sos" and "eos" tokens."""
        self.beg_str = "sos"
        self.end_str = "eos"
        return [self.beg_str] + list(dict_character) + [self.end_str]

    def get_ignored_tokens(self):
        """Return the indices of the start and end tokens (skipped on decode)."""
        beg_idx = self.get_beg_end_flag_idx("beg")
        end_idx = self.get_beg_end_flag_idx("end")
        return [beg_idx, end_idx]

    def get_beg_end_flag_idx(self, beg_or_end):
        """Return the index of the start ("beg") or end ("end") token.

        Raises:
            ValueError: If ``beg_or_end`` is neither "beg" nor "end".
        """
        if beg_or_end == "beg":
            return np.array(self.dict[self.beg_str])
        if beg_or_end == "end":
            return np.array(self.dict[self.end_str])
        # Explicit raise (not assert) so the check survives ``python -O``.
        raise ValueError("unsupported type %s in get_beg_end_flag_idx" % beg_or_end)

    def __call__(self, pred, img_size, ori_img_size):
        """Decode one prediction into HTML structure, cell boxes and a score.

        Args:
            pred: Model outputs. ``pred[0][0]`` is read per decoding step for
                box regressions and ``pred[1][0]`` for token probabilities.
            img_size: Per-batch (width, height) after padding; index 0 is used.
            ori_img_size: Per-batch original (width, height); index 0 is used.

        Returns:
            list[dict]: Dicts with keys "bbox", "structure" (HTML tokens
            wrapped in html/body/table tags) and "structure_score".
        """
        num_steps = len(pred[0][0])
        # Only the first batch item is decoded; re-wrap it as a batch of one.
        bbox_preds = np.array([[pred[0][0][i] for i in range(num_steps)]])
        structure_probs = np.array([[pred[1][0][i] for i in range(num_steps)]])
        bbox_list, structure_str_list, structure_score = self.decode(
            structure_probs, bbox_preds, img_size, ori_img_size
        )
        structure_str_list = [
            ["<html>", "<body>", "<table>"]
            + structure
            + ["</table>", "</body>", "</html>"]
            for structure in structure_str_list
        ]
        return [
            {"bbox": bbox, "structure": structure, "structure_score": structure_score}
            for bbox, structure in zip(bbox_list, structure_str_list)
        ]

    def decode(self, structure_probs, bbox_preds, padding_size, ori_img_size):
        """Convert predicted token probabilities and boxes into tokens and boxes.

        Args:
            structure_probs (np.ndarray): Token probabilities indexed as
                [batch, step, vocab] (argmax/max are taken over axis 2).
            bbox_preds (np.ndarray): Box regressions indexed as
                [batch, step, coords]. Not modified.
            padding_size: Per-batch (width, height) after padding.
            ori_img_size: Per-batch original (width, height).

        Returns:
            tuple: ``(bbox_batch_list, structure_batch_list, structure_score)``.
            ``structure_score`` is the mean token probability of the LAST
            batch item only (NaN if the batch is empty).
        """
        ignored_tokens = self.get_ignored_tokens()
        end_idx = self.dict[self.end_str]
        structure_idx = structure_probs.argmax(axis=2)
        structure_probs = structure_probs.max(axis=2)
        structure_batch_list = []
        bbox_batch_list = []
        # Defined up front so an empty batch does not raise UnboundLocalError.
        structure_score = np.nan
        batch_size = len(structure_idx)
        for batch_idx in range(batch_size):
            structure_list = []
            bbox_list = []
            score_list = []
            for idx in range(len(structure_idx[batch_idx])):
                char_idx = int(structure_idx[batch_idx][idx])
                # An end token at step 0 does not stop decoding.
                if idx > 0 and char_idx == end_idx:
                    break
                if char_idx in ignored_tokens:
                    continue
                text = self.character[char_idx]
                if text in self.td_token:
                    # Copy: _bbox_decode scales in place, and a plain index
                    # would be a view into the caller's bbox_preds.
                    bbox = bbox_preds[batch_idx, idx].copy()
                    bbox = self._bbox_decode(
                        bbox, padding_size[batch_idx], ori_img_size[batch_idx]
                    )
                    bbox_list.append(bbox.astype(int))
                structure_list.append(text)
                score_list.append(structure_probs[batch_idx, idx])
            structure_batch_list.append(structure_list)
            structure_score = np.mean(score_list)
            bbox_batch_list.append(bbox_list)
        return bbox_batch_list, structure_batch_list, structure_score

    def decode_label(self, batch):
        """Convert ground-truth label indices back into tokens and boxes.

        Args:
            batch: Indexable; ``batch[1]`` holds token indices, ``batch[2]``
                ground-truth boxes and ``batch[-1]`` per-sample shape info.
        """
        structure_idx = batch[1]
        gt_bbox_list = batch[2]
        shape_list = batch[-1]
        ignored_tokens = self.get_ignored_tokens()
        end_idx = self.dict[self.end_str]
        structure_batch_list = []
        bbox_batch_list = []
        batch_size = len(structure_idx)
        for batch_idx in range(batch_size):
            structure_list = []
            bbox_list = []
            for idx in range(len(structure_idx[batch_idx])):
                char_idx = int(structure_idx[batch_idx][idx])
                if idx > 0 and char_idx == end_idx:
                    break
                if char_idx in ignored_tokens:
                    continue
                structure_list.append(self.character[char_idx])
                bbox = gt_bbox_list[batch_idx][idx]
                # All-zero boxes are treated as "no box" and skipped.
                if bbox.sum() != 0:
                    # NOTE(review): _bbox_decode takes (bbox, padding_shape,
                    # ori_shape) but only two arguments are passed here, so
                    # this raises TypeError for any non-zero gt box. The
                    # layout of shape_list is not visible here -- TODO fix.
                    bbox = self._bbox_decode(bbox, shape_list[batch_idx])
                    bbox_list.append(bbox.astype(int))
            structure_batch_list.append(structure_list)
            bbox_batch_list.append(bbox_list)
        return bbox_batch_list, structure_batch_list

    def _bbox_decode(self, bbox, padding_shape, ori_shape):
        """Rescale ``bbox`` in place and return it.

        Even indices are scaled as x, odd indices as y. Coordinates are
        presumably normalized to [0, 1] -- confirm against the model.

        Args:
            bbox (np.ndarray): Float box coordinates; modified in place.
            padding_shape: (width, height) after padding.
            ori_shape: Original (width, height).
        """
        if self.model_name == "SLANet":
            # SLANet: scale directly to the original image size.
            w, h = ori_shape
            bbox[0::2] *= w
            bbox[1::2] *= h
        else:
            # Other models: scale to the padded size, then undo the resize
            # ratio that mapped the original image into it.
            w, h = padding_shape
            ori_w, ori_h = ori_shape
            ratio_w = w / ori_w
            ratio_h = h / ori_h
            ratio = min(ratio_w, ratio_h)
            bbox[0::2] *= w
            bbox[1::2] *= h
            bbox[0::2] /= ratio
            bbox[1::2] /= ratio
        return bbox