Quellcode durchsuchen

feat(ocr_utils): 增强财务数字标准化功能,添加金额 token 纠错逻辑,支持逗号和小数点的正确用法

zhch158_admin vor 2 Wochen
Ursprung
Commit
35c6e6cf36
1 geänderte Datei mit 216 neuen und 71 gelöschten Zeilen
  1. 216 71
      ocr_utils/normalize_financial_numbers.py

+ 216 - 71
ocr_utils/normalize_financial_numbers.py

@@ -1,16 +1,83 @@
 import re
 import os
 from pathlib import Path
+from decimal import Decimal, InvalidOperation
+
+
+def _normalize_amount_token(token: str) -> str:
+    """
+    Repair comma / decimal-point misuse inside a single amount token.
+
+    A token is rewritten only when it clearly looks like a grouped amount:
+    optional sign, at least two separators, a final 1-2 digit segment
+    (the fraction), and every integer group after the first exactly three
+    digits long. The result uses "," for thousands and "." for the decimal
+    point. Any other token is returned unchanged.
+
+    Handled shapes:
+      A. mixed "." and "," with a comma last:  12.123,456,00 -> 12,123,456.00
+      B. dots only:                            1.234.567.89  -> 1,234,567.89
+      C. commas only:                          1,234,567,89  -> 1,234,567.89
+
+    Args:
+        token: candidate number text; may be empty.
+
+    Returns:
+        The corrected token, or ``token`` itself when no rule applies.
+    """
+    if not token:
+        return token
+
+    # Only plain digit/separator tokens are candidates; anything carrying
+    # letters or other symbols is left alone.
+    if not re.fullmatch(r"[+-]?\d[\d,\.]*\d", token):
+        return token
+
+    sign = ""
+    core = token
+    if core[0] in "+-":
+        sign, core = core[0], core[1:]
+
+    has_dot = "." in core
+    has_comma = "," in core
+
+    # Mixed separators are only repaired when the comma comes last; a token
+    # whose last separator is "." (e.g. 1,234.56) is left untouched.
+    if has_dot and has_comma and core.rfind(",") < core.rfind("."):
+        return token
+
+    *int_groups, frac = re.split(r"[.,]", core)
+    if len(int_groups) < 2:
+        # Zero or one separator: nothing to disambiguate.
+        return token
+    if len(frac) > 2:
+        # A 3+ digit tail is a thousands group, not a fraction.
+        return token
+    # Every group after the first must be exactly three digits. This also
+    # rejects empty groups produced by doubled separators and keeps mixed
+    # tokens such as "1.5,10" from being rewritten to "1,5.10".
+    if any(len(group) != 3 for group in int_groups[1:]):
+        return token
+
+    return sign + ",".join(int_groups) + "." + frac
+
 
 def normalize_financial_numbers(text: str) -> str:
     """
-    标准化财务数字:将全角字符转换为半角字符
-    
-    Args:
-        text: 原始文本
-    
-    Returns:
-        标准化后的文本
+    标准化财务数字:将全角字符转换为半角字符,并纠正常见的逗号/小数点错用。
     """
     if not text:
         return text
@@ -31,30 +98,30 @@ def normalize_financial_numbers(text: str) -> str:
         '%': '%',  # 全角百分号转半角百分号
     }
     
-    # 第一步:执行基础字符替换
+    # 第一步:执行基础字符替换(全角 -> 半角)
     normalized_text = text
     for fullwidth, halfwidth in fullwidth_to_halfwidth.items():
         normalized_text = normalized_text.replace(fullwidth, halfwidth)
     
-    # 第二步:处理数字序列中的空格和分隔符
-    # 修改正则表达式以匹配完整的数字序列,包括空格
-    # 匹配模式:数字 + (空格? + 逗号 + 空格? + 数字)* + (空格? + 小数点 + 数字+)?
+    # 第二步:处理数字序列中的空格和分隔符(保留原有逻辑)
     number_sequence_pattern = r'(\d+(?:\s*[,,]\s*\d+)*(?:\s*[。..]\s*\d+)?)'
     
     def normalize_number_sequence(match):
         sequence = match.group(1)
-        
-        # 处理千分位分隔符周围的空格
-        # 将 "数字 + 空格 + 逗号 + 空格 + 数字" 标准化为 "数字,数字"
         sequence = re.sub(r'(\d)\s*[,,]\s*(\d)', r'\1,\2', sequence)
-        
-        # 处理小数点周围的空格
-        # 将 "数字 + 空格 + 小数点 + 空格 + 数字" 标准化为 "数字.数字"
         sequence = re.sub(r'(\d)\s*[。..]\s*(\d)', r'\1.\2', sequence)
-        
         return sequence
     
     normalized_text = re.sub(number_sequence_pattern, normalize_number_sequence, normalized_text)
+
+    # 第三步:对疑似金额 token 做逗号/小数点纠错
+    amount_pattern = r'(?P<tok>[+-]?\d[\d,\.]*\d)'
+
+    def _amount_sub(m: re.Match) -> str:
+        tok = m.group('tok')
+        return _normalize_amount_token(tok)
+
+    normalized_text = re.sub(amount_pattern, _amount_sub, normalized_text)
     return normalized_text
     
 def normalize_markdown_table(markdown_content: str) -> str:
@@ -78,7 +145,7 @@ def normalize_markdown_table(markdown_content: str) -> str:
     table_pattern = r'(<table[^>]*>.*?</table>)'
     
     def normalize_table_match(match):
-        """处理单个表格匹配,保留原始格式"""
+        """处理单个表格匹配,保留原始格式,并追加数字标准化说明注释。"""
         table_html = match.group(1)
         original_table_html = table_html  # 保存原始HTML用于比较
         
@@ -86,52 +153,65 @@ def normalize_markdown_table(markdown_content: str) -> str:
         soup = BeautifulSoup(table_html, 'html.parser')
         tables = soup.find_all('table')
         
-        # 记录所有需要替换的文本(原始文本 -> 标准化文本)
-        replacements = []
+        # 记录本表格中所有数值修改
+        changes: list[dict] = []
         
         for table in tables:
-            if isinstance(table, Tag):
-                cells = table.find_all(['td', 'th'])
-                for cell in cells:
-                    if isinstance(cell, Tag):
-                        # 获取单元格的纯文本内容
-                        original_text = cell.get_text()
-                        normalized_text = normalize_financial_numbers(original_text)
-                        
-                        # 如果内容发生了变化,记录替换
-                        if original_text != normalized_text:
-                            # 找到单元格中所有文本节点并替换
-                            from bs4.element import NavigableString
-                            for text_node in cell.find_all(string=True, recursive=True):
-                                if isinstance(text_node, NavigableString):
-                                    text_str = str(text_node)
-                                    if text_str.strip():
-                                        normalized = normalize_financial_numbers(text_str.strip())
-                                        if normalized != text_str.strip():
-                                            # 保留原始文本节点的前后空白
-                                            if text_str.strip() == text_str:
-                                                # 纯文本节点,直接替换
-                                                text_node.replace_with(normalized)
-                                            else:
-                                                # 有前后空白,需要保留
-                                                leading_ws = text_str[:len(text_str) - len(text_str.lstrip())]
-                                                trailing_ws = text_str[len(text_str.rstrip()):]
-                                                text_node.replace_with(leading_ws + normalized + trailing_ws)
+            if not isinstance(table, Tag):
+                continue
+            # 通过 tr / td(th) 计算行列位置
+            for row_idx, tr in enumerate(table.find_all('tr')):  # type: ignore[reportAttributeAccessIssue]
+                cells = tr.find_all(['td', 'th'])  # type: ignore[reportAttributeAccessIssue]
+                for col_idx, cell in enumerate(cells):
+                    if not isinstance(cell, Tag):
+                        continue
+                    # 获取单元格纯文本
+                    original_text = cell.get_text()
+                    normalized_text = normalize_financial_numbers(original_text)
+                    if original_text == normalized_text:
+                        continue
+                    # 记录一条修改
+                    changes.append(
+                        {
+                            "row": row_idx,
+                            "col": col_idx,
+                            "old": original_text,
+                            "new": normalized_text,
+                        }
+                    )
+                    # 具体替换:保持原有逻辑,按文本节点逐个替换以保留空白
+                    from bs4.element import NavigableString
+                    for text_node in cell.find_all(string=True, recursive=True):
+                        if isinstance(text_node, NavigableString):
+                            text_str = str(text_node)
+                            if not text_str.strip():
+                                continue
+                            normalized = normalize_financial_numbers(text_str.strip())
+                            if normalized != text_str.strip():
+                                if text_str.strip() == text_str:
+                                    text_node.replace_with(normalized)
+                                else:
+                                    leading_ws = text_str[: len(text_str) - len(text_str.lstrip())]
+                                    trailing_ws = text_str[len(text_str.rstrip()) :]
+                                    text_node.replace_with(leading_ws + normalized + trailing_ws)
+        
+        # 如果没有任何数值修改,直接返回原始 HTML
+        if not changes:
+            return original_table_html
         
         # 获取修改后的HTML
         modified_html = str(soup)
         
-        # 如果内容没有变化,返回原始HTML(保持原始格式)
-        # 检查是否只是格式变化(换行、空格等)
-        original_text_only = re.sub(r'\s+', '', original_table_html)
-        modified_text_only = re.sub(r'\s+', '', modified_html)
-        
-        if original_text_only == modified_text_only:
-            # 只有格式变化,返回原始HTML以保留换行符
-            return original_table_html
+        # 在表格后追加注释,说明哪些单元格被修改
+        lines = ["<!-- 数字标准化说明:"]
+        for ch in changes:
+            lines.append(
+                f"  - [row={ch['row']},col={ch['col']}] {ch['old']} -> {ch['new']}"
+            )
+        lines.append("-->")
+        comment = "\n".join(lines)
         
-        # 有实际内容变化,返回修改后的HTML
-        return modified_html
+        return modified_html + "\n\n" + comment
     
     # 使用正则替换,只替换表格内容,保留其他部分(包括换行符)不变
     normalized_content = re.sub(table_pattern, normalize_table_match, markdown_content, flags=re.DOTALL)
@@ -162,6 +242,7 @@ def normalize_json_table(json_content: str) -> str:
     ]
     """
     import json
+    from ast import literal_eval
     
     try:
         # 解析JSON内容
@@ -185,23 +266,43 @@ def normalize_json_table(json_content: str) -> str:
                 
                 soup = BeautifulSoup(table_html, 'html.parser')
                 tables = soup.find_all('table')
+
+                table_changes: list[dict] = []
                 
                 for table in tables:
-                    if isinstance(table, Tag):
-                        cells = table.find_all(['td', 'th'])
-                        for cell in cells:
-                            if isinstance(cell, Tag):
-                                original_text = cell.get_text()
-                                
-                                # 应用数字标准化
-                                normalized_text = normalize_financial_numbers(original_text)
-                                
-                                # 如果内容发生了变化,更新单元格内容
-                                if original_text != normalized_text:
-                                    cell.string = normalized_text
+                    if not isinstance(table, Tag):
+                        continue
+                    # 通过 tr / td(th) 计算行列位置
+                    for row_idx, tr in enumerate(table.find_all('tr')):  # type: ignore[reportAttributeAccessIssue]
+                        cells = tr.find_all(['td', 'th'])  # type: ignore[reportAttributeAccessIssue]
+                        for col_idx, cell in enumerate(cells):
+                            if not isinstance(cell, Tag):
+                                continue
+                            original_text = cell.get_text()
+                            normalized_text = normalize_financial_numbers(original_text)
+                            if original_text == normalized_text:
+                                continue
+                            # 记录本单元格的变更
+                            change: dict[str, object] = {
+                                "row": row_idx,
+                                "col": col_idx,
+                                "old": original_text,
+                                "new": normalized_text,
+                            }
+                            bbox_attr = cell.get("data-bbox")
+                            if isinstance(bbox_attr, str):
+                                try:
+                                    change["bbox"] = literal_eval(bbox_attr)
+                                except Exception:
+                                    change["bbox"] = bbox_attr
+                            table_changes.append(change)
+                            # 更新单元格内容(简单覆盖文本即可)
+                            cell.string = normalized_text
                 
-                # 更新item中的表格内容
+                # 更新 item 中的表格内容
                 item['text'] = str(soup)
+                if table_changes:
+                    item['number_normalization_changes'] = table_changes
             
             # 同时标准化普通文本中的数字(如果需要)
             # elif 'text' in item:
@@ -266,4 +367,48 @@ def normalize_json_file(file_path: str, output_path: str | None = None) -> str:
     
     print(f"📄 标准化结果已保存到: {output_file}")
     return normalized_content
+    
+
+if __name__ == "__main__":
+    # Manual smoke test, executed only when this module is run as a script.
+    # It feeds one JSON sample and one Markdown sample, both containing
+    # amounts with scrambled separators, through the normalizers and prints
+    # the text before and after.
+    # NOTE: the string below is a bare expression statement, not a docstring.
+    """
+    简单验证:构造一份“故意打乱逗号/小数点”的 JSON / Markdown 示例,
+    并打印标准化前后的差异。
+    """
+    import json
+
+    print("=== JSON 示例:金额格式纠错 + 变更记录 ===")
+    # One "Table" item whose "text" field holds the table HTML; each cell
+    # carries a data-bbox attribute.
+    demo_json_data = [
+        {
+            "category": "Table",
+            "text": (
+                "<table><tbody>"
+                "<tr><td data-bbox=\"[0,0,10,10]\">项目</td>"
+                "<td data-bbox=\"[10,0,20,10]\">2023 年12 月31 日</td></tr>"
+                # Deliberately scrambled amounts; the intended values are
+                # 12,123,456.00 and 1,234,567.89
+                "<tr><td data-bbox=\"[0,10,10,20]\">测试金额A</td>"
+                "<td data-bbox=\"[10,10,20,20]\">12.123,456,00</td></tr>"
+                "<tr><td data-bbox=\"[0,20,10,30]\">测试金额B</td>"
+                "<td data-bbox=\"[10,20,20,30]\">1,234,567,89</td></tr>"
+                "</tbody></table>"
+            ),
+        }
+    ]
+    # ensure_ascii=False keeps the Chinese text readable in the printed JSON.
+    demo_json_str = json.dumps(demo_json_data, ensure_ascii=False, indent=2)
+    print("原始 JSON:")
+    print(demo_json_str)
+    normalized_json_str = normalize_json_table(demo_json_str)
+    print("\n标准化后 JSON:")
+    print(normalized_json_str)
 
+    print("\n=== Markdown 示例:金额格式纠错 + 注释说明 ===")
+    # Same table as the JSON sample, without the data-bbox attributes.
+    demo_md = """<table><tbody>
+<tr><td>项目</td><td>2023 年12 月31 日</td></tr>
+<tr><td>测试金额A</td><td>12.123,456,00</td></tr>
+<tr><td>测试金额B</td><td>1,234,567,89</td></tr>
+</tbody></table>
+"""
+    print("原始 Markdown:")
+    print(demo_md)
+    normalized_md = normalize_markdown_table(demo_md)
+    print("\n标准化后 Markdown:")
+    print(normalized_md)