| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792 |
- # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import re
- import numpy as np
- import torch
class BaseRecLabelDecode(object):
    """ Convert between text-label and text-index """

    def __init__(self,
                 character_dict_path=None,
                 use_space_char=False):
        """Build the index -> character table.

        Args:
            character_dict_path: path to a UTF-8 file with one character per
                line. When None, the built-in "0-9a-z" table is used and
                ``use_space_char`` has no effect.
            use_space_char: append a space character to a file-based table.
        """
        self.beg_str = "sos"
        self.end_str = "eos"
        # Set when the dictionary path mentions "arabic"; decode() then
        # re-orders the recognized text with pred_reverse().
        self.reverse = False
        self.character_str = []
        if character_dict_path is None:
            self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
            dict_character = list(self.character_str)
        else:
            with open(character_dict_path, "rb") as fin:
                lines = fin.readlines()
                for line in lines:
                    line = line.decode('utf-8').strip("\n").strip("\r\n")
                    self.character_str.append(line)
            if use_space_char:
                self.character_str.append(" ")
            dict_character = list(self.character_str)
            if "arabic" in character_dict_path:
                self.reverse = True

        dict_character = self.add_special_char(dict_character)
        # character -> index (for duplicated characters the last index wins)
        self.dict = {}
        for i, char in enumerate(dict_character):
            self.dict[char] = i
        self.character = dict_character

    def pred_reverse(self, pred):
        """Reverse the order of text runs (used for right-to-left scripts).

        Consecutive characters matching ``[a-zA-Z0-9 :*./%+-]`` are kept
        together as one run, so Latin words and numbers keep their internal
        order. Every other character is a run on its own. The runs are then
        joined in reverse order.
        """
        pred_re = []
        c_current = ""
        for c in pred:
            if not bool(re.search("[a-zA-Z0-9 :*./%+-]", c)):
                if c_current != "":
                    pred_re.append(c_current)
                pred_re.append(c)
                c_current = ""
            else:
                c_current += c
        if c_current != "":
            pred_re.append(c_current)
        return "".join(pred_re[::-1])

    def add_special_char(self, dict_character):
        """Hook for subclasses to add special tokens; the base adds none."""
        return dict_character

    def get_word_info(self, text, selection):
        """
        Group the decoded characters and record the corresponding decoded positions.
        Args:
            text: the decoded text
            selection: the bool array that identifies which columns of features are decoded as non-separated characters
        Returns:
            word_list: list of the grouped words
            word_col_list: list of decoding positions corresponding to each character in the grouped word
            state_list: list of marker to identify the type of grouping words, including two types of grouping words:
                        - 'cn': continuous chinese characters (e.g., 你好啊)
                        - 'en&num': continuous english characters (e.g., hello), number (e.g., 123, 1.123), or mixed of them connected by '-' (e.g., VGG-16)
                        The remaining characters in text are treated as separators between groups (e.g., space, '(', ')', etc.).
        """
        state = None
        word_content = []
        word_col_content = []
        word_list = []
        word_col_list = []
        state_list = []
        # column index of every kept (non-separator) decoding step; the
        # c-th character of `text` was produced at column valid_col[c]
        valid_col = np.where(selection == True)[0]

        for c_i, char in enumerate(text):
            if "\u4e00" <= char <= "\u9fff":
                c_state = "cn"
            elif bool(re.search("[a-zA-Z0-9]", char)):
                c_state = "en&num"
            else:
                c_state = "splitter"

            if (
                char == "."
                and state == "en&num"
                and c_i + 1 < len(text)
                and bool(re.search("[0-9]", text[c_i + 1]))
            ):  # grouping floating number
                c_state = "en&num"
            if (
                char == "-" and state == "en&num"
            ):  # grouping word with '-', such as 'state-of-the-art'
                c_state = "en&num"

            if state is None:
                state = c_state

            if state != c_state:
                if len(word_content) != 0:
                    word_list.append(word_content)
                    word_col_list.append(word_col_content)
                    state_list.append(state)
                    word_content = []
                    word_col_content = []
                state = c_state

            if state != "splitter":
                word_content.append(char)
                word_col_content.append(valid_col[c_i])

        if len(word_content) != 0:
            word_list.append(word_content)
            word_col_list.append(word_col_content)
            state_list.append(state)

        return word_list, word_col_list, state_list

    def decode(
        self,
        text_index,
        text_prob=None,
        is_remove_duplicate=False,
        return_word_box=False,
    ):
        """Convert index sequences into text.

        Args:
            text_index: batch of index sequences (``text_index[b][t]``).
            text_prob: optional per-step scores with the same layout; when
                omitted every kept step scores 1.
            is_remove_duplicate: drop a step whose index equals the previous
                step's index (CTC-style collapse).
            return_word_box: also return
                ``[seq_len, word_list, word_col_list, state_list]`` as a third
                tuple element (see get_word_info).

        Returns:
            list of ``(text, score)`` tuples, or ``(text, score, word_info)``
            when ``return_word_box`` is True. ``score`` is the mean score of
            the kept steps, or 0 when no step was kept.
        """
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            seq = text_index[batch_idx]
            seq_len = len(seq)
            # which decoding steps produced a character in `text`
            selection = np.zeros(seq_len, dtype=bool)
            char_list = []
            conf_list = []
            for idx in range(seq_len):
                if seq[idx] in ignored_tokens:
                    continue
                if is_remove_duplicate:
                    # only for predict
                    if idx > 0 and seq[idx - 1] == seq[idx]:
                        continue
                selection[idx] = True
                char_list.append(self.character[int(seq[idx])])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            if len(conf_list) == 0:
                # np.mean([]) would return nan and emit a RuntimeWarning
                conf_list = [0]
            text = ''.join(char_list)
            if self.reverse:  # right-to-left dictionaries (arabic)
                text = self.pred_reverse(text)
            score = np.mean(conf_list)
            if return_word_box:
                word_list, word_col_list, state_list = self.get_word_info(
                    text, selection)
                result_list.append((
                    text,
                    score,
                    [seq_len, word_list, word_col_list, state_list],
                ))
            else:
                result_list.append((text, score))
        return result_list

    def get_ignored_tokens(self):
        return [0]  # for ctc blank
class CTCLabelDecode(BaseRecLabelDecode):
    """ Convert between text-label and text-index """

    def __init__(self,
                 character_dict_path=None,
                 use_space_char=False,
                 **kwargs):
        super(CTCLabelDecode, self).__init__(character_dict_path,
                                             use_space_char)

    def __call__(self, preds, label=None, return_word_box=False, *args, **kwargs):
        """Greedy CTC decoding of network scores.

        Args:
            preds: per-step class scores (tensor or ndarray); argmax/max are
                taken over axis 2, so presumably (batch, steps, classes) —
                confirm against the model head.
            label: optional label indices, decoded with the same table.
            return_word_box: forward to decode(); additionally requires
                ``wh_ratio_list`` and ``max_wh_ratio`` in kwargs, which rescale
                the sequence-length entry of each word-box result.

        Returns:
            decoded predictions, or ``(predictions, labels)`` if label given.
        """
        if isinstance(preds, torch.Tensor):
            # detach + cpu so CUDA tensors and tensors requiring grad convert
            preds = preds.detach().cpu().numpy()
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(
            preds_idx,
            preds_prob,
            is_remove_duplicate=True,
            return_word_box=return_word_box,
        )
        if return_word_box:
            for rec_idx, rec in enumerate(text):
                wh_ratio = kwargs["wh_ratio_list"][rec_idx]
                max_wh_ratio = kwargs["max_wh_ratio"]
                rec[2][0] = rec[2][0] * (wh_ratio / max_wh_ratio)
        if label is None:
            return text
        label = self.decode(label)
        return text, label

    def add_special_char(self, dict_character):
        """Prepend the CTC blank so that it occupies index 0."""
        dict_character = ['blank'] + dict_character
        return dict_character
class NRTRLabelDecode(BaseRecLabelDecode):
    """ Convert between text-label and text-index """

    def __init__(self, character_dict_path=None, use_space_char=True, **kwargs):
        super(NRTRLabelDecode, self).__init__(character_dict_path,
                                              use_space_char)

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode NRTR outputs.

        ``preds`` is either an ``(ids, probs)`` pair (list/tuple) produced by
        the model's own decoding, or a score array/tensor whose argmax and
        max are taken over axis 2.
        """
        # Only a real pair is unpacked: a plain array/tensor with batch size
        # 2 also has len() == 2 and must go through the argmax path.
        if isinstance(preds, (list, tuple)) and len(preds) == 2:
            preds_id, preds_prob = preds
            if isinstance(preds_id, torch.Tensor):
                preds_id = preds_id.detach().cpu().numpy()
            if isinstance(preds_prob, torch.Tensor):
                preds_prob = preds_prob.detach().cpu().numpy()
            # index 2 is '<s>' (see add_special_char); drop that leading step
            if preds_id[0][0] == 2:
                preds_idx = preds_id[:, 1:]
                preds_prob = preds_prob[:, 1:]
            else:
                preds_idx = preds_id
        else:
            if isinstance(preds, torch.Tensor):
                preds = preds.detach().cpu().numpy()
            preds_idx = preds.argmax(axis=2)
            preds_prob = preds.max(axis=2)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
        if label is None:
            return text
        label = self.decode(label[:, 1:])
        return text, label

    def add_special_char(self, dict_character):
        """Prepend 'blank', '<unk>', '<s>', '</s>' (indices 0-3)."""
        dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
        return dict_character

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """Convert index sequences to lower-cased text, stopping at '</s>'.

        Out-of-range indices are skipped. ``is_remove_duplicate`` is accepted
        for interface compatibility and is not used. The score is the mean of
        the kept steps, or 0 when no step was kept.
        """
        result_list = []
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            char_list = []
            conf_list = []
            for idx in range(len(text_index[batch_idx])):
                try:
                    char_idx = self.character[int(text_index[batch_idx][idx])]
                except (IndexError, TypeError, ValueError):
                    # index not in the table: skip this step
                    continue
                if char_idx == '</s>':  # end
                    break
                char_list.append(char_idx)
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            if len(conf_list) == 0:
                # np.mean([]) would return nan and emit a RuntimeWarning
                conf_list = [0]
            text = ''.join(char_list)
            result_list.append((text.lower(), np.mean(conf_list).tolist()))
        return result_list
class ViTSTRLabelDecode(NRTRLabelDecode):
    """ Convert between text-label and text-index """

    def __init__(self, character_dict_path=None, use_space_char=False,
                 **kwargs):
        super(ViTSTRLabelDecode, self).__init__(character_dict_path,
                                                use_space_char)

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode ViTSTR scores; the first step of preds and label is dropped.

        argmax/max are taken over axis 2 of ``preds``.
        """
        if isinstance(preds, torch.Tensor):
            # detach + cpu so CUDA tensors and tensors requiring grad convert
            preds = preds.detach().cpu().numpy()
        preds = preds[:, 1:]
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
        if label is None:
            return text
        label = self.decode(label[:, 1:])
        return text, label

    def add_special_char(self, dict_character):
        """Prepend '<s>' (index 0) and '</s>' (index 1)."""
        dict_character = ['<s>', '</s>'] + dict_character
        return dict_character
class AttnLabelDecode(BaseRecLabelDecode):
    """ Convert between text-label and text-index """

    def __init__(self,
                 character_dict_path=None,
                 use_space_char=False,
                 **kwargs):
        super(AttnLabelDecode, self).__init__(character_dict_path,
                                              use_space_char)

    def add_special_char(self, dict_character):
        """Wrap the table with 'sos' (index 0) and 'eos' (last index)."""
        self.beg_str = "sos"
        self.end_str = "eos"
        dict_character = [self.beg_str] + dict_character + [self.end_str]
        return dict_character

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """Convert index sequences into text, stopping at the first 'eos'.

        Returns a list of ``(text, score)``; ``score`` is the mean score of
        the kept steps, or 0 when no step was kept.
        """
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        end_idx = int(ignored_tokens[1])
        ignored = {int(token) for token in ignored_tokens}
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            seq = text_index[batch_idx]
            char_list = []
            conf_list = []
            for idx in range(len(seq)):
                cur_idx = int(seq[idx])
                # 'eos' is itself an ignored token, so it must be checked
                # before the ignore test; otherwise this break is unreachable
                # and decoding continues past the end of the sequence.
                if cur_idx == end_idx:
                    break
                if cur_idx in ignored:
                    continue
                if is_remove_duplicate:
                    # only for predict
                    if idx > 0 and seq[idx - 1] == seq[idx]:
                        continue
                char_list.append(self.character[cur_idx])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            if len(conf_list) == 0:
                # np.mean([]) would return nan and emit a RuntimeWarning
                conf_list = [0]
            text = ''.join(char_list)
            result_list.append((text, np.mean(conf_list)))
        return result_list

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode attention-decoder scores (argmax/max over axis 2).

        Returns the decoded predictions, or ``(predictions, labels)`` when a
        label is given.
        """
        if isinstance(preds, torch.Tensor):
            # detach so tensors that require grad can be converted as well
            preds = preds.detach().cpu().numpy()
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
        if label is None:
            return text
        label = self.decode(label, is_remove_duplicate=False)
        return text, label

    def get_ignored_tokens(self):
        beg_idx = self.get_beg_end_flag_idx("beg")
        end_idx = self.get_beg_end_flag_idx("end")
        return [beg_idx, end_idx]

    def get_beg_end_flag_idx(self, beg_or_end):
        if beg_or_end == "beg":
            idx = np.array(self.dict[self.beg_str])
        elif beg_or_end == "end":
            idx = np.array(self.dict[self.end_str])
        else:
            assert False, "unsupport type %s in get_beg_end_flag_idx" \
                          % beg_or_end
        return idx
class RFLLabelDecode(BaseRecLabelDecode):
    """ Convert between text-label and text-index """

    def __init__(self, character_dict_path=None, use_space_char=False,
                 **kwargs):
        super(RFLLabelDecode, self).__init__(character_dict_path,
                                             use_space_char)

    def add_special_char(self, dict_character):
        """Wrap the table with 'sos' (index 0) and 'eos' (last index)."""
        self.beg_str = "sos"
        self.end_str = "eos"
        dict_character = [self.beg_str] + dict_character + [self.end_str]
        return dict_character

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """Convert index sequences into text, stopping at the first 'eos'.

        Returns a list of ``(text, score)``; ``score`` is the mean score of
        the kept steps, or 0 when no step was kept.
        """
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        end_idx = int(ignored_tokens[1])
        ignored = {int(token) for token in ignored_tokens}
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            seq = text_index[batch_idx]
            char_list = []
            conf_list = []
            for idx in range(len(seq)):
                cur_idx = int(seq[idx])
                # 'eos' is itself an ignored token, so it must be checked
                # before the ignore test; otherwise this break is unreachable.
                if cur_idx == end_idx:
                    break
                if cur_idx in ignored:
                    continue
                if is_remove_duplicate:
                    # only for predict
                    if idx > 0 and seq[idx - 1] == seq[idx]:
                        continue
                char_list.append(self.character[cur_idx])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            if len(conf_list) == 0:
                # np.mean([]) would return nan and emit a RuntimeWarning
                conf_list = [0]
            text = ''.join(char_list)
            result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode RFL outputs.

        A ``(cnt_outputs, seq_outputs)`` list/tuple is decoded as text from
        ``seq_outputs`` (argmax/max over axis 2). Any other input is treated
        as counting output: each batch row is summed and rounded to give a
        predicted length; with a label, the label text lengths are returned
        alongside.
        """
        if isinstance(preds, (tuple, list)):
            _, seq_outputs = preds
            if isinstance(seq_outputs, torch.Tensor):
                # detach + cpu so CUDA tensors and tensors requiring grad convert
                seq_outputs = seq_outputs.detach().cpu().numpy()
            preds_idx = seq_outputs.argmax(axis=2)
            preds_prob = seq_outputs.max(axis=2)
            text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
            if label is None:
                return text
            label = self.decode(label, is_remove_duplicate=False)
            return text, label

        cnt_outputs = preds
        if isinstance(cnt_outputs, torch.Tensor):
            cnt_outputs = cnt_outputs.detach().cpu().numpy()
        cnt_length = []
        for lens in cnt_outputs:
            length = round(np.sum(lens))
            cnt_length.append(length)
        if label is None:
            return cnt_length
        label = self.decode(label, is_remove_duplicate=False)
        length = [len(res[0]) for res in label]
        return cnt_length, length

    def get_ignored_tokens(self):
        beg_idx = self.get_beg_end_flag_idx("beg")
        end_idx = self.get_beg_end_flag_idx("end")
        return [beg_idx, end_idx]

    def get_beg_end_flag_idx(self, beg_or_end):
        if beg_or_end == "beg":
            idx = np.array(self.dict[self.beg_str])
        elif beg_or_end == "end":
            idx = np.array(self.dict[self.end_str])
        else:
            assert False, "unsupport type %s in get_beg_end_flag_idx" \
                          % beg_or_end
        return idx
class SRNLabelDecode(BaseRecLabelDecode):
    """ Convert between text-label and text-index """

    def __init__(self,
                 character_dict_path=None,
                 use_space_char=False,
                 **kwargs):
        self.max_text_length = kwargs.get('max_text_length', 25)
        super(SRNLabelDecode, self).__init__(character_dict_path,
                                             use_space_char)

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode SRN output ``preds['predict']``.

        The scores are reshaped to rows of ``len(character_str) + 2`` classes
        (table plus 'sos'/'eos'), and the indices are regrouped into
        sequences of ``max_text_length`` steps.
        """
        pred = preds['predict']
        char_num = len(self.character_str) + 2
        if isinstance(pred, torch.Tensor):
            # detach + cpu so CUDA tensors and tensors requiring grad convert
            pred = pred.detach().cpu().numpy()
        pred = np.reshape(pred, [-1, char_num])

        preds_idx = np.argmax(pred, axis=1)
        preds_prob = np.max(pred, axis=1)

        preds_idx = np.reshape(preds_idx, [-1, self.max_text_length])
        preds_prob = np.reshape(preds_prob, [-1, self.max_text_length])

        # decoded once; previously the same decode ran twice when label was None
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
        if label is None:
            return text
        label = self.decode(label)
        return text, label

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """Convert index sequences into text, skipping 'sos'/'eos'.

        Returns a list of ``(text, score)``; ``score`` is the mean score of
        the kept steps, or 0 when no step was kept.
        """
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        batch_size = len(text_index)

        for batch_idx in range(batch_size):
            char_list = []
            conf_list = []
            for idx in range(len(text_index[batch_idx])):
                if text_index[batch_idx][idx] in ignored_tokens:
                    continue
                if is_remove_duplicate:
                    # only for predict
                    if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
                            batch_idx][idx]:
                        continue
                char_list.append(self.character[int(text_index[batch_idx][
                    idx])])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            if len(conf_list) == 0:
                # np.mean([]) would return nan and emit a RuntimeWarning
                conf_list = [0]
            text = ''.join(char_list)
            result_list.append((text, np.mean(conf_list)))
        return result_list

    def add_special_char(self, dict_character):
        """Append 'sos' and 'eos' after the regular characters."""
        dict_character = dict_character + [self.beg_str, self.end_str]
        return dict_character

    def get_ignored_tokens(self):
        beg_idx = self.get_beg_end_flag_idx("beg")
        end_idx = self.get_beg_end_flag_idx("end")
        return [beg_idx, end_idx]

    def get_beg_end_flag_idx(self, beg_or_end):
        if beg_or_end == "beg":
            idx = np.array(self.dict[self.beg_str])
        elif beg_or_end == "end":
            idx = np.array(self.dict[self.end_str])
        else:
            assert False, "unsupport type %s in get_beg_end_flag_idx" \
                          % beg_or_end
        return idx
class TableLabelDecode(object):
    """Decode table-structure predictions into HTML tokens and cell boxes."""

    def __init__(self,
                 character_dict_path,
                 **kwargs):
        list_character, list_elem = self.load_char_elem_dict(character_dict_path)
        list_character = self.add_special_char(list_character)
        list_elem = self.add_special_char(list_elem)
        self.dict_character = {}
        self.dict_idx_character = {}
        for i, char in enumerate(list_character):
            self.dict_idx_character[i] = char
            self.dict_character[char] = i
        self.dict_elem = {}
        self.dict_idx_elem = {}
        for i, elem in enumerate(list_elem):
            self.dict_idx_elem[i] = elem
            self.dict_elem[elem] = i

    def load_char_elem_dict(self, character_dict_path):
        """Read the char and structure-element tables from one file.

        The first line holds "<char_count>\\t<elem_count>"; it is followed by
        that many character lines and then that many element lines.
        """
        list_character = []
        list_elem = []
        with open(character_dict_path, "rb") as fin:
            lines = fin.readlines()
            substr = lines[0].decode('utf-8').strip("\n").strip("\r\n").split("\t")
            character_num = int(substr[0])
            elem_num = int(substr[1])
            for cno in range(1, 1 + character_num):
                character = lines[cno].decode('utf-8').strip("\n").strip("\r\n")
                list_character.append(character)
            for eno in range(1 + character_num, 1 + character_num + elem_num):
                elem = lines[eno].decode('utf-8').strip("\n").strip("\r\n")
                list_elem.append(elem)
        return list_character, list_elem

    def add_special_char(self, list_character):
        """Wrap a table with 'sos' (index 0) and 'eos' (last index)."""
        self.beg_str = "sos"
        self.end_str = "eos"
        list_character = [self.beg_str] + list_character + [self.end_str]
        return list_character

    def __call__(self, preds):
        """Decode ``preds['structure_probs']`` and gather cell locations.

        For every decoded '<td>' / '<td' token, the row of
        ``preds['loc_preds']`` at that token's step is collected.
        """
        structure_probs = preds['structure_probs']
        loc_preds = preds['loc_preds']
        if isinstance(structure_probs, torch.Tensor):
            # detach + cpu so CUDA tensors and tensors requiring grad convert
            structure_probs = structure_probs.detach().cpu().numpy()
        if isinstance(loc_preds, torch.Tensor):
            loc_preds = loc_preds.detach().cpu().numpy()
        structure_idx = structure_probs.argmax(axis=2)
        structure_probs = structure_probs.max(axis=2)
        (structure_str, structure_pos, result_score_list,
         result_elem_idx_list) = self.decode(structure_idx, structure_probs,
                                             'elem')
        res_html_code_list = []
        res_loc_list = []
        for bno in range(len(structure_str)):
            res_loc = []
            for sno, text in enumerate(structure_str[bno]):
                if text in ['<td>', '<td']:
                    pos = structure_pos[bno][sno]
                    res_loc.append(loc_preds[bno, pos])
            res_html_code_list.append(''.join(structure_str[bno]))
            res_loc_list.append(np.array(res_loc))
        return {'res_html_code': res_html_code_list, 'res_loc': res_loc_list, 'res_score_list': result_score_list,
                'res_elem_idx_list': result_elem_idx_list, 'structure_str_list': structure_str}

    def decode(self, text_index, structure_probs, char_or_elem):
        """Convert index sequences into tokens of the requested table.

        Args:
            text_index: batch of index sequences.
            structure_probs: per-step scores indexed as ``[batch, step]``.
            char_or_elem: "char" selects the character table; any other value
                selects the structure-element table.

        Returns:
            (tokens, positions, scores, indices) per batch item. Decoding
            skips that table's 'sos'/'eos' and stops at its 'eos' after step 0.
        """
        if char_or_elem == "char":
            current_dict = self.dict_idx_character
            token_kind = "char"
        else:
            current_dict = self.dict_idx_elem
            token_kind = "elem"
        # use the special-token indices of the table actually being decoded
        # (previously the 'elem' indices were used for 'char' as well)
        ignored_tokens = self.get_ignored_tokens(token_kind)
        end_idx = ignored_tokens[1]

        result_list = []
        result_pos_list = []
        result_score_list = []
        result_elem_idx_list = []
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            char_list = []
            elem_pos_list = []
            elem_idx_list = []
            score_list = []
            for idx in range(len(text_index[batch_idx])):
                tmp_elem_idx = int(text_index[batch_idx][idx])
                if idx > 0 and tmp_elem_idx == end_idx:
                    break
                if tmp_elem_idx in ignored_tokens:
                    continue
                char_list.append(current_dict[tmp_elem_idx])
                elem_pos_list.append(idx)
                score_list.append(structure_probs[batch_idx, idx])
                elem_idx_list.append(tmp_elem_idx)
            result_list.append(char_list)
            result_pos_list.append(elem_pos_list)
            result_score_list.append(score_list)
            result_elem_idx_list.append(elem_idx_list)
        return result_list, result_pos_list, result_score_list, result_elem_idx_list

    def get_ignored_tokens(self, char_or_elem):
        beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
        end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
        return [beg_idx, end_idx]

    def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
        if char_or_elem == "char":
            if beg_or_end == "beg":
                idx = self.dict_character[self.beg_str]
            elif beg_or_end == "end":
                idx = self.dict_character[self.end_str]
            else:
                assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \
                              % beg_or_end
        elif char_or_elem == "elem":
            if beg_or_end == "beg":
                idx = self.dict_elem[self.beg_str]
            elif beg_or_end == "end":
                idx = self.dict_elem[self.end_str]
            else:
                assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \
                              % beg_or_end
        else:
            assert False, "Unsupport type %s in char_or_elem" \
                          % char_or_elem
        return idx
class SARLabelDecode(BaseRecLabelDecode):
    """ Convert between text-label and text-index """

    def __init__(self, character_dict_path=None, use_space_char=False,
                 **kwargs):
        super(SARLabelDecode, self).__init__(character_dict_path,
                                             use_space_char)
        self.rm_symbol = kwargs.get('rm_symbol', False)

    def add_special_char(self, dict_character):
        """Append '<UKN>', '<BOS/EOS>' and '<PAD>' and record their indices.

        Start and end share the single '<BOS/EOS>' entry.
        """
        beg_end_str = "<BOS/EOS>"
        unknown_str = "<UKN>"
        padding_str = "<PAD>"
        dict_character = dict_character + [unknown_str]
        self.unknown_idx = len(dict_character) - 1
        dict_character = dict_character + [beg_end_str]
        self.start_idx = len(dict_character) - 1
        self.end_idx = len(dict_character) - 1
        dict_character = dict_character + [padding_str]
        self.padding_idx = len(dict_character) - 1
        return dict_character

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """Convert index sequences into text, stopping at '<BOS/EOS>'.

        Padding is skipped. For labels (``text_prob is None``), a
        '<BOS/EOS>' at step 0 is treated as the start token and skipped. When
        ``rm_symbol`` is set, the text is lower-cased and characters outside
        the regex class below are removed. The score is the mean of the kept
        steps, or 0 when no step was kept.
        """
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        # compiled once per call rather than once per batch item
        comp = re.compile('[^A-Z^a-z^0-9^\u4e00-\u9fa5]') if self.rm_symbol else None

        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            char_list = []
            conf_list = []
            for idx in range(len(text_index[batch_idx])):
                if text_index[batch_idx][idx] in ignored_tokens:
                    continue
                if int(text_index[batch_idx][idx]) == int(self.end_idx):
                    if text_prob is None and idx == 0:
                        continue
                    else:
                        break
                if is_remove_duplicate:
                    # only for predict
                    if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
                            batch_idx][idx]:
                        continue
                char_list.append(self.character[int(text_index[batch_idx][
                    idx])])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            if len(conf_list) == 0:
                # np.mean([]) would return nan and emit a RuntimeWarning
                conf_list = [0]
            text = ''.join(char_list)
            if comp is not None:
                text = text.lower()
                text = comp.sub('', text)
            result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode SAR scores (argmax/max over axis 2)."""
        if isinstance(preds, torch.Tensor):
            # detach so tensors that require grad can be converted as well
            preds = preds.detach().cpu().numpy()
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
        if label is None:
            return text
        label = self.decode(label, is_remove_duplicate=False)
        return text, label

    def get_ignored_tokens(self):
        return [self.padding_idx]
class CANLabelDecode(BaseRecLabelDecode):
    """ Convert between latex-symbol and symbol-index """

    def __init__(self, character_dict_path=None, use_space_char=False,
                 **kwargs):
        super(CANLabelDecode, self).__init__(character_dict_path,
                                             use_space_char)

    def decode(self, text_index, preds_prob=None):
        """Convert index sequences into space-separated symbol strings.

        Each row is cut at the first occurrence of its minimum value when
        that minimum is <= 0 (the terminator, presumably 'eos' at index 0 —
        confirm against the dictionary). A row without such a terminator is
        decoded in full; the previous argmin-only rule cut it at its smallest
        symbol instead. Returns ``[text, probs]`` per row; ``probs`` holds the
        matching prefix of ``preds_prob`` or is empty.
        """
        result_list = []
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            row = text_index[batch_idx]
            if len(row) == 0:
                seq_end = 0
            else:
                seq_end = int(row.argmin(0))
                if row[seq_end] > 0:
                    # no terminator present: keep the whole sequence
                    seq_end = len(row)
            idx_list = row[:seq_end].tolist()
            symbol_list = [self.character[idx] for idx in idx_list]
            probs = []
            if preds_prob is not None:
                probs = preds_prob[batch_idx][:len(symbol_list)].tolist()
            result_list.append([' '.join(symbol_list), probs])
        return result_list

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode the first of four CAN outputs (argmax over axis 2)."""
        pred_prob, _, _, _ = preds
        preds_idx = pred_prob.argmax(axis=2)
        text = self.decode(preds_idx)
        if label is None:
            return text
        label = self.decode(label)
        return text, label
|