Browse Source

Merge pull request #3803 from myhloli/dev

Dev
Xiaomeng Zhao 3 weeks ago
parent
commit
5c1ca9271e

+ 1 - 1
mineru.template.json

@@ -17,7 +17,7 @@
         "title_aided": {
             "api_key": "your_api_key",
             "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
-            "model": "qwen2.5-32b-instruct",
+            "model": "qwen3-next-80b-a3b-instruct",
             "enable": false
         }
     },

+ 4 - 2
mineru/backend/pipeline/model_init.py

@@ -17,8 +17,10 @@ from ...model.table.rec.unet_table.main import UnetTableModel
 from ...utils.enum_class import ModelPath
 from ...utils.models_download_utils import auto_download_and_get_model_root_path
 
-MFR_MODEL = "unimernet_small"
-# MFR_MODEL = "pp_formulanet_plus_m"
+MFR_MODEL = os.getenv('MINERU_MFR_MODEL', None)
+if MFR_MODEL is None:
+    # MFR_MODEL = "unimernet_small"
+    MFR_MODEL = "pp_formulanet_plus_m"
 
 
 def img_orientation_cls_model_init():

+ 8 - 1
mineru/backend/pipeline/model_json_to_middle_json.py

@@ -231,7 +231,14 @@ def result_to_middle_json(model_list, images_list, pdf_doc, image_writer, lang=N
     para_split(middle_json["pdf_info"])
 
     """表格跨页合并"""
-    merge_table(middle_json["pdf_info"])
+    is_merge_table = os.getenv('MINERU_MERGE_TABLE', 'true')
+    if is_merge_table.lower() == 'true':
+        merge_table(middle_json["pdf_info"])
+    elif is_merge_table.lower() == 'false':
+        pass
+    else:
+        logger.warning(f'unknown MINERU_MERGE_TABLE config: {is_merge_table}, pass')
+        pass
 
     """llm优化"""
     llm_aided_config = get_llm_aided_config()

+ 8 - 1
mineru/backend/vlm/model_output_to_middle_json.py

@@ -110,7 +110,14 @@ def result_to_middle_json(model_output_blocks_list, images_list, pdf_doc, image_
     """表格跨页合并"""
     table_enable = get_table_enable(os.getenv('MINERU_VLM_TABLE_ENABLE', 'True').lower() == 'true')
     if table_enable:
-        merge_table(middle_json["pdf_info"])
+        is_merge_table = os.getenv('MINERU_MERGE_TABLE', 'true')
+        if is_merge_table.lower() == 'true':
+            merge_table(middle_json["pdf_info"])
+        elif is_merge_table.lower() == 'false':
+            pass
+        else:
+            logger.warning(f'unknown MINERU_MERGE_TABLE config: {is_merge_table}, pass')
+            pass
 
     """llm优化标题分级"""
     if heading_level_import_success:

+ 9 - 6
mineru/model/mfr/pp_formulanet_plus_m/predict_formula.py

@@ -45,7 +45,6 @@ class FormulaRecognizer(BaseOCRV20):
         super(FormulaRecognizer, self).__init__(network_config)
 
         self.load_state_dict(weights)
-        # device = "cpu"
         self.device = torch.device(device) if isinstance(device, str) else device
         self.net.to(self.device)
         self.net.eval()
@@ -65,18 +64,22 @@ class FormulaRecognizer(BaseOCRV20):
         )
 
     def predict(self, img_list, batch_size: int = 64):
+        # Reduce batch size by 50% to avoid potential memory issues during inference.
+        batch_size = int(0.5 * batch_size)
         batch_imgs = self.pre_tfs["UniMERNetImgDecode"](imgs=img_list)
         batch_imgs = self.pre_tfs["UniMERNetTestTransform"](imgs=batch_imgs)
         batch_imgs = self.pre_tfs["LatexImageFormat"](imgs=batch_imgs)
-        x = self.pre_tfs["ToBatch"](imgs=batch_imgs)
-        x = torch.from_numpy(x[0]).to(self.device)
+        inp = self.pre_tfs["ToBatch"](imgs=batch_imgs)
+        inp = torch.from_numpy(inp[0])
+        inp = inp.to(self.device)
         rec_formula = []
         with torch.no_grad():
-            with tqdm(total=len(x), desc="Formula Predict") as pbar:
-                for index in range(0, len(x), batch_size):
-                    batch_data = x[index: index + batch_size]
+            with tqdm(total=len(inp), desc="MFR Predict") as pbar:
+                for index in range(0, len(inp), batch_size):
+                    batch_data = inp[index: index + batch_size]
                     batch_preds = [self.net(batch_data)]
                     batch_preds = [p.reshape([-1]) for p in batch_preds[0]]
+                    batch_preds = [bp.cpu().numpy() for bp in batch_preds]
                     rec_formula += self.post_op(batch_preds)
                     pbar.update(len(batch_preds))
         return rec_formula

+ 23 - 1
mineru/model/mfr/pp_formulanet_plus_m/processors.py

@@ -6,9 +6,14 @@ import re
 
 from PIL import Image, ImageOps
 from typing import List, Optional, Tuple, Union, Dict, Any
+
+from loguru import logger
 from tokenizers import AddedToken
 from tokenizers import Tokenizer as TokenizerFast
 
+from mineru.model.mfr.utils import fix_latex_left_right, fix_latex_environments, remove_up_commands, \
+    remove_unsupported_commands
+
 
 class UniMERNetImgDecode(object):
     """Class for decoding images for UniMERNet, including cropping margins, resizing, and padding."""
@@ -602,7 +607,24 @@ class UniMERNetDecode(object):
 
         text = self.remove_chinese_text_wrapping(text)
         text = fix_text(text)
-        text = self.normalize(text)
+        # logger.debug(f"Text after ftfy fix: {text}")
+        text = self.fix_latex(text)
+        return text
+
+    def fix_latex(self, text: str) -> str:
+        """Fixes LaTeX formatting in a string.
+
+        Args:
+            text (str): String to fix.
+
+        Returns:
+            str: Fixed string.
+        """
+        text = fix_latex_left_right(text, fix_delimiter=False)
+        text = fix_latex_environments(text)
+        text = remove_up_commands(text)
+        text = remove_unsupported_commands(text)
+        # text = self.normalize(text)
         return text
 
     def __call__(

+ 1 - 326
mineru/model/mfr/unimernet/unimernet_hf/modeling_unimernet.py

@@ -1,5 +1,4 @@
 import os
-import re
 import warnings
 from typing import Optional
 
@@ -13,6 +12,7 @@ from transformers.models.vision_encoder_decoder.modeling_vision_encoder_decoder
 
 from .unimer_swin import UnimerSwinConfig, UnimerSwinModel, UnimerSwinImageProcessor
 from .unimer_mbart import UnimerMBartConfig, UnimerMBartForCausalLM
+from ...utils import latex_rm_whitespace
 
 AutoConfig.register(UnimerSwinConfig.model_type, UnimerSwinConfig)
 AutoConfig.register(UnimerMBartConfig.model_type, UnimerMBartConfig)
@@ -57,331 +57,6 @@ class TokenizerWrapper:
                     del toks[b][i]
         return toks
 
-
-LEFT_PATTERN = re.compile(r'(\\left)(\S*)')
-RIGHT_PATTERN = re.compile(r'(\\right)(\S*)')
-LEFT_COUNT_PATTERN = re.compile(r'\\left(?![a-zA-Z])')
-RIGHT_COUNT_PATTERN = re.compile(r'\\right(?![a-zA-Z])')
-LEFT_RIGHT_REMOVE_PATTERN = re.compile(r'\\left\.?|\\right\.?')
-
-def fix_latex_left_right(s):
-    """
-    修复LaTeX中的\\left和\\right命令
-    1. 确保它们后面跟有效分隔符
-    2. 平衡\\left和\\right的数量
-    """
-    # 白名单分隔符
-    valid_delims_list = [r'(', r')', r'[', r']', r'{', r'}', r'/', r'|',
-                         r'\{', r'\}', r'\lceil', r'\rceil', r'\lfloor',
-                         r'\rfloor', r'\backslash', r'\uparrow', r'\downarrow',
-                         r'\Uparrow', r'\Downarrow', r'\|', r'\.']
-
-    # 为\left后缺失有效分隔符的情况添加点
-    def fix_delim(match, is_left=True):
-        cmd = match.group(1)  # \left 或 \right
-        rest = match.group(2) if len(match.groups()) > 1 else ""
-        if not rest or rest not in valid_delims_list:
-            return cmd + "."
-        return match.group(0)
-
-    # 使用更精确的模式匹配\left和\right命令
-    # 确保它们是独立的命令,不是其他命令的一部分
-    # 使用预编译正则和统一回调函数
-    s = LEFT_PATTERN.sub(lambda m: fix_delim(m, True), s)
-    s = RIGHT_PATTERN.sub(lambda m: fix_delim(m, False), s)
-
-    # 更精确地计算\left和\right的数量
-    left_count = len(LEFT_COUNT_PATTERN.findall(s))  # 不匹配\lefteqn等
-    right_count = len(RIGHT_COUNT_PATTERN.findall(s))  # 不匹配\rightarrow等
-
-    if left_count == right_count:
-        # 如果数量相等,检查是否在同一组
-        return fix_left_right_pairs(s)
-    else:
-        # 如果数量不等,移除所有\left和\right
-        # logger.debug(f"latex:{s}")
-        # logger.warning(f"left_count: {left_count}, right_count: {right_count}")
-        return LEFT_RIGHT_REMOVE_PATTERN.sub('', s)
-
-
-def fix_left_right_pairs(latex_formula):
-    """
-    检测并修复LaTeX公式中\\left和\\right不在同一组的情况
-
-    Args:
-        latex_formula (str): 输入的LaTeX公式
-
-    Returns:
-        str: 修复后的LaTeX公式
-    """
-    # 用于跟踪花括号嵌套层级
-    brace_stack = []
-    # 用于存储\left信息: (位置, 深度, 分隔符)
-    left_stack = []
-    # 存储需要调整的\right信息: (开始位置, 结束位置, 目标位置)
-    adjustments = []
-
-    i = 0
-    while i < len(latex_formula):
-        # 检查是否是转义字符
-        if i > 0 and latex_formula[i - 1] == '\\':
-            backslash_count = 0
-            j = i - 1
-            while j >= 0 and latex_formula[j] == '\\':
-                backslash_count += 1
-                j -= 1
-
-            if backslash_count % 2 == 1:
-                i += 1
-                continue
-
-        # 检测\left命令
-        if i + 5 < len(latex_formula) and latex_formula[i:i + 5] == "\\left" and i + 5 < len(latex_formula):
-            delimiter = latex_formula[i + 5]
-            left_stack.append((i, len(brace_stack), delimiter))
-            i += 6  # 跳过\left和分隔符
-            continue
-
-        # 检测\right命令
-        elif i + 6 < len(latex_formula) and latex_formula[i:i + 6] == "\\right" and i + 6 < len(latex_formula):
-            delimiter = latex_formula[i + 6]
-
-            if left_stack:
-                left_pos, left_depth, left_delim = left_stack.pop()
-
-                # 如果\left和\right不在同一花括号深度
-                if left_depth != len(brace_stack):
-                    # 找到\left所在花括号组的结束位置
-                    target_pos = find_group_end(latex_formula, left_pos, left_depth)
-                    if target_pos != -1:
-                        # 记录需要移动的\right
-                        adjustments.append((i, i + 7, target_pos))
-
-            i += 7  # 跳过\right和分隔符
-            continue
-
-        # 处理花括号
-        if latex_formula[i] == '{':
-            brace_stack.append(i)
-        elif latex_formula[i] == '}':
-            if brace_stack:
-                brace_stack.pop()
-
-        i += 1
-
-    # 应用调整,从后向前处理以避免索引变化
-    if not adjustments:
-        return latex_formula
-
-    result = list(latex_formula)
-    adjustments.sort(reverse=True, key=lambda x: x[0])
-
-    for start, end, target in adjustments:
-        # 提取\right部分
-        right_part = result[start:end]
-        # 从原位置删除
-        del result[start:end]
-        # 在目标位置插入
-        result.insert(target, ''.join(right_part))
-
-    return ''.join(result)
-
-
-def find_group_end(text, pos, depth):
-    """查找特定深度的花括号组的结束位置"""
-    current_depth = depth
-    i = pos
-
-    while i < len(text):
-        if text[i] == '{' and (i == 0 or not is_escaped(text, i)):
-            current_depth += 1
-        elif text[i] == '}' and (i == 0 or not is_escaped(text, i)):
-            current_depth -= 1
-            if current_depth < depth:
-                return i
-        i += 1
-
-    return -1  # 未找到对应结束位置
-
-
-def is_escaped(text, pos):
-    """检查字符是否被转义"""
-    backslash_count = 0
-    j = pos - 1
-    while j >= 0 and text[j] == '\\':
-        backslash_count += 1
-        j -= 1
-
-    return backslash_count % 2 == 1
-
-
-def fix_unbalanced_braces(latex_formula):
-    """
-    检测LaTeX公式中的花括号是否闭合,并删除无法配对的花括号
-
-    Args:
-        latex_formula (str): 输入的LaTeX公式
-
-    Returns:
-        str: 删除无法配对的花括号后的LaTeX公式
-    """
-    stack = []  # 存储左括号的索引
-    unmatched = set()  # 存储不匹配括号的索引
-    i = 0
-
-    while i < len(latex_formula):
-        # 检查是否是转义的花括号
-        if latex_formula[i] in ['{', '}']:
-            # 计算前面连续的反斜杠数量
-            backslash_count = 0
-            j = i - 1
-            while j >= 0 and latex_formula[j] == '\\':
-                backslash_count += 1
-                j -= 1
-
-            # 如果前面有奇数个反斜杠,则该花括号是转义的,不参与匹配
-            if backslash_count % 2 == 1:
-                i += 1
-                continue
-
-            # 否则,该花括号参与匹配
-            if latex_formula[i] == '{':
-                stack.append(i)
-            else:  # latex_formula[i] == '}'
-                if stack:  # 有对应的左括号
-                    stack.pop()
-                else:  # 没有对应的左括号
-                    unmatched.add(i)
-
-        i += 1
-
-    # 所有未匹配的左括号
-    unmatched.update(stack)
-
-    # 构建新字符串,删除不匹配的括号
-    return ''.join(char for i, char in enumerate(latex_formula) if i not in unmatched)
-
-
-def process_latex(input_string):
-    """
-        处理LaTeX公式中的反斜杠:
-        1. 如果\后跟特殊字符(#$%&~_^\\{})或空格,保持不变
-        2. 如果\后跟两个小写字母,保持不变
-        3. 其他情况,在\后添加空格
-
-        Args:
-            input_string (str): 输入的LaTeX公式
-
-        Returns:
-            str: 处理后的LaTeX公式
-        """
-
-    def replace_func(match):
-        # 获取\后面的字符
-        next_char = match.group(1)
-
-        # 如果是特殊字符或空格,保持不变
-        if next_char in "#$%&~_^|\\{} \t\n\r\v\f":
-            return match.group(0)
-
-        # 如果是字母,检查下一个字符
-        if 'a' <= next_char <= 'z' or 'A' <= next_char <= 'Z':
-            pos = match.start() + 2  # \x后的位置
-            if pos < len(input_string) and ('a' <= input_string[pos] <= 'z' or 'A' <= input_string[pos] <= 'Z'):
-                # 下一个字符也是字母,保持不变
-                return match.group(0)
-
-        # 其他情况,在\后添加空格
-        return '\\' + ' ' + next_char
-
-    # 匹配\后面跟一个字符的情况
-    pattern = r'\\(.)'
-
-    return re.sub(pattern, replace_func, input_string)
-
-# 常见的在KaTeX/MathJax中可用的数学环境
-ENV_TYPES = ['array', 'matrix', 'pmatrix', 'bmatrix', 'vmatrix',
-             'Bmatrix', 'Vmatrix', 'cases', 'aligned', 'gathered']
-ENV_BEGIN_PATTERNS = {env: re.compile(r'\\begin\{' + env + r'\}') for env in ENV_TYPES}
-ENV_END_PATTERNS = {env: re.compile(r'\\end\{' + env + r'\}') for env in ENV_TYPES}
-ENV_FORMAT_PATTERNS = {env: re.compile(r'\\begin\{' + env + r'\}\{([^}]*)\}') for env in ENV_TYPES}
-
-def fix_latex_environments(s):
-    """
-    检测LaTeX中环境(如array)的\\begin和\\end是否匹配
-    1. 如果缺少\\begin标签则在开头添加
-    2. 如果缺少\\end标签则在末尾添加
-    """
-    for env in ENV_TYPES:
-        begin_count = len(ENV_BEGIN_PATTERNS[env].findall(s))
-        end_count = len(ENV_END_PATTERNS[env].findall(s))
-
-        if begin_count != end_count:
-            if end_count > begin_count:
-                format_match = ENV_FORMAT_PATTERNS[env].search(s)
-                default_format = '{c}' if env == 'array' else ''
-                format_str = '{' + format_match.group(1) + '}' if format_match else default_format
-
-                missing_count = end_count - begin_count
-                begin_command = '\\begin{' + env + '}' + format_str + ' '
-                s = begin_command * missing_count + s
-            else:
-                missing_count = begin_count - end_count
-                s = s + (' \\end{' + env + '}') * missing_count
-
-    return s
-
-
-UP_PATTERN = re.compile(r'\\up([a-zA-Z]+)')
-COMMANDS_TO_REMOVE_PATTERN = re.compile(
-    r'\\(?:lefteqn|boldmath|ensuremath|centering|textsubscript|sides|textsl|textcent|emph|protect|null)')
-REPLACEMENTS_PATTERNS = {
-    re.compile(r'\\underbar'): r'\\underline',
-    re.compile(r'\\Bar'): r'\\hat',
-    re.compile(r'\\Hat'): r'\\hat',
-    re.compile(r'\\Tilde'): r'\\tilde',
-    re.compile(r'\\slash'): r'/',
-    re.compile(r'\\textperthousand'): r'‰',
-    re.compile(r'\\sun'): r'☉',
-    re.compile(r'\\textunderscore'): r'\\_',
-    re.compile(r'\\fint'): r'⨏',
-    re.compile(r'\\up '): r'\\ ',
-    re.compile(r'\\vline = '): r'\\models ',
-    re.compile(r'\\vDash '): r'\\models ',
-    re.compile(r'\\sq \\sqcup '): r'\\square ',
-    re.compile(r'\\copyright'): r'©',
-}
-QQUAD_PATTERN = re.compile(r'\\qquad(?!\s)')
-
-def latex_rm_whitespace(s: str):
-    """Remove unnecessary whitespace from LaTeX code."""
-    s = fix_unbalanced_braces(s)
-    s = fix_latex_left_right(s)
-    s = fix_latex_environments(s)
-
-    # 使用预编译的正则表达式
-    s = UP_PATTERN.sub(
-        lambda m: m.group(0) if m.group(1) in ["arrow", "downarrow", "lus", "silon"] else f"\\{m.group(1)}", s
-    )
-    s = COMMANDS_TO_REMOVE_PATTERN.sub('', s)
-
-    # 应用所有替换
-    for pattern, replacement in REPLACEMENTS_PATTERNS.items():
-        s = pattern.sub(replacement, s)
-
-    # 处理LaTeX中的反斜杠和空格
-    s = process_latex(s)
-
-    # \qquad后补空格
-    s = QQUAD_PATTERN.sub(r'\\qquad ', s)
-
-    # 如果字符串以反斜杠结尾,去掉最后的反斜杠
-    while s.endswith('\\'):
-        s = s[:-1]
-
-    return s
-
-
 class UnimernetModel(VisionEncoderDecoderModel):
     def __init__(
         self,

+ 338 - 0
mineru/model/mfr/utils.py

@@ -0,0 +1,338 @@
+import re
+
+LEFT_PATTERN = re.compile(r'(\\left)(\S*)')
+RIGHT_PATTERN = re.compile(r'(\\right)(\S*)')
+LEFT_COUNT_PATTERN = re.compile(r'\\left(?![a-zA-Z])')
+RIGHT_COUNT_PATTERN = re.compile(r'\\right(?![a-zA-Z])')
+LEFT_RIGHT_REMOVE_PATTERN = re.compile(r'\\left\.?|\\right\.?')
+
+def fix_latex_left_right(s, fix_delimiter=True):
+    """
+    修复LaTeX中的\\left和\\right命令
+    1. 确保它们后面跟有效分隔符
+    2. 平衡\\left和\\right的数量
+    """
+    # 白名单分隔符
+    valid_delims_list = [r'(', r')', r'[', r']', r'{', r'}', r'/', r'|',
+                         r'\{', r'\}', r'\lceil', r'\rceil', r'\lfloor',
+                         r'\rfloor', r'\backslash', r'\uparrow', r'\downarrow',
+                         r'\Uparrow', r'\Downarrow', r'\|', r'\.']
+
+    # 为\left后缺失有效分隔符的情况添加点
+    def fix_delim(match, is_left=True):
+        cmd = match.group(1)  # \left 或 \right
+        rest = match.group(2) if len(match.groups()) > 1 else ""
+        if not rest or rest not in valid_delims_list:
+            return cmd + "."
+        return match.group(0)
+
+    # 使用更精确的模式匹配\left和\right命令
+    # 确保它们是独立的命令,不是其他命令的一部分
+    # 使用预编译正则和统一回调函数
+    if fix_delimiter:
+        s = LEFT_PATTERN.sub(lambda m: fix_delim(m, True), s)
+        s = RIGHT_PATTERN.sub(lambda m: fix_delim(m, False), s)
+
+    # 更精确地计算\left和\right的数量
+    left_count = len(LEFT_COUNT_PATTERN.findall(s))  # 不匹配\lefteqn等
+    right_count = len(RIGHT_COUNT_PATTERN.findall(s))  # 不匹配\rightarrow等
+
+    if left_count == right_count:
+        # 如果数量相等,检查是否在同一组
+        return fix_left_right_pairs(s)
+        # return s
+    else:
+        # 如果数量不等,移除所有\left和\right
+        # logger.debug(f"latex:{s}")
+        # logger.warning(f"left_count: {left_count}, right_count: {right_count}")
+        return LEFT_RIGHT_REMOVE_PATTERN.sub('', s)
+
+
+def fix_left_right_pairs(latex_formula):
+    """
+    检测并修复LaTeX公式中\\left和\\right不在同一组的情况
+
+    Args:
+        latex_formula (str): 输入的LaTeX公式
+
+    Returns:
+        str: 修复后的LaTeX公式
+    """
+    # 用于跟踪花括号嵌套层级
+    brace_stack = []
+    # 用于存储\left信息: (位置, 深度, 分隔符)
+    left_stack = []
+    # 存储需要调整的\right信息: (开始位置, 结束位置, 目标位置)
+    adjustments = []
+
+    i = 0
+    while i < len(latex_formula):
+        # 检查是否是转义字符
+        if i > 0 and latex_formula[i - 1] == '\\':
+            backslash_count = 0
+            j = i - 1
+            while j >= 0 and latex_formula[j] == '\\':
+                backslash_count += 1
+                j -= 1
+
+            if backslash_count % 2 == 1:
+                i += 1
+                continue
+
+        # 检测\left命令
+        if i + 5 < len(latex_formula) and latex_formula[i:i + 5] == "\\left" and i + 5 < len(latex_formula):
+            delimiter = latex_formula[i + 5]
+            left_stack.append((i, len(brace_stack), delimiter))
+            i += 6  # 跳过\left和分隔符
+            continue
+
+        # 检测\right命令
+        elif i + 6 < len(latex_formula) and latex_formula[i:i + 6] == "\\right" and i + 6 < len(latex_formula):
+            delimiter = latex_formula[i + 6]
+
+            if left_stack:
+                left_pos, left_depth, left_delim = left_stack.pop()
+
+                # 如果\left和\right不在同一花括号深度
+                if left_depth != len(brace_stack):
+                    # 找到\left所在花括号组的结束位置
+                    target_pos = find_group_end(latex_formula, left_pos, left_depth)
+                    if target_pos != -1:
+                        # 记录需要移动的\right
+                        adjustments.append((i, i + 7, target_pos))
+
+            i += 7  # 跳过\right和分隔符
+            continue
+
+        # 处理花括号
+        if latex_formula[i] == '{':
+            brace_stack.append(i)
+        elif latex_formula[i] == '}':
+            if brace_stack:
+                brace_stack.pop()
+
+        i += 1
+
+    # 应用调整,从后向前处理以避免索引变化
+    if not adjustments:
+        return latex_formula
+
+    result = list(latex_formula)
+    adjustments.sort(reverse=True, key=lambda x: x[0])
+
+    for start, end, target in adjustments:
+        # 提取\right部分
+        right_part = result[start:end]
+        # 从原位置删除
+        del result[start:end]
+        # 在目标位置插入
+        result.insert(target, ''.join(right_part))
+
+    return ''.join(result)
+
+
+def find_group_end(text, pos, depth):
+    """查找特定深度的花括号组的结束位置"""
+    current_depth = depth
+    i = pos
+
+    while i < len(text):
+        if text[i] == '{' and (i == 0 or not is_escaped(text, i)):
+            current_depth += 1
+        elif text[i] == '}' and (i == 0 or not is_escaped(text, i)):
+            current_depth -= 1
+            if current_depth < depth:
+                return i
+        i += 1
+
+    return -1  # 未找到对应结束位置
+
+
+def is_escaped(text, pos):
+    """检查字符是否被转义"""
+    backslash_count = 0
+    j = pos - 1
+    while j >= 0 and text[j] == '\\':
+        backslash_count += 1
+        j -= 1
+
+    return backslash_count % 2 == 1
+
+
+def fix_unbalanced_braces(latex_formula):
+    """
+    检测LaTeX公式中的花括号是否闭合,并删除无法配对的花括号
+
+    Args:
+        latex_formula (str): 输入的LaTeX公式
+
+    Returns:
+        str: 删除无法配对的花括号后的LaTeX公式
+    """
+    stack = []  # 存储左括号的索引
+    unmatched = set()  # 存储不匹配括号的索引
+    i = 0
+
+    while i < len(latex_formula):
+        # 检查是否是转义的花括号
+        if latex_formula[i] in ['{', '}']:
+            # 计算前面连续的反斜杠数量
+            backslash_count = 0
+            j = i - 1
+            while j >= 0 and latex_formula[j] == '\\':
+                backslash_count += 1
+                j -= 1
+
+            # 如果前面有奇数个反斜杠,则该花括号是转义的,不参与匹配
+            if backslash_count % 2 == 1:
+                i += 1
+                continue
+
+            # 否则,该花括号参与匹配
+            if latex_formula[i] == '{':
+                stack.append(i)
+            else:  # latex_formula[i] == '}'
+                if stack:  # 有对应的左括号
+                    stack.pop()
+                else:  # 没有对应的左括号
+                    unmatched.add(i)
+
+        i += 1
+
+    # 所有未匹配的左括号
+    unmatched.update(stack)
+
+    # 构建新字符串,删除不匹配的括号
+    return ''.join(char for i, char in enumerate(latex_formula) if i not in unmatched)
+
+
+def process_latex(input_string):
+    """
+        处理LaTeX公式中的反斜杠:
+        1. 如果\后跟特殊字符(#$%&~_^\\{})或空格,保持不变
+        2. 如果\后跟两个小写字母,保持不变
+        3. 其他情况,在\后添加空格
+
+        Args:
+            input_string (str): 输入的LaTeX公式
+
+        Returns:
+            str: 处理后的LaTeX公式
+        """
+
+    def replace_func(match):
+        # 获取\后面的字符
+        next_char = match.group(1)
+
+        # 如果是特殊字符或空格,保持不变
+        if next_char in "#$%&~_^|\\{} \t\n\r\v\f":
+            return match.group(0)
+
+        # 如果是字母,检查下一个字符
+        if 'a' <= next_char <= 'z' or 'A' <= next_char <= 'Z':
+            pos = match.start() + 2  # \x后的位置
+            if pos < len(input_string) and ('a' <= input_string[pos] <= 'z' or 'A' <= input_string[pos] <= 'Z'):
+                # 下一个字符也是字母,保持不变
+                return match.group(0)
+
+        # 其他情况,在\后添加空格
+        return '\\' + ' ' + next_char
+
+    # 匹配\后面跟一个字符的情况
+    pattern = r'\\(.)'
+
+    return re.sub(pattern, replace_func, input_string)
+
+# 常见的在KaTeX/MathJax中可用的数学环境
+ENV_TYPES = ['array', 'matrix', 'pmatrix', 'bmatrix', 'vmatrix',
+             'Bmatrix', 'Vmatrix', 'cases', 'aligned', 'gathered', 'align', 'align*']
+ENV_BEGIN_PATTERNS = {env: re.compile(r'\\begin\{' + env + r'\}') for env in ENV_TYPES}
+ENV_END_PATTERNS = {env: re.compile(r'\\end\{' + env + r'\}') for env in ENV_TYPES}
+ENV_FORMAT_PATTERNS = {env: re.compile(r'\\begin\{' + env + r'\}\{([^}]*)\}') for env in ENV_TYPES}
+
+def fix_latex_environments(s):
+    """
+    检测LaTeX中环境(如array)的\\begin和\\end是否匹配
+    1. 如果缺少\\begin标签则在开头添加
+    2. 如果缺少\\end标签则在末尾添加
+    """
+    for env in ENV_TYPES:
+        begin_count = len(ENV_BEGIN_PATTERNS[env].findall(s))
+        end_count = len(ENV_END_PATTERNS[env].findall(s))
+
+        if begin_count != end_count:
+            if end_count > begin_count:
+                format_match = ENV_FORMAT_PATTERNS[env].search(s)
+                default_format = '{c}' if env == 'array' else ''
+                format_str = '{' + format_match.group(1) + '}' if format_match else default_format
+
+                missing_count = end_count - begin_count
+                begin_command = '\\begin{' + env + '}' + format_str + ' '
+                s = begin_command * missing_count + s
+            else:
+                missing_count = begin_count - end_count
+                s = s + (' \\end{' + env + '}') * missing_count
+
+    return s
+
+
+REPLACEMENTS_PATTERNS = {
+    re.compile(r'\\underbar'): r'\\underline',
+    re.compile(r'\\Bar'): r'\\hat',
+    re.compile(r'\\Hat'): r'\\hat',
+    re.compile(r'\\Tilde'): r'\\tilde',
+    re.compile(r'\\slash'): r'/',
+    re.compile(r'\\textperthousand'): r'‰',
+    re.compile(r'\\sun'): r'☉',
+    re.compile(r'\\textunderscore'): r'\\_',
+    re.compile(r'\\fint'): r'⨏',
+    re.compile(r'\\up '): r'\\ ',
+    re.compile(r'\\vline = '): r'\\models ',
+    re.compile(r'\\vDash '): r'\\models ',
+    re.compile(r'\\sq \\sqcup '): r'\\square ',
+    re.compile(r'\\copyright'): r'©',
+}
+QQUAD_PATTERN = re.compile(r'\\qquad(?!\s)')
+
+
+def remove_up_commands(s: str):
+    """Remove unnecessary up commands from LaTeX code."""
+    UP_PATTERN = re.compile(r'\\up([a-zA-Z]+)')
+    s = UP_PATTERN.sub(
+        lambda m: m.group(0) if m.group(1) in ["arrow", "downarrow", "lus", "silon"] else f"\\{m.group(1)}", s
+    )
+    return s
+
+
+def remove_unsupported_commands(s: str):
+    """Remove unsupported LaTeX commands."""
+    COMMANDS_TO_REMOVE_PATTERN = re.compile(
+        r'\\(?:lefteqn|boldmath|ensuremath|centering|textsubscript|sides|textsl|textcent|emph|protect|null)')
+    s = COMMANDS_TO_REMOVE_PATTERN.sub('', s)
+    return s
+
+
+def latex_rm_whitespace(s: str):
+    """Remove unnecessary whitespace from LaTeX code."""
+    s = fix_unbalanced_braces(s)
+    s = fix_latex_left_right(s)
+    s = fix_latex_environments(s)
+
+    s = remove_up_commands(s)
+    s = remove_unsupported_commands(s)
+
+    # 应用所有替换
+    for pattern, replacement in REPLACEMENTS_PATTERNS.items():
+        s = pattern.sub(replacement, s)
+
+    # 处理LaTeX中的反斜杠和空格
+    s = process_latex(s)
+
+    # \qquad后补空格
+    s = QQUAD_PATTERN.sub(r'\\qquad ', s)
+
+    # 如果字符串以反斜杠结尾,去掉最后的反斜杠
+    while s.endswith('\\'):
+        s = s[:-1]
+
+    return s

+ 11 - 7
mineru/model/utils/pytorchocr/modeling/heads/rec_ppformulanet_head.py

@@ -23,6 +23,7 @@ from dataclasses import dataclass, fields, is_dataclass
 
 from sympy import totient
 
+from mineru.utils.config_reader import get_device
 from .rec_unimernet_head import (
     MBartForCausalLM,
     MBartDecoder,
@@ -797,6 +798,7 @@ class PPFormulaNet_Head(UniMERNetHead):
                 generation_config["forced_eos_token_id"],
             )
         )
+        self.device = torch.device(get_device())
 
     def prepare_inputs_for_generation(
             self,
@@ -891,8 +893,8 @@ class PPFormulaNet_Head(UniMERNetHead):
 
     def stopping_criteria(self, input_ids):
         if self.is_export:
-            return input_ids[:, -1] == torch.Tensor([self.eos_token_id])
-        is_done = torch.isin(input_ids[:, -1], torch.Tensor([self.eos_token_id]))
+            return input_ids[:, -1].cpu() == torch.Tensor([self.eos_token_id])
+        is_done = torch.isin(input_ids[:, -1].cpu(), torch.Tensor([self.eos_token_id]))
         return is_done
 
     def stopping_criteria_parallel(self, input_ids):
@@ -997,6 +999,7 @@ class PPFormulaNet_Head(UniMERNetHead):
                         torch.ones(
                             (batch_size, parallel_step),
                             dtype=torch.int64,
+                            device=self.device,
                         )
                         * decoder_start_token_id
                 )
@@ -1005,6 +1008,7 @@ class PPFormulaNet_Head(UniMERNetHead):
                         torch.ones(
                             (batch_size, 1),
                             dtype=torch.int64,
+                            device=self.device,
                         )
                         * decoder_start_token_id
                 )
@@ -1078,11 +1082,11 @@ class PPFormulaNet_Head(UniMERNetHead):
         eos_token = self.eos_token_id
         if use_parallel:
             unfinished_sequences = torch.ones(
-                [batch_size, parallel_step], dtype=torch.int64
+                [batch_size, parallel_step], dtype=torch.int64, device=self.device
             )
             parallel_length = math.ceil(self.max_seq_len // parallel_step)
         else:
-            unfinished_sequences = torch.ones(batch_size, dtype=torch.int64)
+            unfinished_sequences = torch.ones(batch_size, dtype=torch.int64, device=self.device)
             parallel_length = self.max_seq_len
 
         i_idx = 0
@@ -1103,7 +1107,7 @@ class PPFormulaNet_Head(UniMERNetHead):
             model_inputs = self.prepare_inputs_for_generation_export(
                 past_key_values=past_key_values, **model_kwargs
             )
-            decoder_attention_mask = torch.ones(input_ids.shape)
+            decoder_attention_mask = torch.ones(input_ids.shape, device=self.device)
 
             outputs = self.generate_single_iter(
                 decoder_input_ids=decoder_input_ids,
@@ -1147,12 +1151,12 @@ class PPFormulaNet_Head(UniMERNetHead):
             if use_parallel:
                 unfinished_sequences = (
                         unfinished_sequences
-                        & ~self.stopping_criteria_parallel(input_ids).to(torch.int64)
+                        & ~self.stopping_criteria_parallel(input_ids).to(torch.int64).to(self.device)
                 )
             else:
                 unfinished_sequences = unfinished_sequences & ~self.stopping_criteria(
                     input_ids
-                ).to(torch.int64)
+                ).to(torch.int64).to(self.device)
 
             if (
                     eos_token is not None

+ 11 - 4
mineru/model/utils/pytorchocr/modeling/heads/rec_unimernet_head.py

@@ -14,6 +14,8 @@ from torch import Tensor
 import torch.nn.functional as F
 from torch.nn import CrossEntropyLoss
 
+from mineru.utils.config_reader import get_device
+
 
 class ModelOutput(OrderedDict):
 
@@ -441,13 +443,14 @@ class MBartLearnedPositionalEmbedding(nn.Embedding):
     def __init__(self, num_embeddings, embedding_dim):
         self.offset = 2
         super().__init__(num_embeddings + self.offset, embedding_dim)
+        self.device = torch.device(get_device())
 
     def forward(self, input_ids, past_key_values_length=0):
         """`input_ids' shape is expected to be [bsz x seqlen]."""
         bsz, seq_len = input_ids.shape[:2]
         positions = torch.arange(
             past_key_values_length, past_key_values_length + seq_len, dtype=torch.int64
-        ).expand([bsz, -1])
+        ).expand([bsz, -1]).to(self.device)
         return nn.Embedding.forward(self, positions + self.offset)
 
 
@@ -656,6 +659,7 @@ class MBartDecoderLayer(nn.Module):
         self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
         self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
         self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+        self.device = torch.device(get_device())
 
     def forward(
             self,
@@ -672,9 +676,12 @@ class MBartDecoderLayer(nn.Module):
 
         residual = hidden_states
         hidden_states = self.self_attn_layer_norm(hidden_states)
-        self_attn_past_key_value = (
-            past_key_value[:2] if past_key_value is not None else None
-        )
+
+        self_attn_past_key_value = None
+        if past_key_value is not None:
+            self_attn_past_key_value = tuple(
+                t.to(self.device) if isinstance(t, torch.Tensor) else t for t in past_key_value[:2]
+            )
 
         hidden_states, self_attn_weights, present_key_value = self.self_attn(
             hidden_states=hidden_states,

+ 4 - 2
mineru/utils/llm_aided.py

@@ -51,7 +51,7 @@ def llm_aided_title(page_info_list, title_aided_config):
 3. 保持字典内key-value的对应关系不变
 
 4. 优化层次结构:
-    - 为每个标题元素添加适当的层次结构
+    - 根据标题内容的语义为每个标题元素添加适当的层次结构
     - 行高较大的标题一般是更高级别的标题
     - 标题从前至后的层级必须是连续的,不能跳过层级
     - 标题层级最多为4级,不要添加过多的层级
@@ -61,7 +61,6 @@ def llm_aided_title(page_info_list, title_aided_config):
     - 在完成初步分级后,仔细检查分级结果的合理性
     - 根据上下文关系和逻辑顺序,对不合理的分级进行微调
     - 确保最终的分级结果符合文档的实际结构和逻辑
-    - 字典中可能包含被误当成标题的正文,你可以通过将其层级标记为 0 来排除它们
 
 IMPORTANT: 
 请直接返回优化过的由标题层级组成的字典,格式为{{标题id:标题层级}},如下:
@@ -78,6 +77,8 @@ Input title list:
 
 Corrected title list:
 """
+    #5.
+    #- 字典中可能包含被误当成标题的正文,你可以通过将其层级标记为 0 来排除它们
 
     retry_count = 0
     max_retries = 3
@@ -89,6 +90,7 @@ Corrected title list:
                 model=title_aided_config["model"],
                 messages=[
                     {'role': 'user', 'content': title_optimize_prompt}],
+                extra_body={"enable_thinking": False},
                 temperature=0.7,
                 stream=True,
             )

+ 1 - 1
pyproject.toml

@@ -115,7 +115,7 @@ namespaces = false
 
 [tool.setuptools.package-data]
 "mineru" = ["resources/**"]
-"mineru.model.ocr.paddleocr2pytorch.pytorchocr.utils" = ["resources/**"]
+"mineru.model.utils.pytorchocr.utils" = ["resources/**"]
 
 [tool.setuptools]
 include-package-data = true