Bläddra i källkod

feat: 新增工具模块,包含CUDA环境检测、文件处理、文档处理及金融数字标准化功能

zhch158_admin 1 månad sedan
förälder
incheckning
eae609a851

+ 59 - 0
zhch/utils/__init__.py

@@ -0,0 +1,59 @@
"""
Utility module package.

Provides CUDA environment checks, file handling, document (PDF) processing,
and financial-number normalization helpers.
"""

from .check_cuda_env import (
    check_nvidia_environment,
)

from .cuda_utils import (
    monitor_gpu_memory,
    detect_available_gpus,
)

from .file_utils import (
    get_image_files_from_dir,
    get_image_files_from_list,
    get_image_files_from_csv,
    split_files,
    create_temp_file_list,
    collect_pid_files,
    get_input_files,
)

from .doc_utils import (
    load_images_from_pdf,
    fitz_doc_to_image,
)

from .normalize_financial_numbers import (
    normalize_financial_numbers,
    normalize_markdown_table,
    normalize_json_table,
)

__all__ = [
    # CUDA environment check
    'check_nvidia_environment',
    # CUDA utilities
    'monitor_gpu_memory',
    'detect_available_gpus',
    # File utilities
    'get_image_files_from_dir',
    'get_image_files_from_list',
    'get_image_files_from_csv',
    'split_files',
    'create_temp_file_list',
    'collect_pid_files',
    'get_input_files',
    # Document (PDF) utilities — were imported above but missing from __all__
    'load_images_from_pdf',
    'fitz_doc_to_image',
    # Financial-number normalization
    'normalize_financial_numbers',
    'normalize_markdown_table',
    'normalize_json_table',
]

__version__ = "1.0.0"
__author__ = "zhch158"
__description__ = "PaddleX工具模块包,提供CUDA环境检测和文件处理功能"

+ 54 - 0
zhch/utils/check_cuda_env.py

@@ -0,0 +1,54 @@
+import subprocess
+import sys
+
def check_nvidia_environment():
    """Print a diagnostic report of the local NVIDIA/CUDA stack.

    Runs three independent probes — the nvidia-smi CLI, NVML (pynvml) and
    PaddlePaddle — and prints the outcome of each.  Every probe is wrapped
    in its own try/except so one missing component never hides the others.
    """
    print("=== NVIDIA环境检查 ===")
    _probe_nvidia_smi()
    _probe_nvml()
    _probe_paddle()


def _probe_nvidia_smi():
    # Probe 1: the nvidia-smi CLI (reports the installed driver version).
    try:
        result = subprocess.run(['nvidia-smi'], capture_output=True, text=True, timeout=10)
        if result.returncode != 0:
            print("✗ nvidia-smi 失败:")
            print(result.stderr)
            return
        print("✓ nvidia-smi 正常运行")
        # Echo only the line that carries the driver version.
        for line in result.stdout.split('\n'):
            if 'Driver Version:' in line:
                print(f"  {line.strip()}")
                break
    except Exception as e:
        print(f"✗ nvidia-smi 错误: {e}")


def _probe_nvml():
    # Probe 2: NVML via pynvml (driver version + per-device names).
    try:
        import pynvml
        pynvml.nvmlInit()
        driver_version = pynvml.nvmlSystemGetDriverVersion()
        print(f"✓ NVML初始化成功,驱动版本: {driver_version}")

        device_count = pynvml.nvmlDeviceGetCount()
        print(f"✓ 检测到 {device_count} 个GPU设备")

        for i in range(device_count):
            name = pynvml.nvmlDeviceGetName(pynvml.nvmlDeviceGetHandleByIndex(i))
            print(f"  GPU {i}: {name}")
    except Exception as e:
        print(f"✗ NVML错误: {e}")


def _probe_paddle():
    # Probe 3: PaddlePaddle build info (CUDA support + visible GPU count).
    try:
        import paddle
        print(f"✓ PaddlePaddle版本: {paddle.__version__}")
        print(f"✓ CUDA编译支持: {paddle.device.is_compiled_with_cuda()}")
        if paddle.device.is_compiled_with_cuda():
            gpu_count = paddle.device.cuda.device_count()
            print(f"✓ PaddlePaddle检测到 {gpu_count} 个GPU")
    except Exception as e:
        print(f"✗ PaddlePaddle错误: {e}")


if __name__ == "__main__":
    check_nvidia_environment()

+ 24 - 0
zhch/utils/cuda_utils.py

@@ -0,0 +1,24 @@
+import torch
+from typing import List
def detect_available_gpus() -> List[int]:
    """Return the indices of all CUDA devices visible to torch.

    Returns:
        The list [0, 1, ..., n-1] where n is torch's CUDA device count;
        an empty list when detection fails.
    """
    try:
        count = torch.cuda.device_count()
        gpus = list(range(count))
        print(f"检测到 {count} 个可用GPU: {gpus}")
        return gpus
    except Exception as e:
        print(f"GPU检测失败: {e}")
        return []
+
+def monitor_gpu_memory(gpu_ids: List[int] = [0, 1]):
+    """监控GPU内存使用情况"""
+    try:
+        for gpu_id in gpu_ids:
+            torch.cuda.set_device(gpu_id)
+            total = torch.cuda.get_device_properties(gpu_id).total_memory / 1024**3
+            allocated = torch.cuda.memory_allocated(gpu_id) / 1024**3
+            reserved = torch.cuda.memory_reserved(gpu_id) / 1024**3
+            print(f"GPU {gpu_id} - 显存: {total:.2f}GB, 已分配: {allocated:.2f}GB, 已预留: {reserved:.2f}GB")
+    except Exception as e:
+        print(f"GPU内存监控失败: {e}")

+ 59 - 0
zhch/utils/doc_utils.py

@@ -0,0 +1,59 @@
+import fitz
+import numpy as np
+import enum
+from pydantic import BaseModel, Field
+from PIL import Image
+
class SupportedPdfParseMethod(enum.Enum):
    """Parsing strategies supported for PDF pages."""

    # Rasterize the page and run OCR on it.
    OCR = 'ocr'
    # Extract the embedded text layer directly.
    TXT = 'txt'
+
+
class PageInfo(BaseModel):
    """Width and height of a single PDF page.

    NOTE(review): units are presumably PDF points (1/72 inch) as reported
    by the PDF library — confirm against callers.
    """
    w: float = Field(description='the width of page')
    h: float = Field(description='the height of page')
+
+
def fitz_doc_to_image(doc, target_dpi=200, origin_dpi=None) -> Image.Image:
    """Render a fitz (PyMuPDF) page to a PIL RGB image.

    Args:
        doc: A pymupdf page object (anything exposing ``get_pixmap``).
        target_dpi: Rendering resolution. Defaults to 200.
        origin_dpi: Unused; kept for backward compatibility with callers.

    Returns:
        PIL.Image.Image: The rendered page in RGB mode.
        (The previous docstring wrongly claimed a dict was returned.)
    """
    mat = fitz.Matrix(target_dpi / 72, target_dpi / 72)
    pm = doc.get_pixmap(matrix=mat, alpha=False)

    # Guard against huge pixmaps: re-render oversized pages at fitz's
    # default 72 dpi to cap memory use.
    if pm.width > 4500 or pm.height > 4500:
        mat = fitz.Matrix(72 / 72, 72 / 72)
        pm = doc.get_pixmap(matrix=mat, alpha=False)

    # Module-level `from PIL import Image` is used; the redundant local
    # import has been removed.
    return Image.frombytes('RGB', (pm.width, pm.height), pm.samples)
+
+
def load_images_from_pdf(pdf_file, dpi=200, start_page_id=0, end_page_id=None) -> list:
    """Render a page range of a PDF into a list of PIL images.

    Args:
        pdf_file: Path to the PDF file.
        dpi: Rendering resolution forwarded to ``fitz_doc_to_image``.
        start_page_id: First page index (0-based); negative values are
            clamped to 0 (matching the previous behaviour).
        end_page_id: Last page index, inclusive. None or a negative value
            means "through the last page"; values past the end are clamped
            with a warning.

    Returns:
        list: One PIL image per rendered page (empty for an empty range).
    """
    images = []
    with fitz.open(pdf_file) as doc:
        last_page = doc.page_count - 1
        if end_page_id is None or end_page_id < 0:
            end_page_id = last_page
        elif end_page_id > last_page:
            print('end_page_id is out of range, use images length')
            end_page_id = last_page

        # Iterate only the requested window instead of scanning every page
        # and filtering inside the loop.
        for index in range(max(0, start_page_id), end_page_id + 1):
            images.append(fitz_doc_to_image(doc[index], target_dpi=dpi))
    return images

+ 292 - 0
zhch/utils/file_utils.py

@@ -0,0 +1,292 @@
import csv
import json
import tempfile
import traceback
from pathlib import Path
from typing import List, Tuple

from .doc_utils import load_images_from_pdf
+
def split_files(file_list: List[str], num_splits: int) -> List[List[str]]:
    """Partition ``file_list`` into at most ``num_splits`` contiguous chunks.

    Chunk sizes differ by at most one: the first ``len(file_list) %
    num_splits`` chunks receive one extra element.  Empty chunks are
    dropped, so fewer than ``num_splits`` lists may come back.

    Args:
        file_list: Paths to distribute.
        num_splits: Desired number of chunks; values <= 0 yield the whole
            list as a single chunk.

    Returns:
        Non-empty contiguous sub-lists covering ``file_list`` in order.
    """
    if num_splits <= 0:
        return [file_list]

    base, extra = divmod(len(file_list), num_splits)

    chunks = []
    cursor = 0
    for index in range(num_splits):
        size = base + (1 if index < extra else 0)
        if size:
            chunks.append(file_list[cursor:cursor + size])
            cursor += size
    return chunks
+
def create_temp_file_list(file_chunk: List[str]) -> str:
    """Write one path per line into a fresh temporary ``.txt`` file.

    The file is created with ``delete=False``: the caller owns it and is
    responsible for removing it when done.

    Args:
        file_chunk: Paths to record.

    Returns:
        Path of the temporary file.
    """
    with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as handle:
        handle.writelines(f"{path}\n" for path in file_chunk)
        return handle.name
+
+def get_image_files_from_dir(input_dir: Path, pattern: str = "*", max_files: int = None) -> List[str]:
+    """
+    从目录获取图像文件列表
+    
+    Args:
+        input_dir: 输入目录
+        max_files: 最大文件数量限制
+        
+    Returns:
+        图像文件路径列表
+    """
+    image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']
+    image_files = []
+    
+    for ext in image_extensions:
+        image_files.extend(list(input_dir.glob(f"{pattern}{ext}")))
+        image_files.extend(list(input_dir.glob(f"{pattern}{ext.upper()}")))
+
+    # 去重并排序
+    image_files = sorted(list(set(str(f) for f in image_files)))
+    
+    # 限制文件数量
+    if max_files:
+        image_files = image_files[:max_files]
+    
+    return image_files
+
def get_image_files_from_list(file_list_path: str) -> List[str]:
    """Read image paths from a text file (one per line), keeping only
    those that exist on disk.

    Args:
        file_list_path: Path to the list file.

    Returns:
        Paths from the list whose files exist, in file order.
    """
    print(f"📄 Reading file list from: {file_list_path}")

    with open(file_list_path, 'r', encoding='utf-8') as f:
        listed = [stripped for line in f if (stripped := line.strip())]

    # Partition into existing and missing in a single pass.
    valid_files: List[str] = []
    missing_files: List[str] = []
    for candidate in listed:
        (valid_files if Path(candidate).exists() else missing_files).append(candidate)

    if missing_files:
        print(f"⚠️ Warning: {len(missing_files)} files not found:")
        # Show at most five to keep the log readable.
        for missing_file in missing_files[:5]:
            print(f"  - {missing_file}")
        if len(missing_files) > 5:
            print(f"  ... and {len(missing_files) - 5} more")

    print(f"✅ Found {len(valid_files)} valid files out of {len(listed)} in list")
    return valid_files
+
def get_image_files_from_csv(csv_file: str, status_filter: str = "fail") -> List[str]:
    """Read image paths from a status CSV, keeping rows with a given status.

    The CSV is expected to have a header row ``image_path,status`` followed
    by one ``<path>,<status>`` row per image.  The header is now actually
    skipped (the original comment required this, but the code relied on
    the header's status never matching the filter).  Blank or malformed
    rows are ignored instead of raising, and the ``csv`` module handles
    quoted paths that contain commas.

    Args:
        csv_file: CSV file path.
        status_filter: Status value to keep (case-insensitive).

    Returns:
        Image paths whose status matches ``status_filter``.
    """
    print(f"📄 Reading image files from CSV: {csv_file}")

    wanted = status_filter.lower()
    image_files = []
    with open(csv_file, 'r', encoding='utf-8', newline='') as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row
        for row in reader:
            if len(row) < 2:
                continue  # blank or malformed line
            image_file, status = row[0], row[1]
            if status.lower() == wanted:
                image_files.append(image_file)

    return image_files
+
+
def collect_pid_files(pid_output_file: str) -> List[Tuple[str, str]]:
    """Extract ``(image_path, status)`` pairs from a worker result file.

    The file is JSON of the form::

        {
          "results": [
            {"image_path": "...", "success": true, ...},
            ...
          ]
        }

    Args:
        pid_output_file: Path to the per-process result JSON.

    Returns:
        One ``(image_path, "success" | "fail")`` tuple per result entry;
        an empty list when the file is missing or malformed.
    """
    if not Path(pid_output_file).exists():
        print(f"⚠️ Warning: PID output file not found: {pid_output_file}")
        return []

    with open(pid_output_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    if not isinstance(data, dict) or "results" not in data:
        print(f"⚠️ Warning: Invalid PID output file format: {pid_output_file}")
        return []

    # Map each result entry to (path, status); a missing/false "success"
    # flag counts as a failure.
    return [
        (
            entry.get("image_path", ""),
            "success" if entry.get("success", False) else "fail",
        )
        for entry in data.get("results", [])
    ]
+
def convert_pdf_to_images(pdf_file: str, output_dir: str | None = None, dpi: int = 200) -> List[str]:
    """Rasterize every page of a PDF into PNG files.

    Args:
        pdf_file: Path to the source PDF.
        output_dir: Directory under which a sub-directory named after the
            PDF stem is created; defaults to the PDF's own directory.
        dpi: Render resolution.

    Returns:
        Paths of the written PNG files, or an empty list on any failure.
    """
    pdf_path = Path(pdf_file)
    if not pdf_path.exists() or pdf_path.suffix.lower() != '.pdf':
        print(f"❌ Invalid PDF file: {pdf_path}")
        return []

    # Pages land in <output_dir or the PDF's parent>/<pdf stem>/.
    base_dir = pdf_path.parent if output_dir is None else Path(output_dir)
    output_path = (base_dir / pdf_path.stem).resolve()
    output_path.mkdir(parents=True, exist_ok=True)

    try:
        # Render the pages via the shared doc_utils helper.
        pages = load_images_from_pdf(str(pdf_path), dpi=dpi)

        image_paths = []
        for page_no, page_image in enumerate(pages, start=1):
            target = output_path / f"{pdf_path.stem}_page_{page_no:03d}.png"
            page_image.save(str(target))
            image_paths.append(str(target))

        print(f"✅ Converted {len(pages)} pages from {pdf_path.name} to images")
        return image_paths

    except Exception as e:
        print(f"❌ Error converting PDF {pdf_path}: {e}")
        traceback.print_exc()
        return []
+
def get_input_files(args) -> List[str]:
    """Resolve the run's input into a flat list of image file paths.

    Exactly one input source is used, in priority order: ``args.input_csv``
    (failed entries only), ``args.input_file_list``, ``args.input_file``,
    otherwise every image/PDF found directly in ``args.input_dir``.  PDF
    inputs are converted to per-page PNGs via ``convert_pdf_to_images``.

    Args:
        args: Parsed CLI namespace; attributes read here are input_csv,
            input_file_list, input_file, input_dir, output_dir and pdf_dpi
            (assumed from usage below — confirm against the CLI parser).

    Returns:
        Sorted, de-duplicated image file paths ready for processing.
    """
    input_files = []
    
    # Select the raw input source (first matching option wins).
    if args.input_csv:
        raw_files = get_image_files_from_csv(args.input_csv, "fail")
    elif args.input_file_list:
        raw_files = get_image_files_from_list(args.input_file_list)
    elif args.input_file:
        raw_files = [Path(args.input_file).resolve()]
    else:
        input_dir = Path(args.input_dir).resolve()
        if not input_dir.exists():
            print(f"❌ Input directory does not exist: {input_dir}")
            return []
        
        # Gather every supported file type (images and PDFs), both cases.
        image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']
        pdf_extensions = ['.pdf']
        
        raw_files = []
        for ext in image_extensions + pdf_extensions:
            raw_files.extend(list(input_dir.glob(f"*{ext}")))
            raw_files.extend(list(input_dir.glob(f"*{ext.upper()}")))
        
        raw_files = [str(f) for f in raw_files]
    
    # Fan out: PDFs become per-page images, plain images pass through.
    pdf_count = 0
    image_count = 0
    
    for file_path in raw_files:
        file_path = Path(file_path)
        
        if file_path.suffix.lower() == '.pdf':
            # Convert the PDF into one PNG per page.
            print(f"📄 Processing PDF: {file_path.name}")
            pdf_images = convert_pdf_to_images(
                str(file_path), 
                args.output_dir,
                dpi=args.pdf_dpi
            )
            input_files.extend(pdf_images)
            pdf_count += 1
        else:
            # Keep existing image files as-is; silently skip missing ones.
            if file_path.exists():
                input_files.append(str(file_path))
                image_count += 1
    
    print(f"📊 Input summary:")
    print(f"  PDF files processed: {pdf_count}")
    print(f"  Image files found: {image_count}")
    print(f"  Total image files to process: {len(input_files)}")
    
    return sorted(list(set(str(f) for f in input_files)))

+ 250 - 0
zhch/utils/normalize_financial_numbers.py

@@ -0,0 +1,250 @@
+import re
+import os
+from pathlib import Path
+
def normalize_financial_numbers(text: str) -> str:
    """Normalize financial figures: full-width characters become half-width
    and stray spaces inside number sequences are removed.

    Two passes:
      1. Translate full-width digits and punctuation (, 。 . : ; ( ) - + %)
         to half-width equivalents in a single ``str.translate`` call
         (one C-level pass instead of a chain of ``.replace`` calls).
      2. Inside digit sequences, collapse whitespace around thousands
         separators and decimal points ("28, 239, 305.48" -> "28,239,305.48").

    Args:
        text: Raw text; falsy input is returned unchanged.

    Returns:
        The normalized text.
    """
    if not text:
        return text

    # Pass 1: full-width -> half-width, single table lookup per character.
    fullwidth_table = str.maketrans({
        '0': '0', '1': '1', '2': '2', '3': '3', '4': '4',
        '5': '5', '6': '6', '7': '7', '8': '8', '9': '9',
        ',': ',',   # full-width comma
        '。': '.',   # ideographic full stop -> period
        '.': '.',   # full-width period
        ':': ':',   # full-width colon
        ';': ';',   # full-width semicolon
        '(': '(',   # full-width left parenthesis
        ')': ')',   # full-width right parenthesis
        '-': '-',   # full-width minus
        '+': '+',   # full-width plus
        '%': '%',   # full-width percent
    })
    normalized_text = text.translate(fullwidth_table)

    # Pass 2: digits, optionally repeated ", digits" groups (spaces allowed
    # around the comma), optionally a final ". digits" decimal part.  The
    # full-width alternatives in the classes are kept for parity with the
    # original pattern even though pass 1 has already removed them.
    number_sequence_pattern = r'(\d+(?:\s*[,,]\s*\d+)*(?:\s*[。..]\s*\d+)?)'

    def _tighten(match):
        sequence = match.group(1)
        # "1 , 234" -> "1,234"
        sequence = re.sub(r'(\d)\s*[,,]\s*(\d)', r'\1,\2', sequence)
        # "305 . 48" -> "305.48"
        sequence = re.sub(r'(\d)\s*[。..]\s*(\d)', r'\1.\2', sequence)
        return sequence

    return re.sub(number_sequence_pattern, _tighten, normalized_text)
+    
def normalize_markdown_table(markdown_content: str) -> str:
    """Normalize financial numbers inside HTML ``<table>`` cells.

    Despite the name, the input is parsed as HTML (OCR output embeds
    tables as HTML fragments inside the markdown) and the WHOLE content
    is re-serialized by BeautifulSoup, so non-table text also round-trips
    through the parser.

    NOTE(review): assigning ``cell.string`` replaces the cell's entire
    subtree, so nested tags inside a changed cell are flattened to plain
    text — confirm this is acceptable for the OCR output.

    Args:
        markdown_content: Markdown/HTML content.

    Returns:
        The content with table-cell numbers normalized.
    """
    # Parse with BeautifulSoup and walk every table cell.
    from bs4 import BeautifulSoup, Tag
    
    soup = BeautifulSoup(markdown_content, 'html.parser')
    tables = soup.find_all('table')
    
    for table in tables:
        if isinstance(table, Tag):
            cells = table.find_all(['td', 'th'])
            for cell in cells:
                if isinstance(cell, Tag):
                    original_text = cell.get_text()
                    normalized_text = normalize_financial_numbers(original_text)
                    
                    # Only touch cells whose text actually changed.
                    if original_text != normalized_text:
                        cell.string = normalized_text
    
    # Return the re-serialized HTML.
    return str(soup)
+
def normalize_json_table(json_content: str) -> str:
    """Normalize financial numbers inside table entries of a JSON OCR result.

    The input is a JSON array of blocks; entries whose ``category`` is
    ``"Table"`` carry an HTML ``<table>`` in their ``text`` field, whose
    cells are normalized via ``normalize_financial_numbers``.

    Args:
        json_content: JSON string (an already-parsed list is also accepted).

    Returns:
        The normalized JSON serialized with ``ensure_ascii=False, indent=2``;
        on any parsing or processing error the input is returned unchanged.
    """
    """
    json_content 示例:
    [
        {
            "category": "Table",
            "text": "<table>...</table>"
        },
        {
            "category": "Text",
            "text": "Some other text"
        }
    ]
    """
    import json
    
    try:
        # Parse unless the caller already passed a decoded structure.
        data = json.loads(json_content) if isinstance(json_content, str) else json_content
        
        # The expected top-level shape is a list of blocks.
        if not isinstance(data, list):
            return json_content
        
        # Walk every OCR result block.
        for item in data:
            if not isinstance(item, dict):
                continue
                
            # Only table blocks carry HTML that needs normalizing.
            if item.get('category') == 'Table' and 'text' in item:
                table_html = item['text']
                
                # Parse the cell contents with BeautifulSoup.
                from bs4 import BeautifulSoup, Tag
                
                soup = BeautifulSoup(table_html, 'html.parser')
                tables = soup.find_all('table')
                
                for table in tables:
                    if isinstance(table, Tag):
                        cells = table.find_all(['td', 'th'])
                        for cell in cells:
                            if isinstance(cell, Tag):
                                original_text = cell.get_text()
                                
                                # Apply number normalization to the cell text.
                                normalized_text = normalize_financial_numbers(original_text)
                                
                                # Only rewrite cells whose text changed.
                                # NOTE(review): cell.string flattens any
                                # nested tags in the cell — confirm OK.
                                if original_text != normalized_text:
                                    cell.string = normalized_text
                
                # Write the re-serialized table back into the block.
                item['text'] = str(soup)
            
            # Optionally also normalize plain-text blocks (kept disabled).
            # elif 'text' in item:
            #     original_text = item['text']
            #     normalized_text = normalize_financial_numbers(original_text)
            #     if original_text != normalized_text:
            #         item['text'] = normalized_text
        
        # Serialize the normalized structure back to JSON.
        return json.dumps(data, ensure_ascii=False, indent=2)
        
    except json.JSONDecodeError as e:
        print(f"⚠️ JSON解析失败: {e}")
        return json_content
    except Exception as e:
        print(f"⚠️ JSON表格标准化失败: {e}")
        return json_content
+
def normalize_json_file(file_path: str, output_path: str | None = None) -> str:
    """Normalize table numbers in a JSON OCR result file on disk.

    Args:
        file_path: Input JSON file.
        output_path: Destination path; when None the input file is
            overwritten.  When a distinct destination is given and changes
            were made, the untouched original is also saved next to it as
            ``<stem>_original.json``.

    Returns:
        The normalized JSON content that was written.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
    """
    input_file = Path(file_path)
    output_file = Path(output_path) if output_path else input_file

    if not input_file.exists():
        raise FileNotFoundError(f"找不到文件: {file_path}")

    original_content = input_file.read_text(encoding='utf-8')

    print(f"🔧 正在标准化JSON文件: {input_file.name}")

    normalized_content = normalize_json_table(original_content)

    output_file.write_text(normalized_content, encoding='utf-8')

    # Rough change metric: positionally differing characters (zip stops at
    # the shorter string, so this is an approximation).
    changes = sum(1 for o, n in zip(original_content, normalized_content) if o != n)
    if changes > 0:
        print(f"✅ 标准化了 {changes} 个字符")

        # Keep a pristine copy of the input when writing elsewhere.
        if output_path and output_path != file_path:
            original_backup = Path(output_path).parent / f"{Path(output_path).stem}_original.json"
            original_backup.write_text(original_content, encoding='utf-8')
            print(f"📄 原始版本已保存到: {original_backup}")
    else:
        print("ℹ️ 无需标准化(已是标准格式)")

    print(f"📄 标准化结果已保存到: {output_file}")
    return normalized_content
+
if __name__ == "__main__":
    # Smoke test: print before/after for representative financial strings
    # (full-width characters, spaced thousands separators, percentages).
    test_strings = [
        "28, 239, 305.48",
        "2023年净利润为28,239,305.48元",
        "总资产为1,234,567.89元",
        "负债总额为500,000.00元",
        "收入增长了10.5%,达到1,200,000元",
        "费用为300,000元",
        "利润率为15.2%",
        "现金流量为-50,000元",
        "股东权益为2,500,000.00元",
        "每股收益为3.25元",
        "市盈率为20.5倍",
        "营业收入为750,000元",
        "净资产收益率为12.3%",
        "总负债为1,200,000元",
        "流动比率为1.5倍",
        "速动比率为1.2倍",
        "资产负债率为40%",
        "存货周转率为6次/年",
        "应收账款周转率为8次/年",
        "固定资产周转率为2次/年",
        "总资产周转率为1.2次/年",
        "经营活动产生的现金流量净额为200,000元"
    ]
    
    for s in test_strings:
        print("原始: ", s)
        print("标准化: ", normalize_financial_numbers(s))
        print("-" * 50)

+ 40 - 0
zhch/utils/verify_flash_attention.py

@@ -0,0 +1,40 @@
+import torch
+import subprocess
+import pkg_resources
+
def check_flash_attention():
    """Print an installation/availability report for Flash Attention.

    Checks installed package versions (flash-attn, flashinfer), CUDA
    availability through torch, and whether ``flash_attn`` is importable.
    Uses stdlib ``importlib.metadata`` instead of the deprecated
    ``pkg_resources`` API, and catches specific exceptions instead of the
    previous bare ``except:`` (which also swallowed KeyboardInterrupt and
    SystemExit).
    """
    from importlib import metadata  # stdlib replacement for pkg_resources

    print("🔍 Flash Attention 状态检查")
    print("=" * 50)

    # Installed distribution versions (distribution names, not import names).
    try:
        flash_attn_version = metadata.version("flash-attn")
        print(f"✅ flash-attn: {flash_attn_version}")
    except metadata.PackageNotFoundError:
        print("❌ flash-attn: 未安装")

    try:
        flashinfer_version = metadata.version("flashinfer")
        print(f"✅ flashinfer: {flashinfer_version}")
    except metadata.PackageNotFoundError:
        print("❌ flashinfer: 未安装")

    # CUDA runtime status as seen by torch.
    print(f"\n🔧 CUDA 状态:")
    print(f"CUDA 可用: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"CUDA 版本: {torch.version.cuda}")
        print(f"GPU 数量: {torch.cuda.device_count()}")
        for i in range(torch.cuda.device_count()):
            print(f"GPU {i}: {torch.cuda.get_device_name(i)}")

    # Can the extension actually be imported?
    try:
        import flash_attn
        print(f"\n✅ Flash Attention 可导入")
        print(f"Flash Attention 版本: {flash_attn.__version__}")
    except ImportError as e:
        print(f"\n❌ Flash Attention 导入失败: {e}")

if __name__ == "__main__":
    check_flash_attention()