#!/usr/bin/env python3
"""
OCR验证工具的布局管理模块
包含标准布局、滚动布局、紧凑布局的实现
"""
import streamlit as st
from pathlib import Path
from PIL import Image
from typing import Dict, List, Optional
import plotly.graph_objects as go
from typing import Tuple
from ocr_validator_utils import (
convert_html_table_to_markdown,
parse_html_tables,
draw_bbox_on_image,
rotate_image_and_coordinates,
get_ocr_tool_rotation_config,
detect_image_orientation_by_opencv # 新增导入
)
class OCRLayoutManager:
"""OCR布局管理器"""
def __init__(self, validator):
    """Bind the layout manager to a validator and create its empty caches.

    Args:
        validator: Validator object driving the UI. Its ``config``
            attribute is stored on the manager as ``self.config``.
    """
    self.validator = validator
    self.config = validator.config

    # Rotated-image cache plus the entry limit applied by _manage_cache_size.
    self._cache_max_size = 10
    self._rotated_image_cache = {}

    # Orientation-detection results, keyed by image path.
    self._orientation_cache = {}

    # NOTE: _auto_detected_angle is deliberately left unset here; the other
    # methods of this class probe for it with hasattr().
def clear_image_cache(self):
    """Empty the rotated-image cache in place.

    Only ``_rotated_image_cache`` is touched; orientation-detection
    results are kept.
    """
    rotated_images = self._rotated_image_cache
    rotated_images.clear()
def clear_cache_for_image(self, image_path: str):
    """Drop every cached entry that belongs to ``image_path``.

    Rotated-image cache keys have the form ``"{image_path}_{angle}"``, so
    an entry belongs to ``image_path`` exactly when the text before the
    last underscore equals it. A plain prefix test would also delete the
    entries of other images whose path merely starts with the same
    characters (``"a/img"`` vs ``"a/img2.png"``).

    The orientation-detection result cached for the image is removed too,
    so "all caches" really means all of them.

    Args:
        image_path: Path of the image whose cached data should be removed.
    """
    stale_keys = [
        key for key in self._rotated_image_cache
        if key.rpartition("_")[0] == image_path
    ]
    for key in stale_keys:
        del self._rotated_image_cache[key]

    # Orientation results are keyed by the bare image path.
    self._orientation_cache.pop(image_path, None)
def get_cache_info(self) -> dict:
    """Return a snapshot describing the rotated-image cache.

    Returns:
        A dict with ``cache_size`` (number of entries), ``cached_images``
        (list of cache keys) and ``max_size`` (configured entry limit).
    """
    cache = self._rotated_image_cache
    info = dict(
        cache_size=len(cache),
        cached_images=list(cache),
        max_size=self._cache_max_size,
    )
    return info
def _manage_cache_size(self):
    """Evict the oldest cache entries until the cache fits its limit.

    Eviction is FIFO: dicts keep insertion order, so the first key is the
    oldest entry. The loop keeps evicting until the size is within
    ``_cache_max_size``, which also restores the invariant when the cache
    is more than one entry over the limit (for example after the limit
    has been lowered).
    """
    cache = self._rotated_image_cache
    # `cache and ...` stops cleanly on an empty cache even if the limit
    # is misconfigured to a negative number.
    while cache and len(cache) > self._cache_max_size:
        del cache[next(iter(cache))]
def detect_and_suggest_rotation(self, image_path: str) -> Dict:
    """Return the orientation-detection result for an image, memoised per path.

    Args:
        image_path: Path of the image to analyse.

    Returns:
        The result of ``detect_image_orientation_by_opencv`` for this
        path. Repeated calls with the same path return the stored result
        without running the detection again.
    """
    try:
        return self._orientation_cache[image_path]
    except KeyError:
        pass

    # First request for this path: run the detector and remember its result.
    outcome = detect_image_orientation_by_opencv(image_path)
    self._orientation_cache[image_path] = outcome
    return outcome
def get_rotation_angle(self) -> float:
    """Return the rotation angle currently in effect.

    Lookup order:
      1. the ``rotation_angle`` of the first dict item in
         ``self.validator.ocr_data`` that carries that key;
      2. ``self._auto_detected_angle`` when that attribute exists;
      3. ``0.0``.

    The value found in the OCR data is returned as stored, without
    conversion.
    """
    not_found = object()
    angle_from_ocr = next(
        (
            entry['rotation_angle']
            for entry in (self.validator.ocr_data or ())
            if isinstance(entry, dict) and 'rotation_angle' in entry
        ),
        not_found,
    )
    if angle_from_ocr is not not_found:
        return angle_from_ocr

    # No angle in the OCR data: use the manually/automatically chosen one.
    return getattr(self, '_auto_detected_angle', 0.0)
def load_and_rotate_image(self, image_path: str) -> Optional[Image.Image]:
    """Load an image and rotate it to match the current rotation angle.

    Results are cached under the key ``"{image_path}_{rotation_angle}"``.

    When the angle is non-zero and the OCR tool's rotation config says the
    coordinates are NOT pre-rotated, every ``bbox`` stored in
    ``self.validator.text_bbox_mapping`` is overwritten in place with its
    rotated counterpart.

    Args:
        image_path: Path of the image file.

    Returns:
        The (possibly rotated) image, or ``None`` when the path is empty,
        the file does not exist, or loading/rotating raised. In the last
        case the error is reported through ``st.error``.
    """
    if not image_path or not Path(image_path).exists():
        return None

    rotation_angle = self.get_rotation_angle()
    cache_key = f"{image_path}_{rotation_angle}"
    try:
        return self._rotated_image_cache[cache_key]
    except KeyError:
        pass

    try:
        image = Image.open(image_path)

        if rotation_angle == 0:
            # Nothing to rotate: the original image itself gets cached.
            result = image
        else:
            rotation_config = get_ocr_tool_rotation_config(self.validator.ocr_data, self.config)

            if rotation_config['coordinates_are_pre_rotated']:
                # Coordinates in the OCR output already match the rotated
                # page (e.g. PPStructV3), so only the image is rotated.
                # 90/180/270 map to fixed PIL angles; anything else is
                # rotated by the negated angle.
                pil_degrees = {270: -90, 90: 90, 180: 180}.get(rotation_angle, -rotation_angle)
                result = image.rotate(pil_degrees, expand=True)
            else:
                # Image and coordinates must be rotated together
                # (e.g. Dots OCR). Flatten all bbox entries so the rotated
                # boxes can be written back in the same order.
                bbox_entries = [
                    info
                    for info_list in self.validator.text_bbox_mapping.values()
                    for info in info_list
                ]
                result, rotated_bboxes = rotate_image_and_coordinates(
                    image, rotation_angle, [info['bbox'] for info in bbox_entries],
                    rotate_coordinates=True  # this branch only runs for non-pre-rotated coordinates
                )
                for info, new_bbox in zip(bbox_entries, rotated_bboxes):
                    info['bbox'] = new_bbox

        self._rotated_image_cache[cache_key] = result
        self._manage_cache_size()
        return result

    except Exception as e:
        st.error(f"❌ 图像加载失败: {e}")
        return None
def render_content_section(self, layout_type: str = "compact"):
    """Render the OCR content header and the text selector.

    Shows a select box listing every key of
    ``self.validator.text_bbox_mapping`` after a placeholder entry.
    Picking a real entry stores it in ``st.session_state.selected_text``.
    When the mapping is empty a warning is shown instead.

    Args:
        layout_type: Prefix used to build the widget key.
    """
    st.header("📄 OCR识别内容")

    mapping = self.validator.text_bbox_mapping
    if not mapping:
        st.warning("没有找到可点击的文本")
        return

    # Index 0 is the placeholder; real texts start at index 1.
    text_options = ["请选择文本...", *mapping.keys()]

    def _option_label(idx):
        # Long texts are cut to 50 characters for display.
        label = text_options[idx]
        if len(label) > 50:
            return label[:50] + "..."
        return label

    selected_index = st.selectbox(
        "选择要校验的文本",
        range(len(text_options)),
        format_func=_option_label,
        key=f"{layout_type}_text_selector"
    )

    if selected_index > 0:
        st.session_state.selected_text = text_options[selected_index]
def render_md_content(self, layout_type: str) -> Optional[str]:
    """Render the search box and return the Markdown content to display.

    When a search term is entered, only the lines containing it
    (case-insensitive) are kept, and a success or warning message reports
    the number of matches.

    Args:
        layout_type: Prefix used to build the widget key.

    Returns:
        The (possibly filtered) content string, or ``None`` when the
        validator has no Markdown content. The empty case used to return
        the tuple ``(None, None)``, which is truthy and defeats the
        ``content is None`` guard used by consumers of this value.
    """
    if not self.validator.md_content:
        return None

    search_term = st.text_input(
        "🔍 搜索文本内容",
        placeholder="输入关键词搜索...",
        key=f"{layout_type}_search"
    )

    display_content = self.validator.md_content
    if search_term:
        needle = search_term.lower()  # lower-case once, not once per line
        filtered_lines = [
            line for line in display_content.split('\n')
            if needle in line.lower()
        ]
        display_content = '\n'.join(filtered_lines)
        if filtered_lines:
            st.success(f"找到 {len(filtered_lines)} 行包含 '{search_term}'")
        else:
            st.warning(f"未找到包含 '{search_term}' 的内容")

    return display_content
def render_content_by_mode(self, content: str, render_mode: str, font_size: int, container_height: int, layout_type: str):
"""根据渲染模式显示内容 - 增强版本"""
if content is None or render_mode is None:
return
if render_mode == "HTML渲染":
# 增强的HTML渲染样式,支持横向滚动
content_style = f"""
"""
st.markdown(content_style, unsafe_allow_html=True)
st.markdown(f'
{content}
', unsafe_allow_html=True)
elif render_mode == "Markdown渲染":
converted_content = convert_html_table_to_markdown(content)
st.markdown(converted_content, unsafe_allow_html=True)
elif render_mode == "DataFrame表格":
if ' 30 else text_options[x],
key="compact_quick_text_selector" # 使用不同的key
)
if selected_index > 0:
st.session_state.selected_text = text_options[selected_index]
# 处理并显示OCR内容
if self.validator.md_content:
# 高亮可点击文本
highlighted_content = self.validator.md_content
for text in self.validator.text_bbox_mapping.keys():
if len(text) > 2: # 避免高亮过短的文本
css_class = "highlight-text selected-highlight" if text == st.session_state.selected_text else "highlight-text"
highlighted_content = highlighted_content.replace(
text,
f'{text}'
)
self.render_content_by_mode(highlighted_content, "HTML渲染", font_size, container_height, layout_type)
with right_col:
# 修复的对齐图片显示
self.create_aligned_image_display(zoom_level, "compact")
def create_aligned_image_display(self, zoom_level: float = 1.0, layout_type: str = "aligned") -> None:
"""Render the annotated-image panel: controls, orientation tools and plot.

In order, this method draws:
  * a header and a row of controls (zoom slider, "show all boxes" and
    "fit to container" checkboxes, current-angle metric); changed values
    are written back to the validator;
  * an orientation panel with a manual angle selector, an auto-detect
    button and a reset button;
  * the stored detection result, if one exists in ``st.session_state``;
  * the image returned by ``load_and_rotate_image``, resized by the zoom
    factor, with the selected bbox and optionally all bboxes overlaid;
  * details, error-marking buttons and debug information.

Args:
    zoom_level: Not read by this method; the slider starts from
        ``self.validator.zoom_level``.
    layout_type: Prefix used to build widget and session-state keys.
"""
st.header("🖼️ 原图标注")
# Image control options
col1, col2, col3, col4 = st.columns(4)
with col1:
# The slider starts from the validator's zoom level.
# NOTE(review): the zoom_level argument is never read here - confirm intended.
current_zoom = self.validator.zoom_level
current_zoom = st.slider("图片缩放", 0.3, 2.0, current_zoom, 0.1, key=f"{layout_type}_zoom_level")
if current_zoom != self.validator.zoom_level:
self.validator.zoom_level = current_zoom
with col2:
# The checkbox starts from validator.show_all_boxes (the session-state
# variant below is disabled).
# if f"{layout_type}_show_all_boxes" not in st.session_state:
# st.session_state[f"{layout_type}_show_all_boxes"] = False
show_all_boxes = st.checkbox(
"显示所有框",
# value=st.session_state[f"{layout_type}_show_all_boxes"],
value = self.validator.show_all_boxes,
key=f"{layout_type}_show_all_boxes"
)
if show_all_boxes != self.validator.show_all_boxes:
self.validator.show_all_boxes = show_all_boxes
with col3:
# The checkbox starts from validator.fit_to_container.
fit_to_container = st.checkbox(
"适应容器",
value=self.validator.fit_to_container,
key=f"{layout_type}_fit_to_container"
)
if fit_to_container != self.validator.fit_to_container:
self.validator.fit_to_container = fit_to_container
with col4:
# Show the rotation angle currently in effect
current_angle = self.get_rotation_angle()
st.metric("当前角度", f"{current_angle}°", label_visibility="collapsed")
# Orientation detection control panel
with st.expander("🔄 图片方向检测", expanded=False):
col1, col2, col3 = st.columns([1, 1, 1], width='stretch')
with col1:
manual_angle = st.selectbox(
"设置角度",
[0, 90, 180, 270],
index = 0,
label_visibility="collapsed",
# key=f"{layout_type}_manual_angle"
)
# The selection is applied immediately (the explicit apply button is disabled).
# if st.button("应用手动角度", key=f"{layout_type}_apply_manual"):
# NOTE(review): this comparison runs on every script execution, so an angle
# set by "apply suggested angle" or removed by "reset" appears to be
# overwritten with the selectbox value on the following rerun - confirm.
if not hasattr(self, '_auto_detected_angle') or self._auto_detected_angle != manual_angle:
self._auto_detected_angle = float(manual_angle)
# st.success(f"已设置旋转角度为 {manual_angle}")
# Drop cached images and re-run validator.process_data(), which presumably
# rebuilds the bboxes in text_bbox_mapping - verify against the validator.
self.clear_image_cache()
self.validator.process_data()
st.rerun()
with col2:
if st.button("🔍 自动检测方向", key=f"{layout_type}_detect_orientation"):
if self.validator.image_path:
with st.spinner("正在检测图片方向..."):
detection_result = self.detect_and_suggest_rotation(self.validator.image_path)
st.session_state[f'{layout_type}_detection_result'] = detection_result
st.rerun()
with col3:
if st.button("🔄 重置角度", key=f"{layout_type}_reset_angle"):
if hasattr(self, '_auto_detected_angle'):
delattr(self, '_auto_detected_angle')
st.success("已重置旋转角度")
# Drop cached images and re-run validator.process_data() (see above).
self.clear_image_cache()
self.validator.process_data()
st.rerun()
# Show the stored detection result
if f'{layout_type}_detection_result' in st.session_state:
result = st.session_state[f'{layout_type}_detection_result']
st.markdown("### 🎯 检测结果")
# Result overview
result_col1, result_col2, result_col3 = st.columns(3)
with result_col1:
st.metric("建议角度", f"{result['detected_angle']}°")
with result_col2:
st.metric("置信度", f"{result['confidence']:.2%}")
with result_col3:
confidence_color = "🟢" if result['confidence'] > 0.7 else "🟡" if result['confidence'] > 0.4 else "🔴"
st.metric("可信度", f"{confidence_color}")
# Details
st.write(f"**检测信息:** {result['message']}")
if 'method_details' in result:
st.write("**方法详情:**")
for detail in result['method_details']:
st.write(f"• {detail}")
# Offer to apply the suggested angle (non-zero angle, confidence above 0.3)
if result['confidence'] > 0.3 and result['detected_angle'] != 0:
if st.button(f"✅ 应用建议角度 {result['detected_angle']}°", key=f"{layout_type}_apply_suggested"):
self._auto_detected_angle = result['detected_angle']
st.success(f"已应用建议角度 {result['detected_angle']}°")
# Drop cached images and re-run validator.process_data() (see above).
self.clear_image_cache()
self.validator.process_data()
st.rerun()
# Per-method results
if 'individual_results' in result and len(result['individual_results']) > 1:
with st.expander("📊 各方法检测详情", expanded=False):
for i, individual in enumerate(result['individual_results']):
st.write(f"**方法 {i+1}: {individual['method']}**")
st.write(f"角度: {individual['detected_angle']}°, 置信度: {individual['confidence']:.2f}")
st.write(f"信息: {individual['message']}")
if 'error' in individual:
st.error(f"错误: {individual['error']}")
st.write("---")
# Load the image through load_and_rotate_image (rotated as needed, cached)
image = self.load_and_rotate_image(self.validator.image_path)
if image:
try:
# Resize the image by the zoom factor
new_width = int(image.width * current_zoom)
new_height = int(image.height * current_zoom)
resized_image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
# Scale the bbox of the selected text (first entry only)
selected_bbox = None
if st.session_state.selected_text and st.session_state.selected_text in self.validator.text_bbox_mapping:
info = self.validator.text_bbox_mapping[st.session_state.selected_text][0]
bbox = info['bbox']
selected_bbox = [int(coord * current_zoom) for coord in bbox]
# Collect every bbox (scaled) when "show all boxes" is enabled
all_boxes = []
if show_all_boxes:
for text, info_list in self.validator.text_bbox_mapping.items():
for info in info_list:
bbox = info['bbox']
if len(bbox) >= 4:
scaled_bbox = [coord * current_zoom for coord in bbox]
all_boxes.append(scaled_bbox)
# Build the interactive figure
fig = self.create_resized_interactive_plot(resized_image, selected_bbox, current_zoom, all_boxes)
plot_config = {
'displayModeBar': True,
'modeBarButtonsToRemove': ['zoom2d', 'select2d', 'lasso2d', 'autoScale2d'],
'scrollZoom': True,
'doubleClick': 'reset'
}
st.plotly_chart(
fig,
use_container_width=fit_to_container,
config=plot_config,
key=f"{layout_type}_plot"
)
# Details of the selected text
if st.session_state.selected_text and st.session_state.selected_text in self.validator.text_bbox_mapping:
st.subheader("📍 选中文本详情")
info = self.validator.text_bbox_mapping[st.session_state.selected_text][0]
bbox = info['bbox']
info_col1, info_col2 = st.columns(2)
with info_col1:
st.write(f"**文本内容:** {st.session_state.selected_text[:30]}...")
st.write(f"**类别:** {info['category']}")
# Rotation info (only shown for a non-zero angle)
rotation_angle = self.get_rotation_angle()
if rotation_angle != 0:
st.write(f"**旋转角度:** {rotation_angle}°")
with info_col2:
st.write(f"**位置:** [{', '.join(map(str, bbox))}]")
if len(bbox) >= 4:
st.write(f"**大小:** {bbox[2] - bbox[0]} x {bbox[3] - bbox[1]} px")
# Mark / unmark the selected text as an error
col1, col2 = st.columns(2)
with col1:
if st.button("❌ 标记为错误", key=f"{layout_type}_mark_error"):
st.session_state.marked_errors.add(st.session_state.selected_text)
st.rerun()
with col2:
if st.button("✅ 取消错误标记", key=f"{layout_type}_unmark_error"):
st.session_state.marked_errors.discard(st.session_state.selected_text)
st.rerun()
# Debug information about the image and coordinates
with st.expander("🔍 图像和坐标调试信息", expanded=False):
rotation_angle = self.get_rotation_angle()
rotation_config = get_ocr_tool_rotation_config(self.validator.ocr_data, self.config)
col_debug1, col_debug2, col_debug3 = st.columns(3)
with col_debug1:
st.write("**图像信息:**")
st.write(f"原始尺寸: {image.width} x {image.height}")
st.write(f"缩放后尺寸: {resized_image.width} x {resized_image.height}")
st.write(f"当前角度: {rotation_angle}°")
with col_debug2:
st.write("**坐标信息:**")
if selected_bbox:
st.write(f"选中框: {selected_bbox}")
st.write(f"总框数: {len(all_boxes)}")
st.write(f"文本框数: {len(self.validator.text_bbox_mapping)}")
with col_debug3:
st.write("**配置信息:**")
st.write(f"工具类型: {rotation_config.get('coordinates_are_pre_rotated', 'unknown')}")
st.write(f"缓存状态: {len(self._rotated_image_cache)} 项")
if hasattr(self, '_auto_detected_angle'):
st.write(f"自动检测角度: {self._auto_detected_angle}°")
except Exception as e:
st.error(f"❌ 图片处理失败: {e}")
st.exception(e)
else:
st.error("未找到对应的图片文件")
if self.validator.image_path:
st.write(f"期望路径: {self.validator.image_path}")
st.markdown('', unsafe_allow_html=True)
def create_resized_interactive_plot(self, image: Image.Image, selected_bbox: Optional[List[int]], zoom_level: float, all_boxes: list[tuple]) -> go.Figure:
    """Build a pannable Plotly figure showing the image with bbox overlays.

    The image and every box are expected to be already scaled and
    rotated. Box coordinates use an image-style origin (y grows downwards)
    and are flipped here to Plotly's y-up axis.

    Args:
        image: Image to display; its ``width``/``height`` define the axes.
        selected_bbox: ``[x1, y1, x2, y2]`` of the selected text, drawn in
            red, or ``None``.
        zoom_level: Not used by this method.
        all_boxes: Boxes drawn in translucent blue; entries with fewer
            than four values are skipped.

    Returns:
        The configured figure.
    """
    img_w, img_h = image.width, image.height
    fig = go.Figure()

    # Background image anchored at (0, img_h) and stretched over the axes.
    fig.add_layout_image(
        dict(
            source=image,
            xref="x", yref="y",
            x=0, y=img_h,
            sizex=img_w,
            sizey=img_h,
            sizing="stretch",
            opacity=1.0,
            layer="below"
        )
    )

    def _add_rect(box, **style):
        # Flip the y coordinates: image top (y1) becomes the Plotly upper edge.
        left, top, right, bottom = box[:4]
        fig.add_shape(
            type="rect",
            x0=left, y0=img_h - bottom,
            x1=right, y1=img_h - top,
            **style,
        )

    # All boxes first, so the selected box is drawn on top of them.
    if len(all_boxes) > 0:
        for box in all_boxes:
            if len(box) >= 4:
                _add_rect(
                    box,
                    line=dict(color="blue", width=1),
                    fillcolor="rgba(0, 100, 200, 0.2)",
                )

    if selected_bbox and len(selected_bbox) >= 4:
        _add_rect(
            selected_bbox,
            line=dict(color="red", width=3),
            fillcolor="rgba(255, 0, 0, 0.3)",
        )

    # Display size: cap the longer side and keep the aspect ratio.
    max_display_width = 800
    max_display_height = 600
    aspect_ratio = img_w / img_h
    if aspect_ratio > 1:
        # Landscape: width is the limiting side.
        display_width = min(max_display_width, img_w)
        display_height = int(display_width / aspect_ratio)
    else:
        # Portrait or square: height is the limiting side.
        display_height = min(max_display_height, img_h)
        display_width = int(display_height * aspect_ratio)

    # Settings shared by both axes: hidden, fixed explicit range, zoomable.
    shared_axis = dict(
        visible=False,
        constrain="domain",
        fixedrange=False,
        autorange=False,
        showgrid=False,
        zeroline=False,
    )

    fig.update_layout(
        width=display_width,
        height=display_height,
        margin=dict(l=0, r=0, t=0, b=0),
        showlegend=False,
        plot_bgcolor='white',
        dragmode="pan",
        xaxis=dict(shared_axis, range=[0, img_w]),
        # The y axis is locked 1:1 to x so the image is not distorted.
        yaxis=dict(
            shared_axis,
            range=[0, img_h],
            scaleanchor="x",
            scaleratio=1,
        ),
    )
    return fig