Explorar el Código

fix(mfr): optimize LaTeX formula repair functionality

- Improve \left and \right command handling in LaTeX formulas
- Enhance environment type matching for array, matrix, and other structures
- Refactor code for better readability and maintainability
myhloli hace 6 meses
padre
commit
2d1a0f2ca6

+ 27 - 47
magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/modeling_unimernet.py

@@ -58,6 +58,12 @@ class TokenizerWrapper:
         return toks
 
 
+LEFT_PATTERN = re.compile(r'(\\left)(\S*)')
+RIGHT_PATTERN = re.compile(r'(\\right)(\S*)')
+LEFT_COUNT_PATTERN = re.compile(r'\\left(?![a-zA-Z])')
+RIGHT_COUNT_PATTERN = re.compile(r'\\right(?![a-zA-Z])')
+LEFT_RIGHT_REMOVE_PATTERN = re.compile(r'\\left\.?|\\right\.?')
+
 def fix_latex_left_right(s):
     """
     修复LaTeX中的\left和\right命令
@@ -71,31 +77,22 @@ def fix_latex_left_right(s):
                          r'\Uparrow', r'\Downarrow', r'\|', r'\.']
 
     # 为\left后缺失有效分隔符的情况添加点
-    def fix_left_delim(match):
-        cmd = match.group(1)  # \left
-        rest = match.group(2) if len(match.groups()) > 1 else ""
-        # if not rest or not re.match(f"^({valid_delims})", rest):
-        if not rest or rest not in valid_delims_list:
-            return cmd + "."
-        return match.group(0)
-
-    # 为\right后缺失有效分隔符的情况添加点
-    def fix_right_delim(match):
-        cmd = match.group(1)  # \right
+    def fix_delim(match, is_left=True):
+        cmd = match.group(1)  # \left 或 \right
         rest = match.group(2) if len(match.groups()) > 1 else ""
-        # if not rest or not re.match(f"^({valid_delims})", rest):
         if not rest or rest not in valid_delims_list:
             return cmd + "."
         return match.group(0)
 
     # 使用更精确的模式匹配\left和\right命令
     # 确保它们是独立的命令,不是其他命令的一部分
-    s = re.sub(r'(\\left)(\S*)', fix_left_delim, s)
-    s = re.sub(r'(\\right)(\S*)', fix_right_delim, s)
+    # 使用预编译正则和统一回调函数
+    s = LEFT_PATTERN.sub(lambda m: fix_delim(m, True), s)
+    s = RIGHT_PATTERN.sub(lambda m: fix_delim(m, False), s)
 
     # 更精确地计算\left和\right的数量
-    left_count = len(re.findall(r'\\left(?![a-zA-Z])', s))  # 不匹配\lefteqn等
-    right_count = len(re.findall(r'\\right(?![a-zA-Z])', s))  # 不匹配\rightarrow等
+    left_count = len(LEFT_COUNT_PATTERN.findall(s))  # 不匹配\lefteqn等
+    right_count = len(RIGHT_COUNT_PATTERN.findall(s))  # 不匹配\rightarrow等
 
     if left_count == right_count:
         # 如果数量相等,检查是否在同一组
@@ -104,7 +101,7 @@ def fix_latex_left_right(s):
         # 如果数量不等,移除所有\left和\right
         # logger.debug(f"latex:{s}")
         # logger.warning(f"left_count: {left_count}, right_count: {right_count}")
-        return re.sub(r'\\left\.?|\\right\.?', '', s)
+        return LEFT_RIGHT_REMOVE_PATTERN.sub('', s)
 
 
 def fix_left_right_pairs(latex_formula):
@@ -302,6 +299,12 @@ def process_latex(input_string):
 
     return re.sub(pattern, replace_func, input_string)
 
+# 常见的在KaTeX/MathJax中可用的数学环境
+ENV_TYPES = ['array', 'matrix', 'pmatrix', 'bmatrix', 'vmatrix',
+             'Bmatrix', 'Vmatrix', 'cases', 'aligned', 'gathered']
+ENV_BEGIN_PATTERNS = {env: re.compile(r'\\begin\{' + env + r'\}') for env in ENV_TYPES}
+ENV_END_PATTERNS = {env: re.compile(r'\\end\{' + env + r'\}') for env in ENV_TYPES}
+ENV_FORMAT_PATTERNS = {env: re.compile(r'\\begin\{' + env + r'\}\{([^}]*)\}') for env in ENV_TYPES}
 
 def fix_latex_environments(s):
     """
@@ -309,45 +312,22 @@ def fix_latex_environments(s):
     1. 如果缺少\begin标签则在开头添加
     2. 如果缺少\end标签则在末尾添加
     """
-    # 常见的在KaTeX/MathJax中可用的数学环境
-    env_types = [
-        'array', 'matrix', 'pmatrix', 'bmatrix', 'vmatrix',
-        'Bmatrix', 'Vmatrix', 'cases', 'aligned', 'gathered'
-    ]
+    for env in ENV_TYPES:
+        begin_count = len(ENV_BEGIN_PATTERNS[env].findall(s))
+        end_count = len(ENV_END_PATTERNS[env].findall(s))
 
-    for env in env_types:
-        # 计算\begin{env}和\end{env}的数量
-        begin_pattern = r'\\begin\{' + env + r'\}'
-        end_pattern = r'\\end\{' + env + r'\}'
-
-        begin_count = len(re.findall(begin_pattern, s))
-        end_count = len(re.findall(end_pattern, s))
-
-        # 处理两种不匹配情况
         if begin_count != end_count:
-            # 情况1:缺少\begin - 在开头添加缺失的\begin{env}
             if end_count > begin_count:
-                # 尝试从现有的\begin{env}中提取格式
-                format_match = re.search(r'\\begin\{' + env + r'\}\{([^}]*)\}', s)
-
-                # 默认格式,对于array需要列格式
-                default_format = ''
-                if env == 'array':
-                    default_format = '{c}'  # 默认单列居中
-
+                format_match = ENV_FORMAT_PATTERNS[env].search(s)
+                default_format = '{c}' if env == 'array' else ''
                 format_str = '{' + format_match.group(1) + '}' if format_match else default_format
 
-                # 添加缺失的\begin{env}
                 missing_count = end_count - begin_count
                 begin_command = '\\begin{' + env + '}' + format_str + ' '
                 s = begin_command * missing_count + s
-
-            # 情况2:缺少\end - 在末尾添加缺失的\end{env}
-            elif begin_count > end_count:
-                # 添加缺失的\end{env}
+            else:
                 missing_count = begin_count - end_count
-                end_command = ' \\end{' + env + '}'
-                s = s + end_command * missing_count
+                s = s + (' \\end{' + env + '}') * missing_count
 
     return s