"""
OCR验证工具的工具函数模块
包含数据处理、图像处理、统计分析等功能
"""
import json
import pandas as pd
import numpy as np
from pathlib import Path
from PIL import Image, ImageDraw
from typing import Dict, List, Optional, Tuple, Union
from io import StringIO, BytesIO
import re
from html import unescape
import yaml
import base64
from urllib.parse import urlparse
import cv2
import os
def load_config(config_path: str = "config.yaml") -> Dict:
    """Load the YAML configuration file, falling back to built-in defaults.

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        The parsed configuration mapping. The result of
        ``get_default_config()`` is returned instead when the file cannot be
        read or parsed, or when its top-level content is not a mapping
        (for example an empty file, which ``yaml.safe_load`` parses as
        ``None``).
    """
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            loaded = yaml.safe_load(f)
    except Exception:
        # Deliberate best-effort: any read/parse failure means "use defaults".
        return get_default_config()

    # Callers index the result like a dict; never hand back None or a
    # scalar/list document.
    if not isinstance(loaded, dict):
        return get_default_config()
    return loaded
def get_default_config() -> Dict:
    """Build the built-in default configuration.

    Returns:
        A freshly constructed nested dict with four top-level sections:
        ``styles``, ``ui``, ``paths`` and ``ocr``.
    """
    # Visual styling: font sizes, colour palette and layout proportions.
    style_section = {
        'font_sizes': {'small': 10, 'medium': 12, 'large': 14, 'extra_large': 16},
        'colors': {
            'primary': '#0288d1',
            'secondary': '#ff9800',
            'success': '#4caf50',
            'error': '#f44336',
            'warning': '#ff9800',
            'background': '#fafafa',
            'text': '#333333',
        },
        'layout': {
            'default_zoom': 1.0,
            'default_height': 600,
            'sidebar_width': 0.3,
            'content_width': 0.7,
        },
    }

    # Page-level UI settings.
    ui_section = {
        'page_title': 'OCR可视化校验工具',
        'page_icon': '🔍',
        'layout': 'wide',
        'sidebar_state': 'expanded',
        'default_font_size': 'medium',
        'default_layout': '标准布局',
    }

    # Input locations and accepted image extensions.
    path_section = {
        'ocr_out_dir': './sample_data',
        'src_img_dir': './sample_data',
        'supported_image_formats': ['.png', '.jpg', '.jpeg'],
    }

    # Field names used to read each supported OCR tool's JSON output.
    dots_ocr_tool = {
        'name': 'Dots OCR',
        'json_structure': 'array',
        'text_field': 'text',
        'bbox_field': 'bbox',
        'category_field': 'category',
    }
    ppstructv3_tool = {
        'name': 'PPStructV3',
        'json_structure': 'object',
        'parsing_results_field': 'parsing_res_list',
        'text_field': 'block_content',
        'bbox_field': 'block_bbox',
        'category_field': 'block_label',
    }

    # Ordered rules used to guess which tool produced a JSON document.
    detection_rules = [
        {'field_exists': 'parsing_res_list', 'tool_type': 'ppstructv3'},
        {'json_is_array': True, 'tool_type': 'dots_ocr'},
    ]

    ocr_section = {
        'min_text_length': 2,
        'default_confidence': 1.0,
        'exclude_texts': ['Picture', ''],
        'tools': {'dots_ocr': dots_ocr_tool, 'ppstructv3': ppstructv3_tool},
        'auto_detection': {'enabled': True, 'rules': detection_rules},
    }

    return {
        'styles': style_section,
        'ui': ui_section,
        'paths': path_section,
        'ocr': ocr_section,
    }
def load_css_styles(css_path: str = "styles.css") -> str:
    """Return the text of a CSS file, or a minimal built-in stylesheet.

    Args:
        css_path: Path of the stylesheet to read (decoded as UTF-8).

    Returns:
        The file's contents. If reading fails for any reason, a small
        fallback stylesheet that forces a white background and dark text
        is returned instead.
    """
    try:
        with open(css_path, mode='r', encoding='utf-8') as stylesheet:
            css_text = stylesheet.read()
    except Exception:
        # Best-effort: a missing or unreadable stylesheet must not break the UI.
        fallback_rules = (
            ".main > div { background-color: white !important; color: #333333 !important; }",
            ".stApp { background-color: white !important; }",
            ".block-container { background-color: white !important; color: #333333 !important; }",
        )
        return "\n" + "\n".join(fallback_rules) + "\n"
    return css_text
def _rotate_bbox_about_center(
    bbox: List[int],
    angle_deg: float,
    orig_size: Tuple[int, int],
    new_size: Tuple[int, int],
) -> List[int]:
    """Rotate one [x1, y1, x2, y2] box by an arbitrary angle; return its enclosing box."""
    x1, y1, x2, y2 = bbox
    theta = np.radians(angle_deg)
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)

    # Rotation happens around the original image centre; the result is then
    # re-anchored on the centre of the (expanded) rotated image.
    orig_cx, orig_cy = orig_size[0] / 2, orig_size[1] / 2
    new_cx, new_cy = new_size[0] / 2, new_size[1] / 2

    xs = []
    ys = []
    for corner_x, corner_y in ((x1, y1), (x2, y1), (x2, y2), (x1, y2)):
        dx = corner_x - orig_cx
        dy = corner_y - orig_cy
        # Rotation matrix for image coordinates (y axis pointing down):
        #   [ cos  sin] [dx]
        #   [-sin  cos] [dy]
        xs.append(dx * cos_t + dy * sin_t + new_cx)
        ys.append(-dx * sin_t + dy * cos_t + new_cy)

    return [int(min(xs)), int(min(ys)), int(max(xs)), int(max(ys))]


def rotate_image_and_coordinates(
    image: Image.Image,
    angle: float,
    coordinates_list: List[List[int]],
    rotate_coordinates: bool = True
) -> Tuple[Image.Image, List[List[int]]]:
    """Rotate an image and, optionally, its bounding boxes by the same angle.

    The angle is first reduced modulo 360, so equivalent angles (for example
    -270 and 450, or 360 and 0) take the same exact code path instead of the
    floating-point one.

    Args:
        image: Source image.
        angle: Rotation angle in degrees; 90 is treated as counter-clockwise
            and 270 as clockwise by 90 degrees.
        coordinates_list: Boxes in ``[x1, y1, x2, y2]`` form.
        rotate_coordinates: When False the image is rotated but the boxes are
            returned untouched (for OCR tools whose boxes already refer to
            the rotated image).

    Returns:
        ``(rotated_image, rotated_coordinates)``. When the normalised angle is
        0 the inputs are returned unchanged. Entries with fewer than four
        values are passed through as-is. Boxes with negative origin or
        non-positive size are reported with a printed warning and replaced by
        the placeholder ``[0, 0, 50, 50]``. All transformed boxes are clamped
        to the rotated image bounds.
    """
    # Fold equivalent angles onto [0, 360) so right angles are handled exactly.
    normalized_angle = angle % 360
    if normalized_angle == 0:
        return image, coordinates_list

    # 270 degrees is expressed as a clockwise quarter turn.
    rotation_angle = -90 if normalized_angle == 270 else normalized_angle

    rotated_image = image.rotate(rotation_angle, expand=True)
    if not rotate_coordinates:
        return rotated_image, coordinates_list

    orig_width, orig_height = image.size
    new_width, new_height = rotated_image.size

    rotated_coordinates = []
    for coord in coordinates_list:
        if len(coord) < 4:
            rotated_coordinates.append(coord)
            continue

        x1, y1, x2, y2 = coord[:4]

        # Reject boxes that cannot describe a positive-area region.
        if x1 < 0 or y1 < 0 or x2 <= x1 or y2 <= y1:
            print(f"警告: 无效坐标 {coord}")
            rotated_coordinates.append([0, 0, 50, 50])  # placeholder box
            continue

        if rotation_angle == -90:
            # Clockwise quarter turn: (x, y) -> (orig_height - y, x)
            new_x1, new_y1 = orig_height - y2, x1
            new_x2, new_y2 = orig_height - y1, x2
        elif rotation_angle == 90:
            # Counter-clockwise quarter turn: (x, y) -> (y, orig_width - x)
            new_x1, new_y1 = y1, orig_width - x2
            new_x2, new_y2 = y2, orig_width - x1
        elif rotation_angle == 180:
            # Half turn: (x, y) -> (orig_width - x, orig_height - y)
            new_x1, new_y1 = orig_width - x2, orig_height - y2
            new_x2, new_y2 = orig_width - x1, orig_height - y1
        else:
            new_x1, new_y1, new_x2, new_y2 = _rotate_bbox_about_center(
                [x1, y1, x2, y2],
                rotation_angle,
                (orig_width, orig_height),
                (new_width, new_height),
            )

        # Keep the box inside the rotated image. Clamping is monotonic and
        # every branch above yields min <= max, so no reordering is needed.
        new_x1 = max(0, min(new_width, new_x1))
        new_y1 = max(0, min(new_height, new_y1))
        new_x2 = max(0, min(new_width, new_x2))
        new_y2 = max(0, min(new_height, new_y2))

        rotated_coordinates.append([new_x1, new_y1, new_x2, new_y2])

    return rotated_image, rotated_coordinates
def detect_ocr_tool_type(data: Union[List, Dict], config: Dict) -> str:
    """Guess which OCR tool produced ``data`` using the configured rules.

    Args:
        data: Parsed OCR JSON (a list or a dict).
        config: Configuration containing ``ocr.auto_detection`` with an
            ``enabled`` flag and an ordered ``rules`` list.

    Returns:
        The ``tool_type`` of the first matching rule, or ``'dots_ocr'`` when
        auto-detection is disabled or no rule matches.
    """
    fallback_tool = 'dots_ocr'

    detection_cfg = config['ocr']['auto_detection']
    if not detection_cfg['enabled']:
        return fallback_tool

    data_is_mapping = isinstance(data, dict)
    data_is_sequence = isinstance(data, list)

    for candidate in detection_cfg['rules']:
        # Rule form 1: a specific top-level key must be present in a dict.
        matches_field = (
            'field_exists' in candidate
            and data_is_mapping
            and candidate['field_exists'] in data
        )
        if matches_field:
            return candidate['tool_type']

        # Rule form 2: the document itself must be a JSON array.
        matches_array = (
            'json_is_array' in candidate
            and candidate['json_is_array']
            and data_is_sequence
        )
        if matches_array:
            return candidate['tool_type']

    return fallback_tool
def parse_dots_ocr_data(data: List, config: Dict) -> List[Dict]:
    """Convert Dots OCR output into the unified record format.

    Args:
        data: Iterable of raw items; entries that are not dicts are skipped.
        config: Configuration providing ``ocr.tools.dots_ocr`` field names
            and ``ocr.default_confidence``.

    Returns:
        A list of records with the keys ``text``, ``bbox``, ``category``,
        ``confidence`` and ``source_tool``. Only items with truthy text and a
        bbox of at least four values are kept; the bbox is cut to its first
        four values and the text is stringified and stripped.
    """
    fields = config['ocr']['tools']['dots_ocr']
    records = []

    for entry in (candidate for candidate in data if isinstance(candidate, dict)):
        raw_text = entry.get(fields['text_field'], '')
        box = entry.get(fields['bbox_field'], [])
        label = entry.get(fields['category_field'], 'Text')
        score_key = fields.get('confidence_field', 'confidence')
        score = entry.get(score_key, config['ocr']['default_confidence'])

        # Skip entries without usable text or a full four-value box.
        if not (raw_text and box) or len(box) < 4:
            continue

        record = {
            'text': str(raw_text).strip(),
            'bbox': box[:4],  # keep only the first four coordinates
            'category': label,
            'confidence': score,
            'source_tool': 'dots_ocr',
        }
        records.append(record)

    return records
def parse_ppstructv3_data(data: Dict, config: Dict) -> List[Dict]:
    """Convert PPStructV3 output into the unified record format.

    Args:
        data: Raw PPStructV3 JSON object.
        config: Configuration providing ``ocr.tools.ppstructv3`` field names
            and ``ocr.default_confidence``.

    Returns:
        Records with the keys ``text``, ``bbox``, ``category``,
        ``confidence`` and ``source_tool``. Layout blocks from the configured
        results field come first (``source_tool='ppstructv3'``), followed by
        line-level results from ``overall_ocr_res`` when present
        (``source_tool='ppstructv3_ocr'``). An empty list is returned if the
        results field is missing or is not a list.
    """
    fields = config['ocr']['tools']['ppstructv3']
    records = []

    results_key = fields['parsing_results_field']
    if results_key not in data or not isinstance(data[results_key], list):
        return records

    # Layout blocks.
    for block in data[results_key]:
        if not isinstance(block, dict):
            continue

        block_text = block.get(fields['text_field'], '')
        block_box = block.get(fields['bbox_field'], [])
        block_label = block.get(fields['category_field'], 'text')
        score_key = fields.get('confidence_field', 'confidence')
        block_score = block.get(score_key, config['ocr']['default_confidence'])

        if not (block_text and block_box) or len(block_box) < 4:
            continue

        records.append({
            'text': str(block_text).strip(),
            'bbox': block_box[:4],  # keep only the first four coordinates
            'category': block_label,
            'confidence': block_score,
            'source_tool': 'ppstructv3',
        })

    # Optional line-level recognition results.
    overall = data['overall_ocr_res'] if 'overall_ocr_res' in data else None
    has_line_results = (
        isinstance(overall, dict)
        and 'rec_texts' in overall
        and 'rec_boxes' in overall
    )
    if not has_line_results:
        return records

    line_scores = overall.get('rec_scores', [])
    line_pairs = zip(overall['rec_texts'], overall['rec_boxes'])
    for position, (line_text, line_box) in enumerate(line_pairs):
        if not line_text or len(line_box) < 4:
            continue

        # Scores may be shorter than the text list; fall back to the default.
        if position < len(line_scores):
            line_score = line_scores[position]
        else:
            line_score = config['ocr']['default_confidence']

        records.append({
            'text': str(line_text).strip(),
            'bbox': line_box[:4],
            'category': 'OCR_Text',
            'confidence': line_score,
            'source_tool': 'ppstructv3_ocr',
        })

    return records
def normalize_ocr_data(raw_data: Union[List, Dict], config: Dict) -> List[Dict]:
    """Convert raw OCR output from any supported tool into unified records.

    Args:
        raw_data: Parsed OCR JSON (a list or a dict).
        config: Configuration passed through to detection and parsing.

    Returns:
        The list produced by the parser matching the detected tool type.

    Raises:
        ValueError: If the detected tool type is neither ``'dots_ocr'`` nor
            ``'ppstructv3'``.
    """
    detected = detect_ocr_tool_type(raw_data, config)

    # Guard first: anything outside the two supported tools is an error.
    if detected not in ('dots_ocr', 'ppstructv3'):
        raise ValueError(f"不支持的OCR工具类型: {detected}")

    parser = parse_dots_ocr_data if detected == 'dots_ocr' else parse_ppstructv3_data
    return parser(raw_data, config)
def get_rotation_angle_from_ppstructv3(data: Dict) -> float:
    """Read the rotation angle recorded in PPStructV3 output.

    Args:
        data: Raw PPStructV3 JSON object.

    Returns:
        ``data['doc_preprocessor_res']['angle']`` converted to float, or 0.0
        when that entry is absent or ``doc_preprocessor_res`` is not a dict.
    """
    preprocess_info = data['doc_preprocessor_res'] if 'doc_preprocessor_res' in data else None

    # Guard clause: no usable preprocessor section means "not rotated".
    if not isinstance(preprocess_info, dict) or 'angle' not in preprocess_info:
        return 0.0

    return float(preprocess_info['angle'])
def find_image_in_multiple_locations(img_src: str, json_path: str) -> Optional[str]:
    """Locate an image file by probing several candidate locations.

    Candidates are tried in this order:

    1. ``img_src`` itself, when it is an absolute path;
    2. ``img_src`` relative to the JSON file's directory;
    3. ``img_src`` relative to the parent of that directory;
    4. the image's base name inside an ``imgs`` directory next to, then one
       level above, the JSON file;
    5. the same for an ``images`` directory;
    6. the image's base name inside a directory named after the JSON file
       (without extension), next to the JSON file.

    Args:
        img_src: Image reference as found in the OCR output.
        json_path: Path of the JSON file that contains the reference.

    Returns:
        The first candidate that is an existing regular file, or None.
        Directories are never returned, so an empty or directory-like
        ``img_src`` cannot yield a path that fails when opened as an image.
    """
    json_dir = os.path.dirname(json_path)
    parent_dir = os.path.dirname(json_dir)
    img_name = os.path.basename(img_src)
    json_stem = os.path.splitext(os.path.basename(json_path))[0]

    candidates = [
        os.path.join(json_dir, img_src),
        os.path.join(parent_dir, img_src),
        os.path.join(json_dir, 'imgs', img_name),
        os.path.join(parent_dir, 'imgs', img_name),
        os.path.join(json_dir, 'images', img_name),
        os.path.join(parent_dir, 'images', img_name),
        os.path.join(json_dir, json_stem, img_name),
    ]

    # An absolute reference is the most specific hint, so try it first.
    if os.path.isabs(img_src):
        candidates.insert(0, img_src)

    # isfile (not exists): a matching directory is not a usable image.
    return next((path for path in candidates if os.path.isfile(path)), None)
def process_html_images(html_content: str, json_path: str) -> str:
"""
处理HTML内容中的图片引用,将本地图片转换为base64 - 增强版
"""
import re
# 匹配HTML图片标签:
img_pattern = r'
]*src\s*=\s*["\']([^"\']+)["\'][^>]*/?>'
def replace_html_image(match):
full_tag = match.group(0)
img_src = match.group(1)
# 如果已经是base64或者网络链接,直接返回
if img_src.startswith('data:image') or img_src.startswith('http'):
return full_tag
# 增强的图片查找
full_img_path = find_image_in_multiple_locations(img_src, json_path)
# 尝试转换为base64
try:
if full_img_path and os.path.exists(full_img_path):
with open(full_img_path, 'rb') as img_file:
img_data = img_file.read()
# 获取文件扩展名确定MIME类型
ext = os.path.splitext(full_img_path)[1].lower()
mime_type = {
'.png': 'image/png',
'.jpg': 'image/jpeg',
'.jpeg': 'image/jpeg',
'.gif': 'image/gif',
'.bmp': 'image/bmp',
'.webp': 'image/webp'
}.get(ext, 'image/jpeg')
# 转换为base64
img_base64 = base64.b64encode(img_data).decode('utf-8')
data_url = f"data:{mime_type};base64,{img_base64}"
# 替换src属性,保持其他属性不变
updated_tag = re.sub(
r'src\s*=\s*["\'][^"\']+["\']',
f'src="{data_url}"',
full_tag
)
return updated_tag
else:
# 文件不存在,显示详细的错误信息
search_info = f"搜索路径: {img_src}"
if full_img_path:
search_info += f" -> {full_img_path}"
error_content = f"""