@@ -15,9 +15,7 @@

import json
import math
-import os
import re
-import tempfile
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
@@ -325,14 +323,9 @@ class LaTeXOCRDecode(object):
            **kwargs: Additional keyword arguments for initialization.
        """
        super(LaTeXOCRDecode, self).__init__()
-        temp_path = tempfile.gettempdir()
-        rec_char_dict_path = os.path.join(temp_path, "latexocr_tokenizer.json")
-        try:
-            with open(rec_char_dict_path, "w") as f:
-                json.dump(character_list, f)
-        except Exception as e:
-            print(f"创建 latexocr_tokenizer.json 文件失败, 原因{str(e)}")
-        self.tokenizer = TokenizerFast.from_file(rec_char_dict_path)
+        fast_tokenizer_str = json.dumps(character_list)
+        fast_tokenizer_buffer = fast_tokenizer_str.encode("utf-8")
+        self.tokenizer = TokenizerFast.from_buffer(fast_tokenizer_buffer)

    def post_process(self, s: str) -> str:
        """Post-processes the decoded LaTeX string.
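The change above (and its twin in UniMERNetDecode below) replaces a temp-file round trip with an in-memory buffer. A minimal standalone sketch of the same pattern, assuming TokenizerFast is `tokenizers.Tokenizer` from the HuggingFace `tokenizers` package; the toy vocab here is hypothetical, not this module's `character_list`:

```python
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Whitespace

# Build a tiny tokenizer, serialize it to a JSON string, then reload it
# from bytes in memory; no temp file is created, so there are no tempdir
# permission or locale-encoding pitfalls.
tok = Tokenizer(WordLevel({"[PAD]": 0, "a": 1, "b": 2}, unk_token="[PAD]"))
tok.pre_tokenizer = Whitespace()

json_str = tok.to_str()                   # plays the role of json.dumps(character_list)
buffer = json_str.encode("utf-8")
restored = Tokenizer.from_buffer(buffer)  # replaces TokenizerFast.from_file(path)

print(restored.encode("a b").ids)         # [1, 2]
```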
@@ -372,7 +365,7 @@ class LaTeXOCRDecode(object):
        dec = [self.tokenizer.decode(tok) for tok in tokens]
        dec_str_list = [
            "".join(detok.split(" "))
-            .replace("Ġ", " ")
+            .replace("臓", " ")
            .replace("[EOS]", "")
            .replace("[BOS]", "")
            .replace("[PAD]", "")
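For context on the `Ġ` to `臓` swaps in this patch (here and in UniMERNetDecode below): `臓` is plausibly the byte-level BPE space marker `Ġ` (U+0120) after its UTF-8 bytes were decoded as GBK, i.e. the form the marker takes in the vocab this patch now embeds. That is an assumption; the diff itself does not say. A hypothetical illustration:

```python
# Assumption: "Ġ" (U+0120) encodes to the UTF-8 bytes C4 A0, and decoding
# those bytes as GBK yields a CJK ideograph in place of the space marker.
# errors="replace" keeps the demo safe if the codec rejects the sequence.
marker = "Ġ"
mangled = marker.encode("utf-8").decode("gbk", errors="replace")
print(repr(mangled))  # expected: a single CJK character, per the hypothesis above
```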
@@ -631,80 +624,65 @@ class UniMERNetDecode(object):
        self.pad_token_type_id = 0
        self.pad_to_multiple_of = None

-        with tempfile.NamedTemporaryFile(
-            mode="w", suffix=".json", delete=True
-        ) as temp_file1, tempfile.NamedTemporaryFile(
-            mode="w", suffix=".json", delete=True
-        ) as temp_file2:
-            fast_tokenizer_file = temp_file1.name
-            tokenizer_config_file = temp_file2.name
-            try:
-                with open(fast_tokenizer_file, "w") as f:
-                    json.dump(character_list["fast_tokenizer_file"], f)
-                with open(tokenizer_config_file, "w") as f:
-                    json.dump(character_list["tokenizer_config_file"], f)
-            except Exception as e:
-                print(
-                    f"创建 tokenizer.json 和 tokenizer_config.json 文件失败, 原因{str(e)}"
-                )
-
-            self.tokenizer = TokenizerFast.from_file(fast_tokenizer_file)
-            added_tokens_decoder = {}
-            added_tokens_map = {}
-            if tokenizer_config_file is not None:
-                with open(
-                    tokenizer_config_file, encoding="utf-8"
-                ) as tokenizer_config_handle:
-                    init_kwargs = json.load(tokenizer_config_handle)
-                    if "added_tokens_decoder" in init_kwargs:
-                        for idx, token in init_kwargs["added_tokens_decoder"].items():
-                            if isinstance(token, dict):
-                                token = AddedToken(**token)
-                            if isinstance(token, AddedToken):
-                                added_tokens_decoder[int(idx)] = token
-                                added_tokens_map[str(token)] = token
-                            else:
-                                raise ValueError(
-                                    f"Found a {token.__class__} in the saved `added_tokens_decoder`, should be a dictionary or an AddedToken instance"
-                                )
-                    init_kwargs["added_tokens_decoder"] = added_tokens_decoder
-                    added_tokens_decoder = init_kwargs.pop("added_tokens_decoder", {})
-                    tokens_to_add = [
-                        token
-                        for index, token in sorted(
-                            added_tokens_decoder.items(), key=lambda x: x[0]
+        fast_tokenizer_str = json.dumps(character_list["fast_tokenizer_file"])
+        fast_tokenizer_buffer = fast_tokenizer_str.encode("utf-8")
+        self.tokenizer = TokenizerFast.from_buffer(fast_tokenizer_buffer)
+        tokenizer_config = (
+            character_list["tokenizer_config_file"]
+            if "tokenizer_config_file" in character_list
+            else None
+        )
+        added_tokens_decoder = {}
+        added_tokens_map = {}
+        if tokenizer_config is not None:
+            init_kwargs = tokenizer_config
+            if "added_tokens_decoder" in init_kwargs:
+                for idx, token in init_kwargs["added_tokens_decoder"].items():
+                    if isinstance(token, dict):
+                        token = AddedToken(**token)
+                    if isinstance(token, AddedToken):
+                        added_tokens_decoder[int(idx)] = token
+                        added_tokens_map[str(token)] = token
+                    else:
+                        raise ValueError(
+                            f"Found a {token.__class__} in the saved `added_tokens_decoder`, should be a dictionary or an AddedToken instance"
                        )
-                        if token not in added_tokens_decoder
-                    ]
-                    added_tokens_encoder = self.added_tokens_encoder(
-                        added_tokens_decoder
+            init_kwargs["added_tokens_decoder"] = added_tokens_decoder
+            added_tokens_decoder = init_kwargs.pop("added_tokens_decoder", {})
+            tokens_to_add = [
+                token
+                for index, token in sorted(
+                    added_tokens_decoder.items(), key=lambda x: x[0]
+                )
+                if token not in added_tokens_decoder
+            ]
+            added_tokens_encoder = self.added_tokens_encoder(added_tokens_decoder)
+            encoder = list(added_tokens_encoder.keys()) + [
+                str(token) for token in tokens_to_add
+            ]
+            tokens_to_add += [
+                token
+                for token in self.all_special_tokens_extended
+                if token not in encoder and token not in tokens_to_add
+            ]
+            if len(tokens_to_add) > 0:
+                is_last_special = None
+                tokens = []
+                special_tokens = self.all_special_tokens
+                for token in tokens_to_add:
+                    is_special = (
+                        (token.special or str(token) in special_tokens)
+                        if isinstance(token, AddedToken)
+                        else str(token) in special_tokens
                    )
-                    encoder = list(added_tokens_encoder.keys()) + [
-                        str(token) for token in tokens_to_add
-                    ]
-                    tokens_to_add += [
-                        token
-                        for token in self.all_special_tokens_extended
-                        if token not in encoder and token not in tokens_to_add
-                    ]
-                    if len(tokens_to_add) > 0:
-                        is_last_special = None
-                        tokens = []
-                        special_tokens = self.all_special_tokens
-                        for token in tokens_to_add:
-                            is_special = (
-                                (token.special or str(token) in special_tokens)
-                                if isinstance(token, AddedToken)
-                                else str(token) in special_tokens
-                            )
-                            if is_last_special is None or is_last_special == is_special:
-                                tokens.append(token)
-                            else:
-                                self._add_tokens(tokens, special_tokens=is_last_special)
-                                tokens = [token]
-                            is_last_special = is_special
-                        if tokens:
-                            self._add_tokens(tokens, special_tokens=is_last_special)
+                    if is_last_special is None or is_last_special == is_special:
+                        tokens.append(token)
+                    else:
+                        self._add_tokens(tokens, special_tokens=is_last_special)
+                        tokens = [token]
+                    is_last_special = is_special
+                if tokens:
+                    self._add_tokens(tokens, special_tokens=is_last_special)

    def _add_tokens(
        self, new_tokens: "List[Union[AddedToken, str]]", special_tokens: bool = False
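The rewritten block above consumes `added_tokens_decoder` directly from the config dict instead of re-reading a JSON temp file. A hedged sketch of that conversion step with made-up data, assuming `AddedToken` is `tokenizers.AddedToken` (the `special` kwarg needs a reasonably recent `tokenizers` release):

```python
from tokenizers import AddedToken

# A made-up tokenizer_config.json fragment: vocab index -> token description.
added_tokens_decoder_cfg = {
    "0": {"content": "[PAD]", "special": True, "single_word": False,
          "lstrip": False, "rstrip": False, "normalized": False},
    "1": {"content": "[BOS]", "special": True, "single_word": False,
          "lstrip": False, "rstrip": False, "normalized": False},
}

# Same conversion as the hunk: dict entries become AddedToken objects,
# keyed by their integer vocab index.
added_tokens_decoder = {
    int(idx): AddedToken(**cfg) for idx, cfg in added_tokens_decoder_cfg.items()
}
print(added_tokens_decoder[0].content, added_tokens_decoder[0].special)  # [PAD] True
```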
@@ -820,7 +798,7 @@ class UniMERNetDecode(object):
            for i in reversed(range(len(toks[b]))):
                if toks[b][i] is None:
                    toks[b][i] = ""
-                toks[b][i] = toks[b][i].replace("Ġ", " ").strip()
+                toks[b][i] = toks[b][i].replace("臓", " ").strip()
                if toks[b][i] in (
                    [
                        self.tokenizer.bos_token,
@@ -876,6 +854,15 @@ class UniMERNetDecode(object):
                break
        return s

+    def remove_chinese_text_wrapping(self, formula):
+        pattern = re.compile(r"\\text\s*{\s*([^}]*?[\u4e00-\u9fff]+[^}]*?)\s*}")
+
+        def replacer(match):
+            return match.group(1)
+
+        replaced_formula = pattern.sub(replacer, formula)
+        return replaced_formula.replace('"', "")
+
    def post_process(self, text: str) -> str:
        """Post-processes a string by fixing text and normalizing it.

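A quick standalone check of the `\text{...}` unwrapping regex introduced above: only groups containing CJK characters are unwrapped, while ASCII-only `\text{...}` survives:

```python
import re

pattern = re.compile(r"\\text\s*{\s*([^}]*?[\u4e00-\u9fff]+[^}]*?)\s*}")

formula = r"y = \text{速度} \times t + \text{offset}"
# The CJK-bearing \text{速度} is unwrapped; \text{offset} is left intact.
print(pattern.sub(lambda m: m.group(1), formula))
# y = 速度 \times t + \text{offset}
```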
@@ -887,8 +874,9 @@
        """
        from ftfy import fix_text

+        text = self.remove_chinese_text_wrapping(text)
        text = fix_text(text)
        text = self.normalize(text)
        return text

    def __call__(