rec_postprocess.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
import re

import numpy as np
import torch
  16. class BaseRecLabelDecode(object):
  17. """ Convert between text-label and text-index """
  18. def __init__(self,
  19. character_dict_path=None,
  20. use_space_char=False):
  21. self.beg_str = "sos"
  22. self.end_str = "eos"
  23. self.character_str = []
  24. if character_dict_path is None:
  25. self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
  26. dict_character = list(self.character_str)
  27. else:
  28. with open(character_dict_path, "rb") as fin:
  29. lines = fin.readlines()
  30. for line in lines:
  31. line = line.decode('utf-8').strip("\n").strip("\r\n")
  32. self.character_str.append(line)
  33. if use_space_char:
  34. self.character_str.append(" ")
  35. dict_character = list(self.character_str)
  36. dict_character = self.add_special_char(dict_character)
  37. self.dict = {}
  38. for i, char in enumerate(dict_character):
  39. self.dict[char] = i
  40. self.character = dict_character
  41. def add_special_char(self, dict_character):
  42. return dict_character
  43. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  44. """ convert text-index into text-label. """
  45. result_list = []
  46. ignored_tokens = self.get_ignored_tokens()
  47. batch_size = len(text_index)
  48. for batch_idx in range(batch_size):
  49. char_list = []
  50. conf_list = []
  51. for idx in range(len(text_index[batch_idx])):
  52. if text_index[batch_idx][idx] in ignored_tokens:
  53. continue
  54. if is_remove_duplicate:
  55. # only for predict
  56. if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
  57. batch_idx][idx]:
  58. continue
  59. char_list.append(self.character[int(text_index[batch_idx][
  60. idx])])
  61. if text_prob is not None:
  62. conf_list.append(text_prob[batch_idx][idx])
  63. else:
  64. conf_list.append(1)
  65. text = ''.join(char_list)
  66. result_list.append((text, np.mean(conf_list)))
  67. return result_list
  68. def get_ignored_tokens(self):
  69. return [0] # for ctc blank
  70. class CTCLabelDecode(BaseRecLabelDecode):
  71. """ Convert between text-label and text-index """
  72. def __init__(self,
  73. character_dict_path=None,
  74. use_space_char=False,
  75. **kwargs):
  76. super(CTCLabelDecode, self).__init__(character_dict_path,
  77. use_space_char)
  78. def __call__(self, preds, label=None, *args, **kwargs):
  79. if isinstance(preds, torch.Tensor):
  80. preds = preds.numpy()
  81. preds_idx = preds.argmax(axis=2)
  82. preds_prob = preds.max(axis=2)
  83. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
  84. if label is None:
  85. return text
  86. label = self.decode(label)
  87. return text, label
  88. def add_special_char(self, dict_character):
  89. dict_character = ['blank'] + dict_character
  90. return dict_character
  91. class NRTRLabelDecode(BaseRecLabelDecode):
  92. """ Convert between text-label and text-index """
  93. def __init__(self, character_dict_path=None, use_space_char=True, **kwargs):
  94. super(NRTRLabelDecode, self).__init__(character_dict_path,
  95. use_space_char)
  96. def __call__(self, preds, label=None, *args, **kwargs):
  97. if len(preds) == 2:
  98. preds_id = preds[0]
  99. preds_prob = preds[1]
  100. if isinstance(preds_id, torch.Tensor):
  101. preds_id = preds_id.numpy()
  102. if isinstance(preds_prob, torch.Tensor):
  103. preds_prob = preds_prob.numpy()
  104. if preds_id[0][0] == 2:
  105. preds_idx = preds_id[:, 1:]
  106. preds_prob = preds_prob[:, 1:]
  107. else:
  108. preds_idx = preds_id
  109. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  110. if label is None:
  111. return text
  112. label = self.decode(label[:, 1:])
  113. else:
  114. if isinstance(preds, torch.Tensor):
  115. preds = preds.numpy()
  116. preds_idx = preds.argmax(axis=2)
  117. preds_prob = preds.max(axis=2)
  118. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  119. if label is None:
  120. return text
  121. label = self.decode(label[:, 1:])
  122. return text, label
  123. def add_special_char(self, dict_character):
  124. dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
  125. return dict_character
  126. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  127. """ convert text-index into text-label. """
  128. result_list = []
  129. batch_size = len(text_index)
  130. for batch_idx in range(batch_size):
  131. char_list = []
  132. conf_list = []
  133. for idx in range(len(text_index[batch_idx])):
  134. try:
  135. char_idx = self.character[int(text_index[batch_idx][idx])]
  136. except:
  137. continue
  138. if char_idx == '</s>': # end
  139. break
  140. char_list.append(char_idx)
  141. if text_prob is not None:
  142. conf_list.append(text_prob[batch_idx][idx])
  143. else:
  144. conf_list.append(1)
  145. text = ''.join(char_list)
  146. result_list.append((text.lower(), np.mean(conf_list).tolist()))
  147. return result_list
  148. class ViTSTRLabelDecode(NRTRLabelDecode):
  149. """ Convert between text-label and text-index """
  150. def __init__(self, character_dict_path=None, use_space_char=False,
  151. **kwargs):
  152. super(ViTSTRLabelDecode, self).__init__(character_dict_path,
  153. use_space_char)
  154. def __call__(self, preds, label=None, *args, **kwargs):
  155. if isinstance(preds, torch.Tensor):
  156. preds = preds[:, 1:].numpy()
  157. else:
  158. preds = preds[:, 1:]
  159. preds_idx = preds.argmax(axis=2)
  160. preds_prob = preds.max(axis=2)
  161. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  162. if label is None:
  163. return text
  164. label = self.decode(label[:, 1:])
  165. return text, label
  166. def add_special_char(self, dict_character):
  167. dict_character = ['<s>', '</s>'] + dict_character
  168. return dict_character
  169. class AttnLabelDecode(BaseRecLabelDecode):
  170. """ Convert between text-label and text-index """
  171. def __init__(self,
  172. character_dict_path=None,
  173. use_space_char=False,
  174. **kwargs):
  175. super(AttnLabelDecode, self).__init__(character_dict_path,
  176. use_space_char)
  177. def add_special_char(self, dict_character):
  178. self.beg_str = "sos"
  179. self.end_str = "eos"
  180. dict_character = dict_character
  181. dict_character = [self.beg_str] + dict_character + [self.end_str]
  182. return dict_character
  183. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  184. """ convert text-index into text-label. """
  185. result_list = []
  186. ignored_tokens = self.get_ignored_tokens()
  187. [beg_idx, end_idx] = self.get_ignored_tokens()
  188. batch_size = len(text_index)
  189. for batch_idx in range(batch_size):
  190. char_list = []
  191. conf_list = []
  192. for idx in range(len(text_index[batch_idx])):
  193. if text_index[batch_idx][idx] in ignored_tokens:
  194. continue
  195. if int(text_index[batch_idx][idx]) == int(end_idx):
  196. break
  197. if is_remove_duplicate:
  198. # only for predict
  199. if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
  200. batch_idx][idx]:
  201. continue
  202. char_list.append(self.character[int(text_index[batch_idx][
  203. idx])])
  204. if text_prob is not None:
  205. conf_list.append(text_prob[batch_idx][idx])
  206. else:
  207. conf_list.append(1)
  208. text = ''.join(char_list)
  209. result_list.append((text, np.mean(conf_list)))
  210. return result_list
  211. def __call__(self, preds, label=None, *args, **kwargs):
  212. """
  213. text = self.decode(text)
  214. if label is None:
  215. return text
  216. else:
  217. label = self.decode(label, is_remove_duplicate=False)
  218. return text, label
  219. """
  220. if isinstance(preds, torch.Tensor):
  221. preds = preds.cpu().numpy()
  222. preds_idx = preds.argmax(axis=2)
  223. preds_prob = preds.max(axis=2)
  224. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  225. if label is None:
  226. return text
  227. label = self.decode(label, is_remove_duplicate=False)
  228. return text, label
  229. def get_ignored_tokens(self):
  230. beg_idx = self.get_beg_end_flag_idx("beg")
  231. end_idx = self.get_beg_end_flag_idx("end")
  232. return [beg_idx, end_idx]
  233. def get_beg_end_flag_idx(self, beg_or_end):
  234. if beg_or_end == "beg":
  235. idx = np.array(self.dict[self.beg_str])
  236. elif beg_or_end == "end":
  237. idx = np.array(self.dict[self.end_str])
  238. else:
  239. assert False, "unsupport type %s in get_beg_end_flag_idx" \
  240. % beg_or_end
  241. return idx
  242. class RFLLabelDecode(BaseRecLabelDecode):
  243. """ Convert between text-label and text-index """
  244. def __init__(self, character_dict_path=None, use_space_char=False,
  245. **kwargs):
  246. super(RFLLabelDecode, self).__init__(character_dict_path,
  247. use_space_char)
  248. def add_special_char(self, dict_character):
  249. self.beg_str = "sos"
  250. self.end_str = "eos"
  251. dict_character = dict_character
  252. dict_character = [self.beg_str] + dict_character + [self.end_str]
  253. return dict_character
  254. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  255. """ convert text-index into text-label. """
  256. result_list = []
  257. ignored_tokens = self.get_ignored_tokens()
  258. [beg_idx, end_idx] = self.get_ignored_tokens()
  259. batch_size = len(text_index)
  260. for batch_idx in range(batch_size):
  261. char_list = []
  262. conf_list = []
  263. for idx in range(len(text_index[batch_idx])):
  264. if text_index[batch_idx][idx] in ignored_tokens:
  265. continue
  266. if int(text_index[batch_idx][idx]) == int(end_idx):
  267. break
  268. if is_remove_duplicate:
  269. # only for predict
  270. if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
  271. batch_idx][idx]:
  272. continue
  273. char_list.append(self.character[int(text_index[batch_idx][
  274. idx])])
  275. if text_prob is not None:
  276. conf_list.append(text_prob[batch_idx][idx])
  277. else:
  278. conf_list.append(1)
  279. text = ''.join(char_list)
  280. result_list.append((text, np.mean(conf_list).tolist()))
  281. return result_list
  282. def __call__(self, preds, label=None, *args, **kwargs):
  283. # if seq_outputs is not None:
  284. if isinstance(preds, tuple) or isinstance(preds, list):
  285. cnt_outputs, seq_outputs = preds
  286. if isinstance(seq_outputs, torch.Tensor):
  287. seq_outputs = seq_outputs.numpy()
  288. preds_idx = seq_outputs.argmax(axis=2)
  289. preds_prob = seq_outputs.max(axis=2)
  290. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  291. if label is None:
  292. return text
  293. label = self.decode(label, is_remove_duplicate=False)
  294. return text, label
  295. else:
  296. cnt_outputs = preds
  297. if isinstance(cnt_outputs, torch.Tensor):
  298. cnt_outputs = cnt_outputs.numpy()
  299. cnt_length = []
  300. for lens in cnt_outputs:
  301. length = round(np.sum(lens))
  302. cnt_length.append(length)
  303. if label is None:
  304. return cnt_length
  305. label = self.decode(label, is_remove_duplicate=False)
  306. length = [len(res[0]) for res in label]
  307. return cnt_length, length
  308. def get_ignored_tokens(self):
  309. beg_idx = self.get_beg_end_flag_idx("beg")
  310. end_idx = self.get_beg_end_flag_idx("end")
  311. return [beg_idx, end_idx]
  312. def get_beg_end_flag_idx(self, beg_or_end):
  313. if beg_or_end == "beg":
  314. idx = np.array(self.dict[self.beg_str])
  315. elif beg_or_end == "end":
  316. idx = np.array(self.dict[self.end_str])
  317. else:
  318. assert False, "unsupport type %s in get_beg_end_flag_idx" \
  319. % beg_or_end
  320. return idx
  321. class SRNLabelDecode(BaseRecLabelDecode):
  322. """ Convert between text-label and text-index """
  323. def __init__(self,
  324. character_dict_path=None,
  325. use_space_char=False,
  326. **kwargs):
  327. self.max_text_length = kwargs.get('max_text_length', 25)
  328. super(SRNLabelDecode, self).__init__(character_dict_path,
  329. use_space_char)
  330. def __call__(self, preds, label=None, *args, **kwargs):
  331. pred = preds['predict']
  332. char_num = len(self.character_str) + 2
  333. if isinstance(pred, torch.Tensor):
  334. pred = pred.numpy()
  335. pred = np.reshape(pred, [-1, char_num])
  336. preds_idx = np.argmax(pred, axis=1)
  337. preds_prob = np.max(pred, axis=1)
  338. preds_idx = np.reshape(preds_idx, [-1, self.max_text_length])
  339. preds_prob = np.reshape(preds_prob, [-1, self.max_text_length])
  340. text = self.decode(preds_idx, preds_prob)
  341. if label is None:
  342. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  343. return text
  344. label = self.decode(label)
  345. return text, label
  346. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  347. """ convert text-index into text-label. """
  348. result_list = []
  349. ignored_tokens = self.get_ignored_tokens()
  350. batch_size = len(text_index)
  351. for batch_idx in range(batch_size):
  352. char_list = []
  353. conf_list = []
  354. for idx in range(len(text_index[batch_idx])):
  355. if text_index[batch_idx][idx] in ignored_tokens:
  356. continue
  357. if is_remove_duplicate:
  358. # only for predict
  359. if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
  360. batch_idx][idx]:
  361. continue
  362. char_list.append(self.character[int(text_index[batch_idx][
  363. idx])])
  364. if text_prob is not None:
  365. conf_list.append(text_prob[batch_idx][idx])
  366. else:
  367. conf_list.append(1)
  368. text = ''.join(char_list)
  369. result_list.append((text, np.mean(conf_list)))
  370. return result_list
  371. def add_special_char(self, dict_character):
  372. dict_character = dict_character + [self.beg_str, self.end_str]
  373. return dict_character
  374. def get_ignored_tokens(self):
  375. beg_idx = self.get_beg_end_flag_idx("beg")
  376. end_idx = self.get_beg_end_flag_idx("end")
  377. return [beg_idx, end_idx]
  378. def get_beg_end_flag_idx(self, beg_or_end):
  379. if beg_or_end == "beg":
  380. idx = np.array(self.dict[self.beg_str])
  381. elif beg_or_end == "end":
  382. idx = np.array(self.dict[self.end_str])
  383. else:
  384. assert False, "unsupport type %s in get_beg_end_flag_idx" \
  385. % beg_or_end
  386. return idx
  387. class TableLabelDecode(object):
  388. """ """
  389. def __init__(self,
  390. character_dict_path,
  391. **kwargs):
  392. list_character, list_elem = self.load_char_elem_dict(character_dict_path)
  393. list_character = self.add_special_char(list_character)
  394. list_elem = self.add_special_char(list_elem)
  395. self.dict_character = {}
  396. self.dict_idx_character = {}
  397. for i, char in enumerate(list_character):
  398. self.dict_idx_character[i] = char
  399. self.dict_character[char] = i
  400. self.dict_elem = {}
  401. self.dict_idx_elem = {}
  402. for i, elem in enumerate(list_elem):
  403. self.dict_idx_elem[i] = elem
  404. self.dict_elem[elem] = i
  405. def load_char_elem_dict(self, character_dict_path):
  406. list_character = []
  407. list_elem = []
  408. with open(character_dict_path, "rb") as fin:
  409. lines = fin.readlines()
  410. substr = lines[0].decode('utf-8').strip("\n").strip("\r\n").split("\t")
  411. character_num = int(substr[0])
  412. elem_num = int(substr[1])
  413. for cno in range(1, 1 + character_num):
  414. character = lines[cno].decode('utf-8').strip("\n").strip("\r\n")
  415. list_character.append(character)
  416. for eno in range(1 + character_num, 1 + character_num + elem_num):
  417. elem = lines[eno].decode('utf-8').strip("\n").strip("\r\n")
  418. list_elem.append(elem)
  419. return list_character, list_elem
  420. def add_special_char(self, list_character):
  421. self.beg_str = "sos"
  422. self.end_str = "eos"
  423. list_character = [self.beg_str] + list_character + [self.end_str]
  424. return list_character
  425. def __call__(self, preds):
  426. structure_probs = preds['structure_probs']
  427. loc_preds = preds['loc_preds']
  428. if isinstance(structure_probs,torch.Tensor):
  429. structure_probs = structure_probs.numpy()
  430. if isinstance(loc_preds,torch.Tensor):
  431. loc_preds = loc_preds.numpy()
  432. structure_idx = structure_probs.argmax(axis=2)
  433. structure_probs = structure_probs.max(axis=2)
  434. structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(structure_idx,
  435. structure_probs, 'elem')
  436. res_html_code_list = []
  437. res_loc_list = []
  438. batch_num = len(structure_str)
  439. for bno in range(batch_num):
  440. res_loc = []
  441. for sno in range(len(structure_str[bno])):
  442. text = structure_str[bno][sno]
  443. if text in ['<td>', '<td']:
  444. pos = structure_pos[bno][sno]
  445. res_loc.append(loc_preds[bno, pos])
  446. res_html_code = ''.join(structure_str[bno])
  447. res_loc = np.array(res_loc)
  448. res_html_code_list.append(res_html_code)
  449. res_loc_list.append(res_loc)
  450. return {'res_html_code': res_html_code_list, 'res_loc': res_loc_list, 'res_score_list': result_score_list,
  451. 'res_elem_idx_list': result_elem_idx_list,'structure_str_list':structure_str}
  452. def decode(self, text_index, structure_probs, char_or_elem):
  453. """convert text-label into text-index.
  454. """
  455. if char_or_elem == "char":
  456. current_dict = self.dict_idx_character
  457. else:
  458. current_dict = self.dict_idx_elem
  459. ignored_tokens = self.get_ignored_tokens('elem')
  460. beg_idx, end_idx = ignored_tokens
  461. result_list = []
  462. result_pos_list = []
  463. result_score_list = []
  464. result_elem_idx_list = []
  465. batch_size = len(text_index)
  466. for batch_idx in range(batch_size):
  467. char_list = []
  468. elem_pos_list = []
  469. elem_idx_list = []
  470. score_list = []
  471. for idx in range(len(text_index[batch_idx])):
  472. tmp_elem_idx = int(text_index[batch_idx][idx])
  473. if idx > 0 and tmp_elem_idx == end_idx:
  474. break
  475. if tmp_elem_idx in ignored_tokens:
  476. continue
  477. char_list.append(current_dict[tmp_elem_idx])
  478. elem_pos_list.append(idx)
  479. score_list.append(structure_probs[batch_idx, idx])
  480. elem_idx_list.append(tmp_elem_idx)
  481. result_list.append(char_list)
  482. result_pos_list.append(elem_pos_list)
  483. result_score_list.append(score_list)
  484. result_elem_idx_list.append(elem_idx_list)
  485. return result_list, result_pos_list, result_score_list, result_elem_idx_list
  486. def get_ignored_tokens(self, char_or_elem):
  487. beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
  488. end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
  489. return [beg_idx, end_idx]
  490. def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
  491. if char_or_elem == "char":
  492. if beg_or_end == "beg":
  493. idx = self.dict_character[self.beg_str]
  494. elif beg_or_end == "end":
  495. idx = self.dict_character[self.end_str]
  496. else:
  497. assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \
  498. % beg_or_end
  499. elif char_or_elem == "elem":
  500. if beg_or_end == "beg":
  501. idx = self.dict_elem[self.beg_str]
  502. elif beg_or_end == "end":
  503. idx = self.dict_elem[self.end_str]
  504. else:
  505. assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \
  506. % beg_or_end
  507. else:
  508. assert False, "Unsupport type %s in char_or_elem" \
  509. % char_or_elem
  510. return idx
  511. class SARLabelDecode(BaseRecLabelDecode):
  512. """ Convert between text-label and text-index """
  513. def __init__(self, character_dict_path=None, use_space_char=False,
  514. **kwargs):
  515. super(SARLabelDecode, self).__init__(character_dict_path,
  516. use_space_char)
  517. self.rm_symbol = kwargs.get('rm_symbol', False)
  518. def add_special_char(self, dict_character):
  519. beg_end_str = "<BOS/EOS>"
  520. unknown_str = "<UKN>"
  521. padding_str = "<PAD>"
  522. dict_character = dict_character + [unknown_str]
  523. self.unknown_idx = len(dict_character) - 1
  524. dict_character = dict_character + [beg_end_str]
  525. self.start_idx = len(dict_character) - 1
  526. self.end_idx = len(dict_character) - 1
  527. dict_character = dict_character + [padding_str]
  528. self.padding_idx = len(dict_character) - 1
  529. return dict_character
  530. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  531. """ convert text-index into text-label. """
  532. result_list = []
  533. ignored_tokens = self.get_ignored_tokens()
  534. batch_size = len(text_index)
  535. for batch_idx in range(batch_size):
  536. char_list = []
  537. conf_list = []
  538. for idx in range(len(text_index[batch_idx])):
  539. if text_index[batch_idx][idx] in ignored_tokens:
  540. continue
  541. if int(text_index[batch_idx][idx]) == int(self.end_idx):
  542. if text_prob is None and idx == 0:
  543. continue
  544. else:
  545. break
  546. if is_remove_duplicate:
  547. # only for predict
  548. if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
  549. batch_idx][idx]:
  550. continue
  551. char_list.append(self.character[int(text_index[batch_idx][
  552. idx])])
  553. if text_prob is not None:
  554. conf_list.append(text_prob[batch_idx][idx])
  555. else:
  556. conf_list.append(1)
  557. text = ''.join(char_list)
  558. if self.rm_symbol:
  559. comp = re.compile('[^A-Z^a-z^0-9^\u4e00-\u9fa5]')
  560. text = text.lower()
  561. text = comp.sub('', text)
  562. result_list.append((text, np.mean(conf_list).tolist()))
  563. return result_list
  564. def __call__(self, preds, label=None, *args, **kwargs):
  565. if isinstance(preds, torch.Tensor):
  566. preds = preds.cpu().numpy()
  567. preds_idx = preds.argmax(axis=2)
  568. preds_prob = preds.max(axis=2)
  569. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  570. if label is None:
  571. return text
  572. label = self.decode(label, is_remove_duplicate=False)
  573. return text, label
  574. def get_ignored_tokens(self):
  575. return [self.padding_idx]
  576. class CANLabelDecode(BaseRecLabelDecode):
  577. """ Convert between latex-symbol and symbol-index """
  578. def __init__(self, character_dict_path=None, use_space_char=False,
  579. **kwargs):
  580. super(CANLabelDecode, self).__init__(character_dict_path,
  581. use_space_char)
  582. def decode(self, text_index, preds_prob=None):
  583. result_list = []
  584. batch_size = len(text_index)
  585. for batch_idx in range(batch_size):
  586. seq_end = text_index[batch_idx].argmin(0)
  587. idx_list = text_index[batch_idx][:seq_end].tolist()
  588. symbol_list = [self.character[idx] for idx in idx_list]
  589. probs = []
  590. if preds_prob is not None:
  591. probs = preds_prob[batch_idx][:len(symbol_list)].tolist()
  592. result_list.append([' '.join(symbol_list), probs])
  593. return result_list
  594. def __call__(self, preds, label=None, *args, **kwargs):
  595. pred_prob, _, _, _ = preds
  596. preds_idx = pred_prob.argmax(axis=2)
  597. text = self.decode(preds_idx)
  598. if label is None:
  599. return text
  600. label = self.decode(label)
  601. return text, label