Procházet zdrojové kódy

feat: 添加多个工具模块,包括设备检测、图像处理、HTML/Markdown 处理和数字解析功能

zhch158_admin před 1 týdnem
rodič
revize
a7520b9498

+ 43 - 0
ocr_utils/__init__.py

@@ -34,6 +34,28 @@ from .file_utils import (
     parse_page_range
 )
 from .log_utils import setup_logging
+from .device_utils import get_device, get_device_name
+from .image_utils import (
+    img_decode,
+    check_img,
+    alpha_to_color,
+    preprocess_image,
+    bbox_to_points,
+    points_to_bbox,
+    rotate_image_and_coordinates
+)
+from .html_utils import (
+    find_image_in_multiple_locations,
+    process_html_images,
+    process_markdown_images,
+    process_all_images_in_content,
+    convert_html_table_to_markdown,
+    parse_html_tables
+)
+from .number_utils import (
+    parse_number,
+    normalize_text_number
+)
 
 __all__ = [
     # PDF 工具
@@ -69,6 +91,27 @@ __all__ = [
     'setup_logging',
     # bbox 工具
     'BBoxExtractor',
+    # 设备工具
+    'get_device',
+    'get_device_name',
+    # 图像处理工具
+    'img_decode',
+    'check_img',
+    'alpha_to_color',
+    'preprocess_image',
+    'bbox_to_points',
+    'points_to_bbox',
+    'rotate_image_and_coordinates',
+    # HTML/Markdown 处理工具
+    'find_image_in_multiple_locations',
+    'process_html_images',
+    'process_markdown_images',
+    'process_all_images_in_content',
+    'convert_html_table_to_markdown',
+    'parse_html_tables',
+    # 数字解析工具
+    'parse_number',
+    'normalize_text_number',
 ]
 
 

+ 91 - 0
ocr_utils/device_utils.py

@@ -0,0 +1,91 @@
+"""
+设备检测工具
+从 MinerU 移植,用于自动检测可用设备 (CUDA/MPS/NPU/CPU)
+"""
+import os
+
+try:
+    import torch
+except ImportError:
+    torch = None
+
+try:
+    import torch_npu
+except ImportError:
+    torch_npu = None
+
+
+def get_device():
+    """
+    自动检测并返回可用的设备
+    
+    优先级: CUDA > MPS > NPU > CPU
+    
+    Returns:
+        str: 设备名称 ('cuda', 'mps', 'npu', 'cpu')
+    
+    Environment Variables:
+        MINERU_DEVICE_MODE: 强制指定设备模式
+    """
+    # 支持通过环境变量强制指定设备
+    device_mode = os.getenv('MINERU_DEVICE_MODE', None)
+    if device_mode is not None:
+        return device_mode
+    
+    # 如果没有 torch,返回 cpu
+    if torch is None:
+        return "cpu"
+    
+    # 自动检测
+    if torch.cuda.is_available():
+        return "cuda"
+    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+        return "mps"
+    else:
+        # 尝试检测华为 NPU
+        try:
+            if torch_npu is not None and torch_npu.npu.is_available():
+                return "npu"
+        except Exception:
+            pass
+    
+    return "cpu"
+
+
+def get_device_name():
+    """
+    获取设备的友好名称
+    
+    Returns:
+        str: 设备的友好名称
+    """
+    device = get_device()
+    
+    device_names = {
+        "cuda": "NVIDIA CUDA",
+        "mps": "Apple Metal (MPS)",
+        "npu": "Huawei NPU",
+        "cpu": "CPU"
+    }
+    
+    return device_names.get(device, device.upper())
+
+
+if __name__ == "__main__":
+    """测试设备检测"""
+    print(f"🔍 Detecting available device...")
+    device = get_device()
+    device_name = get_device_name()
+    print(f"✅ Device: {device} ({device_name})")
+    
+    # 测试 torch 设备
+    if torch is not None:
+        try:
+            test_tensor = torch.tensor([1.0, 2.0, 3.0])
+            if device != "cpu":
+                test_tensor = test_tensor.to(device)
+                print(f"✅ Torch tensor moved to {device}")
+            print(f"   Tensor device: {test_tensor.device}")
+        except Exception as e:
+            print(f"⚠️  Failed to move tensor to {device}: {e}")
+

+ 317 - 0
ocr_utils/html_utils.py

@@ -0,0 +1,317 @@
+"""
+HTML/Markdown 处理工具模块
+
+提供 HTML 和 Markdown 内容的处理功能:
+- 图片引用处理(转换为 base64)
+- HTML 表格转换
+- 表格解析
+"""
+import os
+import base64
+import pandas as pd
+from io import StringIO
+from html import unescape
+from typing import List, Optional
+import re
+
+
+def find_image_in_multiple_locations(img_src: str, json_path: str) -> Optional[str]:
+    """
+    在多个可能的位置查找图片文件
+    
+    Args:
+        img_src: 图片源路径
+        json_path: JSON 文件路径(用于相对路径解析)
+        
+    Returns:
+        找到的图片完整路径,如果未找到则返回 None
+    """
+    json_dir = os.path.dirname(json_path)
+    
+    # 可能的搜索路径
+    search_paths = [
+        # 相对于JSON文件的路径
+        os.path.join(json_dir, img_src),
+        # 相对于JSON文件父目录的路径
+        os.path.join(os.path.dirname(json_dir), img_src),
+        # imgs目录(常见的图片目录)
+        os.path.join(json_dir, 'imgs', os.path.basename(img_src)),
+        os.path.join(os.path.dirname(json_dir), 'imgs', os.path.basename(img_src)),
+        # images目录
+        os.path.join(json_dir, 'images', os.path.basename(img_src)),
+        os.path.join(os.path.dirname(json_dir), 'images', os.path.basename(img_src)),
+        # 同名目录
+        os.path.join(json_dir, os.path.splitext(os.path.basename(json_path))[0], os.path.basename(img_src)),
+    ]
+    
+    # 如果是绝对路径,也加入搜索
+    if os.path.isabs(img_src):
+        search_paths.insert(0, img_src)
+    
+    # 查找存在的文件
+    for path in search_paths:
+        if os.path.exists(path):
+            return path
+    
+    return None
+
+
+def process_html_images(html_content: str, json_path: str) -> str:
+    """
+    处理HTML内容中的图片引用,将本地图片转换为base64
+    
+    Args:
+        html_content: HTML 内容
+        json_path: JSON 文件路径(用于相对路径解析)
+        
+    Returns:
+        处理后的 HTML 内容(图片已转换为 base64)
+    """
+    # 匹配HTML图片标签: <img src="path" ... />
+    img_pattern = r'<img\s+[^>]*src\s*=\s*["\']([^"\']+)["\'][^>]*/?>'
+    
+    def replace_html_image(match):
+        full_tag = match.group(0)
+        img_src = match.group(1)
+        
+        # 如果已经是base64或者网络链接,直接返回
+        if img_src.startswith('data:image') or img_src.startswith('http'):
+            return full_tag
+        
+        # 增强的图片查找
+        full_img_path = find_image_in_multiple_locations(img_src, json_path)
+        
+        # 尝试转换为base64
+        try:
+            if full_img_path and os.path.exists(full_img_path):
+                with open(full_img_path, 'rb') as img_file:
+                    img_data = img_file.read()
+                    
+                # 获取文件扩展名确定MIME类型
+                ext = os.path.splitext(full_img_path)[1].lower()
+                mime_type = {
+                    '.png': 'image/png',
+                    '.jpg': 'image/jpeg',
+                    '.jpeg': 'image/jpeg',
+                    '.gif': 'image/gif',
+                    '.bmp': 'image/bmp',
+                    '.webp': 'image/webp'
+                }.get(ext, 'image/jpeg')
+                
+                # 转换为base64
+                img_base64 = base64.b64encode(img_data).decode('utf-8')
+                data_url = f"data:{mime_type};base64,{img_base64}"
+                
+                # 替换src属性,保持其他属性不变
+                updated_tag = re.sub(
+                    r'src\s*=\s*["\'][^"\']+["\']',
+                    f'src="{data_url}"',
+                    full_tag
+                )
+                return updated_tag
+            else:
+                # 文件不存在,显示详细的错误信息
+                error_content = f"""
+                <div style="
+                    color: #d32f2f; 
+                    border: 2px dashed #d32f2f; 
+                    padding: 10px; 
+                    margin: 10px 0; 
+                    border-radius: 5px;
+                    background-color: #ffebee;
+                    text-align: center;
+                ">
+                    <strong>🖼️ 图片无法加载</strong><br>
+                    <small>原始路径: {img_src}</small><br>
+                    <small>JSON文件: {os.path.basename(json_path)}</small><br>
+                    <em>请检查图片文件是否存在</em>
+                </div>
+                """
+                return error_content
+        except Exception as e:
+            # 转换失败,返回错误信息
+            error_content = f"""
+            <div style="
+                color: #f57c00; 
+                border: 2px dashed #f57c00; 
+                padding: 10px; 
+                margin: 10px 0; 
+                border-radius: 5px;
+                background-color: #fff3e0;
+                text-align: center;
+            ">
+                <strong>⚠️ 图片处理失败</strong><br>
+                <small>文件: {img_src}</small><br>
+                <small>错误: {str(e)}</small>
+            </div>
+            """
+            return error_content
+    
+    # 替换所有HTML图片标签
+    processed_content = re.sub(img_pattern, replace_html_image, html_content, flags=re.IGNORECASE)
+    return processed_content
+
+
+def process_markdown_images(md_content: str, json_path: str) -> str:
+    """
+    处理Markdown中的图片引用,将本地图片转换为base64
+    
+    Args:
+        md_content: Markdown 内容
+        json_path: JSON 文件路径(用于相对路径解析)
+        
+    Returns:
+        处理后的 Markdown 内容(图片已转换为 base64)
+    """
+    # 匹配Markdown图片语法: ![alt](path)
+    img_pattern = r'!\[([^\]]*)\]\(([^)]+)\)'
+    
+    def replace_image(match):
+        alt_text = match.group(1)
+        img_path = match.group(2)
+        
+        # 如果已经是base64或者网络链接,直接返回
+        if img_path.startswith('data:image') or img_path.startswith('http'):
+            return match.group(0)
+        
+        # 处理相对路径
+        if not os.path.isabs(img_path):
+            # 相对于JSON文件的路径
+            json_dir = os.path.dirname(json_path)
+            full_img_path = os.path.join(json_dir, img_path)
+        else:
+            full_img_path = img_path
+        
+        # 尝试转换为base64
+        try:
+            if os.path.exists(full_img_path):
+                with open(full_img_path, 'rb') as img_file:
+                    img_data = img_file.read()
+                    
+                # 获取文件扩展名确定MIME类型
+                ext = os.path.splitext(full_img_path)[1].lower()
+                mime_type = {
+                    '.png': 'image/png',
+                    '.jpg': 'image/jpeg',
+                    '.jpeg': 'image/jpeg',
+                    '.gif': 'image/gif',
+                    '.bmp': 'image/bmp',
+                    '.webp': 'image/webp'
+                }.get(ext, 'image/jpeg')
+                
+                # 转换为base64
+                img_base64 = base64.b64encode(img_data).decode('utf-8')
+                data_url = f"data:{mime_type};base64,{img_base64}"
+                
+                return f'![{alt_text}]({data_url})'
+            else:
+                # 文件不存在,返回原始链接但添加警告
+                return f'![{alt_text} (文件不存在)]({img_path})'
+        except Exception as e:
+            # 转换失败,返回原始链接
+            return f'![{alt_text} (加载失败)]({img_path})'
+    
+    # 替换所有图片引用
+    processed_content = re.sub(img_pattern, replace_image, md_content)
+    return processed_content
+
+
+def process_all_images_in_content(content: str, json_path: str) -> str:
+    """
+    处理内容中的所有图片引用(包括Markdown和HTML格式)
+    
+    Args:
+        content: 内容(可能包含 HTML 和 Markdown)
+        json_path: JSON 文件路径
+        
+    Returns:
+        处理后的内容(所有图片已转换为 base64)
+    """
+    # 先处理HTML图片
+    content = process_html_images(content, json_path)
+    # 再处理Markdown图片
+    content = process_markdown_images(content, json_path)
+    return content
+
+
+def convert_html_table_to_markdown(content: str) -> str:
+    """
+    将HTML表格转换为Markdown表格格式 - 支持横向滚动的增强版本
+    
+    Args:
+        content: 包含 HTML 表格的内容
+        
+    Returns:
+        转换后的内容(HTML 表格已转换为 Markdown)
+    """
+    def replace_table(match):
+        table_html = match.group(0)
+        
+        # 提取所有行
+        rows = re.findall(r'<tr[^>]*>(.*?)</tr>', table_html, re.DOTALL | re.IGNORECASE)
+        if not rows:
+            return table_html
+        
+        markdown_rows = []
+        max_cols = 0
+        
+        # 处理所有行,找出最大列数
+        processed_rows = []
+        for row in rows:
+            # 提取单元格,支持 th 和 td
+            cells = re.findall(r'<t[hd][^>]*>(.*?)</t[hd]>', row, re.DOTALL | re.IGNORECASE)
+            if cells:
+                clean_cells = []
+                for cell in cells:
+                    cell_text = re.sub(r'<[^>]+>', '', cell).strip()
+                    cell_text = unescape(cell_text)
+                    # 限制单元格长度,避免表格过宽
+                    if len(cell_text) > 30:
+                        cell_text = cell_text[:27] + "..."
+                    clean_cells.append(cell_text or " ")  # 空单元格用空格替代
+                
+                processed_rows.append(clean_cells)
+                max_cols = max(max_cols, len(clean_cells))
+        
+        # 统一所有行的列数
+        for i, row_cells in enumerate(processed_rows):
+            while len(row_cells) < max_cols:
+                row_cells.append(" ")
+            
+            # 构建Markdown行
+            markdown_row = '| ' + ' | '.join(row_cells) + ' |'
+            markdown_rows.append(markdown_row)
+            
+            # 在第一行后添加分隔符
+            if i == 0:
+                separator = '| ' + ' | '.join(['---'] * max_cols) + ' |'
+                markdown_rows.append(separator)
+        
+        # 添加滚动提示
+        if max_cols > 8:
+            scroll_note = "\n> 📋 **提示**: 此表格列数较多,在某些视图中可能需要横向滚动查看完整内容。\n"
+            return scroll_note + '\n'.join(markdown_rows) if markdown_rows else table_html
+        
+        return '\n'.join(markdown_rows) if markdown_rows else table_html
+    
+    # 替换所有HTML表格
+    converted = re.sub(r'<table[^>]*>.*?</table>', replace_table, content, flags=re.DOTALL | re.IGNORECASE)
+    return converted
+
+
+def parse_html_tables(html_content: str) -> List[pd.DataFrame]:
+    """
+    解析HTML内容中的表格为DataFrame列表
+    
+    Args:
+        html_content: HTML 内容
+        
+    Returns:
+        DataFrame 列表,每个 DataFrame 对应一个表格
+    """
+    try:
+        tables = pd.read_html(StringIO(html_content))
+        return tables if tables else []
+    except Exception:
+        return []
+

+ 257 - 0
ocr_utils/image_utils.py

@@ -0,0 +1,257 @@
+"""
+图像处理工具模块
+
+提供通用的图像处理功能:
+- 图像解码和格式转换
+- Alpha 通道处理
+- 图像预处理
+- BBox 和点坐标转换
+- 图像旋转和坐标转换
+"""
+import cv2
+import numpy as np
+from typing import List, Tuple, Union
+from PIL import Image
+
+
+def img_decode(content: bytes) -> np.ndarray:
+    """
+    解码字节流为图像
+    
+    Args:
+        content: 图像字节流
+        
+    Returns:
+        np.ndarray: 解码后的图像
+    """
+    np_arr = np.frombuffer(content, dtype=np.uint8)
+    return cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED)
+
+
+def check_img(img: Union[bytes, np.ndarray]) -> np.ndarray:
+    """
+    检查并转换图像格式
+    
+    Args:
+        img: 图像(可以是 bytes 或 np.ndarray)
+        
+    Returns:
+        np.ndarray: BGR 格式图像
+    """
+    if isinstance(img, bytes):
+        img = img_decode(img)
+    if isinstance(img, np.ndarray) and len(img.shape) == 2:
+        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+    return img
+
+
+def alpha_to_color(img: np.ndarray, alpha_color: Tuple[int, int, int] = (255, 255, 255)) -> np.ndarray:
+    """
+    将带 alpha 通道的图像转换为 RGB
+    
+    Args:
+        img: 输入图像
+        alpha_color: 背景颜色 (B, G, R)
+        
+    Returns:
+        np.ndarray: RGB 图像
+    """
+    if len(img.shape) == 3 and img.shape[2] == 4:
+        B, G, R, A = cv2.split(img)
+        alpha = A / 255
+
+        R = (alpha_color[0] * (1 - alpha) + R * alpha).astype(np.uint8)
+        G = (alpha_color[1] * (1 - alpha) + G * alpha).astype(np.uint8)
+        B = (alpha_color[2] * (1 - alpha) + B * alpha).astype(np.uint8)
+
+        img = cv2.merge((B, G, R))
+    return img
+
+
+def preprocess_image(_image: np.ndarray) -> np.ndarray:
+    """
+    预处理图像(去除 alpha 通道)
+    
+    Args:
+        _image: 输入图像
+        
+    Returns:
+        np.ndarray: 预处理后的图像
+    """
+    alpha_color = (255, 255, 255)
+    _image = alpha_to_color(_image, alpha_color)
+    return _image
+
+
+def bbox_to_points(bbox: List[float]) -> np.ndarray:
+    """
+    将 bbox 格式转换为四个顶点的数组
+    
+    Args:
+        bbox: [x0, y0, x1, y1]
+        
+    Returns:
+        np.ndarray: [[x0, y0], [x1, y0], [x1, y1], [x0, y1]]
+    """
+    x0, y0, x1, y1 = bbox
+    return np.array([[x0, y0], [x1, y0], [x1, y1], [x0, y1]]).astype('float32')
+
+
+def points_to_bbox(points: np.ndarray) -> List[float]:
+    """
+    将四个顶点的数组转换为 bbox 格式
+    
+    Args:
+        points: [[x0, y0], [x1, y1], [x2, y2], [x3, y3]]
+        
+    Returns:
+        list: [x0, y0, x1, y1]
+    """
+    x0, y0 = points[0]
+    x1, _ = points[1]
+    _, y1 = points[2]
+    return [x0, y0, x1, y1]
+
+
+def rotate_image_and_coordinates(
+    image: Image.Image, 
+    angle: float, 
+    coordinates_list: List[List[int]], 
+    rotate_coordinates: bool = True
+) -> Tuple[Image.Image, List[List[int]]]:
+    """
+    根据角度旋转图像和坐标 - 修正版本
+    
+    Args:
+        image: 原始图像(PIL Image)
+        angle: 旋转角度(度数:0, 90, 180, 270 或任意角度)
+        coordinates_list: 坐标列表,每个坐标为[x1, y1, x2, y2]格式
+        rotate_coordinates: 是否需要旋转坐标(针对不同OCR工具的处理方式)
+    
+    Returns:
+        rotated_image: 旋转后的图像
+        rotated_coordinates: 处理后的坐标列表
+    """
+    if angle == 0:
+        return image, coordinates_list
+    
+    # 标准化旋转角度
+    if angle == 270:
+        rotation_angle = -90  # 顺时针90度
+    elif angle == 90:
+        rotation_angle = 90   # 逆时针90度
+    elif angle == 180:
+        rotation_angle = 180  # 180度
+    else:
+        rotation_angle = angle
+    
+    # 旋转图像
+    rotated_image = image.rotate(rotation_angle, expand=True)
+    
+    # 如果不需要旋转坐标,直接返回原坐标
+    if not rotate_coordinates:
+        return rotated_image, coordinates_list
+    
+    # 获取原始和旋转后的图像尺寸
+    orig_width, orig_height = image.size
+    new_width, new_height = rotated_image.size
+    
+    # 计算旋转后的坐标
+    rotated_coordinates = []
+    
+    for coord in coordinates_list:
+        if len(coord) < 4:
+            rotated_coordinates.append(coord)
+            continue
+            
+        x1, y1, x2, y2 = coord[:4]
+        
+        # 验证原始坐标是否有效
+        if x1 < 0 or y1 < 0 or x2 <= x1 or y2 <= y1:
+            print(f"警告: 无效坐标 {coord}")
+            rotated_coordinates.append([0, 0, 50, 50])  # 使用默认坐标
+            continue
+        
+        # 根据旋转角度变换坐标
+        if rotation_angle == -90:  # 顺时针90度 (270度逆时针)
+            # 变换公式: (x, y) -> (orig_height - y, x)
+            new_x1 = orig_height - y2  # 这里是y2
+            new_y1 = x1
+            new_x2 = orig_height - y1  # 这里是y1
+            new_y2 = x2
+            
+        elif rotation_angle == 90:  # 逆时针90度
+            # 变换公式: (x, y) -> (y, orig_width - x)
+            new_x1 = y1
+            new_y1 = orig_width - x2  # 这里是x2
+            new_x2 = y2
+            new_y2 = orig_width - x1  # 这里是x1
+
+        elif rotation_angle == 180:  # 180度
+            # 变换公式: (x, y) -> (orig_width - x, orig_height - y)
+            new_x1 = orig_width - x2
+            new_y1 = orig_height - y2
+            new_x2 = orig_width - x1
+            new_y2 = orig_height - y1
+            
+        else:  # 任意角度算法 - 修正版本
+            # 将角度转换为弧度
+            angle_rad = np.radians(rotation_angle)
+            cos_angle = np.cos(angle_rad)
+            sin_angle = np.sin(angle_rad)
+            
+            # 原图像中心点
+            orig_center_x = orig_width / 2
+            orig_center_y = orig_height / 2
+            
+            # 旋转后图像中心点
+            new_center_x = new_width / 2
+            new_center_y = new_height / 2
+            
+            # 将bbox的四个角点转换为相对于原图像中心的坐标
+            corners = [
+                (x1 - orig_center_x, y1 - orig_center_y),  # 左上角
+                (x2 - orig_center_x, y1 - orig_center_y),  # 右上角
+                (x2 - orig_center_x, y2 - orig_center_y),  # 右下角
+                (x1 - orig_center_x, y2 - orig_center_y)   # 左下角
+            ]
+            
+            # 应用修正后的旋转矩阵变换每个角点
+            rotated_corners = []
+            for x, y in corners:
+                # 修正后的旋转矩阵: [cos(θ)  sin(θ)] [x]
+                #                  [-sin(θ) cos(θ)] [y]
+                rotated_x = x * cos_angle + y * sin_angle
+                rotated_y = -x * sin_angle + y * cos_angle
+                
+                # 转换回绝对坐标(相对于新图像)
+                abs_x = rotated_x + new_center_x
+                abs_y = rotated_y + new_center_y
+                
+                rotated_corners.append((abs_x, abs_y))
+            
+            # 从旋转后的四个角点计算新的边界框
+            x_coords = [corner[0] for corner in rotated_corners]
+            y_coords = [corner[1] for corner in rotated_corners]
+            
+            new_x1 = int(min(x_coords))
+            new_y1 = int(min(y_coords))
+            new_x2 = int(max(x_coords))
+            new_y2 = int(max(y_coords))
+        
+        # 确保坐标在有效范围内
+        new_x1 = max(0, min(new_width, new_x1))
+        new_y1 = max(0, min(new_height, new_y1))
+        new_x2 = max(0, min(new_width, new_x2))
+        new_y2 = max(0, min(new_height, new_y2))
+        
+        # 确保x1 < x2, y1 < y2
+        if new_x1 > new_x2:
+            new_x1, new_x2 = new_x2, new_x1
+        if new_y1 > new_y2:
+            new_y1, new_y2 = new_y2, new_y1
+        
+        rotated_coordinates.append([new_x1, new_y1, new_x2, new_y2])
+    
+    return rotated_image, rotated_coordinates
+

+ 1 - 0
ocr_utils/markdown_generator.py

@@ -304,6 +304,7 @@ pages: {len(results.get('pages', []))}
             f"-->",
             "",
         ]
+        md_lines.append("")
         
         for element in page.get('elements', []):
             elem_type = element.get('type', '')

+ 55 - 0
ocr_utils/number_utils.py

@@ -0,0 +1,55 @@
+"""
+数字解析工具模块
+
+提供数字解析和标准化功能:
+- 解析数字(处理千分位和货币符号)
+- 标准化文本型数字
+"""
+import re
+
+
+def parse_number(text: str) -> float:
+    """
+    解析数字,处理千分位和货币符号
+    
+    Args:
+        text: 包含数字的文本
+        
+    Returns:
+        解析后的浮点数
+    """
+    if not text:
+        return 0.0
+    
+    clean_text = re.sub(r'[¥$€£,,\s]', '', text)
+    
+    is_negative = False
+    if clean_text.startswith('-') or clean_text.startswith('−'):
+        is_negative = True
+        clean_text = clean_text[1:]
+    
+    if clean_text.startswith('(') and clean_text.endswith(')'):
+        is_negative = True
+        clean_text = clean_text[1:-1]
+    
+    try:
+        number = float(clean_text)
+        return -number if is_negative else number
+    except ValueError:
+        return 0.0
+
+
+def normalize_text_number(text: str) -> str:
+    """
+    标准化文本型数字:移除空格和连字符
+    
+    Args:
+        text: 文本型数字(如账号、订单号)
+        
+    Returns:
+        标准化后的文本
+    """
+    if not text:
+        return ""
+    return re.sub(r'[\s\-\u3000]', '', text)
+

+ 38 - 0
ocr_utils/visualization_utils.py

@@ -433,4 +433,42 @@ class VisualizationUtils:
                 continue
         
         return ImageFont.load_default()
+    
+    @staticmethod
+    def draw_bbox_on_image(image: Image.Image, bbox: List[int], color: str = "red", width: int = 3) -> Image.Image:
+        """
+        在图片上绘制bbox框
+        
+        Args:
+            image: PIL Image 对象
+            bbox: 边界框坐标 [x1, y1, x2, y2]
+            color: 边框颜色(字符串,如 "red", "blue", "green")
+            width: 边框宽度
+            
+        Returns:
+            绘制了 bbox 的图像副本
+        """
+        img_copy = image.copy()
+        draw = ImageDraw.Draw(img_copy)
+        
+        x1, y1, x2, y2 = bbox
+        
+        # 绘制矩形框
+        draw.rectangle([x1, y1, x2, y2], outline=color, width=width)
+        
+        # 添加半透明填充
+        overlay = Image.new('RGBA', img_copy.size, (0, 0, 0, 0))
+        overlay_draw = ImageDraw.Draw(overlay)
+        
+        color_map = {
+            "red": (255, 0, 0, 30),
+            "blue": (0, 0, 255, 30),
+            "green": (0, 255, 0, 30)
+        }
+        fill_color = color_map.get(color, (255, 255, 0, 30))
+        
+        overlay_draw.rectangle([x1, y1, x2, y2], fill=fill_color)
+        img_copy = Image.alpha_composite(img_copy.convert('RGBA'), overlay).convert('RGB')
+        
+        return img_copy