Răsfoiți Sursa

feat: 添加自定义Numpy编码器以支持JSON格式化,优化输出中的numpy数据处理

zhch158_admin 2 zile în urmă
părinte
comite
985c776438
2 a modificat fișierele cu 29 adăugiri și 2 ștergeri
  1. 15 1
      ocr_utils/json_formatters.py
  2. 14 1
      ocr_utils/output_formatter_v2.py

+ 15 - 1
ocr_utils/json_formatters.py

@@ -9,6 +9,7 @@ JSON 格式化工具模块
 """
 """
 import json
 import json
 import sys
 import sys
+import numpy as np
 from pathlib import Path
 from pathlib import Path
 from typing import Dict, Any, List, Optional
 from typing import Dict, Any, List, Optional
 from loguru import logger
 from loguru import logger
@@ -16,6 +17,19 @@ from loguru import logger
 # 导入数字标准化工具
 # 导入数字标准化工具
 from .normalize_financial_numbers import normalize_json_table
 from .normalize_financial_numbers import normalize_json_table
 
 
+
+class NumpyEncoder(json.JSONEncoder):
+    """JSON encoder that converts numpy values to native Python types.
+
+    Pass it as ``cls=NumpyEncoder`` to ``json.dumps`` so that numpy
+    booleans, integers, floats and arrays found in the payload are
+    serialized instead of raising ``TypeError``.
+    """
+    def default(self, obj):
+        """Return a JSON-serializable stand-in for *obj*.
+
+        Args:
+            obj: A value the standard encoder could not serialize.
+
+        Returns:
+            ``bool``, ``int`` or ``float`` for the matching numpy scalar,
+            or a (nested) list for a numpy array.
+
+        Raises:
+            TypeError: Raised by the base class for any other type.
+        """
+        # np.bool_ is not a subclass of np.integer, so it needs its own branch.
+        if isinstance(obj, np.bool_):
+            return bool(obj)
+        if isinstance(obj, np.integer):
+            return int(obj)
+        if isinstance(obj, np.floating):
+            return float(obj)
+        if isinstance(obj, np.ndarray):
+            return obj.tolist()
+        return super().default(obj)
+
+
 class JSONFormatters:
 class JSONFormatters:
     """JSON 格式化工具类"""
     """JSON 格式化工具类"""
     
     
@@ -236,7 +250,7 @@ class JSONFormatters:
                     page_elements.append(converted)
                     page_elements.append(converted)
             
             
             # 转换为 JSON 字符串
             # 转换为 JSON 字符串
-            json_content = json.dumps(page_elements, ensure_ascii=False, indent=2)
+            json_content = json.dumps(page_elements, ensure_ascii=False, indent=2, cls=NumpyEncoder)
             
             
             # 金额数字标准化
             # 金额数字标准化
             if normalize_numbers:
             if normalize_numbers:

+ 14 - 1
ocr_utils/output_formatter_v2.py

@@ -19,6 +19,7 @@
 """
 """
 import json
 import json
 import sys
 import sys
+import numpy as np
 from pathlib import Path
 from pathlib import Path
 from typing import Dict, Any, List, Optional
 from typing import Dict, Any, List, Optional
 from loguru import logger
 from loguru import logger
@@ -33,6 +34,18 @@ from .visualization_utils import VisualizationUtils
 from .normalize_financial_numbers import normalize_markdown_table, normalize_json_table
 from .normalize_financial_numbers import normalize_markdown_table, normalize_json_table
 
 
 
 
+# NOTE(review): this class duplicates NumpyEncoder in
+# ocr_utils/json_formatters.py; consider importing a single shared definition.
+class NumpyEncoder(json.JSONEncoder):
+    """JSON encoder that converts numpy values to native Python types.
+
+    Pass it as ``cls=NumpyEncoder`` to ``json.dumps`` so that numpy
+    booleans, integers, floats and arrays found in the payload are
+    serialized instead of raising ``TypeError``.
+    """
+    def default(self, obj):
+        """Return a JSON-serializable stand-in for *obj*.
+
+        Args:
+            obj: A value the standard encoder could not serialize.
+
+        Returns:
+            ``bool``, ``int`` or ``float`` for the matching numpy scalar,
+            or a (nested) list for a numpy array.
+
+        Raises:
+            TypeError: Raised by the base class for any other type.
+        """
+        # np.bool_ is not a subclass of np.integer, so it needs its own branch.
+        if isinstance(obj, np.bool_):
+            return bool(obj)
+        if isinstance(obj, np.integer):
+            return int(obj)
+        if isinstance(obj, np.floating):
+            return float(obj)
+        if isinstance(obj, np.ndarray):
+            return obj.tolist()
+        return super().default(obj)
+
+
 class OutputFormatterV2:
 class OutputFormatterV2:
     """
     """
     统一输出格式化器
     统一输出格式化器
@@ -161,7 +174,7 @@ class OutputFormatterV2:
         # 3. 保存 middle.json
         # 3. 保存 middle.json
         if output_config.get('save_json', True):
         if output_config.get('save_json', True):
             json_path = doc_output_dir / f"{doc_name}_middle.json"
             json_path = doc_output_dir / f"{doc_name}_middle.json"
-            json_content = json.dumps(middle_json, ensure_ascii=False, indent=2)
+            json_content = json.dumps(middle_json, ensure_ascii=False, indent=2, cls=NumpyEncoder)
             
             
             # 金额数字标准化
             # 金额数字标准化
             normalize_numbers = output_config.get('normalize_numbers', True)
             normalize_numbers = output_config.get('normalize_numbers', True)