瀏覽代碼

feat: 添加自定义Numpy编码器以支持JSON格式化,优化输出中的numpy数据处理

zhch158_admin 3 天之前
父節點
當前提交
985c776438
共有 2 個文件被更改,包括 29 次插入2 次删除
  1. 15 1
      ocr_utils/json_formatters.py
  2. 14 1
      ocr_utils/output_formatter_v2.py

+ 15 - 1
ocr_utils/json_formatters.py

@@ -9,6 +9,7 @@ JSON 格式化工具模块
 """
 import json
 import sys
+import numpy as np
 from pathlib import Path
 from typing import Dict, Any, List, Optional
 from loguru import logger
@@ -16,6 +17,19 @@ from loguru import logger
 # 导入数字标准化工具
 from .normalize_financial_numbers import normalize_json_table
 
+
+class NumpyEncoder(json.JSONEncoder):
+    """自定义JSON编码器,处理numpy类型"""
+    def default(self, obj):
+        if isinstance(obj, np.integer):
+            return int(obj)
+        elif isinstance(obj, np.floating):
+            return float(obj)
+        elif isinstance(obj, np.ndarray):
+            return obj.tolist()
+        return super().default(obj)
+
+
 class JSONFormatters:
     """JSON 格式化工具类"""
     
@@ -236,7 +250,7 @@ class JSONFormatters:
                     page_elements.append(converted)
             
             # 转换为 JSON 字符串
-            json_content = json.dumps(page_elements, ensure_ascii=False, indent=2)
+            json_content = json.dumps(page_elements, ensure_ascii=False, indent=2, cls=NumpyEncoder)
             
             # 金额数字标准化
             if normalize_numbers:

+ 14 - 1
ocr_utils/output_formatter_v2.py

@@ -19,6 +19,7 @@
 """
 import json
 import sys
+import numpy as np
 from pathlib import Path
 from typing import Dict, Any, List, Optional
 from loguru import logger
@@ -33,6 +34,18 @@ from .visualization_utils import VisualizationUtils
 from .normalize_financial_numbers import normalize_markdown_table, normalize_json_table
 
 
+class NumpyEncoder(json.JSONEncoder):
+    """自定义JSON编码器,处理numpy类型"""
+    def default(self, obj):
+        if isinstance(obj, np.integer):
+            return int(obj)
+        elif isinstance(obj, np.floating):
+            return float(obj)
+        elif isinstance(obj, np.ndarray):
+            return obj.tolist()
+        return super().default(obj)
+
+
 class OutputFormatterV2:
     """
     统一输出格式化器
@@ -161,7 +174,7 @@ class OutputFormatterV2:
         # 3. 保存 middle.json
         if output_config.get('save_json', True):
             json_path = doc_output_dir / f"{doc_name}_middle.json"
-            json_content = json.dumps(middle_json, ensure_ascii=False, indent=2)
+            json_content = json.dumps(middle_json, ensure_ascii=False, indent=2, cls=NumpyEncoder)
             
             # 金额数字标准化
             normalize_numbers = output_config.get('normalize_numbers', True)