#!/usr/bin/env python3
"""
OCR验证工具的布局管理模块
包含标准布局、滚动布局、紧凑布局的实现
"""
import streamlit as st
from pathlib import Path
from PIL import Image
from typing import Dict, List, Optional
import plotly.graph_objects as go
from ocr_validator_utils import (
convert_html_table_to_markdown,
parse_html_tables,
draw_bbox_on_image
)
class OCRLayoutManager:
"""OCR布局管理器"""
def __init__(self, validator):
self.validator = validator
self.config = validator.config
def create_interactive_plot(self, image: Image.Image, selected_bbox: Optional[List[int]] = None) -> go.Figure:
    """Build a Plotly figure of *image* with every OCR bbox overlaid.

    Each bbox in ``self.validator.text_bbox_mapping`` is drawn as a thin
    rectangle. If its text is in ``self.validator.marked_errors``, it gets a
    translucent red fill and the ``error`` style colour as outline. Otherwise
    it gets a translucent blue fill and the ``primary`` style colour.
    *selected_bbox*, when it has at least four values, is drawn last with a
    thicker outline.

    Bboxes are read as ``[x1, y1, x2, y2, ...]`` in image pixels with y
    growing downwards. This figure's y axis runs bottom-up (range
    ``[0, image.height]``), so every y value is flipped as ``image.height - y``.

    Args:
        image: Source image, used as the figure background.
        selected_bbox: Optional bbox to highlight on top of the others.

    Returns:
        The configured figure. Its height is fixed at 800 px and its width
        follows the image aspect ratio.
    """
    img_w, img_h = image.width, image.height
    figure = go.Figure()

    # Background image: its top edge is placed at y=img_h, so it fills
    # y in [0, img_h] on the bottom-up axis.
    figure.add_layout_image(
        dict(
            source=image,
            xref="x", yref="y",
            x=0, y=img_h,
            sizex=img_w, sizey=img_h,
            sizing="stretch", opacity=1.0, layer="below",
        )
    )

    palette = self.config['styles']['colors']

    def draw_box(box, outline, fill):
        # Map top-down image coordinates onto the figure's bottom-up y axis.
        left, top, right, bottom = box[:4]
        figure.add_shape(
            type="rect",
            x0=left, y0=img_h - bottom,
            x1=right, y1=img_h - top,
            line=outline,
            fillcolor=fill,
        )

    # Light overlay for every known bbox; marked errors are tinted red.
    for text, entries in self.validator.text_bbox_mapping.items():
        for entry in entries:
            box = entry['bbox']
            if len(box) < 4:
                continue
            if text in self.validator.marked_errors:
                fill, edge = "rgba(244, 67, 54, 0.3)", palette['error']
            else:
                fill, edge = "rgba(2, 136, 209, 0.2)", palette['primary']
            draw_box(box, dict(color=edge, width=1), fill)

    # The selected bbox is drawn last, so it sits on top with a thick outline.
    if selected_bbox and len(selected_bbox) >= 4:
        draw_box(selected_bbox, dict(color=palette['error'], width=3), "rgba(255, 0, 0, 0.2)")

    figure.update_xaxes(visible=False, range=[0, img_w])
    figure.update_yaxes(visible=False, range=[0, img_h], scaleanchor="x")

    plot_height = 800
    figure.update_layout(
        width=int(plot_height * (img_w / img_h)),
        height=plot_height,
        margin=dict(l=0, r=0, t=0, b=0),
        xaxis_showgrid=False, yaxis_showgrid=False,
        plot_bgcolor='white',
    )
    return figure
def render_content_section(self, layout_type: str = "standard"):
    """Render the "OCR content" header and a picker over recognised texts.

    The picker lists a placeholder entry followed by every key of
    ``self.validator.text_bbox_mapping``. Choosing anything other than the
    placeholder stores that text in ``st.session_state.selected_text``. When
    the mapping is empty, a warning is shown instead of the picker.

    Args:
        layout_type: Prefix for the selectbox widget key, so several layouts
            can render this section without Streamlit key collisions.
    """
    st.header("📄 OCR识别内容")

    mapping = self.validator.text_bbox_mapping
    if not mapping:
        st.warning("没有找到可点击的文本")
        return

    options = ["请选择文本..."]
    options.extend(mapping.keys())

    def label_for(idx):
        # Keep dropdown labels short: at most 50 chars plus an ellipsis.
        option = options[idx]
        if len(option) > 50:
            return option[:50] + "..."
        return option

    chosen = st.selectbox(
        "选择要校验的文本",
        range(len(options)),
        format_func=label_for,
        key=f"{layout_type}_text_selector",
    )
    # Index 0 is the placeholder entry; only a real choice updates the selection.
    if chosen > 0:
        st.session_state.selected_text = options[chosen]
def render_md_content(self, layout_type: str):
    """Render the search box and the render-mode picker for the markdown.

    Args:
        layout_type: Prefix for the widget keys, so several layouts can coexist
            on one page without Streamlit key collisions.

    Returns:
        ``(content, render_mode)``. ``content`` is the full markdown text or,
        when a search term was entered, only the lines that contain it
        (case-insensitive match). Returns ``(None, None)`` when the validator
        has no markdown content, in which case no widgets are rendered.
    """
    md_text = self.validator.md_content
    if not md_text:
        return None, None

    query = st.text_input(
        "🔍 搜索文本内容",
        placeholder="输入关键词搜索...",
        key=f"{layout_type}_search",
    )

    if query:
        needle = query.lower()
        matches = []
        for row in md_text.split('\n'):
            if needle in row.lower():
                matches.append(row)
        md_text = '\n'.join(matches)
        # Report the hit count, or warn when nothing matched.
        if matches:
            st.success(f"找到 {len(matches)} 行包含 '{query}'")
        else:
            st.warning(f"未找到包含 '{query}' 的内容")

    mode = st.radio(
        "选择渲染方式",
        ["HTML渲染", "Markdown渲染", "DataFrame表格", "原始文本"],
        horizontal=True,
        key=f"{layout_type}_render_mode",
    )
    return md_text, mode
# NOTE(review): this region of the file looks corrupted and cannot run as-is.
# - Several string literals below (the st.markdown(f'...') calls) span a raw
#   line break, which is a SyntaxError for single-quoted strings.
# - The content_style / CSS f-strings are empty. It looks like HTML/CSS markup
#   was stripped out of them.
# - The line starting `if ' 30 else text_options[x],` appears to splice the
#   DataFrame branch of render_content_by_mode onto the middle of a different
#   method's selectbox call, so both bodies are incomplete here.
# Restore this region from version control rather than patching it by hand.
def render_content_by_mode(self, content: str, render_mode: str, font_size: int, layout_type: str):
"""Display *content* according to *render_mode* (shared by all layouts).

Args:
    content: Text to render; nothing is rendered when it is None.
    render_mode: One of the labels offered by the render-mode radio
        ("HTML渲染", "Markdown渲染", "DataFrame表格", ...); nothing is rendered
        when it is None.
    font_size: Not referenced in the visible code; presumably used by the
        (stripped) CSS in content_style -- TODO confirm.
    layout_type: Not referenced in the visible code -- TODO confirm usage.
"""
# Nothing to render -- presumably the (None, None) returned by
# render_md_content when there is no markdown content.
if content is None or render_mode is None:
return
if render_mode == "HTML渲染":
content_style = f"""
"""
st.markdown(content_style, unsafe_allow_html=True)
st.markdown(f'
{content}
', unsafe_allow_html=True)
# Markdown mode: HTML tables are converted to Markdown before rendering.
elif render_mode == "Markdown渲染":
converted_content = convert_html_table_to_markdown(content)
content_style = f"""
"""
st.markdown(content_style, unsafe_allow_html=True)
st.markdown(f'{converted_content}
', unsafe_allow_html=True)
elif render_mode == "DataFrame表格":
# NOTE(review): the DataFrame branch is truncated at the next line. What
# follows belongs to another (compact-layout) method: it references
# right_col and zoom_level, neither of which is defined in this method.
if ' 30 else text_options[x],
key="compact_quick_text_selector" # use a distinct widget key
)
# Presumably index 0 is a placeholder option, as in render_content_section.
if selected_index > 0:
st.session_state.selected_text = text_options[selected_index]
# Custom CSS styles (the payload of this literal appears to be missing).
st.markdown(f"""
""", unsafe_allow_html=True)
# Process and display the OCR content
if self.validator.md_content:
# Highlight clickable text
highlighted_content = self.validator.md_content
# NOTE(review): as written, the replacement below is f'{text}', i.e. a
# no-op, and css_class is never used. The wrapping markup seems to have been
# lost. Once restored, note that each replace() scans the whole content, and
# later replacements may match inside markup inserted by earlier ones.
for text in self.validator.text_bbox_mapping.keys():
if len(text) > 2: # skip very short texts to avoid noisy highlighting
# The currently selected text gets the extra "selected-highlight" class.
css_class = "highlight-text selected-highlight" if text == st.session_state.selected_text else "highlight-text"
highlighted_content = highlighted_content.replace(
text,
f'{text}'
)
st.markdown(
f'{highlighted_content}
',
unsafe_allow_html=True
)
# Right-hand column: the annotated image panel for the compact layout.
with right_col:
# Aligned image display (the "fixed" variant)
self.create_aligned_image_display(zoom_level, "compact")
def create_aligned_image_display(self, zoom_level: float = 1.0, layout_type: str = "aligned"):
    """Render the annotated-image panel shown next to the OCR text column.

    The panel contains:
    - a zoom slider and a "show all boxes" toggle;
    - the source image, resized by the zoom factor, as an interactive Plotly
      figure with the selected text's bbox highlighted;
    - details about the selected text;
    - buttons that add it to, or remove it from,
      ``st.session_state.marked_errors``.

    Args:
        zoom_level: Initial value for the zoom slider. It is coerced to float
            and clamped to the slider's 0.3-2.0 range, because ``st.slider``
            rejects a default of a different numeric type or outside the range.
        layout_type: Prefix for every widget key, so several layouts can render
            this panel on one page without Streamlit key collisions.
    """
    # NOTE(review): the markup passed to the st.markdown calls in this method
    # looks stripped: an empty CSS payload, an empty opening tag, and a closing
    # call whose literal used to span a raw line break (a SyntaxError). The
    # literal values are kept as they appeared; restore the intended HTML/CSS.
    st.markdown("\n", unsafe_allow_html=True)
    st.markdown("", unsafe_allow_html=True)
    st.header("🖼️ 原图标注")

    # Zoom and box-visibility controls.
    zoom_col, boxes_col = st.columns(2)
    with zoom_col:
        initial_zoom = min(max(float(zoom_level), 0.3), 2.0)
        current_zoom = st.slider("图片缩放", 0.3, 2.0, initial_zoom, 0.1, key=f"{layout_type}_zoom_level")
    with boxes_col:
        show_all_boxes = st.checkbox("显示所有框", value=False, key=f"{layout_type}_show_all_boxes")

    image_path = self.validator.image_path
    if image_path and Path(image_path).exists():
        try:
            # Context manager releases the file handle (Pillow keeps it open for
            # multi-frame formats); resize() returns an independent copy.
            with Image.open(image_path) as image:
                # Clamp to >= 1 px: Pillow rejects zero-sized resize targets,
                # which tiny images at low zoom would otherwise produce.
                new_size = (
                    max(1, int(image.width * current_zoom)),
                    max(1, int(image.height * current_zoom)),
                )
                resized_image = image.resize(new_size, Image.Resampling.LANCZOS)

            selected_text = st.session_state.selected_text
            # First bbox entry recorded for the selected text, if any. An empty
            # entry list counts as "no selection" instead of raising IndexError
            # into the generic handler below.
            entries = self.validator.text_bbox_mapping.get(selected_text) if selected_text else None
            info = entries[0] if entries else None

            # The highlight is drawn unscaled by the plot helper, so scale it
            # into the resized image's pixel space here.
            selected_bbox = None
            if info is not None:
                selected_bbox = [int(coord * current_zoom) for coord in info['bbox']]

            fig = self.create_resized_interactive_plot(resized_image, selected_bbox, current_zoom, show_all_boxes)
            st.plotly_chart(fig, use_container_width=True, key=f"{layout_type}_plot")

            # Details and error-marking controls only make sense with a
            # selection, so None is never added to marked_errors.
            if info is not None:
                st.subheader("📍 选中文本详情")
                bbox = info['bbox']
                info_col1, info_col2 = st.columns(2)
                with info_col1:
                    # Append an ellipsis only when the text is actually truncated.
                    preview = selected_text if len(selected_text) <= 30 else selected_text[:30] + "..."
                    st.write(f"**文本内容:** {preview}")
                    st.write(f"**类别:** {info['category']}")
                with info_col2:
                    st.write(f"**位置:** [{', '.join(map(str, bbox))}]")
                    if len(bbox) >= 4:
                        st.write(f"**大小:** {bbox[2] - bbox[0]} x {bbox[3] - bbox[1]} px")

                # Error marking: update the set, then rerun the script at once.
                mark_col, unmark_col = st.columns(2)
                with mark_col:
                    if st.button("❌ 标记为错误", key=f"{layout_type}_mark_error"):
                        st.session_state.marked_errors.add(selected_text)
                        st.rerun()
                with unmark_col:
                    if st.button("✅ 取消错误标记", key=f"{layout_type}_unmark_error"):
                        st.session_state.marked_errors.discard(selected_text)
                        st.rerun()
        except Exception as e:
            # Best-effort UI: report the failure inside the panel, don't crash.
            st.error(f"❌ 图片处理失败: {e}")
    else:
        st.error("未找到对应的图片文件")
        if image_path:
            st.write(f"期望路径: {image_path}")
    st.markdown("\n", unsafe_allow_html=True)
def create_resized_interactive_plot(self, image: Image.Image, selected_bbox: Optional[List[int]], zoom_level: float, show_all_boxes: bool) -> go.Figure:
    """Build a Plotly figure of an already-resized *image* on a top-down y axis.

    The y axis is reversed (range ``[image.height, 0]``), so bbox coordinates
    in image pixel space (y growing downwards) are drawn without flipping.

    Args:
        image: The image after zooming; it is used as the figure background.
        selected_bbox: Bbox to highlight, ALREADY scaled to *image*'s pixel
            space. It is drawn as given, with no zoom applied.
        zoom_level: Factor applied to the raw bboxes from
            ``self.validator.text_bbox_mapping`` when *show_all_boxes* is set.
        show_all_boxes: Whether to overlay every known bbox. Texts in
            ``self.validator.marked_errors`` are drawn red, the rest blue.

    Returns:
        The configured figure. Its height is the image height clamped to
        600-1000 px, and its width follows the aspect ratio.
    """
    w, h = image.width, image.height
    fig = go.Figure()

    # yanchor defaults to "top": the image's top edge sits at y=0, and with the
    # reversed y axis below, it extends downwards to y=h.
    fig.add_layout_image(
        dict(
            source=image,
            xref="x", yref="y",
            x=0, y=0,
            sizex=w, sizey=h,
            sizing="stretch", opacity=1.0, layer="below",
        )
    )

    # (fill, outline) pairs: each outline is the fill colour at a higher alpha.
    normal_style = ("rgba(0, 100, 200, 0.2)", "rgba(0, 100, 200, 0.8)")
    error_style = ("rgba(255, 0, 0, 0.3)", "rgba(255, 0, 0, 1.0)")

    if show_all_boxes:
        for text, entries in self.validator.text_bbox_mapping.items():
            for entry in entries:
                raw = entry['bbox']
                if len(raw) < 4:
                    continue
                x1, y1, x2, y2 = (c * zoom_level for c in raw[:4])
                fill, outline = error_style if text in self.validator.marked_errors else normal_style
                fig.add_shape(
                    type="rect",
                    x0=x1, y0=y1,
                    x1=x2, y1=y2,
                    line=dict(color=outline, width=1),
                    fillcolor=fill,
                )

    # Selected bbox on top, in solid red.
    if selected_bbox and len(selected_bbox) >= 4:
        sx1, sy1, sx2, sy2 = selected_bbox[:4]
        fig.add_shape(
            type="rect",
            x0=sx1, y0=sy1,
            x1=sx2, y1=sy2,
            line=dict(color="red", width=2),
            fillcolor="rgba(255, 0, 0, 0.3)",
        )

    # Hidden axes, locked 1:1; the y range is reversed so y=0 is at the top.
    fig.update_xaxes(visible=False, range=[0, w], constrain="domain")
    fig.update_yaxes(
        visible=False,
        range=[h, 0],
        constrain="domain",
        scaleanchor="x",
        scaleratio=1,
    )

    display_height = min(1000, max(600, h))
    fig.update_layout(
        width=int(display_height * (w / h)),
        height=display_height,
        margin=dict(l=0, r=0, t=0, b=0),
        showlegend=False,
        plot_bgcolor='white',
    )
    return fig