Преглед изворни кода

fix(magic_pdf): improve LaTeX formula processing and environment handling

- Refactor LaTeX left/right pair fixing logic for better balance
- Add environment detection and correction for common math environments
- Implement more robust whitespace handling and command substitution
- Optimize regex patterns for improved performance and readability
myhloli пре 6 месеци
родитељ
комит
c8747cffb4

+ 90 - 74
magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/modeling_unimernet.py

@@ -65,13 +65,17 @@ def fix_latex_left_right(s):
     2. 平衡\left和\right的数量
     """
     # 白名单分隔符
-    valid_delims = r'[()\[\]{}/|]|\\{|\\}|\\lceil|\\rceil|\\lfloor|\\rfloor|/|\\backslash|\\uparrow|\\downarrow|\\Uparrow|\\Downarrow|\\||\\.'
+    valid_delims_list = [r'(', r')', r'[', r']', r'{', r'}', r'/', r'|',
+                         r'\{', r'\}', r'\lceil', r'\rceil', r'\lfloor',
+                         r'\rfloor', r'\backslash', r'\uparrow', r'\downarrow',
+                         r'\Uparrow', r'\Downarrow', r'\|', r'\.']
 
     # 为\left后缺失有效分隔符的情况添加点
     def fix_left_delim(match):
         cmd = match.group(1)  # \left
         rest = match.group(2) if len(match.groups()) > 1 else ""
-        if not rest or not re.match(f"^({valid_delims})", rest):
+        # if not rest or not re.match(f"^({valid_delims})", rest):
+        if not rest or rest not in valid_delims_list:
             return cmd + "."
         return match.group(0)
 
@@ -79,7 +83,8 @@ def fix_latex_left_right(s):
     def fix_right_delim(match):
         cmd = match.group(1)  # \right
         rest = match.group(2) if len(match.groups()) > 1 else ""
-        if not rest or not re.match(f"^({valid_delims})", rest):
+        # if not rest or not re.match(f"^({valid_delims})", rest):
+        if not rest or rest not in valid_delims_list:
             return cmd + "."
         return match.group(0)
 
@@ -92,28 +97,15 @@ def fix_latex_left_right(s):
     left_count = len(re.findall(r'\\left(?![a-zA-Z])', s))  # 不匹配\lefteqn等
     right_count = len(re.findall(r'\\right(?![a-zA-Z])', s))  # 不匹配\rightarrow等
 
-    if left_count != right_count:
-        logger.debug(f"latex:{s}")
-        logger.warning(f"left_count: {left_count}, right_count: {right_count}")
-
-    if left_count > right_count:
-        s += ''.join(['\\right.' for _ in range(left_count - right_count)])
-    elif right_count > left_count:
-        # 不再在开头插入\left.,而是在第一个\right前插入
-        if '\\right' in s:
-            # 找出所有\right的位置
-            right_positions = [m.start() for m in re.finditer(r'\\right(?![a-zA-Z])', s)]
-
-            # 从前到后为每个缺失的\left处理
-            new_s = s
-            offset = 0
-            for i in range(min(right_count - left_count, len(right_positions))):
-                pos = right_positions[i] + offset
-                new_s = new_s[:pos] + '\\left.' + new_s[pos:]
-                offset += 6  # \left.的长度
-            s = new_s
-
-    return fix_left_right_pairs(s)
+    if left_count == right_count:
+        # 如果数量相等,检查是否在同一组
+        return fix_left_right_pairs(s)
+    else:
+        # 如果数量不等,移除所有\left和\right
+        # logger.debug(f"latex:{s}")
+        # logger.warning(f"left_count: {left_count}, right_count: {right_count}")
+        return re.sub(r'\\left\.?|\\right\.?', '', s)
+
 
 def fix_left_right_pairs(latex_formula):
     """
@@ -311,66 +303,90 @@ def process_latex(input_string):
     return re.sub(pattern, replace_func, input_string)
 
 
-def latex_rm_whitespace(s: str):
-    """Remove unnecessary whitespace from LaTeX code.
+def fix_latex_environments(s):
     """
-    # logger.debug(f"latex_orig: {s}")
-
-    s = fix_unbalanced_braces(s)
-
-    # left right不匹配的情况(只考虑了不在同一个组里挪到同一个组里的逻辑,没有考虑right比left多的情况)
-    # 还有加一个\left或\right后至少要跟随一个符号,如果没符号就补.
-    s = fix_latex_left_right(s)
-    # s = fix_left_right_pairs(s)
-
-    # 用正则删除\left,\left.,\right,\right.
-    # s = re.sub(r'\\left\.?|\\right\.?', '', s)
+    检测LaTeX中环境(如array)的\begin和\end是否匹配
+    1. 如果缺少\begin标签则在开头添加
+    2. 如果缺少\end标签则在末尾添加
+    """
+    # 常见的在KaTeX/MathJax中可用的数学环境
+    env_types = [
+        'array', 'matrix', 'pmatrix', 'bmatrix', 'vmatrix',
+        'Bmatrix', 'Vmatrix', 'cases', 'aligned', 'gathered'
+    ]
+
+    for env in env_types:
+        # 计算\begin{env}和\end{env}的数量
+        begin_pattern = r'\\begin\{' + env + r'\}'
+        end_pattern = r'\\end\{' + env + r'\}'
+
+        begin_count = len(re.findall(begin_pattern, s))
+        end_count = len(re.findall(end_pattern, s))
+
+        # 处理两种不匹配情况
+        if begin_count != end_count:
+            # 情况1:缺少\begin - 在开头添加缺失的\begin{env}
+            if end_count > begin_count:
+                # 尝试从现有的\begin{env}中提取格式
+                format_match = re.search(r'\\begin\{' + env + r'\}\{([^}]*)\}', s)
+
+                # 默认格式,对于array需要列格式
+                default_format = ''
+                if env == 'array':
+                    default_format = '{c}'  # 默认单列居中
+
+                format_str = '{' + format_match.group(1) + '}' if format_match else default_format
+
+                # 添加缺失的\begin{env}
+                missing_count = end_count - begin_count
+                begin_command = '\\begin{' + env + '}' + format_str + ' '
+                s = begin_command * missing_count + s
+
+            # 情况2:缺少\end - 在末尾添加缺失的\end{env}
+            elif begin_count > end_count:
+                # 添加缺失的\end{env}
+                missing_count = begin_count - end_count
+                end_command = ' \\end{' + env + '}'
+                s = s + end_command * missing_count
 
+    return s
 
-    # 替换\up命令
-    s = re.sub(r'\\up([a-zA-Z]+)',
-               lambda m: m.group(0) if m.group(1) in ["arrow", "downarrow", "lus", "silon",] else f"\\{m.group(1)}", s)
 
-    # 替换\underbar为underline
-    s = re.sub(r'\\underbar', r'\\underline', s)
+UP_PATTERN = re.compile(r'\\up([a-zA-Z]+)')
+COMMANDS_TO_REMOVE_PATTERN = re.compile(
+    r'\\(?:lefteqn|boldmath|ensuremath|centering|textsubscript|sides|textsl|textcent|emph)')
+REPLACEMENTS_PATTERNS = {
+    re.compile(r'\\underbar'): r'\\underline',
+    re.compile(r'\\Bar'): r'\\hat',
+    re.compile(r'\\Hat'): r'\\hat',
+    re.compile(r'\\Tilde'): r'\\tilde',
+    re.compile(r'\\slash'): r'/',
+    re.compile(r'\\textperthousand'): r'‰',
+    re.compile(r'\\sun'): r'☉'
+}
+QQUAD_PATTERN = re.compile(r'\\qquad(?!\s)')
 
-    # 删除\lefteqn
-    s = re.sub(r'\\lefteqn', r'', s)
+def latex_rm_whitespace(s: str):
+    """Remove unnecessary whitespace from LaTeX code."""
+    s = fix_unbalanced_braces(s)
+    s = fix_latex_left_right(s)
+    s = fix_latex_environments(s)
 
-    # 删除\boldmath
-    s = re.sub(r'\\boldmath', r'', s)
+    # 使用预编译的正则表达式
+    s = UP_PATTERN.sub(
+        lambda m: m.group(0) if m.group(1) in ["arrow", "downarrow", "lus", "silon"] else f"\\{m.group(1)}", s
+    )
+    s = COMMANDS_TO_REMOVE_PATTERN.sub('', s)
 
-    # \Bar换成\hat
-    s = re.sub(r'\\Bar', r'\\hat', s)
+    # 应用所有替换
+    for pattern, replacement in REPLACEMENTS_PATTERNS.items():
+        s = pattern.sub(replacement, s)
 
-    # \后缺失空格的补空格
+    # 处理LaTeX中的反斜杠和空格
     s = process_latex(s)
 
     # \qquad后补空格
-    s = re.sub(r'\\qquad(?!\s)', r'\\qquad ', s)
-
-    # \slash 换成 /
-    s = re.sub(r'\\slash', r'/', s)
-
-    # # 先保存 "\ " 序列,防止被错误处理
-    # s = re.sub(r'\\ ', r'\\SPACE', s)
-    #
-    # text_reg = r'(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})'
-    # letter = r'[a-zA-Z]'
-    # noletter = r'[\W_^\d]'
-    # names = [x[0].replace(' ', '') for x in re.findall(text_reg, s)]
-    # s = re.sub(text_reg, lambda _: str(names.pop(0)), s)
-    # news = s
-    # while True:
-    #     s = news
-    #     news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, noletter), r'\1\2', s)
-    #     news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, letter), r'\1\2', news)
-    #     news = re.sub(r'(%s)\s+?(%s)' % (letter, noletter), r'\1\2', news)
-    #     if news == s:
-    #         break
-    #
-    # # 恢复 "\ " 序列
-    # news = re.sub(r'\\SPACE', r'\\ ', news)
+    s = QQUAD_PATTERN.sub(r'\\qquad ', s)
 
     return s