Преглед изворни кода

新增OCR工具功能,支持HTML和Markdown内容中图片引用的处理,增强图片查找和转换为base64的能力

zhch158_admin пре 1 месец
родитељ
комит
72ddd4812f
1 измењених фајлова са 440 додато и 0 уклоњено
  1. 440 0
      ocr_validator_file_utils.py

+ 440 - 0
ocr_validator_file_utils.py

@@ -0,0 +1,440 @@
+import os
+import base64
+import pandas as pd
+import numpy as np
+from io import StringIO
+from html import unescape
+from typing import Dict, Optional, List
+from PIL import Image, ImageDraw
+from io import BytesIO
+import cv2
+import re
+
+def load_css_styles(css_path: str = "styles.css") -> str:
+    """加载CSS样式文件"""
+    try:
+        with open(css_path, 'r', encoding='utf-8') as f:
+            return f.read()
+    except Exception:
+        # 返回基本样式
+        return """
+        .main > div { background-color: white !important; color: #333333 !important; }
+        .stApp { background-color: white !important; }
+        .block-container { background-color: white !important; color: #333333 !important; }
+        """
+
+def find_image_in_multiple_locations(img_src: str, json_path: str) -> Optional[str]:
+    """
+    在多个可能的位置查找图片文件
+    """
+    json_dir = os.path.dirname(json_path)
+    
+    # 可能的搜索路径
+    search_paths = [
+        # 相对于JSON文件的路径
+        os.path.join(json_dir, img_src),
+        # 相对于JSON文件父目录的路径
+        os.path.join(os.path.dirname(json_dir), img_src),
+        # imgs目录(常见的图片目录)
+        os.path.join(json_dir, 'imgs', os.path.basename(img_src)),
+        os.path.join(os.path.dirname(json_dir), 'imgs', os.path.basename(img_src)),
+        # images目录
+        os.path.join(json_dir, 'images', os.path.basename(img_src)),
+        os.path.join(os.path.dirname(json_dir), 'images', os.path.basename(img_src)),
+        # 同名目录
+        os.path.join(json_dir, os.path.splitext(os.path.basename(json_path))[0], os.path.basename(img_src)),
+    ]
+    
+    # 如果是绝对路径,也加入搜索
+    if os.path.isabs(img_src):
+        search_paths.insert(0, img_src)
+    
+    # 查找存在的文件
+    for path in search_paths:
+        if os.path.exists(path):
+            return path
+    
+    return None
+
+
+def process_html_images(html_content: str, json_path: str) -> str:
+    """
+    处理HTML内容中的图片引用,将本地图片转换为base64 - 增强版
+    """
+    import re
+    
+    # 匹配HTML图片标签: <img src="path" ... />
+    img_pattern = r'<img\s+[^>]*src\s*=\s*["\']([^"\']+)["\'][^>]*/?>'
+    
+    def replace_html_image(match):
+        full_tag = match.group(0)
+        img_src = match.group(1)
+        
+        # 如果已经是base64或者网络链接,直接返回
+        if img_src.startswith('data:image') or img_src.startswith('http'):
+            return full_tag
+        
+        # 增强的图片查找
+        full_img_path = find_image_in_multiple_locations(img_src, json_path)
+        
+        # 尝试转换为base64
+        try:
+            if full_img_path and os.path.exists(full_img_path):
+                with open(full_img_path, 'rb') as img_file:
+                    img_data = img_file.read()
+                    
+                # 获取文件扩展名确定MIME类型
+                ext = os.path.splitext(full_img_path)[1].lower()
+                mime_type = {
+                    '.png': 'image/png',
+                    '.jpg': 'image/jpeg',
+                    '.jpeg': 'image/jpeg',
+                    '.gif': 'image/gif',
+                    '.bmp': 'image/bmp',
+                    '.webp': 'image/webp'
+                }.get(ext, 'image/jpeg')
+                
+                # 转换为base64
+                img_base64 = base64.b64encode(img_data).decode('utf-8')
+                data_url = f"data:{mime_type};base64,{img_base64}"
+                
+                # 替换src属性,保持其他属性不变
+                updated_tag = re.sub(
+                    r'src\s*=\s*["\'][^"\']+["\']',
+                    f'src="{data_url}"',
+                    full_tag
+                )
+                return updated_tag
+            else:
+                # 文件不存在,显示详细的错误信息
+                search_info = f"搜索路径: {img_src}"
+                if full_img_path:
+                    search_info += f" -> {full_img_path}"
+                
+                error_content = f"""
+                <div style="
+                    color: #d32f2f; 
+                    border: 2px dashed #d32f2f; 
+                    padding: 10px; 
+                    margin: 10px 0; 
+                    border-radius: 5px;
+                    background-color: #ffebee;
+                    text-align: center;
+                ">
+                    <strong>🖼️ 图片无法加载</strong><br>
+                    <small>原始路径: {img_src}</small><br>
+                    <small>JSON文件: {os.path.basename(json_path)}</small><br>
+                    <em>请检查图片文件是否存在</em>
+                </div>
+                """
+                return error_content
+        except Exception as e:
+            # 转换失败,返回错误信息
+            error_content = f"""
+            <div style="
+                color: #f57c00; 
+                border: 2px dashed #f57c00; 
+                padding: 10px; 
+                margin: 10px 0; 
+                border-radius: 5px;
+                background-color: #fff3e0;
+                text-align: center;
+            ">
+                <strong>⚠️ 图片处理失败</strong><br>
+                <small>文件: {img_src}</small><br>
+                <small>错误: {str(e)}</small>
+            </div>
+            """
+            return error_content
+    
+    # 替换所有HTML图片标签
+    processed_content = re.sub(img_pattern, replace_html_image, html_content, flags=re.IGNORECASE)
+    return processed_content
+
+def process_markdown_images(md_content: str, json_path: str) -> str:
+    """
+    处理Markdown中的图片引用,将本地图片转换为base64
+    """
+    import re
+    
+    # 匹配Markdown图片语法: ![alt](path)
+    img_pattern = r'!\[([^\]]*)\]\(([^)]+)\)'
+    
+    def replace_image(match):
+        alt_text = match.group(1)
+        img_path = match.group(2)
+        
+        # 如果已经是base64或者网络链接,直接返回
+        if img_path.startswith('data:image') or img_path.startswith('http'):
+            return match.group(0)
+        
+        # 处理相对路径
+        if not os.path.isabs(img_path):
+            # 相对于JSON文件的路径
+            json_dir = os.path.dirname(json_path)
+            full_img_path = os.path.join(json_dir, img_path)
+        else:
+            full_img_path = img_path
+        
+        # 尝试转换为base64
+        try:
+            if os.path.exists(full_img_path):
+                with open(full_img_path, 'rb') as img_file:
+                    img_data = img_file.read()
+                    
+                # 获取文件扩展名确定MIME类型
+                ext = os.path.splitext(full_img_path)[1].lower()
+                mime_type = {
+                    '.png': 'image/png',
+                    '.jpg': 'image/jpeg',
+                    '.jpeg': 'image/jpeg',
+                    '.gif': 'image/gif',
+                    '.bmp': 'image/bmp',
+                    '.webp': 'image/webp'
+                }.get(ext, 'image/jpeg')
+                
+                # 转换为base64
+                img_base64 = base64.b64encode(img_data).decode('utf-8')
+                data_url = f"data:{mime_type};base64,{img_base64}"
+                
+                return f'![{alt_text}]({data_url})'
+            else:
+                # 文件不存在,返回原始链接但添加警告
+                return f'![{alt_text} (文件不存在)]({img_path})'
+        except Exception as e:
+            # 转换失败,返回原始链接
+            return f'![{alt_text} (加载失败)]({img_path})'
+    
+    # 替换所有图片引用
+    processed_content = re.sub(img_pattern, replace_image, md_content)
+    return processed_content
+
+def process_all_images_in_content(content: str, json_path: str) -> str:
+    """
+    处理内容中的所有图片引用(包括Markdown和HTML格式)
+    """
+    # 先处理HTML图片
+    content = process_html_images(content, json_path)
+    # 再处理Markdown图片
+    content = process_markdown_images(content, json_path)
+    return content
+
+
+def draw_bbox_on_image(image: Image.Image, bbox: List[int], color: str = "red", width: int = 3) -> Image.Image:
+    """在图片上绘制bbox框"""
+    img_copy = image.copy()
+    draw = ImageDraw.Draw(img_copy)
+    
+    x1, y1, x2, y2 = bbox
+    
+    # 绘制矩形框
+    draw.rectangle([x1, y1, x2, y2], outline=color, width=width)
+    
+    # 添加半透明填充
+    overlay = Image.new('RGBA', img_copy.size, (0, 0, 0, 0))
+    overlay_draw = ImageDraw.Draw(overlay)
+    
+    color_map = {
+        "red": (255, 0, 0, 30),
+        "blue": (0, 0, 255, 30),
+        "green": (0, 255, 0, 30)
+    }
+    fill_color = color_map.get(color, (255, 255, 0, 30))
+    
+    overlay_draw.rectangle([x1, y1, x2, y2], fill=fill_color)
+    img_copy = Image.alpha_composite(img_copy.convert('RGBA'), overlay).convert('RGB')
+    
+    return img_copy
+
+
+def convert_html_table_to_markdown(content: str) -> str:
+    """将HTML表格转换为Markdown表格格式 - 支持横向滚动的增强版本"""
+    def replace_table(match):
+        table_html = match.group(0)
+        
+        # 提取所有行
+        rows = re.findall(r'<tr[^>]*>(.*?)</tr>', table_html, re.DOTALL | re.IGNORECASE)
+        if not rows:
+            return table_html
+        
+        markdown_rows = []
+        max_cols = 0
+        
+        # 处理所有行,找出最大列数
+        processed_rows = []
+        for row in rows:
+            # 提取单元格,支持 th 和 td
+            cells = re.findall(r'<t[hd][^>]*>(.*?)</t[hd]>', row, re.DOTALL | re.IGNORECASE)
+            if cells:
+                clean_cells = []
+                for cell in cells:
+                    cell_text = re.sub(r'<[^>]+>', '', cell).strip()
+                    cell_text = unescape(cell_text)
+                    # 限制单元格长度,避免表格过宽
+                    if len(cell_text) > 30:
+                        cell_text = cell_text[:27] + "..."
+                    clean_cells.append(cell_text or " ")  # 空单元格用空格替代
+                
+                processed_rows.append(clean_cells)
+                max_cols = max(max_cols, len(clean_cells))
+        
+        # 统一所有行的列数
+        for i, row_cells in enumerate(processed_rows):
+            while len(row_cells) < max_cols:
+                row_cells.append(" ")
+            
+            # 构建Markdown行
+            markdown_row = '| ' + ' | '.join(row_cells) + ' |'
+            markdown_rows.append(markdown_row)
+            
+            # 在第一行后添加分隔符
+            if i == 0:
+                separator = '| ' + ' | '.join(['---'] * max_cols) + ' |'
+                markdown_rows.append(separator)
+        
+        # 添加滚动提示
+        if max_cols > 8:
+            scroll_note = "\n> 📋 **提示**: 此表格列数较多,在某些视图中可能需要横向滚动查看完整内容。\n"
+            return scroll_note + '\n'.join(markdown_rows) if markdown_rows else table_html
+        
+        return '\n'.join(markdown_rows) if markdown_rows else table_html
+    
+    # 替换所有HTML表格
+    converted = re.sub(r'<table[^>]*>.*?</table>', replace_table, content, flags=re.DOTALL | re.IGNORECASE)
+    return converted
+
+
+def parse_html_tables(html_content: str) -> List[pd.DataFrame]:
+    """解析HTML内容中的表格为DataFrame列表"""
+    try:
+        tables = pd.read_html(StringIO(html_content))
+        return tables if tables else []
+    except Exception:
+        return []
+
+
+def create_dynamic_css(config: Dict, font_size_key: str, height: int) -> str:
+    """根据配置动态创建CSS样式"""
+    colors = config['styles']['colors']
+    font_size = config['styles']['font_sizes'][font_size_key]
+    
+    return f"""
+    <style>
+    .dynamic-content {{
+        height: {height}px;
+        font-size: {font_size}px !important;
+        line-height: 1.4;
+        background-color: {colors['background']} !important;
+        color: {colors['text']} !important;
+        border: 1px solid #ddd;
+        padding: 10px;
+        border-radius: 5px;
+    }}
+    
+    .highlight-selected {{
+        background-color: {colors['success']} !important;
+        color: white !important;
+    }}
+    
+    .highlight-error {{
+        background-color: {colors['error']} !important;
+        color: white !important;
+    }}
+    </style>
+    """
+
+
+def export_tables_to_excel(tables: List[pd.DataFrame], filename: str = "ocr_tables.xlsx") -> BytesIO:
+    """导出表格数据到Excel"""
+    output = BytesIO()
+    with pd.ExcelWriter(output, engine='openpyxl') as writer:
+        for i, table in enumerate(tables):
+            table.to_excel(writer, sheet_name=f'Table_{i+1}', index=False)
+    return output
+
+
+def get_table_statistics(tables: List[pd.DataFrame]) -> List[Dict]:
+    """获取表格统计信息"""
+    stats = []
+    for i, table in enumerate(tables):
+        numeric_cols = len(table.select_dtypes(include=[np.number]).columns)
+        stats.append({
+            'table_index': i + 1,
+            'rows': len(table),
+            'columns': len(table.columns),
+            'numeric_columns': numeric_cols
+        })
+    return stats
+
+
+def detect_image_orientation_by_opencv(image_path: str) -> Dict:
+    """
+    使用OpenCV的文本检测来判断图片方向
+    """
+    try:
+        # 读取图像
+        image = cv2.imread(image_path)
+        if image is None:
+            raise ValueError("无法读取图像文件")
+        
+        height, width = image.shape[:2]
+        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+        
+        # 使用EAST文本检测器或其他方法
+        # 这里使用简单的边缘检测和轮廓分析
+        edges = cv2.Canny(gray, 50, 150, apertureSize=3)
+        
+        # 检测直线
+        lines = cv2.HoughLines(edges, 1, np.pi/180, threshold=100)
+        
+        if lines is None:
+            return {
+                'detected_angle': 0.0,
+                'confidence': 0.0,
+                'method': 'opencv_analysis',
+                'message': '未检测到足够的直线特征'
+            }
+        
+        # 分析直线角度
+        angles = []
+        for rho, theta in lines[:, 0]:
+            angle = theta * 180 / np.pi
+            # 将角度标准化到0-180度
+            if angle > 90:
+                angle = angle - 180
+            angles.append(angle)
+        
+        # 统计主要角度
+        angle_hist = np.histogram(angles, bins=36, range=(-90, 90))[0]
+        dominant_angle_idx = np.argmax(angle_hist)
+        dominant_angle = -90 + dominant_angle_idx * 5  # 每个bin 5度
+        
+        # 将角度映射到标准旋转角度
+        if -22.5 <= dominant_angle <= 22.5:
+            detected_angle = 0.0
+        elif 22.5 < dominant_angle <= 67.5:
+            detected_angle = 270.0
+        elif 67.5 < dominant_angle <= 90 or -90 <= dominant_angle < -67.5:
+            detected_angle = 90.0
+        else:
+            detected_angle = 180.0
+        
+        confidence = angle_hist[dominant_angle_idx] / len(lines) if len(lines) > 0 else 0.0
+        
+        return {
+            'detected_angle': detected_angle,
+            'confidence': min(1.0, confidence),
+            'method': 'opencv_analysis',
+            'line_count': len(lines),
+            'dominant_angle': dominant_angle,
+            'message': f'基于{len(lines)}条直线检测到旋转角度: {detected_angle}°'
+        }
+        
+    except Exception as e:
+        return {
+            'detected_angle': 0.0,
+            'confidence': 0.0,
+            'method': 'opencv_analysis',
+            'error': str(e),
+            'message': f'OpenCV检测过程中发生错误: {str(e)}'
+        }