Переглянути джерело

refactor(normalize_financial_numbers): 优化金额 token 规范化逻辑,增强对多种格式的支持

zhch158_admin 2 тижнів тому
батько
коміт
1e4a3fa9b9
1 змінених файлів з 78 додано та 63 видалено
  1. 78 63
      ocr_utils/normalize_financial_numbers.py

+ 78 - 63
ocr_utils/normalize_financial_numbers.py

@@ -1,85 +1,58 @@
 import re
 import os
 from pathlib import Path
-from decimal import Decimal, InvalidOperation
 
 
 def _normalize_amount_token(token: str) -> str:
     """
-    规范单个金额 token 中逗号/小数点的用法。
-    仅在形态明显为金额时进行纠错,其他情况原样返回。
+    Normalize comma/dot usage in a single amount token to US format (comma thousands separators + dot decimal point).
+
+    Algorithm:
+    1. Find the decimal separator: prefer the last '.' (if exactly 1-2 digits follow it),
+       otherwise the last ',' (same condition); if neither qualifies, treat the token as an integer.
+    2. Strip every comma and dot from the integer part, then regroup the digits in threes with commas.
+    3. Append the decimal part, producing the xxx,xxx.xx form.
     """
     if not token:
         return token
 
-    # 只处理包含数字的简单 token,避免带字母/其他符号的误改
+    # Only handle digit strings that contain a separator, so plain numbers such as years/IDs are left alone
     if not re.fullmatch(r"[+-]?\d[\d,\.]*\d", token):
         return token
+    if "," not in token and "." not in token:
+        return token
 
     sign = ""
     core = token
     if core[0] in "+-":
         sign, core = core[0], core[1:]
 
-    has_dot = "." in core
-    has_comma = "," in core
-
-    # 辅助: 尝试解析为 Decimal;失败则认为不安全,回退原值
-    def _safe_decimal(s: str) -> bool:
-        try:
-            Decimal(s.replace(",", ""))
-            return True
-        except (InvalidOperation, ValueError):
-            return False
-
-    # 规则A:同时包含 . 和 ,,最后一个分隔符是逗号,且其后为 1-2 位数字
-    if has_dot and has_comma:
-        last_comma = core.rfind(",")
-        last_dot = core.rfind(".")
-        if last_comma > last_dot and last_comma != -1:
-            frac = core[last_comma + 1 :]
-            if 1 <= len(frac) <= 2 and frac.isdigit():
-                # 先把所有点当作千分位逗号,再把最后一个逗号当作小数点
-                temp = core.replace(".", ",")
-                idx = temp.rfind(",")
-                if idx != -1:
-                    candidate = temp[:idx] + "." + temp[idx + 1 :]
-                    if _safe_decimal(candidate):
-                        return sign + candidate
-
-    # 规则B:只有 .,多个点,最后一段视为小数,其余为千分位
-    if has_dot and not has_comma:
-        parts = core.split(".")
-        if len(parts) >= 3:
-            last = parts[-1]
-            ints = parts[:-1]
-            if 1 <= len(last) <= 2 and all(len(p) == 3 for p in ints[1:]):
-                candidate = ",".join(ints) + "." + last
-                if _safe_decimal(candidate):
-                    return sign + candidate
-
-    # 规则C:只有 ,,多个逗号,最后一段长度为 1-2 且前面为 3 位分组
-    if has_comma and not has_dot:
-        parts = core.split(",")
-        if len(parts) >= 3:
-            last = parts[-1]
-            ints = parts[:-1]
-            if 1 <= len(last) <= 2 and all(len(p) == 3 for p in ints[1:]):
-                # 将最后一个逗号视为小数点
-                idx = core.rfind(",")
-                candidate = core[:idx] + "." + core[idx + 1 :]
-                if _safe_decimal(candidate):
-                    return sign + candidate
-        # 规则D:只有 ,,且仅有一个逗号、逗号后 1-2 位数字 → 欧洲格式小数,如 301,55 → 301.55
-        elif len(parts) == 2:
-            left, right = parts[0], parts[1]
-            if 1 <= len(right) <= 2 and right.isdigit() and left.isdigit():
-                candidate = left + "." + right
-                if _safe_decimal(candidate):
-                    return sign + candidate
-
-    # 没有需要纠错的典型形态,直接返回原 token
-    return token
+    # Step 1: pick the decimal separator ('.' takes precedence over ',')
+    dec_digits: str | None = None
+    int_part = core
+    for sep in (".", ","):
+        pos = core.rfind(sep)
+        if pos == -1:
+            continue
+        after = core[pos + 1 :]
+        if 1 <= len(after) <= 2 and after.isdigit():
+            dec_digits = after
+            int_part = core[:pos]
+            break
+
+    # Step 2: strip all separators from the integer part (NOTE(review): group widths are not validated, so "2023.01.15" becomes "202,301.15" - confirm intended)
+    int_digits = re.sub(r"[,.]", "", int_part)
+    if not int_digits or not int_digits.isdigit():
+        return token  # cannot be parsed; keep the token unchanged
+
+    # Step 3: rebuild the thousands grouping
+    n = len(int_digits)
+    rem = n % 3 or 3  # width of the leading group (1-3 digits)
+    groups = [int_digits[:rem]] + [int_digits[i : i + 3] for i in range(rem, n, 3)]
+    result = sign + ",".join(groups)
+    if dec_digits is not None:
+        result += "." + dec_digits
+    return result
 
 
 def normalize_financial_numbers(text: str) -> str:
@@ -411,6 +384,8 @@ if __name__ == "__main__":
                 "<td data-bbox=\"[10,20,20,30]\">1,234,567,89</td></tr>"
                 "<tr><td data-bbox=\"[0,20,10,40]\">测试金额C</td>"
                 "<td data-bbox=\"[10,20,20,40]\">301,55</td></tr>"
+                "<tr><td data-bbox=\"[0,20,10,50]\">测试金额D</td>"
+                "<td data-bbox=\"[10,20,20,40]\">1.068.987,094.02</td></tr>"
                 "</tbody></table>"
             ),
         }
@@ -428,6 +403,7 @@ if __name__ == "__main__":
 <tr><td>测试金额A</td><td>12.123,456,00</td></tr>
 <tr><td>测试金额B</td><td>1,234,567,89</td></tr>
 <tr><td>测试金额C</td><td>301,55</td></tr>
+<tr><td>测试金额D</td><td>1.068.987,094.02</td></tr>
 </tbody></table>
 """
     print("原始 Markdown:")
@@ -435,3 +411,42 @@ if __name__ == "__main__":
     normalized_md = normalize_markdown_table(demo_md)
     print("\n标准化后 Markdown:")
     print(normalized_md)
+
+    cases = [  # (input token, expected normalized output)
+        # Class A: standard US format, must be left unchanged
+        ("10,000.00",        "10,000.00"),
+        ("67,455.00",        "67,455.00"),
+        ("89,400.00",        "89,400.00"),
+        ("100,200.00",       "100,200.00"),
+        ("494,339.63",       "494,339.63"),
+        ("1,179.05",         "1,179.05"),
+        ("27,396.05",        "27,396.05"),
+        # Class B: mixed separators / large numbers, must be corrected
+        ("19.879,111.45",    "19,879,111.45"),
+        ("27.072,795.05",    "27,072,795.05"),
+        ("468.348,422.85",   "468,348,422.85"),
+        ("4740,251.56",      "4,740,251.56"),
+        # Class C: extra separators
+        ("585,515.936.19",   "585,515,936.19"),
+        ("22,240.761.60",    "22,240,761.60"),
+        ("198,757.280.38",   "198,757,280.38"),
+        ("618,846.219.71",   "618,846,219.71"),
+        # Cases from the original demo
+        ("12.123,456,00",    "12,123,456.00"),
+        ("1,234,567,89",     "1,234,567.89"),
+        ("301,55",           "301.55"),
+        ("1.068.987,094.02", "1,068,987,094.02"),
+        # Standard European format
+        ("1.234,56",         "1,234.56"),
+        ("1.234.567,89",     "1,234,567.89"),
+    ]
+
+    fail = 0  # only failures are tallied; total and pass count derive from len(cases)
+    for inp, expected in cases:
+        got = _normalize_amount_token(inp)
+        status = "✅" if got == expected else "❌"
+        if got != expected:
+            fail += 1
+        print(f"{status} {inp!r:30s} → {got!r}" + (f"  (期望 {expected!r})" if got != expected else ""))
+
+    print(f"\n共 {len(cases)} 个,通过 {len(cases)-fail},失败 {fail}")