소스 검색

feat: 添加数字标准化功能,支持Markdown和JSON内容的全角字符转换为半角字符

zhch158_admin 1 개월 전
부모
커밋
b7bf47fd78
1개의 변경된 파일56개의 추가작업 그리고 16개의 파일을 삭제
  1. 56 16
      zhch/OmniDocBench_DotsOCR_multthreads.py

+ 56 - 16
zhch/OmniDocBench_DotsOCR_multthreads.py

@@ -33,7 +33,9 @@ from utils import (
     get_image_files_from_dir,
     get_image_files_from_list,
     get_image_files_from_csv,
-    collect_pid_files
+    collect_pid_files,
+    normalize_markdown_table,
+    normalize_json_table
 )
 
 def convert_pdf_to_images(pdf_file: str, output_dir: str | None = None, dpi: int = 200) -> List[str]:
@@ -159,7 +161,8 @@ class DotsOCRProcessor:
                  prompt_mode: str = "prompt_layout_all_en",
                  dpi: int = 200,
                  min_pixels: int = MIN_PIXELS,
-                 max_pixels: int = MAX_PIXELS):
+                 max_pixels: int = MAX_PIXELS,
+                 normalize_numbers: bool = False):
         """
         初始化处理器
         
@@ -179,7 +182,8 @@ class DotsOCRProcessor:
         self.dpi = dpi
         self.min_pixels = min_pixels
         self.max_pixels = max_pixels
-        
+        self.normalize_numbers = normalize_numbers
+
         # 初始化解析器
         self.parser = DotsOCRParser(
             ip=ip,
@@ -232,9 +236,27 @@ class DotsOCRProcessor:
             else:
                 md_content = "# 解析失败\n\n未能提取到有效的文档内容。"
             
+            # 如果启用数字标准化,处理 Markdown 内容
+            original_text = md_content
+            if self.normalize_numbers:
+                # generated_text = normalize_financial_numbers(generated_text)
+                # 只对Markdown表格进行数字标准化
+                generated_text = normalize_markdown_table(md_content)
+                
+                # 统计标准化的变化
+                changes_count = len([1 for o, n in zip(original_text, generated_text) if o != n])
+                if changes_count > 0:
+                    saved_files['md_normalized'] = f"✅ 已标准化 {changes_count} 个字符(全角→半角)"
+                else:
+                    saved_files['md_normalized'] = "ℹ️ 无需标准化(已是标准格式)"
             with open(output_md_path, 'w', encoding='utf-8') as f:
-                f.write(md_content)
+                f.write(generated_text)
             saved_files['md'] = output_md_path
+            # 如果启用了标准化,也保存原始版本用于对比
+            if self.normalize_numbers and original_text != generated_text:
+                original_markdown_path = Path(output_dir) / f"{Path(image_name).stem}_original.md"
+                with open(original_markdown_path, 'w', encoding='utf-8') as f:
+                    f.write(original_text)
             
             # 2. 保存 JSON 文件
             output_json_path = os.path.join(output_dir, f"{image_name}.json")
@@ -242,16 +264,29 @@ class DotsOCRProcessor:
             
             if 'layout_info_path' in result and os.path.exists(result['layout_info_path']):
                 with open(result['layout_info_path'], 'r', encoding='utf-8') as f:
-                    json_data = json.load(f)
+                    json_content = f.read()
             else:
-                json_data = {
-                    "error": "未能提取到有效的布局信息",
-                    "cells": []
-                }
+                json_content = f'{{"error": "未能提取到有效的布局信息"}}'
             
+            # 对json中的表格内容进行数字标准化,
+            original_json_text = json_content
+            if self.normalize_numbers:
+                json_content = normalize_json_table(json_content)
+                
+                # 统计标准化的变化
+                changes_count = len([1 for o, n in zip(original_json_text, json_content) if o != n])
+                if changes_count > 0:
+                    saved_files['json_normalized'] = f"✅ 已标准化 {changes_count} 个字符(全角→半角)"
+                else:
+                    saved_files['json_normalized'] = "ℹ️ 无需标准化(已是标准格式)"
             with open(output_json_path, 'w', encoding='utf-8') as f:
-                json.dump(json_data, f, ensure_ascii=False, indent=2)
+                f.write(json_content)
             saved_files['json'] = output_json_path
+            # 如果启用了标准化,也保存原始版本用于对比
+            if self.normalize_numbers and original_json_text != json_content:
+                original_json_path = Path(output_dir) / f"{Path(image_name).stem}_original.json"
+                with open(original_json_path, 'w', encoding='utf-8') as f:
+                    f.write(original_json_text)
             
             # 3. 保存带标注的布局图片
             output_layout_image_path = os.path.join(output_dir, f"{image_name}_layout.jpg")
@@ -553,6 +588,7 @@ def main():
     parser.add_argument("--min_pixels", type=int, default=MIN_PIXELS, help="Minimum pixels")
     parser.add_argument("--max_pixels", type=int, default=MAX_PIXELS, help="Maximum pixels")
     parser.add_argument("--dpi", type=int, default=200, help="PDF processing DPI")
+    parser.add_argument('--no-normalize', action='store_true', help='禁用数字标准化')
     
     # 处理参数
     parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
@@ -597,7 +633,8 @@ def main():
             prompt_mode=args.prompt_mode,
             dpi=args.dpi,
             min_pixels=args.min_pixels,
-            max_pixels=args.max_pixels
+            max_pixels=args.max_pixels,
+            normalize_numbers=not args.no_normalize
         )
         
         # 开始处理
@@ -709,17 +746,20 @@ if __name__ == "__main__":
         
         # 默认配置
         default_config = {
-            "input_dir": "../../OmniDocBench/OpenDataLab___OmniDocBench/images",
-            "output_dir": "./OmniDocBench_DotsOCR_Results",
-            # "ip": "10.192.72.11",
-            "ip": "127.0.0.1",
+            "input_file": "./sample_data/2023年度报告母公司_page_003.png",
+            "output_dir": "./sample_data",
+            "collect_results": "./sample_data/processed_files.csv",
+            # "input_dir": "../../OmniDocBench/OpenDataLab___OmniDocBench/images",
+            # "output_dir": "./OmniDocBench_DotsOCR_Results",
+            # "collect_results": "./OmniDocBench_DotsOCR_Results/processed_files.csv",
+            "ip": "10.192.72.11",
+            # "ip": "127.0.0.1",
             "port": "8101",
             "model_name": "DotsOCR",
             "prompt_mode": "prompt_layout_all_en",
             "batch_size": "1",
             "max_workers": "3",
             "dpi": "200",
-            "collect_results": "./OmniDocBench_DotsOCR_Results/processed_files.csv",
         }
         
         # 如果需要处理失败的文件,可以使用这个配置