rec_postprocess.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import re
  15. import numpy as np
  16. import torch
  17. class BaseRecLabelDecode(object):
  18. """ Convert between text-label and text-index """
  19. def __init__(self,
  20. character_dict_path=None,
  21. use_space_char=False):
  22. self.beg_str = "sos"
  23. self.end_str = "eos"
  24. self.reverse = False
  25. self.character_str = []
  26. if character_dict_path is None:
  27. self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
  28. dict_character = list(self.character_str)
  29. else:
  30. with open(character_dict_path, "rb") as fin:
  31. lines = fin.readlines()
  32. for line in lines:
  33. line = line.decode('utf-8').strip("\n").strip("\r\n")
  34. self.character_str.append(line)
  35. if use_space_char:
  36. self.character_str.append(" ")
  37. dict_character = list(self.character_str)
  38. if "arabic" in character_dict_path:
  39. self.reverse = True
  40. dict_character = self.add_special_char(dict_character)
  41. self.dict = {}
  42. for i, char in enumerate(dict_character):
  43. self.dict[char] = i
  44. self.character = np.array(dict_character)
  45. def pred_reverse(self, pred):
  46. pred_re = []
  47. c_current = ""
  48. for c in pred:
  49. if not bool(re.search("[a-zA-Z0-9 :*./%+-]", c)):
  50. if c_current != "":
  51. pred_re.append(c_current)
  52. pred_re.append(c)
  53. c_current = ""
  54. else:
  55. c_current += c
  56. if c_current != "":
  57. pred_re.append(c_current)
  58. return "".join(pred_re[::-1])
  59. def add_special_char(self, dict_character):
  60. return dict_character
  61. def get_word_info(self, text, selection):
  62. """
  63. Group the decoded characters and record the corresponding decoded positions.
  64. Args:
  65. text: the decoded text
  66. selection: the bool array that identifies which columns of features are decoded as non-separated characters
  67. Returns:
  68. word_list: list of the grouped words
  69. word_col_list: list of decoding positions corresponding to each character in the grouped word
  70. state_list: list of marker to identify the type of grouping words, including two types of grouping words:
  71. - 'cn': continuous chinese characters (e.g., 你好啊)
  72. - 'en&num': continuous english characters (e.g., hello), number (e.g., 123, 1.123), or mixed of them connected by '-' (e.g., VGG-16)
  73. The remaining characters in text are treated as separators between groups (e.g., space, '(', ')', etc.).
  74. """
  75. state = None
  76. word_content = []
  77. word_col_content = []
  78. word_list = []
  79. word_col_list = []
  80. state_list = []
  81. valid_col = np.where(selection == True)[0]
  82. for c_i, char in enumerate(text):
  83. if "\u4e00" <= char <= "\u9fff":
  84. c_state = "cn"
  85. elif bool(re.search("[a-zA-Z0-9]", char)):
  86. c_state = "en&num"
  87. else:
  88. c_state = "splitter"
  89. if (
  90. char == "."
  91. and state == "en&num"
  92. and c_i + 1 < len(text)
  93. and bool(re.search("[0-9]", text[c_i + 1]))
  94. ): # grouping floating number
  95. c_state = "en&num"
  96. if (
  97. char == "-" and state == "en&num"
  98. ): # grouping word with '-', such as 'state-of-the-art'
  99. c_state = "en&num"
  100. if state == None:
  101. state = c_state
  102. if state != c_state:
  103. if len(word_content) != 0:
  104. word_list.append(word_content)
  105. word_col_list.append(word_col_content)
  106. state_list.append(state)
  107. word_content = []
  108. word_col_content = []
  109. state = c_state
  110. if state != "splitter":
  111. word_content.append(char)
  112. word_col_content.append(valid_col[c_i])
  113. if len(word_content) != 0:
  114. word_list.append(word_content)
  115. word_col_list.append(word_col_content)
  116. state_list.append(state)
  117. return word_list, word_col_list, state_list
  118. def decode(
  119. self,
  120. text_index,
  121. text_prob=None,
  122. is_remove_duplicate=False,
  123. return_word_box=False,
  124. ):
  125. """ convert text-index into text-label. """
  126. result_list = []
  127. batch_size = text_index.shape[0]
  128. blank_word = self.get_ignored_tokens()[0]
  129. for batch_idx in range(batch_size):
  130. probs = None if text_prob is None else np.array(text_prob[batch_idx])
  131. sequence = text_index[batch_idx]
  132. final_mask = sequence != blank_word
  133. if is_remove_duplicate:
  134. duplicate_mask = np.insert(sequence[1:] != sequence[:-1], 0, True)
  135. final_mask &= duplicate_mask
  136. sequence = sequence[final_mask]
  137. probs = None if probs is None else probs[final_mask]
  138. text = "".join(self.character[sequence])
  139. if text_prob is not None and probs is not None and len(probs) > 0:
  140. mean_conf = np.mean(probs)
  141. else:
  142. # 如果没有提供概率或最终结果为空,则默认置信度为1.0
  143. mean_conf = 1.0
  144. result_list.append((text, mean_conf))
  145. return result_list
  146. def get_ignored_tokens(self):
  147. return [0] # for ctc blank
  148. class CTCLabelDecode(BaseRecLabelDecode):
  149. """ Convert between text-label and text-index """
  150. def __init__(self,
  151. character_dict_path=None,
  152. use_space_char=False,
  153. **kwargs):
  154. super(CTCLabelDecode, self).__init__(character_dict_path,
  155. use_space_char)
  156. def __call__(self, preds, label=None, return_word_box=False, *args, **kwargs):
  157. preds_prob, preds_idx = preds.max(axis=2)
  158. text = self.decode(
  159. preds_idx.cpu().numpy(),
  160. preds_prob.float().cpu().numpy(),
  161. is_remove_duplicate=True,
  162. return_word_box=return_word_box,
  163. )
  164. if return_word_box:
  165. for rec_idx, rec in enumerate(text):
  166. wh_ratio = kwargs["wh_ratio_list"][rec_idx]
  167. max_wh_ratio = kwargs["max_wh_ratio"]
  168. rec[2][0] = rec[2][0] * (wh_ratio / max_wh_ratio)
  169. if label is None:
  170. return text
  171. label = self.decode(label.cpu().numpy())
  172. return text, label
  173. def add_special_char(self, dict_character):
  174. dict_character = ['blank'] + dict_character
  175. return dict_character
  176. class NRTRLabelDecode(BaseRecLabelDecode):
  177. """ Convert between text-label and text-index """
  178. def __init__(self, character_dict_path=None, use_space_char=True, **kwargs):
  179. super(NRTRLabelDecode, self).__init__(character_dict_path,
  180. use_space_char)
  181. def __call__(self, preds, label=None, *args, **kwargs):
  182. if len(preds) == 2:
  183. preds_id = preds[0]
  184. preds_prob = preds[1]
  185. if isinstance(preds_id, torch.Tensor):
  186. preds_id = preds_id.numpy()
  187. if isinstance(preds_prob, torch.Tensor):
  188. preds_prob = preds_prob.numpy()
  189. if preds_id[0][0] == 2:
  190. preds_idx = preds_id[:, 1:]
  191. preds_prob = preds_prob[:, 1:]
  192. else:
  193. preds_idx = preds_id
  194. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  195. if label is None:
  196. return text
  197. label = self.decode(label[:, 1:])
  198. else:
  199. if isinstance(preds, torch.Tensor):
  200. preds = preds.numpy()
  201. preds_idx = preds.argmax(axis=2)
  202. preds_prob = preds.max(axis=2)
  203. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  204. if label is None:
  205. return text
  206. label = self.decode(label[:, 1:])
  207. return text, label
  208. def add_special_char(self, dict_character):
  209. dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
  210. return dict_character
  211. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  212. """ convert text-index into text-label. """
  213. result_list = []
  214. batch_size = len(text_index)
  215. for batch_idx in range(batch_size):
  216. char_list = []
  217. conf_list = []
  218. for idx in range(len(text_index[batch_idx])):
  219. try:
  220. char_idx = self.character[int(text_index[batch_idx][idx])]
  221. except:
  222. continue
  223. if char_idx == '</s>': # end
  224. break
  225. char_list.append(char_idx)
  226. if text_prob is not None:
  227. conf_list.append(text_prob[batch_idx][idx])
  228. else:
  229. conf_list.append(1)
  230. text = ''.join(char_list)
  231. result_list.append((text.lower(), np.mean(conf_list).tolist()))
  232. return result_list
  233. class ViTSTRLabelDecode(NRTRLabelDecode):
  234. """ Convert between text-label and text-index """
  235. def __init__(self, character_dict_path=None, use_space_char=False,
  236. **kwargs):
  237. super(ViTSTRLabelDecode, self).__init__(character_dict_path,
  238. use_space_char)
  239. def __call__(self, preds, label=None, *args, **kwargs):
  240. if isinstance(preds, torch.Tensor):
  241. preds = preds[:, 1:].numpy()
  242. else:
  243. preds = preds[:, 1:]
  244. preds_idx = preds.argmax(axis=2)
  245. preds_prob = preds.max(axis=2)
  246. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  247. if label is None:
  248. return text
  249. label = self.decode(label[:, 1:])
  250. return text, label
  251. def add_special_char(self, dict_character):
  252. dict_character = ['<s>', '</s>'] + dict_character
  253. return dict_character
  254. class AttnLabelDecode(BaseRecLabelDecode):
  255. """ Convert between text-label and text-index """
  256. def __init__(self,
  257. character_dict_path=None,
  258. use_space_char=False,
  259. **kwargs):
  260. super(AttnLabelDecode, self).__init__(character_dict_path,
  261. use_space_char)
  262. def add_special_char(self, dict_character):
  263. self.beg_str = "sos"
  264. self.end_str = "eos"
  265. dict_character = dict_character
  266. dict_character = [self.beg_str] + dict_character + [self.end_str]
  267. return dict_character
  268. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  269. """ convert text-index into text-label. """
  270. result_list = []
  271. ignored_tokens = self.get_ignored_tokens()
  272. [beg_idx, end_idx] = self.get_ignored_tokens()
  273. batch_size = len(text_index)
  274. for batch_idx in range(batch_size):
  275. char_list = []
  276. conf_list = []
  277. for idx in range(len(text_index[batch_idx])):
  278. if text_index[batch_idx][idx] in ignored_tokens:
  279. continue
  280. if int(text_index[batch_idx][idx]) == int(end_idx):
  281. break
  282. if is_remove_duplicate:
  283. # only for predict
  284. if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
  285. batch_idx][idx]:
  286. continue
  287. char_list.append(self.character[int(text_index[batch_idx][
  288. idx])])
  289. if text_prob is not None:
  290. conf_list.append(text_prob[batch_idx][idx])
  291. else:
  292. conf_list.append(1)
  293. text = ''.join(char_list)
  294. result_list.append((text, np.mean(conf_list)))
  295. return result_list
  296. def __call__(self, preds, label=None, *args, **kwargs):
  297. """
  298. text = self.decode(text)
  299. if label is None:
  300. return text
  301. else:
  302. label = self.decode(label, is_remove_duplicate=False)
  303. return text, label
  304. """
  305. if isinstance(preds, torch.Tensor):
  306. preds = preds.cpu().numpy()
  307. preds_idx = preds.argmax(axis=2)
  308. preds_prob = preds.max(axis=2)
  309. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  310. if label is None:
  311. return text
  312. label = self.decode(label, is_remove_duplicate=False)
  313. return text, label
  314. def get_ignored_tokens(self):
  315. beg_idx = self.get_beg_end_flag_idx("beg")
  316. end_idx = self.get_beg_end_flag_idx("end")
  317. return [beg_idx, end_idx]
  318. def get_beg_end_flag_idx(self, beg_or_end):
  319. if beg_or_end == "beg":
  320. idx = np.array(self.dict[self.beg_str])
  321. elif beg_or_end == "end":
  322. idx = np.array(self.dict[self.end_str])
  323. else:
  324. assert False, "unsupport type %s in get_beg_end_flag_idx" \
  325. % beg_or_end
  326. return idx
  327. class RFLLabelDecode(BaseRecLabelDecode):
  328. """ Convert between text-label and text-index """
  329. def __init__(self, character_dict_path=None, use_space_char=False,
  330. **kwargs):
  331. super(RFLLabelDecode, self).__init__(character_dict_path,
  332. use_space_char)
  333. def add_special_char(self, dict_character):
  334. self.beg_str = "sos"
  335. self.end_str = "eos"
  336. dict_character = dict_character
  337. dict_character = [self.beg_str] + dict_character + [self.end_str]
  338. return dict_character
  339. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  340. """ convert text-index into text-label. """
  341. result_list = []
  342. ignored_tokens = self.get_ignored_tokens()
  343. [beg_idx, end_idx] = self.get_ignored_tokens()
  344. batch_size = len(text_index)
  345. for batch_idx in range(batch_size):
  346. char_list = []
  347. conf_list = []
  348. for idx in range(len(text_index[batch_idx])):
  349. if text_index[batch_idx][idx] in ignored_tokens:
  350. continue
  351. if int(text_index[batch_idx][idx]) == int(end_idx):
  352. break
  353. if is_remove_duplicate:
  354. # only for predict
  355. if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
  356. batch_idx][idx]:
  357. continue
  358. char_list.append(self.character[int(text_index[batch_idx][
  359. idx])])
  360. if text_prob is not None:
  361. conf_list.append(text_prob[batch_idx][idx])
  362. else:
  363. conf_list.append(1)
  364. text = ''.join(char_list)
  365. result_list.append((text, np.mean(conf_list).tolist()))
  366. return result_list
  367. def __call__(self, preds, label=None, *args, **kwargs):
  368. # if seq_outputs is not None:
  369. if isinstance(preds, tuple) or isinstance(preds, list):
  370. cnt_outputs, seq_outputs = preds
  371. if isinstance(seq_outputs, torch.Tensor):
  372. seq_outputs = seq_outputs.numpy()
  373. preds_idx = seq_outputs.argmax(axis=2)
  374. preds_prob = seq_outputs.max(axis=2)
  375. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  376. if label is None:
  377. return text
  378. label = self.decode(label, is_remove_duplicate=False)
  379. return text, label
  380. else:
  381. cnt_outputs = preds
  382. if isinstance(cnt_outputs, torch.Tensor):
  383. cnt_outputs = cnt_outputs.numpy()
  384. cnt_length = []
  385. for lens in cnt_outputs:
  386. length = round(np.sum(lens))
  387. cnt_length.append(length)
  388. if label is None:
  389. return cnt_length
  390. label = self.decode(label, is_remove_duplicate=False)
  391. length = [len(res[0]) for res in label]
  392. return cnt_length, length
  393. def get_ignored_tokens(self):
  394. beg_idx = self.get_beg_end_flag_idx("beg")
  395. end_idx = self.get_beg_end_flag_idx("end")
  396. return [beg_idx, end_idx]
  397. def get_beg_end_flag_idx(self, beg_or_end):
  398. if beg_or_end == "beg":
  399. idx = np.array(self.dict[self.beg_str])
  400. elif beg_or_end == "end":
  401. idx = np.array(self.dict[self.end_str])
  402. else:
  403. assert False, "unsupport type %s in get_beg_end_flag_idx" \
  404. % beg_or_end
  405. return idx
  406. class SRNLabelDecode(BaseRecLabelDecode):
  407. """ Convert between text-label and text-index """
  408. def __init__(self,
  409. character_dict_path=None,
  410. use_space_char=False,
  411. **kwargs):
  412. self.max_text_length = kwargs.get('max_text_length', 25)
  413. super(SRNLabelDecode, self).__init__(character_dict_path,
  414. use_space_char)
  415. def __call__(self, preds, label=None, *args, **kwargs):
  416. pred = preds['predict']
  417. char_num = len(self.character_str) + 2
  418. if isinstance(pred, torch.Tensor):
  419. pred = pred.numpy()
  420. pred = np.reshape(pred, [-1, char_num])
  421. preds_idx = np.argmax(pred, axis=1)
  422. preds_prob = np.max(pred, axis=1)
  423. preds_idx = np.reshape(preds_idx, [-1, self.max_text_length])
  424. preds_prob = np.reshape(preds_prob, [-1, self.max_text_length])
  425. text = self.decode(preds_idx, preds_prob)
  426. if label is None:
  427. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  428. return text
  429. label = self.decode(label)
  430. return text, label
  431. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  432. """ convert text-index into text-label. """
  433. result_list = []
  434. ignored_tokens = self.get_ignored_tokens()
  435. batch_size = len(text_index)
  436. for batch_idx in range(batch_size):
  437. char_list = []
  438. conf_list = []
  439. for idx in range(len(text_index[batch_idx])):
  440. if text_index[batch_idx][idx] in ignored_tokens:
  441. continue
  442. if is_remove_duplicate:
  443. # only for predict
  444. if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
  445. batch_idx][idx]:
  446. continue
  447. char_list.append(self.character[int(text_index[batch_idx][
  448. idx])])
  449. if text_prob is not None:
  450. conf_list.append(text_prob[batch_idx][idx])
  451. else:
  452. conf_list.append(1)
  453. text = ''.join(char_list)
  454. result_list.append((text, np.mean(conf_list)))
  455. return result_list
  456. def add_special_char(self, dict_character):
  457. dict_character = dict_character + [self.beg_str, self.end_str]
  458. return dict_character
  459. def get_ignored_tokens(self):
  460. beg_idx = self.get_beg_end_flag_idx("beg")
  461. end_idx = self.get_beg_end_flag_idx("end")
  462. return [beg_idx, end_idx]
  463. def get_beg_end_flag_idx(self, beg_or_end):
  464. if beg_or_end == "beg":
  465. idx = np.array(self.dict[self.beg_str])
  466. elif beg_or_end == "end":
  467. idx = np.array(self.dict[self.end_str])
  468. else:
  469. assert False, "unsupport type %s in get_beg_end_flag_idx" \
  470. % beg_or_end
  471. return idx
  472. class TableLabelDecode(object):
  473. """ """
  474. def __init__(self,
  475. character_dict_path,
  476. **kwargs):
  477. list_character, list_elem = self.load_char_elem_dict(character_dict_path)
  478. list_character = self.add_special_char(list_character)
  479. list_elem = self.add_special_char(list_elem)
  480. self.dict_character = {}
  481. self.dict_idx_character = {}
  482. for i, char in enumerate(list_character):
  483. self.dict_idx_character[i] = char
  484. self.dict_character[char] = i
  485. self.dict_elem = {}
  486. self.dict_idx_elem = {}
  487. for i, elem in enumerate(list_elem):
  488. self.dict_idx_elem[i] = elem
  489. self.dict_elem[elem] = i
  490. def load_char_elem_dict(self, character_dict_path):
  491. list_character = []
  492. list_elem = []
  493. with open(character_dict_path, "rb") as fin:
  494. lines = fin.readlines()
  495. substr = lines[0].decode('utf-8').strip("\n").strip("\r\n").split("\t")
  496. character_num = int(substr[0])
  497. elem_num = int(substr[1])
  498. for cno in range(1, 1 + character_num):
  499. character = lines[cno].decode('utf-8').strip("\n").strip("\r\n")
  500. list_character.append(character)
  501. for eno in range(1 + character_num, 1 + character_num + elem_num):
  502. elem = lines[eno].decode('utf-8').strip("\n").strip("\r\n")
  503. list_elem.append(elem)
  504. return list_character, list_elem
  505. def add_special_char(self, list_character):
  506. self.beg_str = "sos"
  507. self.end_str = "eos"
  508. list_character = [self.beg_str] + list_character + [self.end_str]
  509. return list_character
  510. def __call__(self, preds):
  511. structure_probs = preds['structure_probs']
  512. loc_preds = preds['loc_preds']
  513. if isinstance(structure_probs,torch.Tensor):
  514. structure_probs = structure_probs.numpy()
  515. if isinstance(loc_preds,torch.Tensor):
  516. loc_preds = loc_preds.numpy()
  517. structure_idx = structure_probs.argmax(axis=2)
  518. structure_probs = structure_probs.max(axis=2)
  519. structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(structure_idx,
  520. structure_probs, 'elem')
  521. res_html_code_list = []
  522. res_loc_list = []
  523. batch_num = len(structure_str)
  524. for bno in range(batch_num):
  525. res_loc = []
  526. for sno in range(len(structure_str[bno])):
  527. text = structure_str[bno][sno]
  528. if text in ['<td>', '<td']:
  529. pos = structure_pos[bno][sno]
  530. res_loc.append(loc_preds[bno, pos])
  531. res_html_code = ''.join(structure_str[bno])
  532. res_loc = np.array(res_loc)
  533. res_html_code_list.append(res_html_code)
  534. res_loc_list.append(res_loc)
  535. return {'res_html_code': res_html_code_list, 'res_loc': res_loc_list, 'res_score_list': result_score_list,
  536. 'res_elem_idx_list': result_elem_idx_list,'structure_str_list':structure_str}
  537. def decode(self, text_index, structure_probs, char_or_elem):
  538. """convert text-label into text-index.
  539. """
  540. if char_or_elem == "char":
  541. current_dict = self.dict_idx_character
  542. else:
  543. current_dict = self.dict_idx_elem
  544. ignored_tokens = self.get_ignored_tokens('elem')
  545. beg_idx, end_idx = ignored_tokens
  546. result_list = []
  547. result_pos_list = []
  548. result_score_list = []
  549. result_elem_idx_list = []
  550. batch_size = len(text_index)
  551. for batch_idx in range(batch_size):
  552. char_list = []
  553. elem_pos_list = []
  554. elem_idx_list = []
  555. score_list = []
  556. for idx in range(len(text_index[batch_idx])):
  557. tmp_elem_idx = int(text_index[batch_idx][idx])
  558. if idx > 0 and tmp_elem_idx == end_idx:
  559. break
  560. if tmp_elem_idx in ignored_tokens:
  561. continue
  562. char_list.append(current_dict[tmp_elem_idx])
  563. elem_pos_list.append(idx)
  564. score_list.append(structure_probs[batch_idx, idx])
  565. elem_idx_list.append(tmp_elem_idx)
  566. result_list.append(char_list)
  567. result_pos_list.append(elem_pos_list)
  568. result_score_list.append(score_list)
  569. result_elem_idx_list.append(elem_idx_list)
  570. return result_list, result_pos_list, result_score_list, result_elem_idx_list
  571. def get_ignored_tokens(self, char_or_elem):
  572. beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
  573. end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
  574. return [beg_idx, end_idx]
  575. def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
  576. if char_or_elem == "char":
  577. if beg_or_end == "beg":
  578. idx = self.dict_character[self.beg_str]
  579. elif beg_or_end == "end":
  580. idx = self.dict_character[self.end_str]
  581. else:
  582. assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \
  583. % beg_or_end
  584. elif char_or_elem == "elem":
  585. if beg_or_end == "beg":
  586. idx = self.dict_elem[self.beg_str]
  587. elif beg_or_end == "end":
  588. idx = self.dict_elem[self.end_str]
  589. else:
  590. assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \
  591. % beg_or_end
  592. else:
  593. assert False, "Unsupport type %s in char_or_elem" \
  594. % char_or_elem
  595. return idx
  596. class SARLabelDecode(BaseRecLabelDecode):
  597. """ Convert between text-label and text-index """
  598. def __init__(self, character_dict_path=None, use_space_char=False,
  599. **kwargs):
  600. super(SARLabelDecode, self).__init__(character_dict_path,
  601. use_space_char)
  602. self.rm_symbol = kwargs.get('rm_symbol', False)
  603. def add_special_char(self, dict_character):
  604. beg_end_str = "<BOS/EOS>"
  605. unknown_str = "<UKN>"
  606. padding_str = "<PAD>"
  607. dict_character = dict_character + [unknown_str]
  608. self.unknown_idx = len(dict_character) - 1
  609. dict_character = dict_character + [beg_end_str]
  610. self.start_idx = len(dict_character) - 1
  611. self.end_idx = len(dict_character) - 1
  612. dict_character = dict_character + [padding_str]
  613. self.padding_idx = len(dict_character) - 1
  614. return dict_character
  615. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  616. """ convert text-index into text-label. """
  617. result_list = []
  618. ignored_tokens = self.get_ignored_tokens()
  619. batch_size = len(text_index)
  620. for batch_idx in range(batch_size):
  621. char_list = []
  622. conf_list = []
  623. for idx in range(len(text_index[batch_idx])):
  624. if text_index[batch_idx][idx] in ignored_tokens:
  625. continue
  626. if int(text_index[batch_idx][idx]) == int(self.end_idx):
  627. if text_prob is None and idx == 0:
  628. continue
  629. else:
  630. break
  631. if is_remove_duplicate:
  632. # only for predict
  633. if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
  634. batch_idx][idx]:
  635. continue
  636. char_list.append(self.character[int(text_index[batch_idx][
  637. idx])])
  638. if text_prob is not None:
  639. conf_list.append(text_prob[batch_idx][idx])
  640. else:
  641. conf_list.append(1)
  642. text = ''.join(char_list)
  643. if self.rm_symbol:
  644. comp = re.compile('[^A-Z^a-z^0-9^\u4e00-\u9fa5]')
  645. text = text.lower()
  646. text = comp.sub('', text)
  647. result_list.append((text, np.mean(conf_list).tolist()))
  648. return result_list
  649. def __call__(self, preds, label=None, *args, **kwargs):
  650. if isinstance(preds, torch.Tensor):
  651. preds = preds.cpu().numpy()
  652. preds_idx = preds.argmax(axis=2)
  653. preds_prob = preds.max(axis=2)
  654. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  655. if label is None:
  656. return text
  657. label = self.decode(label, is_remove_duplicate=False)
  658. return text, label
  659. def get_ignored_tokens(self):
  660. return [self.padding_idx]
  661. class CANLabelDecode(BaseRecLabelDecode):
  662. """ Convert between latex-symbol and symbol-index """
  663. def __init__(self, character_dict_path=None, use_space_char=False,
  664. **kwargs):
  665. super(CANLabelDecode, self).__init__(character_dict_path,
  666. use_space_char)
  667. def decode(self, text_index, preds_prob=None):
  668. result_list = []
  669. batch_size = len(text_index)
  670. for batch_idx in range(batch_size):
  671. seq_end = text_index[batch_idx].argmin(0)
  672. idx_list = text_index[batch_idx][:seq_end].tolist()
  673. symbol_list = [self.character[idx] for idx in idx_list]
  674. probs = []
  675. if preds_prob is not None:
  676. probs = preds_prob[batch_idx][:len(symbol_list)].tolist()
  677. result_list.append([' '.join(symbol_list), probs])
  678. return result_list
  679. def __call__(self, preds, label=None, *args, **kwargs):
  680. pred_prob, _, _, _ = preds
  681. preds_idx = pred_prob.argmax(axis=2)
  682. text = self.decode(preds_idx)
  683. if label is None:
  684. return text
  685. label = self.decode(label)
  686. return text, label