|
|
@@ -0,0 +1,323 @@
|
|
|
+"""
|
|
|
+OCR验证工具的工具函数模块
|
|
|
+包含数据处理、图像处理、统计分析等功能
|
|
|
+"""
|
|
|
+
|
|
|
+import json
|
|
|
+import pandas as pd
|
|
|
+import numpy as np
|
|
|
+from pathlib import Path
|
|
|
+from PIL import Image, ImageDraw
|
|
|
+from typing import Dict, List, Optional, Tuple
|
|
|
+from io import StringIO, BytesIO
|
|
|
+import re
|
|
|
+from html import unescape
|
|
|
+import yaml
|
|
|
+
|
|
|
+
|
|
|
+def load_config(config_path: str = "config.yaml") -> Dict:
|
|
|
+ """加载配置文件"""
|
|
|
+ try:
|
|
|
+ with open(config_path, 'r', encoding='utf-8') as f:
|
|
|
+ return yaml.safe_load(f)
|
|
|
+ except Exception as e:
|
|
|
+ # 返回默认配置
|
|
|
+ return get_default_config()
|
|
|
+
|
|
|
+
|
|
|
+def get_default_config() -> Dict:
|
|
|
+ """获取默认配置"""
|
|
|
+ return {
|
|
|
+ 'styles': {
|
|
|
+ 'font_sizes': {'small': 10, 'medium': 12, 'large': 14, 'extra_large': 16},
|
|
|
+ 'colors': {
|
|
|
+ 'primary': '#0288d1', 'secondary': '#ff9800', 'success': '#4caf50',
|
|
|
+ 'error': '#f44336', 'warning': '#ff9800', 'background': '#fafafa', 'text': '#333333'
|
|
|
+ },
|
|
|
+ 'layout': {'default_zoom': 1.0, 'default_height': 600, 'sidebar_width': 0.3, 'content_width': 0.7}
|
|
|
+ },
|
|
|
+ 'ui': {
|
|
|
+ 'page_title': 'OCR可视化校验工具', 'page_icon': '🔍', 'layout': 'wide',
|
|
|
+ 'sidebar_state': 'expanded', 'default_font_size': 'medium', 'default_layout': '标准布局'
|
|
|
+ },
|
|
|
+ 'paths': {
|
|
|
+ 'output_dir': 'output', 'sample_data_dir': './sample_data',
|
|
|
+ 'supported_image_formats': ['.png', '.jpg', '.jpeg']
|
|
|
+ },
|
|
|
+ 'ocr': {'min_text_length': 2, 'default_confidence': 1.0, 'exclude_texts': ['Picture', '']}
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+def load_css_styles(css_path: str = "styles.css") -> str:
|
|
|
+ """加载CSS样式文件"""
|
|
|
+ try:
|
|
|
+ with open(css_path, 'r', encoding='utf-8') as f:
|
|
|
+ return f.read()
|
|
|
+ except Exception:
|
|
|
+ # 返回基本样式
|
|
|
+ return """
|
|
|
+ .main > div { background-color: white !important; color: #333333 !important; }
|
|
|
+ .stApp { background-color: white !important; }
|
|
|
+ .block-container { background-color: white !important; color: #333333 !important; }
|
|
|
+ """
|
|
|
+
|
|
|
+
|
|
|
+def load_ocr_data_file(json_path: str, config: Dict) -> Tuple[List, str, str]:
|
|
|
+ """加载OCR相关数据文件"""
|
|
|
+ json_file = Path(json_path)
|
|
|
+ ocr_data = []
|
|
|
+ md_content = ""
|
|
|
+ image_path = ""
|
|
|
+
|
|
|
+ # 加载JSON数据
|
|
|
+ try:
|
|
|
+ with open(json_file, 'r', encoding='utf-8') as f:
|
|
|
+ data = json.load(f)
|
|
|
+ if isinstance(data, list):
|
|
|
+ ocr_data = data
|
|
|
+ elif isinstance(data, dict) and 'results' in data:
|
|
|
+ ocr_data = data['results']
|
|
|
+ else:
|
|
|
+ raise ValueError(f"不支持的JSON格式: {json_path}")
|
|
|
+ except Exception as e:
|
|
|
+ raise Exception(f"加载JSON文件失败: {e}")
|
|
|
+
|
|
|
+ # 加载MD文件
|
|
|
+ md_file = json_file.with_suffix('.md')
|
|
|
+ if md_file.exists():
|
|
|
+ with open(md_file, 'r', encoding='utf-8') as f:
|
|
|
+ md_content = f.read()
|
|
|
+
|
|
|
+ # 推断图片路径
|
|
|
+ image_name = json_file.stem
|
|
|
+ sample_data_dir = Path(config['paths']['sample_data_dir'])
|
|
|
+
|
|
|
+ image_candidates = []
|
|
|
+ for ext in config['paths']['supported_image_formats']:
|
|
|
+ image_candidates.extend([
|
|
|
+ sample_data_dir / f"{image_name}{ext}",
|
|
|
+ json_file.parent / f"{image_name}{ext}",
|
|
|
+ ])
|
|
|
+
|
|
|
+ for candidate in image_candidates:
|
|
|
+ if candidate.exists():
|
|
|
+ image_path = str(candidate)
|
|
|
+ break
|
|
|
+
|
|
|
+ return ocr_data, md_content, image_path
|
|
|
+
|
|
|
+
|
|
|
+def process_ocr_data(ocr_data: List, config: Dict) -> Dict[str, List]:
|
|
|
+ """处理OCR数据,建立文本到bbox的映射"""
|
|
|
+ text_bbox_mapping = {}
|
|
|
+ exclude_texts = config['ocr']['exclude_texts']
|
|
|
+ min_text_length = config['ocr']['min_text_length']
|
|
|
+
|
|
|
+ if not isinstance(ocr_data, list):
|
|
|
+ return text_bbox_mapping
|
|
|
+
|
|
|
+ for i, item in enumerate(ocr_data):
|
|
|
+ if not isinstance(item, dict):
|
|
|
+ continue
|
|
|
+
|
|
|
+ if 'text' in item and 'bbox' in item:
|
|
|
+ text = str(item['text']).strip()
|
|
|
+ if text and text not in exclude_texts and len(text) >= min_text_length:
|
|
|
+ bbox = item['bbox']
|
|
|
+ if isinstance(bbox, list) and len(bbox) == 4:
|
|
|
+ if text not in text_bbox_mapping:
|
|
|
+ text_bbox_mapping[text] = []
|
|
|
+ text_bbox_mapping[text].append({
|
|
|
+ 'bbox': bbox,
|
|
|
+ 'category': item.get('category', 'Text'),
|
|
|
+ 'index': i,
|
|
|
+ 'confidence': item.get('confidence', config['ocr']['default_confidence'])
|
|
|
+ })
|
|
|
+
|
|
|
+ return text_bbox_mapping
|
|
|
+
|
|
|
+
|
|
|
+def draw_bbox_on_image(image: Image.Image, bbox: List[int], color: str = "red", width: int = 3) -> Image.Image:
|
|
|
+ """在图片上绘制bbox框"""
|
|
|
+ img_copy = image.copy()
|
|
|
+ draw = ImageDraw.Draw(img_copy)
|
|
|
+
|
|
|
+ x1, y1, x2, y2 = bbox
|
|
|
+
|
|
|
+ # 绘制矩形框
|
|
|
+ draw.rectangle([x1, y1, x2, y2], outline=color, width=width)
|
|
|
+
|
|
|
+ # 添加半透明填充
|
|
|
+ overlay = Image.new('RGBA', img_copy.size, (0, 0, 0, 0))
|
|
|
+ overlay_draw = ImageDraw.Draw(overlay)
|
|
|
+
|
|
|
+ color_map = {
|
|
|
+ "red": (255, 0, 0, 30),
|
|
|
+ "blue": (0, 0, 255, 30),
|
|
|
+ "green": (0, 255, 0, 30)
|
|
|
+ }
|
|
|
+ fill_color = color_map.get(color, (255, 255, 0, 30))
|
|
|
+
|
|
|
+ overlay_draw.rectangle([x1, y1, x2, y2], fill=fill_color)
|
|
|
+ img_copy = Image.alpha_composite(img_copy.convert('RGBA'), overlay).convert('RGB')
|
|
|
+
|
|
|
+ return img_copy
|
|
|
+
|
|
|
+
|
|
|
+def get_ocr_statistics(ocr_data: List, text_bbox_mapping: Dict, marked_errors: set) -> Dict:
|
|
|
+ """获取OCR数据统计信息"""
|
|
|
+ if not isinstance(ocr_data, list) or not ocr_data:
|
|
|
+ return {
|
|
|
+ 'total_texts': 0, 'clickable_texts': 0, 'marked_errors': 0,
|
|
|
+ 'categories': {}, 'accuracy_rate': 0
|
|
|
+ }
|
|
|
+
|
|
|
+ total_texts = len(ocr_data)
|
|
|
+ clickable_texts = len(text_bbox_mapping)
|
|
|
+ marked_errors_count = len(marked_errors)
|
|
|
+
|
|
|
+ # 按类别统计
|
|
|
+ categories = {}
|
|
|
+ for item in ocr_data:
|
|
|
+ if isinstance(item, dict):
|
|
|
+ category = item.get('category', 'Unknown')
|
|
|
+ elif isinstance(item, str):
|
|
|
+ category = 'Text'
|
|
|
+ else:
|
|
|
+ category = 'Unknown'
|
|
|
+
|
|
|
+ categories[category] = categories.get(category, 0) + 1
|
|
|
+
|
|
|
+ accuracy_rate = (clickable_texts - marked_errors_count) / clickable_texts * 100 if clickable_texts > 0 else 0
|
|
|
+
|
|
|
+ return {
|
|
|
+ 'total_texts': total_texts,
|
|
|
+ 'clickable_texts': clickable_texts,
|
|
|
+ 'marked_errors': marked_errors_count,
|
|
|
+ 'categories': categories,
|
|
|
+ 'accuracy_rate': accuracy_rate
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+def convert_html_table_to_markdown(content: str) -> str:
|
|
|
+ """将HTML表格转换为Markdown表格格式"""
|
|
|
+ def replace_table(match):
|
|
|
+ table_html = match.group(0)
|
|
|
+
|
|
|
+ # 提取所有行
|
|
|
+ rows = re.findall(r'<tr>(.*?)</tr>', table_html, re.DOTALL | re.IGNORECASE)
|
|
|
+ if not rows:
|
|
|
+ return table_html
|
|
|
+
|
|
|
+ markdown_rows = []
|
|
|
+ for i, row in enumerate(rows):
|
|
|
+ # 提取单元格
|
|
|
+ cells = re.findall(r'<td[^>]*>(.*?)</td>', row, re.DOTALL | re.IGNORECASE)
|
|
|
+ if cells:
|
|
|
+ # 清理单元格内容
|
|
|
+ clean_cells = []
|
|
|
+ for cell in cells:
|
|
|
+ cell_text = re.sub(r'<[^>]+>', '', cell).strip()
|
|
|
+ cell_text = unescape(cell_text)
|
|
|
+ clean_cells.append(cell_text)
|
|
|
+
|
|
|
+ # 构建Markdown行
|
|
|
+ markdown_row = '| ' + ' | '.join(clean_cells) + ' |'
|
|
|
+ markdown_rows.append(markdown_row)
|
|
|
+
|
|
|
+ # 在第一行后添加分隔符
|
|
|
+ if i == 0:
|
|
|
+ separator = '| ' + ' | '.join(['---'] * len(clean_cells)) + ' |'
|
|
|
+ markdown_rows.append(separator)
|
|
|
+
|
|
|
+ return '\n'.join(markdown_rows) if markdown_rows else table_html
|
|
|
+
|
|
|
+ # 替换所有HTML表格
|
|
|
+ converted = re.sub(r'<table[^>]*>.*?</table>', replace_table, content, flags=re.DOTALL | re.IGNORECASE)
|
|
|
+ return converted
|
|
|
+
|
|
|
+
|
|
|
+def parse_html_tables(html_content: str) -> List[pd.DataFrame]:
|
|
|
+ """解析HTML内容中的表格为DataFrame列表"""
|
|
|
+ try:
|
|
|
+ tables = pd.read_html(StringIO(html_content))
|
|
|
+ return tables if tables else []
|
|
|
+ except Exception:
|
|
|
+ return []
|
|
|
+
|
|
|
+
|
|
|
+def find_available_ocr_files(output_dir: str) -> List[str]:
|
|
|
+ """查找可用的OCR文件"""
|
|
|
+ available_files = []
|
|
|
+ output_path = Path(output_dir)
|
|
|
+
|
|
|
+ if output_path.exists():
|
|
|
+ for json_file in output_path.rglob("*.json"):
|
|
|
+ available_files.append(str(json_file))
|
|
|
+
|
|
|
+ return available_files
|
|
|
+
|
|
|
+
|
|
|
+def create_dynamic_css(config: Dict, font_size_key: str, height: int) -> str:
|
|
|
+ """根据配置动态创建CSS样式"""
|
|
|
+ colors = config['styles']['colors']
|
|
|
+ font_size = config['styles']['font_sizes'][font_size_key]
|
|
|
+
|
|
|
+ return f"""
|
|
|
+ <style>
|
|
|
+ .dynamic-content {{
|
|
|
+ height: {height}px;
|
|
|
+ font-size: {font_size}px !important;
|
|
|
+ line-height: 1.4;
|
|
|
+ background-color: {colors['background']} !important;
|
|
|
+ color: {colors['text']} !important;
|
|
|
+ border: 1px solid #ddd;
|
|
|
+ padding: 10px;
|
|
|
+ border-radius: 5px;
|
|
|
+ }}
|
|
|
+
|
|
|
+ .highlight-selected {{
|
|
|
+ background-color: {colors['success']} !important;
|
|
|
+ color: white !important;
|
|
|
+ }}
|
|
|
+
|
|
|
+ .highlight-error {{
|
|
|
+ background-color: {colors['error']} !important;
|
|
|
+ color: white !important;
|
|
|
+ }}
|
|
|
+ </style>
|
|
|
+ """
|
|
|
+
|
|
|
+
|
|
|
+def export_tables_to_excel(tables: List[pd.DataFrame], filename: str = "ocr_tables.xlsx") -> BytesIO:
|
|
|
+ """导出表格数据到Excel"""
|
|
|
+ output = BytesIO()
|
|
|
+ with pd.ExcelWriter(output, engine='openpyxl') as writer:
|
|
|
+ for i, table in enumerate(tables):
|
|
|
+ table.to_excel(writer, sheet_name=f'Table_{i+1}', index=False)
|
|
|
+ return output
|
|
|
+
|
|
|
+
|
|
|
+def get_table_statistics(tables: List[pd.DataFrame]) -> List[Dict]:
|
|
|
+ """获取表格统计信息"""
|
|
|
+ stats = []
|
|
|
+ for i, table in enumerate(tables):
|
|
|
+ numeric_cols = len(table.select_dtypes(include=[np.number]).columns)
|
|
|
+ stats.append({
|
|
|
+ 'table_index': i + 1,
|
|
|
+ 'rows': len(table),
|
|
|
+ 'columns': len(table.columns),
|
|
|
+ 'numeric_columns': numeric_cols
|
|
|
+ })
|
|
|
+ return stats
|
|
|
+
|
|
|
+
|
|
|
+def group_texts_by_category(text_bbox_mapping: Dict[str, List]) -> Dict[str, List[str]]:
|
|
|
+ """按类别对文本进行分组"""
|
|
|
+ categories = {}
|
|
|
+ for text, info_list in text_bbox_mapping.items():
|
|
|
+ category = info_list[0]['category']
|
|
|
+ if category not in categories:
|
|
|
+ categories[category] = []
|
|
|
+ categories[category].append(text)
|
|
|
+ return categories
|