Browse Source

feat(ocr_utils): 增强金额标准化功能,支持欧洲格式小数和JSON表格内容的标准化处理

zhch158_admin 2 weeks ago
parent
commit
6e96478c23
1 changed files with 132 additions and 109 deletions
  1. 132 109
      ocr_utils/normalize_financial_numbers.py

+ 132 - 109
ocr_utils/normalize_financial_numbers.py

@@ -70,6 +70,13 @@ def _normalize_amount_token(token: str) -> str:
                 candidate = core[:idx] + "." + core[idx + 1 :]
                 if _safe_decimal(candidate):
                     return sign + candidate
+        # 规则D:只有 ,,且仅有一个逗号、逗号后 1-2 位数字 → 欧洲格式小数,如 301,55 → 301.55
+        elif len(parts) == 2:
+            left, right = parts[0], parts[1]
+            if 1 <= len(right) <= 2 and right.isdigit() and left.isdigit():
+                candidate = left + "." + right
+                if _safe_decimal(candidate):
+                    return sign + candidate
 
     # 没有需要纠错的典型形态,直接返回原 token
     return token
@@ -165,7 +172,7 @@ def normalize_markdown_table(markdown_content: str) -> str:
                 for col_idx, cell in enumerate(cells):
                     if not isinstance(cell, Tag):
                         continue
-                    # 获取单元格纯文本
+                    # 与 normalize_json_table 一致:整格取文本、只标准化一次、再写回
                     original_text = cell.get_text()
                     normalized_text = normalize_financial_numbers(original_text)
                     if original_text == normalized_text:
@@ -179,21 +186,8 @@ def normalize_markdown_table(markdown_content: str) -> str:
                             "new": normalized_text,
                         }
                     )
-                    # 具体替换:保持原有逻辑,按文本节点逐个替换以保留空白
-                    from bs4.element import NavigableString
-                    for text_node in cell.find_all(string=True, recursive=True):
-                        if isinstance(text_node, NavigableString):
-                            text_str = str(text_node)
-                            if not text_str.strip():
-                                continue
-                            normalized = normalize_financial_numbers(text_str.strip())
-                            if normalized != text_str.strip():
-                                if text_str.strip() == text_str:
-                                    text_node.replace_with(normalized)
-                                else:
-                                    leading_ws = text_str[: len(text_str) - len(text_str.lstrip())]
-                                    trailing_ws = text_str[len(text_str.rstrip()) :]
-                                    text_node.replace_with(leading_ws + normalized + trailing_ws)
+                    # 整格替换为标准化后的文本(与 normalize_json_table 的 cell.string = normalized_text 一致)
+                    cell.string = normalized_text
         
         # 如果没有任何数值修改,直接返回原始 HTML
         if not changes:
@@ -218,102 +212,111 @@ def normalize_markdown_table(markdown_content: str) -> str:
     
     return normalized_content
 
-def normalize_json_table(json_content: str) -> str:
+def normalize_json_table(
+    json_content: str,
+    *,
+    table_type_key: str = "category",
+    table_type_value: str = "Table",
+    html_key: str = "text",
+    cells_key: str | None = None,
+) -> str:
     """
-    专门处理JSON格式OCR结果中表格的数字标准化
-    
+    专门处理JSON格式OCR结果中表格的数字标准化。
+    通过参数指定提取用的 key,以兼容不同 OCR 工具的 JSON 结构。
+
     Args:
-        json_content: JSON格式的OCR结果内容
-    
+        json_content: JSON格式的OCR结果内容(字符串或已解析的 list)
+        table_type_key: 用于判断“是否为表格”的字段名,如 "type" 或 "category"
+        table_type_value: 上述字段等于该值时视为表格,如 "table" 或 "Table"
+        html_key: 存放表格 HTML 的字段名,如 "table_body" 或 "text"
+        cells_key: 存放单元格列表的字段名,如 "table_cells";为 None 则不处理 cells,
+                   仅标准化 html_key 中的表格
+
     Returns:
-        标准化后的JSON内容
-    """
-    """
-    json_content 示例:
-    [
-        {
-            "category": "Table",
-            "text": "<table>...</table>"
-        },
-        {
-            "category": "Text",
-            "text": "Some other text"
-        }
-    ]
+        标准化后的JSON内容(字符串)
+
+    常见格式示例:
+        - 旧格式: category="Table", html 在 "text"
+          normalize_json_table(s)  # 默认即此
+        - mineru_vllm_results_cell_bbox: type="table", html 在 "table_body", cells 在 "table_cells"
+          normalize_json_table(s, table_type_key="type", table_type_value="table",
+                               html_key="table_body", cells_key="table_cells")
     """
     import json
     from ast import literal_eval
-    
+
     try:
-        # 解析JSON内容
         data = json.loads(json_content) if isinstance(json_content, str) else json_content
-        
-        # 确保data是列表格式
         if not isinstance(data, list):
             return json_content
-        
-        # 遍历所有OCR结果项
+
         for item in data:
             if not isinstance(item, dict):
                 continue
-                
-            # 检查是否是表格类型
-            if item.get('category') == 'Table' and 'text' in item:
-                table_html = item['text']
-                
-                # 使用BeautifulSoup处理HTML表格
-                from bs4 import BeautifulSoup, Tag
-                
-                soup = BeautifulSoup(table_html, 'html.parser')
-                tables = soup.find_all('table')
-
-                table_changes: list[dict] = []
-                
-                for table in tables:
-                    if not isinstance(table, Tag):
+            # 按参数判断是否为表格项,且包含 HTML
+            if item.get(table_type_key) != table_type_value or html_key not in item:
+                continue
+
+            table_html = item[html_key]
+            if not table_html or not isinstance(table_html, str):
+                continue
+
+            from bs4 import BeautifulSoup, Tag
+
+            soup = BeautifulSoup(table_html, "html.parser")
+            tables = soup.find_all("table")
+            table_changes: list[dict] = []
+
+            for table in tables:
+                if not isinstance(table, Tag):
+                    continue
+                for row_idx, tr in enumerate(table.find_all("tr")):  # type: ignore[reportAttributeAccessIssue]
+                    cells_tag = tr.find_all(["td", "th"])  # type: ignore[reportAttributeAccessIssue]
+                    for col_idx, cell in enumerate(cells_tag):
+                        if not isinstance(cell, Tag):
+                            continue
+                        original_text = cell.get_text()
+                        normalized_text = normalize_financial_numbers(original_text)
+                        if original_text == normalized_text:
+                            continue
+                        change: dict[str, object] = {
+                            "row": row_idx,
+                            "col": col_idx,
+                            "old": original_text,
+                            "new": normalized_text,
+                        }
+                        bbox_attr = cell.get("data-bbox")
+                        if isinstance(bbox_attr, str):
+                            try:
+                                change["bbox"] = literal_eval(bbox_attr)
+                            except Exception:
+                                change["bbox"] = bbox_attr
+                        table_changes.append(change)
+                        cell.string = normalized_text
+
+            # 写回 HTML
+            item[html_key] = str(soup)
+            if table_changes:
+                item["number_normalization_changes"] = table_changes
+
+            # 若指定了 cells_key,同时标准化 cells 中每格的 text(及 matched_text)
+            # for key in ("text", "matched_text"):
+            table_cell_text_keys = ["text"]
+            if cells_key and cells_key in item and isinstance(item[cells_key], list):
+                for cell in item[cells_key]:
+                    if not isinstance(cell, dict):
                         continue
-                    # 通过 tr / td(th) 计算行列位置
-                    for row_idx, tr in enumerate(table.find_all('tr')):  # type: ignore[reportAttributeAccessIssue]
-                        cells = tr.find_all(['td', 'th'])  # type: ignore[reportAttributeAccessIssue]
-                        for col_idx, cell in enumerate(cells):
-                            if not isinstance(cell, Tag):
-                                continue
-                            original_text = cell.get_text()
-                            normalized_text = normalize_financial_numbers(original_text)
-                            if original_text == normalized_text:
-                                continue
-                            # 记录本单元格的变更
-                            change: dict[str, object] = {
-                                "row": row_idx,
-                                "col": col_idx,
-                                "old": original_text,
-                                "new": normalized_text,
-                            }
-                            bbox_attr = cell.get("data-bbox")
-                            if isinstance(bbox_attr, str):
-                                try:
-                                    change["bbox"] = literal_eval(bbox_attr)
-                                except Exception:
-                                    change["bbox"] = bbox_attr
-                            table_changes.append(change)
-                            # 更新单元格内容(简单覆盖文本即可)
-                            cell.string = normalized_text
-                
-                # 更新 item 中的表格内容
-                item['text'] = str(soup)
-                if table_changes:
-                    item['number_normalization_changes'] = table_changes
-            
-            # 同时标准化普通文本中的数字(如果需要)
-            # elif 'text' in item:
-            #     original_text = item['text']
-            #     normalized_text = normalize_financial_numbers(original_text)
-            #     if original_text != normalized_text:
-            #         item['text'] = normalized_text
-        
-        # 返回标准化后的JSON字符串
+
+                    for key in table_cell_text_keys:
+                        if key not in cell or not isinstance(cell[key], str):
+                            continue
+                        orig = cell[key]
+                        norm = normalize_financial_numbers(orig)
+                        if norm != orig:
+                            cell[key] = norm
+
         return json.dumps(data, ensure_ascii=False, indent=2)
-        
+
     except json.JSONDecodeError as e:
         print(f"⚠️ JSON解析失败: {e}")
         return json_content
@@ -321,31 +324,48 @@ def normalize_json_table(json_content: str) -> str:
         print(f"⚠️ JSON表格标准化失败: {e}")
         return json_content
 
-def normalize_json_file(file_path: str, output_path: str | None = None) -> str:
+def normalize_json_file(
+    file_path: str,
+    output_path: str | None = None,
+    *,
+    table_type_key: str = "category",
+    table_type_value: str = "Table",
+    html_key: str = "text",
+    cells_key: str | None = None,
+) -> str:
     """
-    标准化JSON文件中的表格数字
-    
+    标准化JSON文件中的表格数字。
+    提取表格时使用的 key 可通过参数指定,以兼容不同 OCR 工具。
+
     Args:
         file_path: 输入JSON文件路径
         output_path: 输出文件路径,如果为None则覆盖原文件
-    
+        table_type_key: 判断表格的字段名(见 normalize_json_table)
+        table_type_value: 判断表格的字段值
+        html_key: 表格 HTML 所在字段名
+        cells_key: 单元格列表所在字段名,None 表示不处理 cells
+
     Returns:
         标准化后的JSON内容
     """
     input_file = Path(file_path)
     output_file = Path(output_path) if output_path else input_file
-    
+
     if not input_file.exists():
         raise FileNotFoundError(f"找不到文件: {file_path}")
-    
-    # 读取原始JSON文件
-    with open(input_file, 'r', encoding='utf-8') as f:
+
+    with open(input_file, "r", encoding="utf-8") as f:
         original_content = f.read()
-    
+
     print(f"🔧 正在标准化JSON文件: {input_file.name}")
-    
-    # 标准化内容
-    normalized_content = normalize_json_table(original_content)
+
+    normalized_content = normalize_json_table(
+        original_content,
+        table_type_key=table_type_key,
+        table_type_value=table_type_value,
+        html_key=html_key,
+        cells_key=cells_key,
+    )
     
     # 保存标准化后的文件
     with open(output_file, 'w', encoding='utf-8') as f:
@@ -389,6 +409,8 @@ if __name__ == "__main__":
                 "<td data-bbox=\"[10,10,20,20]\">12.123,456,00</td></tr>"
                 "<tr><td data-bbox=\"[0,20,10,30]\">测试金额B</td>"
                 "<td data-bbox=\"[10,20,20,30]\">1,234,567,89</td></tr>"
+                "<tr><td data-bbox=\"[0,20,10,40]\">测试金额C</td>"
+                "<td data-bbox=\"[10,20,20,40]\">301,55</td></tr>"
                 "</tbody></table>"
             ),
         }
@@ -405,6 +427,7 @@ if __name__ == "__main__":
 <tr><td>项目</td><td>2023 年12 月31 日</td></tr>
 <tr><td>测试金额A</td><td>12.123,456,00</td></tr>
 <tr><td>测试金额B</td><td>1,234,567,89</td></tr>
+<tr><td>测试金额C</td><td>301,55</td></tr>
 </tbody></table>
 """
     print("原始 Markdown:")