|
@@ -6,9 +6,13 @@ import re
|
|
|
|
|
|
|
|
from PIL import Image, ImageOps
|
|
from PIL import Image, ImageOps
|
|
|
from typing import List, Optional, Tuple, Union, Dict, Any
|
|
from typing import List, Optional, Tuple, Union, Dict, Any
|
|
|
|
|
+
|
|
|
|
|
+from loguru import logger
|
|
|
from tokenizers import AddedToken
|
|
from tokenizers import AddedToken
|
|
|
from tokenizers import Tokenizer as TokenizerFast
|
|
from tokenizers import Tokenizer as TokenizerFast
|
|
|
|
|
|
|
|
|
|
+from mineru.model.mfr.unimernet.unimernet_hf.modeling_unimernet import fix_latex_left_right
|
|
|
|
|
+
|
|
|
|
|
|
|
|
class UniMERNetImgDecode(object):
|
|
class UniMERNetImgDecode(object):
|
|
|
"""Class for decoding images for UniMERNet, including cropping margins, resizing, and padding."""
|
|
"""Class for decoding images for UniMERNet, including cropping margins, resizing, and padding."""
|
|
@@ -589,6 +593,7 @@ class UniMERNetDecode(object):
|
|
|
replaced_formula = pattern.sub(replacer, formula)
|
|
replaced_formula = pattern.sub(replacer, formula)
|
|
|
return replaced_formula.replace('"', "")
|
|
return replaced_formula.replace('"', "")
|
|
|
|
|
|
|
|
|
|
+ UP_PATTERN = re.compile(r'\\up([a-zA-Z]+)')
|
|
|
def post_process(self, text: str) -> str:
|
|
def post_process(self, text: str) -> str:
|
|
|
"""Post-processes a string by fixing text and normalizing it.
|
|
"""Post-processes a string by fixing text and normalizing it.
|
|
|
|
|
|
|
@@ -602,6 +607,10 @@ class UniMERNetDecode(object):
|
|
|
|
|
|
|
|
text = self.remove_chinese_text_wrapping(text)
|
|
text = self.remove_chinese_text_wrapping(text)
|
|
|
text = fix_text(text)
|
|
text = fix_text(text)
|
|
|
|
|
+ text = fix_latex_left_right(text)
|
|
|
|
|
+ text = self.UP_PATTERN.sub(
|
|
|
|
|
+ lambda m: m.group(0) if m.group(1) in ["arrow", "downarrow", "lus", "silon"] else f"\\{m.group(1)}", text
|
|
|
|
|
+ )
|
|
|
text = self.normalize(text)
|
|
text = self.normalize(text)
|
|
|
return text
|
|
return text
|
|
|
|
|
|