rec_postprocess.py 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import re
  15. import numpy as np
  16. import torch
  17. class BaseRecLabelDecode(object):
  18. """ Convert between text-label and text-index """
  19. def __init__(self,
  20. character_dict_path=None,
  21. use_space_char=False):
  22. self.beg_str = "sos"
  23. self.end_str = "eos"
  24. self.reverse = False
  25. self.character_str = []
  26. if character_dict_path is None:
  27. self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
  28. dict_character = list(self.character_str)
  29. else:
  30. with open(character_dict_path, "rb") as fin:
  31. lines = fin.readlines()
  32. for line in lines:
  33. line = line.decode('utf-8').strip("\n").strip("\r\n")
  34. self.character_str.append(line)
  35. if use_space_char:
  36. self.character_str.append(" ")
  37. dict_character = list(self.character_str)
  38. if "arabic" in character_dict_path:
  39. self.reverse = True
  40. dict_character = self.add_special_char(dict_character)
  41. self.dict = {}
  42. for i, char in enumerate(dict_character):
  43. self.dict[char] = i
  44. self.character = dict_character
  45. def pred_reverse(self, pred):
  46. pred_re = []
  47. c_current = ""
  48. for c in pred:
  49. if not bool(re.search("[a-zA-Z0-9 :*./%+-]", c)):
  50. if c_current != "":
  51. pred_re.append(c_current)
  52. pred_re.append(c)
  53. c_current = ""
  54. else:
  55. c_current += c
  56. if c_current != "":
  57. pred_re.append(c_current)
  58. return "".join(pred_re[::-1])
  59. def add_special_char(self, dict_character):
  60. return dict_character
  61. def get_word_info(self, text, selection):
  62. """
  63. Group the decoded characters and record the corresponding decoded positions.
  64. Args:
  65. text: the decoded text
  66. selection: the bool array that identifies which columns of features are decoded as non-separated characters
  67. Returns:
  68. word_list: list of the grouped words
  69. word_col_list: list of decoding positions corresponding to each character in the grouped word
  70. state_list: list of marker to identify the type of grouping words, including two types of grouping words:
  71. - 'cn': continuous chinese characters (e.g., 你好啊)
  72. - 'en&num': continuous english characters (e.g., hello), number (e.g., 123, 1.123), or mixed of them connected by '-' (e.g., VGG-16)
  73. The remaining characters in text are treated as separators between groups (e.g., space, '(', ')', etc.).
  74. """
  75. state = None
  76. word_content = []
  77. word_col_content = []
  78. word_list = []
  79. word_col_list = []
  80. state_list = []
  81. valid_col = np.where(selection == True)[0]
  82. for c_i, char in enumerate(text):
  83. if "\u4e00" <= char <= "\u9fff":
  84. c_state = "cn"
  85. elif bool(re.search("[a-zA-Z0-9]", char)):
  86. c_state = "en&num"
  87. else:
  88. c_state = "splitter"
  89. if (
  90. char == "."
  91. and state == "en&num"
  92. and c_i + 1 < len(text)
  93. and bool(re.search("[0-9]", text[c_i + 1]))
  94. ): # grouping floating number
  95. c_state = "en&num"
  96. if (
  97. char == "-" and state == "en&num"
  98. ): # grouping word with '-', such as 'state-of-the-art'
  99. c_state = "en&num"
  100. if state == None:
  101. state = c_state
  102. if state != c_state:
  103. if len(word_content) != 0:
  104. word_list.append(word_content)
  105. word_col_list.append(word_col_content)
  106. state_list.append(state)
  107. word_content = []
  108. word_col_content = []
  109. state = c_state
  110. if state != "splitter":
  111. word_content.append(char)
  112. word_col_content.append(valid_col[c_i])
  113. if len(word_content) != 0:
  114. word_list.append(word_content)
  115. word_col_list.append(word_col_content)
  116. state_list.append(state)
  117. return word_list, word_col_list, state_list
  118. def decode(
  119. self,
  120. text_index,
  121. text_prob=None,
  122. is_remove_duplicate=False,
  123. return_word_box=False,
  124. ):
  125. """ convert text-index into text-label. """
  126. result_list = []
  127. ignored_tokens = self.get_ignored_tokens()
  128. batch_size = len(text_index)
  129. for batch_idx in range(batch_size):
  130. char_list = []
  131. conf_list = []
  132. for idx in range(len(text_index[batch_idx])):
  133. if text_index[batch_idx][idx] in ignored_tokens:
  134. continue
  135. if is_remove_duplicate:
  136. # only for predict
  137. if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
  138. batch_idx][idx]:
  139. continue
  140. char_list.append(self.character[int(text_index[batch_idx][
  141. idx])])
  142. if text_prob is not None:
  143. conf_list.append(text_prob[batch_idx][idx])
  144. else:
  145. conf_list.append(1)
  146. text = ''.join(char_list)
  147. result_list.append((text, np.mean(conf_list)))
  148. return result_list
  149. def get_ignored_tokens(self):
  150. return [0] # for ctc blank
  151. class CTCLabelDecode(BaseRecLabelDecode):
  152. """ Convert between text-label and text-index """
  153. def __init__(self,
  154. character_dict_path=None,
  155. use_space_char=False,
  156. **kwargs):
  157. super(CTCLabelDecode, self).__init__(character_dict_path,
  158. use_space_char)
  159. def __call__(self, preds, label=None, return_word_box=False, *args, **kwargs):
  160. if isinstance(preds, torch.Tensor):
  161. preds = preds.numpy()
  162. preds_idx = preds.argmax(axis=2)
  163. preds_prob = preds.max(axis=2)
  164. text = self.decode(
  165. preds_idx,
  166. preds_prob,
  167. is_remove_duplicate=True,
  168. return_word_box=return_word_box,
  169. )
  170. if return_word_box:
  171. for rec_idx, rec in enumerate(text):
  172. wh_ratio = kwargs["wh_ratio_list"][rec_idx]
  173. max_wh_ratio = kwargs["max_wh_ratio"]
  174. rec[2][0] = rec[2][0] * (wh_ratio / max_wh_ratio)
  175. if label is None:
  176. return text
  177. label = self.decode(label)
  178. return text, label
  179. def add_special_char(self, dict_character):
  180. dict_character = ['blank'] + dict_character
  181. return dict_character
  182. class NRTRLabelDecode(BaseRecLabelDecode):
  183. """ Convert between text-label and text-index """
  184. def __init__(self, character_dict_path=None, use_space_char=True, **kwargs):
  185. super(NRTRLabelDecode, self).__init__(character_dict_path,
  186. use_space_char)
  187. def __call__(self, preds, label=None, *args, **kwargs):
  188. if len(preds) == 2:
  189. preds_id = preds[0]
  190. preds_prob = preds[1]
  191. if isinstance(preds_id, torch.Tensor):
  192. preds_id = preds_id.numpy()
  193. if isinstance(preds_prob, torch.Tensor):
  194. preds_prob = preds_prob.numpy()
  195. if preds_id[0][0] == 2:
  196. preds_idx = preds_id[:, 1:]
  197. preds_prob = preds_prob[:, 1:]
  198. else:
  199. preds_idx = preds_id
  200. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  201. if label is None:
  202. return text
  203. label = self.decode(label[:, 1:])
  204. else:
  205. if isinstance(preds, torch.Tensor):
  206. preds = preds.numpy()
  207. preds_idx = preds.argmax(axis=2)
  208. preds_prob = preds.max(axis=2)
  209. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  210. if label is None:
  211. return text
  212. label = self.decode(label[:, 1:])
  213. return text, label
  214. def add_special_char(self, dict_character):
  215. dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
  216. return dict_character
  217. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  218. """ convert text-index into text-label. """
  219. result_list = []
  220. batch_size = len(text_index)
  221. for batch_idx in range(batch_size):
  222. char_list = []
  223. conf_list = []
  224. for idx in range(len(text_index[batch_idx])):
  225. try:
  226. char_idx = self.character[int(text_index[batch_idx][idx])]
  227. except:
  228. continue
  229. if char_idx == '</s>': # end
  230. break
  231. char_list.append(char_idx)
  232. if text_prob is not None:
  233. conf_list.append(text_prob[batch_idx][idx])
  234. else:
  235. conf_list.append(1)
  236. text = ''.join(char_list)
  237. result_list.append((text.lower(), np.mean(conf_list).tolist()))
  238. return result_list
  239. class ViTSTRLabelDecode(NRTRLabelDecode):
  240. """ Convert between text-label and text-index """
  241. def __init__(self, character_dict_path=None, use_space_char=False,
  242. **kwargs):
  243. super(ViTSTRLabelDecode, self).__init__(character_dict_path,
  244. use_space_char)
  245. def __call__(self, preds, label=None, *args, **kwargs):
  246. if isinstance(preds, torch.Tensor):
  247. preds = preds[:, 1:].numpy()
  248. else:
  249. preds = preds[:, 1:]
  250. preds_idx = preds.argmax(axis=2)
  251. preds_prob = preds.max(axis=2)
  252. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  253. if label is None:
  254. return text
  255. label = self.decode(label[:, 1:])
  256. return text, label
  257. def add_special_char(self, dict_character):
  258. dict_character = ['<s>', '</s>'] + dict_character
  259. return dict_character
  260. class AttnLabelDecode(BaseRecLabelDecode):
  261. """ Convert between text-label and text-index """
  262. def __init__(self,
  263. character_dict_path=None,
  264. use_space_char=False,
  265. **kwargs):
  266. super(AttnLabelDecode, self).__init__(character_dict_path,
  267. use_space_char)
  268. def add_special_char(self, dict_character):
  269. self.beg_str = "sos"
  270. self.end_str = "eos"
  271. dict_character = dict_character
  272. dict_character = [self.beg_str] + dict_character + [self.end_str]
  273. return dict_character
  274. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  275. """ convert text-index into text-label. """
  276. result_list = []
  277. ignored_tokens = self.get_ignored_tokens()
  278. [beg_idx, end_idx] = self.get_ignored_tokens()
  279. batch_size = len(text_index)
  280. for batch_idx in range(batch_size):
  281. char_list = []
  282. conf_list = []
  283. for idx in range(len(text_index[batch_idx])):
  284. if text_index[batch_idx][idx] in ignored_tokens:
  285. continue
  286. if int(text_index[batch_idx][idx]) == int(end_idx):
  287. break
  288. if is_remove_duplicate:
  289. # only for predict
  290. if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
  291. batch_idx][idx]:
  292. continue
  293. char_list.append(self.character[int(text_index[batch_idx][
  294. idx])])
  295. if text_prob is not None:
  296. conf_list.append(text_prob[batch_idx][idx])
  297. else:
  298. conf_list.append(1)
  299. text = ''.join(char_list)
  300. result_list.append((text, np.mean(conf_list)))
  301. return result_list
  302. def __call__(self, preds, label=None, *args, **kwargs):
  303. """
  304. text = self.decode(text)
  305. if label is None:
  306. return text
  307. else:
  308. label = self.decode(label, is_remove_duplicate=False)
  309. return text, label
  310. """
  311. if isinstance(preds, torch.Tensor):
  312. preds = preds.cpu().numpy()
  313. preds_idx = preds.argmax(axis=2)
  314. preds_prob = preds.max(axis=2)
  315. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  316. if label is None:
  317. return text
  318. label = self.decode(label, is_remove_duplicate=False)
  319. return text, label
  320. def get_ignored_tokens(self):
  321. beg_idx = self.get_beg_end_flag_idx("beg")
  322. end_idx = self.get_beg_end_flag_idx("end")
  323. return [beg_idx, end_idx]
  324. def get_beg_end_flag_idx(self, beg_or_end):
  325. if beg_or_end == "beg":
  326. idx = np.array(self.dict[self.beg_str])
  327. elif beg_or_end == "end":
  328. idx = np.array(self.dict[self.end_str])
  329. else:
  330. assert False, "unsupport type %s in get_beg_end_flag_idx" \
  331. % beg_or_end
  332. return idx
  333. class RFLLabelDecode(BaseRecLabelDecode):
  334. """ Convert between text-label and text-index """
  335. def __init__(self, character_dict_path=None, use_space_char=False,
  336. **kwargs):
  337. super(RFLLabelDecode, self).__init__(character_dict_path,
  338. use_space_char)
  339. def add_special_char(self, dict_character):
  340. self.beg_str = "sos"
  341. self.end_str = "eos"
  342. dict_character = dict_character
  343. dict_character = [self.beg_str] + dict_character + [self.end_str]
  344. return dict_character
  345. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  346. """ convert text-index into text-label. """
  347. result_list = []
  348. ignored_tokens = self.get_ignored_tokens()
  349. [beg_idx, end_idx] = self.get_ignored_tokens()
  350. batch_size = len(text_index)
  351. for batch_idx in range(batch_size):
  352. char_list = []
  353. conf_list = []
  354. for idx in range(len(text_index[batch_idx])):
  355. if text_index[batch_idx][idx] in ignored_tokens:
  356. continue
  357. if int(text_index[batch_idx][idx]) == int(end_idx):
  358. break
  359. if is_remove_duplicate:
  360. # only for predict
  361. if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
  362. batch_idx][idx]:
  363. continue
  364. char_list.append(self.character[int(text_index[batch_idx][
  365. idx])])
  366. if text_prob is not None:
  367. conf_list.append(text_prob[batch_idx][idx])
  368. else:
  369. conf_list.append(1)
  370. text = ''.join(char_list)
  371. result_list.append((text, np.mean(conf_list).tolist()))
  372. return result_list
  373. def __call__(self, preds, label=None, *args, **kwargs):
  374. # if seq_outputs is not None:
  375. if isinstance(preds, tuple) or isinstance(preds, list):
  376. cnt_outputs, seq_outputs = preds
  377. if isinstance(seq_outputs, torch.Tensor):
  378. seq_outputs = seq_outputs.numpy()
  379. preds_idx = seq_outputs.argmax(axis=2)
  380. preds_prob = seq_outputs.max(axis=2)
  381. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  382. if label is None:
  383. return text
  384. label = self.decode(label, is_remove_duplicate=False)
  385. return text, label
  386. else:
  387. cnt_outputs = preds
  388. if isinstance(cnt_outputs, torch.Tensor):
  389. cnt_outputs = cnt_outputs.numpy()
  390. cnt_length = []
  391. for lens in cnt_outputs:
  392. length = round(np.sum(lens))
  393. cnt_length.append(length)
  394. if label is None:
  395. return cnt_length
  396. label = self.decode(label, is_remove_duplicate=False)
  397. length = [len(res[0]) for res in label]
  398. return cnt_length, length
  399. def get_ignored_tokens(self):
  400. beg_idx = self.get_beg_end_flag_idx("beg")
  401. end_idx = self.get_beg_end_flag_idx("end")
  402. return [beg_idx, end_idx]
  403. def get_beg_end_flag_idx(self, beg_or_end):
  404. if beg_or_end == "beg":
  405. idx = np.array(self.dict[self.beg_str])
  406. elif beg_or_end == "end":
  407. idx = np.array(self.dict[self.end_str])
  408. else:
  409. assert False, "unsupport type %s in get_beg_end_flag_idx" \
  410. % beg_or_end
  411. return idx
  412. class SRNLabelDecode(BaseRecLabelDecode):
  413. """ Convert between text-label and text-index """
  414. def __init__(self,
  415. character_dict_path=None,
  416. use_space_char=False,
  417. **kwargs):
  418. self.max_text_length = kwargs.get('max_text_length', 25)
  419. super(SRNLabelDecode, self).__init__(character_dict_path,
  420. use_space_char)
  421. def __call__(self, preds, label=None, *args, **kwargs):
  422. pred = preds['predict']
  423. char_num = len(self.character_str) + 2
  424. if isinstance(pred, torch.Tensor):
  425. pred = pred.numpy()
  426. pred = np.reshape(pred, [-1, char_num])
  427. preds_idx = np.argmax(pred, axis=1)
  428. preds_prob = np.max(pred, axis=1)
  429. preds_idx = np.reshape(preds_idx, [-1, self.max_text_length])
  430. preds_prob = np.reshape(preds_prob, [-1, self.max_text_length])
  431. text = self.decode(preds_idx, preds_prob)
  432. if label is None:
  433. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  434. return text
  435. label = self.decode(label)
  436. return text, label
  437. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  438. """ convert text-index into text-label. """
  439. result_list = []
  440. ignored_tokens = self.get_ignored_tokens()
  441. batch_size = len(text_index)
  442. for batch_idx in range(batch_size):
  443. char_list = []
  444. conf_list = []
  445. for idx in range(len(text_index[batch_idx])):
  446. if text_index[batch_idx][idx] in ignored_tokens:
  447. continue
  448. if is_remove_duplicate:
  449. # only for predict
  450. if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
  451. batch_idx][idx]:
  452. continue
  453. char_list.append(self.character[int(text_index[batch_idx][
  454. idx])])
  455. if text_prob is not None:
  456. conf_list.append(text_prob[batch_idx][idx])
  457. else:
  458. conf_list.append(1)
  459. text = ''.join(char_list)
  460. result_list.append((text, np.mean(conf_list)))
  461. return result_list
  462. def add_special_char(self, dict_character):
  463. dict_character = dict_character + [self.beg_str, self.end_str]
  464. return dict_character
  465. def get_ignored_tokens(self):
  466. beg_idx = self.get_beg_end_flag_idx("beg")
  467. end_idx = self.get_beg_end_flag_idx("end")
  468. return [beg_idx, end_idx]
  469. def get_beg_end_flag_idx(self, beg_or_end):
  470. if beg_or_end == "beg":
  471. idx = np.array(self.dict[self.beg_str])
  472. elif beg_or_end == "end":
  473. idx = np.array(self.dict[self.end_str])
  474. else:
  475. assert False, "unsupport type %s in get_beg_end_flag_idx" \
  476. % beg_or_end
  477. return idx
class TableLabelDecode(object):
    """Decode table-structure model output into HTML structure tokens plus
    the predicted bounding boxes of the '<td>' cells."""

    def __init__(self,
                 character_dict_path,
                 **kwargs):
        # The dict file packs two vocabularies: characters and structure
        # elements (HTML tags); build index<->token maps for both.
        list_character, list_elem = self.load_char_elem_dict(character_dict_path)
        list_character = self.add_special_char(list_character)
        list_elem = self.add_special_char(list_elem)
        self.dict_character = {}
        self.dict_idx_character = {}
        for i, char in enumerate(list_character):
            self.dict_idx_character[i] = char
            self.dict_character[char] = i
        self.dict_elem = {}
        self.dict_idx_elem = {}
        for i, elem in enumerate(list_elem):
            self.dict_idx_elem[i] = elem
            self.dict_elem[elem] = i

    def load_char_elem_dict(self, character_dict_path):
        """Read the combined dict file.

        The first line holds the character count and element count separated
        by a tab; the next character_num lines are characters, followed by
        elem_num lines of structure elements.
        """
        list_character = []
        list_elem = []
        with open(character_dict_path, "rb") as fin:
            lines = fin.readlines()
            substr = lines[0].decode('utf-8').strip("\n").strip("\r\n").split("\t")
            character_num = int(substr[0])
            elem_num = int(substr[1])
            for cno in range(1, 1 + character_num):
                character = lines[cno].decode('utf-8').strip("\n").strip("\r\n")
                list_character.append(character)
            for eno in range(1 + character_num, 1 + character_num + elem_num):
                elem = lines[eno].decode('utf-8').strip("\n").strip("\r\n")
                list_elem.append(elem)
        return list_character, list_elem

    def add_special_char(self, list_character):
        # frame each vocabulary with sos/eos markers
        self.beg_str = "sos"
        self.end_str = "eos"
        list_character = [self.beg_str] + list_character + [self.end_str]
        return list_character

    def __call__(self, preds):
        """Decode a batch of predictions.

        Args:
            preds: dict with 'structure_probs' (per-step element
                probabilities) and 'loc_preds' (per-step location
                predictions), as tensors or arrays.
        Returns:
            dict with the joined HTML code, the locations collected for
            '<td>'/'<td' tokens, per-step scores, element indices and the
            raw token lists.
        """
        structure_probs = preds['structure_probs']
        loc_preds = preds['loc_preds']
        if isinstance(structure_probs, torch.Tensor):
            structure_probs = structure_probs.numpy()
        if isinstance(loc_preds, torch.Tensor):
            loc_preds = loc_preds.numpy()
        structure_idx = structure_probs.argmax(axis=2)
        structure_probs = structure_probs.max(axis=2)
        structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(structure_idx,
                                                                                            structure_probs, 'elem')
        res_html_code_list = []
        res_loc_list = []
        batch_num = len(structure_str)
        for bno in range(batch_num):
            res_loc = []
            for sno in range(len(structure_str[bno])):
                text = structure_str[bno][sno]
                # keep the location prediction for every cell-opening tag
                if text in ['<td>', '<td']:
                    pos = structure_pos[bno][sno]
                    res_loc.append(loc_preds[bno, pos])
            res_html_code = ''.join(structure_str[bno])
            res_loc = np.array(res_loc)
            res_html_code_list.append(res_html_code)
            res_loc_list.append(res_loc)
        return {'res_html_code': res_html_code_list, 'res_loc': res_loc_list, 'res_score_list': result_score_list,
                'res_elem_idx_list': result_elem_idx_list, 'structure_str_list': structure_str}

    def decode(self, text_index, structure_probs, char_or_elem):
        """Convert index sequences into token sequences, recording for each
        kept token its position, score and raw index.
        """
        if char_or_elem == "char":
            current_dict = self.dict_idx_character
        else:
            current_dict = self.dict_idx_elem
        # NOTE(review): ignored tokens are always taken from the 'elem'
        # vocabulary, even when decoding characters — confirm intent.
        ignored_tokens = self.get_ignored_tokens('elem')
        beg_idx, end_idx = ignored_tokens
        result_list = []
        result_pos_list = []
        result_score_list = []
        result_elem_idx_list = []
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            char_list = []
            elem_pos_list = []
            elem_idx_list = []
            score_list = []
            for idx in range(len(text_index[batch_idx])):
                tmp_elem_idx = int(text_index[batch_idx][idx])
                # stop at the first end token after the start position
                if idx > 0 and tmp_elem_idx == end_idx:
                    break
                if tmp_elem_idx in ignored_tokens:
                    continue
                char_list.append(current_dict[tmp_elem_idx])
                elem_pos_list.append(idx)
                score_list.append(structure_probs[batch_idx, idx])
                elem_idx_list.append(tmp_elem_idx)
            result_list.append(char_list)
            result_pos_list.append(elem_pos_list)
            result_score_list.append(score_list)
            result_elem_idx_list.append(elem_idx_list)
        return result_list, result_pos_list, result_score_list, result_elem_idx_list

    def get_ignored_tokens(self, char_or_elem):
        # [beg, end] indices for the requested vocabulary
        beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
        end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
        return [beg_idx, end_idx]

    def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
        if char_or_elem == "char":
            if beg_or_end == "beg":
                idx = self.dict_character[self.beg_str]
            elif beg_or_end == "end":
                idx = self.dict_character[self.end_str]
            else:
                assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \
                    % beg_or_end
        elif char_or_elem == "elem":
            if beg_or_end == "beg":
                idx = self.dict_elem[self.beg_str]
            elif beg_or_end == "end":
                idx = self.dict_elem[self.end_str]
            else:
                assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \
                    % beg_or_end
        else:
            assert False, "Unsupport type %s in char_or_elem" \
                % char_or_elem
        return idx
  602. class SARLabelDecode(BaseRecLabelDecode):
  603. """ Convert between text-label and text-index """
  604. def __init__(self, character_dict_path=None, use_space_char=False,
  605. **kwargs):
  606. super(SARLabelDecode, self).__init__(character_dict_path,
  607. use_space_char)
  608. self.rm_symbol = kwargs.get('rm_symbol', False)
  609. def add_special_char(self, dict_character):
  610. beg_end_str = "<BOS/EOS>"
  611. unknown_str = "<UKN>"
  612. padding_str = "<PAD>"
  613. dict_character = dict_character + [unknown_str]
  614. self.unknown_idx = len(dict_character) - 1
  615. dict_character = dict_character + [beg_end_str]
  616. self.start_idx = len(dict_character) - 1
  617. self.end_idx = len(dict_character) - 1
  618. dict_character = dict_character + [padding_str]
  619. self.padding_idx = len(dict_character) - 1
  620. return dict_character
  621. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  622. """ convert text-index into text-label. """
  623. result_list = []
  624. ignored_tokens = self.get_ignored_tokens()
  625. batch_size = len(text_index)
  626. for batch_idx in range(batch_size):
  627. char_list = []
  628. conf_list = []
  629. for idx in range(len(text_index[batch_idx])):
  630. if text_index[batch_idx][idx] in ignored_tokens:
  631. continue
  632. if int(text_index[batch_idx][idx]) == int(self.end_idx):
  633. if text_prob is None and idx == 0:
  634. continue
  635. else:
  636. break
  637. if is_remove_duplicate:
  638. # only for predict
  639. if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
  640. batch_idx][idx]:
  641. continue
  642. char_list.append(self.character[int(text_index[batch_idx][
  643. idx])])
  644. if text_prob is not None:
  645. conf_list.append(text_prob[batch_idx][idx])
  646. else:
  647. conf_list.append(1)
  648. text = ''.join(char_list)
  649. if self.rm_symbol:
  650. comp = re.compile('[^A-Z^a-z^0-9^\u4e00-\u9fa5]')
  651. text = text.lower()
  652. text = comp.sub('', text)
  653. result_list.append((text, np.mean(conf_list).tolist()))
  654. return result_list
  655. def __call__(self, preds, label=None, *args, **kwargs):
  656. if isinstance(preds, torch.Tensor):
  657. preds = preds.cpu().numpy()
  658. preds_idx = preds.argmax(axis=2)
  659. preds_prob = preds.max(axis=2)
  660. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  661. if label is None:
  662. return text
  663. label = self.decode(label, is_remove_duplicate=False)
  664. return text, label
  665. def get_ignored_tokens(self):
  666. return [self.padding_idx]
  667. class CANLabelDecode(BaseRecLabelDecode):
  668. """ Convert between latex-symbol and symbol-index """
  669. def __init__(self, character_dict_path=None, use_space_char=False,
  670. **kwargs):
  671. super(CANLabelDecode, self).__init__(character_dict_path,
  672. use_space_char)
  673. def decode(self, text_index, preds_prob=None):
  674. result_list = []
  675. batch_size = len(text_index)
  676. for batch_idx in range(batch_size):
  677. seq_end = text_index[batch_idx].argmin(0)
  678. idx_list = text_index[batch_idx][:seq_end].tolist()
  679. symbol_list = [self.character[idx] for idx in idx_list]
  680. probs = []
  681. if preds_prob is not None:
  682. probs = preds_prob[batch_idx][:len(symbol_list)].tolist()
  683. result_list.append([' '.join(symbol_list), probs])
  684. return result_list
  685. def __call__(self, preds, label=None, *args, **kwargs):
  686. pred_prob, _, _, _ = preds
  687. preds_idx = pred_prob.argmax(axis=2)
  688. text = self.decode(preds_idx)
  689. if label is None:
  690. return text
  691. label = self.decode(label)
  692. return text, label