Przeglądaj źródła

新增OCR验证工具的工具函数模块,包含配置加载、数据处理、图像处理和统计分析功能

zhch158_admin 2 miesiące temu
rodzic
commit
f904f50706
1 zmieniony plik: 323 dodania i 0 usunięć
  1. 323 0
      ocr_validator_utils.py

+ 323 - 0
ocr_validator_utils.py

@@ -0,0 +1,323 @@
+"""
+OCR验证工具的工具函数模块
+包含数据处理、图像处理、统计分析等功能
+"""
+
+import json
+import pandas as pd
+import numpy as np
+from pathlib import Path
+from PIL import Image, ImageDraw
+from typing import Dict, List, Optional, Tuple
+from io import StringIO, BytesIO
+import re
+from html import unescape
+import yaml
+
+
+def load_config(config_path: str = "config.yaml") -> Dict:
+    """Load the YAML configuration file.
+
+    Falls back to ``get_default_config()`` when the file cannot be read or
+    parsed, and also when its top-level value is not a mapping (for example
+    an empty file, for which ``yaml.safe_load`` returns ``None``).
+
+    Args:
+        config_path: Path of the YAML file to read.
+
+    Returns:
+        The parsed configuration mapping, or the default configuration.
+    """
+    try:
+        with open(config_path, 'r', encoding='utf-8') as f:
+            loaded = yaml.safe_load(f)
+    except Exception:
+        # Deliberate best effort: a missing, unreadable or malformed config
+        # file must not stop the tool from starting.
+        return get_default_config()
+
+    # Callers index the result like a dict, so anything else (None for an
+    # empty file, a bare scalar, a list) is treated as "no configuration".
+    if not isinstance(loaded, dict):
+        return get_default_config()
+    return loaded
+
+
+def get_default_config() -> Dict:
+    """Return the built-in configuration used when no config file is usable.
+
+    A fresh dictionary is built on every call, so callers may mutate the
+    result freely.
+    """
+    font_sizes = {'small': 10, 'medium': 12, 'large': 14, 'extra_large': 16}
+    colors = {
+        'primary': '#0288d1',
+        'secondary': '#ff9800',
+        'success': '#4caf50',
+        'error': '#f44336',
+        'warning': '#ff9800',
+        'background': '#fafafa',
+        'text': '#333333',
+    }
+    layout = {
+        'default_zoom': 1.0,
+        'default_height': 600,
+        'sidebar_width': 0.3,
+        'content_width': 0.7,
+    }
+    styles = {'font_sizes': font_sizes, 'colors': colors, 'layout': layout}
+
+    ui = {
+        'page_title': 'OCR可视化校验工具',
+        'page_icon': '🔍',
+        'layout': 'wide',
+        'sidebar_state': 'expanded',
+        'default_font_size': 'medium',
+        'default_layout': '标准布局',
+    }
+
+    paths = {
+        'output_dir': 'output',
+        'sample_data_dir': './sample_data',
+        'supported_image_formats': ['.png', '.jpg', '.jpeg'],
+    }
+
+    ocr = {
+        'min_text_length': 2,
+        'default_confidence': 1.0,
+        'exclude_texts': ['Picture', ''],
+    }
+
+    return {'styles': styles, 'ui': ui, 'paths': paths, 'ocr': ocr}
+
+
+def load_css_styles(css_path: str = "styles.css") -> str:
+    """Return the text of the CSS file at ``css_path``.
+
+    If the file cannot be opened or read for any reason, a small built-in
+    stylesheet is returned instead.
+    """
+    fallback_css = """
+        .main > div { background-color: white !important; color: #333333 !important; }
+        .stApp { background-color: white !important; }
+        .block-container { background-color: white !important; color: #333333 !important; }
+        """
+    try:
+        with open(css_path, 'r', encoding='utf-8') as css_file:
+            css_text = css_file.read()
+    except Exception:
+        return fallback_css
+    return css_text
+
+
+def load_ocr_data_file(json_path: str, config: Dict) -> Tuple[List, str, str]:
+    """Load an OCR result JSON file plus its companion Markdown and image.
+
+    Args:
+        json_path: Path of the OCR JSON file. Its top-level value must be
+            either a list, or a dict with a ``'results'`` key.
+        config: Configuration mapping; ``config['paths']['sample_data_dir']``
+            and ``config['paths']['supported_image_formats']`` are read to
+            locate the image.
+
+    Returns:
+        A ``(ocr_data, md_content, image_path)`` tuple. ``md_content`` is the
+        text of the ``.md`` file next to the JSON file, or ``""`` if there is
+        none. ``image_path`` is the first existing file named after the JSON
+        file's stem with a supported extension, or ``""`` if none is found.
+
+    Raises:
+        Exception: If the JSON file cannot be read or parsed, or has an
+            unsupported top-level shape. The underlying error is attached
+            as ``__cause__``.
+    """
+    json_file = Path(json_path)
+
+    try:
+        with open(json_file, 'r', encoding='utf-8') as f:
+            data = json.load(f)
+        # The file is closed before the shape check; only parsing needs it.
+        if isinstance(data, list):
+            ocr_data = data
+        elif isinstance(data, dict) and 'results' in data:
+            ocr_data = data['results']
+        else:
+            raise ValueError(f"不支持的JSON格式: {json_path}")
+    except Exception as e:
+        # Chain explicitly so the original traceback stays attached.
+        raise Exception(f"加载JSON文件失败: {e}") from e
+
+    # Companion Markdown file: same name, ".md" suffix.
+    md_content = ""
+    md_file = json_file.with_suffix('.md')
+    if md_file.is_file():
+        md_content = md_file.read_text(encoding='utf-8')
+
+    # For each extension, look in the sample-data directory first and then
+    # next to the JSON file. is_file() keeps a directory that merely shares
+    # the name from being reported as the image.
+    image_name = json_file.stem
+    search_dirs = (Path(config['paths']['sample_data_dir']), json_file.parent)
+    candidates = (
+        directory / f"{image_name}{ext}"
+        for ext in config['paths']['supported_image_formats']
+        for directory in search_dirs
+    )
+    image_path = next((str(c) for c in candidates if c.is_file()), "")
+
+    return ocr_data, md_content, image_path
+
+
+def process_ocr_data(ocr_data: List, config: Dict) -> Dict[str, List]:
+    """Index OCR items by their stripped text.
+
+    Only dict items that have both ``'text'`` and ``'bbox'`` are considered.
+    An item is kept when its stripped text is non-empty, is not listed in
+    ``config['ocr']['exclude_texts']``, is at least
+    ``config['ocr']['min_text_length']`` characters long, and its bbox is a
+    list of exactly four elements.
+
+    Returns:
+        A mapping from text to a list of occurrence records, each holding
+        ``'bbox'``, ``'category'`` (default ``'Text'``), ``'index'`` (position
+        in ``ocr_data``) and ``'confidence'`` (default taken from
+        ``config['ocr']['default_confidence']``). Empty if ``ocr_data`` is
+        not a list.
+    """
+    excluded = config['ocr']['exclude_texts']
+    shortest_allowed = config['ocr']['min_text_length']
+    mapping = {}
+
+    if not isinstance(ocr_data, list):
+        return mapping
+
+    for position, entry in enumerate(ocr_data):
+        if not isinstance(entry, dict):
+            continue
+        if 'text' not in entry or 'bbox' not in entry:
+            continue
+
+        label = str(entry['text']).strip()
+        if not label or label in excluded or len(label) < shortest_allowed:
+            continue
+
+        box = entry['bbox']
+        if not isinstance(box, list) or len(box) != 4:
+            continue
+
+        # Identical texts accumulate under one key, in document order.
+        mapping.setdefault(label, []).append({
+            'bbox': box,
+            'category': entry.get('category', 'Text'),
+            'index': position,
+            'confidence': entry.get('confidence', config['ocr']['default_confidence']),
+        })
+
+    return mapping
+
+
+def draw_bbox_on_image(image: Image.Image, bbox: List[int], color: str = "red", width: int = 3) -> Image.Image:
+    """Return an RGB copy of ``image`` with ``bbox`` outlined and tinted.
+
+    Args:
+        image: Source image; it is copied, not modified.
+        bbox: Four values unpacked as ``x1, y1, x2, y2``.
+        color: Outline colour. ``"red"``, ``"blue"`` and ``"green"`` get a
+            matching translucent fill; any other value gets a yellow fill.
+        width: Outline thickness passed to the rectangle drawing call.
+    """
+    left, top, right, bottom = bbox
+    box = [left, top, right, bottom]
+
+    # Draw the outline on a copy so the caller's image is left untouched.
+    outlined = image.copy()
+    ImageDraw.Draw(outlined).rectangle(box, outline=color, width=width)
+
+    # RGBA fill per named colour (alpha 30); unknown names fall back to yellow.
+    tints = {
+        "red": (255, 0, 0, 30),
+        "blue": (0, 0, 255, 30),
+        "green": (0, 255, 0, 30)
+    }
+    tint = tints.get(color, (255, 255, 0, 30))
+
+    # The fill lives on a fully transparent layer that is composited on top.
+    tint_layer = Image.new('RGBA', outlined.size, (0, 0, 0, 0))
+    ImageDraw.Draw(tint_layer).rectangle(box, fill=tint)
+
+    blended = Image.alpha_composite(outlined.convert('RGBA'), tint_layer)
+    return blended.convert('RGB')
+
+
+def get_ocr_statistics(ocr_data: List, text_bbox_mapping: Dict, marked_errors: set) -> Dict:
+    """Summarise an OCR result set.
+
+    Returns:
+        A dict with ``'total_texts'`` (number of items in ``ocr_data``),
+        ``'clickable_texts'`` (number of keys in ``text_bbox_mapping``),
+        ``'marked_errors'`` (size of ``marked_errors``), ``'categories'``
+        (item count per category) and ``'accuracy_rate'`` (percentage of
+        clickable texts not marked as errors, or 0 when there are none).
+        All counters are zero when ``ocr_data`` is not a non-empty list.
+    """
+    if not (isinstance(ocr_data, list) and ocr_data):
+        return {
+            'total_texts': 0, 'clickable_texts': 0, 'marked_errors': 0,
+            'categories': {}, 'accuracy_rate': 0
+        }
+
+    item_count = len(ocr_data)
+    clickable = len(text_bbox_mapping)
+    error_count = len(marked_errors)
+
+    # Dict items report their own category; bare strings count as 'Text';
+    # everything else is 'Unknown'.
+    per_category = {}
+    for entry in ocr_data:
+        if isinstance(entry, dict):
+            name = entry.get('category', 'Unknown')
+        else:
+            name = 'Text' if isinstance(entry, str) else 'Unknown'
+        per_category[name] = per_category.get(name, 0) + 1
+
+    if clickable > 0:
+        rate = (clickable - error_count) / clickable * 100
+    else:
+        rate = 0
+
+    return {
+        'total_texts': item_count,
+        'clickable_texts': clickable,
+        'marked_errors': error_count,
+        'categories': per_category,
+        'accuracy_rate': rate
+    }
+
+
+def convert_html_table_to_markdown(content: str) -> str:
+    """Replace every HTML ``<table>`` in ``content`` with a Markdown table.
+
+    Both ``<td>`` and ``<th>`` cells are converted, and tags may carry
+    attributes. Inner tags are stripped from each cell, HTML entities are
+    unescaped, line breaks inside a cell are collapsed to a single space and
+    literal ``|`` characters are escaped so they cannot split a column. The
+    ``---`` separator row is placed after the first row that actually has
+    cells. A table in which no cells are found is left unchanged.
+
+    Args:
+        content: Text that may contain HTML tables.
+
+    Returns:
+        ``content`` with each convertible table replaced by Markdown.
+    """
+    flags = re.DOTALL | re.IGNORECASE
+
+    def clean_cell(cell_html):
+        # Strip nested markup first, then decode entities.
+        text = re.sub(r'<[^>]+>', '', cell_html).strip()
+        text = unescape(text)
+        # A Markdown table row must stay on one line.
+        text = re.sub(r'\s*\n\s*', ' ', text)
+        # An unescaped pipe would be read as a column boundary.
+        return text.replace('|', '\\|')
+
+    def replace_table(match):
+        table_html = match.group(0)
+
+        # \b keeps <tr ...> from also matching tags such as <track>.
+        rows = re.findall(r'<tr\b[^>]*>(.*?)</tr\s*>', table_html, flags)
+
+        markdown_rows = []
+        for row in rows:
+            # \b keeps <th ...> from also matching <thead>.
+            cells = re.findall(r'<t[dh]\b[^>]*>(.*?)</t[dh]\s*>', row, flags)
+            if not cells:
+                continue
+
+            clean_cells = [clean_cell(cell) for cell in cells]
+            markdown_rows.append('| ' + ' | '.join(clean_cells) + ' |')
+
+            # The separator follows the first emitted row, which is not
+            # necessarily the first <tr> (that one may have had no cells).
+            if len(markdown_rows) == 1:
+                markdown_rows.append('| ' + ' | '.join(['---'] * len(clean_cells)) + ' |')
+
+        return '\n'.join(markdown_rows) if markdown_rows else table_html
+
+    return re.sub(r'<table\b[^>]*>.*?</table\s*>', replace_table, content, flags=flags)
+
+
+def parse_html_tables(html_content: str) -> List[pd.DataFrame]:
+    """Parse the tables in ``html_content`` into DataFrames.
+
+    Returns an empty list when parsing raises for any reason or yields
+    nothing.
+    """
+    try:
+        frames = pd.read_html(StringIO(html_content))
+    except Exception:
+        return []
+    return frames or []
+
+
+def find_available_ocr_files(output_dir: str) -> List[str]:
+    """Return the paths of all ``*.json`` files found recursively under ``output_dir``.
+
+    An empty list is returned when ``output_dir`` does not exist.
+    """
+    root = Path(output_dir)
+    if not root.exists():
+        return []
+    return [str(found) for found in root.rglob("*.json")]
+
+
+def create_dynamic_css(config: Dict, font_size_key: str, height: int) -> str:
+    """Build a ``<style>`` block from the configured colours and font size.
+
+    Args:
+        config: Configuration mapping; ``config['styles']['colors']`` and
+            ``config['styles']['font_sizes']`` are read.
+        font_size_key: Key into ``config['styles']['font_sizes']``.
+        height: Height in pixels of the ``.dynamic-content`` rule.
+
+    Returns:
+        The CSS text wrapped in ``<style>`` tags.
+    """
+    palette = config['styles']['colors']
+    px = config['styles']['font_sizes'][font_size_key]
+    background = palette['background']
+    foreground = palette['text']
+    selected = palette['success']
+    error = palette['error']
+
+    return f"""
+    <style>
+    .dynamic-content {{
+        height: {height}px;
+        font-size: {px}px !important;
+        line-height: 1.4;
+        background-color: {background} !important;
+        color: {foreground} !important;
+        border: 1px solid #ddd;
+        padding: 10px;
+        border-radius: 5px;
+    }}
+    
+    .highlight-selected {{
+        background-color: {selected} !important;
+        color: white !important;
+    }}
+    
+    .highlight-error {{
+        background-color: {error} !important;
+        color: white !important;
+    }}
+    </style>
+    """
+
+
+def export_tables_to_excel(tables: List[pd.DataFrame], filename: str = "ocr_tables.xlsx") -> BytesIO:
+    """Write each table to its own sheet of an in-memory Excel workbook.
+
+    Sheets are named ``Table_1``, ``Table_2``, ... in input order. When
+    ``tables`` is empty, a single empty ``Table_1`` sheet is written so the
+    result is still a valid workbook.
+
+    Args:
+        tables: DataFrames to export.
+        filename: Not used by this function; kept so existing callers that
+            pass it keep working.
+
+    Returns:
+        A ``BytesIO`` holding the workbook, positioned at the start so it can
+        be read directly.
+    """
+    output = BytesIO()
+    with pd.ExcelWriter(output, engine='openpyxl') as writer:
+        for sheet_number, table in enumerate(tables, start=1):
+            table.to_excel(writer, sheet_name=f'Table_{sheet_number}', index=False)
+        if not tables:
+            # openpyxl cannot save a workbook that has no visible sheet.
+            pd.DataFrame().to_excel(writer, sheet_name='Table_1', index=False)
+
+    # Writing leaves the cursor at the end; rewind so read() returns data.
+    output.seek(0)
+    return output
+
+
+def get_table_statistics(tables: List[pd.DataFrame]) -> List[Dict]:
+    """Describe the shape of each table.
+
+    Returns:
+        One dict per table, in input order, with ``'table_index'``
+        (1-based), ``'rows'``, ``'columns'`` and ``'numeric_columns'`` (the
+        number of columns with a numeric dtype).
+    """
+    return [
+        {
+            'table_index': position,
+            'rows': len(frame),
+            'columns': len(frame.columns),
+            'numeric_columns': len(frame.select_dtypes(include=[np.number]).columns),
+        }
+        for position, frame in enumerate(tables, start=1)
+    ]
+
+
+def group_texts_by_category(text_bbox_mapping: Dict[str, List]) -> Dict[str, List[str]]:
+    """Group texts by the category of their first recorded occurrence.
+
+    Args:
+        text_bbox_mapping: Mapping from text to a list of occurrence dicts,
+            as produced by ``process_ocr_data``.
+
+    Returns:
+        A mapping from category name to the texts in that category, in the
+        iteration order of ``text_bbox_mapping``. Texts with an empty
+        occurrence list are skipped. An occurrence without a ``'category'``
+        key is filed under ``'Text'``, the same default that
+        ``process_ocr_data`` applies.
+    """
+    categories: Dict[str, List[str]] = {}
+    for text, info_list in text_bbox_mapping.items():
+        if not info_list:
+            # No occurrence recorded, so there is no category to read.
+            continue
+        category = info_list[0].get('category', 'Text')
+        categories.setdefault(category, []).append(text)
+    return categories