瀏覽代碼

新增OCR布局管理模块,支持标准、滚动和紧凑布局的实现,优化交互式图片显示和内容渲染功能

zhch158_admin 2 月之前
父節點
當前提交
e95dc87bde
共有 1 個文件被更改,包括 460 次插入0 次删除
  1. 460 0
      ocr_validator_layout.py

+ 460 - 0
ocr_validator_layout.py

@@ -0,0 +1,460 @@
+#!/usr/bin/env python3
+"""
+OCR验证工具的布局管理模块
+包含标准布局、滚动布局、紧凑布局的实现
+"""
+
+import streamlit as st
+from pathlib import Path
+from PIL import Image
+from typing import Dict, List, Optional
+import plotly.graph_objects as go
+
+from ocr_validator_utils import (
+    convert_html_table_to_markdown, 
+    parse_html_tables,
+    draw_bbox_on_image
+)
+
+
+class OCRLayoutManager:
+    """OCR布局管理器"""
+    
+    def __init__(self, validator):
+        self.validator = validator
+        self.config = validator.config
+    
+    def create_interactive_plot(self, image: Image.Image, selected_bbox: Optional[List[int]] = None) -> go.Figure:
+        """创建交互式图片显示"""
+        fig = go.Figure()
+        
+        # 添加图片
+        fig.add_layout_image(
+            dict(
+                source=image,
+                xref="x", yref="y",
+                x=0, y=image.height,
+                sizex=image.width, sizey=image.height,
+                sizing="stretch", opacity=1.0, layer="below"
+            )
+        )
+        
+        colors = self.config['styles']['colors']
+        
+        # 添加所有bbox(浅色显示)
+        for text, info_list in self.validator.text_bbox_mapping.items():
+            for info in info_list:
+                bbox = info['bbox']
+                if len(bbox) >= 4:
+                    x1, y1, x2, y2 = bbox[:4]
+                    
+                    if text in self.validator.marked_errors:
+                        color = f"rgba(244, 67, 54, 0.3)"  # 错误标记为红色
+                        line_color = colors['error']
+                    else:
+                        color = f"rgba(2, 136, 209, 0.2)"  # 默认浅蓝色
+                        line_color = colors['primary']
+                    
+                    fig.add_shape(
+                        type="rect",
+                        x0=x1, y0=image.height-y2,
+                        x1=x2, y1=image.height-y1,
+                        line=dict(color=line_color, width=1),
+                        fillcolor=color,
+                    )
+        
+        # 高亮显示选中的bbox
+        if selected_bbox and len(selected_bbox) >= 4:
+            x1, y1, x2, y2 = selected_bbox[:4]
+            fig.add_shape(
+                type="rect",
+                x0=x1, y0=image.height-y2,
+                x1=x2, y1=image.height-y1,
+                line=dict(color=colors['error'], width=3),
+                fillcolor="rgba(255, 0, 0, 0.2)",
+            )
+        
+        # 设置布局 - 增加图片大小并确保从顶部开始显示
+        fig.update_xaxes(visible=False, range=[0, image.width])
+        fig.update_yaxes(visible=False, range=[0, image.height], scaleanchor="x")
+        
+        # 计算合适的显示尺寸
+        aspect_ratio = image.width / image.height
+        display_height = 800  # 增加显示高度
+        display_width = int(display_height * aspect_ratio)
+        
+        fig.update_layout(
+            width=display_width,
+            height=display_height,
+            margin=dict(l=0, r=0, t=0, b=0),
+            xaxis_showgrid=False, yaxis_showgrid=False,
+            plot_bgcolor='white'
+        )
+        
+        return fig
+    
+    def render_content_section(self, layout_type: str = "standard"):
+        """渲染内容区域 - 统一方法"""
+        st.header("📄 OCR识别内容")
+        
+        # 文本选择器
+        if self.validator.text_bbox_mapping:
+            text_options = ["请选择文本..."] + list(self.validator.text_bbox_mapping.keys())
+            selected_index = st.selectbox(
+                "选择要校验的文本",
+                range(len(text_options)),
+                format_func=lambda x: text_options[x][:50] + "..." if len(text_options[x]) > 50 else text_options[x],
+                key=f"{layout_type}_text_selector"
+            )
+            
+            if selected_index > 0:
+                st.session_state.selected_text = text_options[selected_index]
+        else:
+            st.warning("没有找到可点击的文本")
+    
+    def render_md_content(self, layout_type: str):
+        """渲染Markdown内容 - 统一方法"""
+        if not self.validator.md_content:
+            return None, None
+            
+        # 搜索功能
+        search_term = st.text_input(
+            "🔍 搜索文本内容", 
+            placeholder="输入关键词搜索...", 
+            key=f"{layout_type}_search"
+        )
+        
+        display_content = self.validator.md_content
+        if search_term:
+            lines = display_content.split('\n')
+            filtered_lines = [line for line in lines if search_term.lower() in line.lower()]
+            display_content = '\n'.join(filtered_lines)
+            if filtered_lines:
+                st.success(f"找到 {len(filtered_lines)} 行包含 '{search_term}'")
+            else:
+                st.warning(f"未找到包含 '{search_term}' 的内容")
+        
+        # 渲染方式选择
+        render_mode = st.radio(
+            "选择渲染方式",
+            ["HTML渲染", "Markdown渲染", "DataFrame表格", "原始文本"],
+            horizontal=True,
+            key=f"{layout_type}_render_mode"
+        )
+        
+        return display_content, render_mode
+    
+    def render_content_by_mode(self, content: str, render_mode: str, font_size: int, layout_type: str):
+        """根据渲染模式显示内容 - 统一方法"""
+        if content is None or render_mode is None:
+            return
+            
+        if render_mode == "HTML渲染":
+            content_style = f"""
+            <style>
+            .{layout_type}-content-display {{
+                font-size: {font_size}px !important;
+                line-height: 1.4;
+                color: #333333 !important;
+                background-color: #fafafa !important;
+                padding: 10px;
+                border-radius: 5px;
+                border: 1px solid #ddd;
+            }}
+            </style>
+            """
+            st.markdown(content_style, unsafe_allow_html=True)
+            st.markdown(f'<div class="{layout_type}-content-display">{content}</div>', unsafe_allow_html=True)
+            
+        elif render_mode == "Markdown渲染":
+            converted_content = convert_html_table_to_markdown(content)
+            content_style = f"""
+            <style>
+            .{layout_type}-content-display {{
+                font-size: {font_size}px !important;
+                line-height: 1.4;
+                color: #333333 !important;
+                background-color: #fafafa !important;
+                padding: 10px;
+                border-radius: 5px;
+                border: 1px solid #ddd;
+            }}
+            </style>
+            """
+            st.markdown(content_style, unsafe_allow_html=True)
+            st.markdown(f'<div class="{layout_type}-content-display">{converted_content}</div>', unsafe_allow_html=True)
+            
+        elif render_mode == "DataFrame表格":
+            if '<table' in content.lower():
+                self.validator.display_html_table_as_dataframe(content)
+            else:
+                st.info("当前内容中没有检测到HTML表格")
+                st.markdown(content, unsafe_allow_html=True)
+        else:  # 原始文本
+            st.text_area(
+                "MD内容预览",
+                content,
+                height=300,
+                key=f"{layout_type}_text_area"
+            )
+    
+    # 三种布局实现
+    def create_standard_layout(self, font_size: int = 10, zoom_level: float = 1.0):
+        """创建标准布局"""
+        if zoom_level is None:
+            zoom_level = self.config['styles']['layout']['default_zoom']
+            
+        # 主要内容区域
+        layout = self.config['styles']['layout']
+        left_col, right_col = st.columns([layout['content_width'], layout['sidebar_width']])
+        
+        with left_col:
+            self.render_content_section("standard")
+            
+            # 显示内容
+            if self.validator.md_content:
+                display_content, render_mode = self.render_md_content("standard")
+                self.render_content_by_mode(display_content, render_mode, font_size, "standard")
+        
+        with right_col:
+            self.create_aligned_image_display(zoom_level, "compact")
+    
+    def create_compact_layout(self, font_size: int = 10, zoom_level: float = 1.0):
+        """创建紧凑的对比布局"""
+        # 主要内容区域
+        layout = self.config['styles']['layout']
+        left_col, right_col = st.columns([layout['content_width'], layout['sidebar_width']])                
+
+        with left_col:
+            self.render_content_section("compact")
+
+            # 只保留一个内容区域高度选择
+            container_height = st.selectbox(
+                "选择内容区域高度", 
+                [400, 600, 800, 1000, 1200], 
+                index=2,
+                key="compact_content_height"
+            )
+            
+            # 快速定位文本选择器(使用不同的key)
+            if self.validator.text_bbox_mapping:
+                text_options = ["请选择文本..."] + list(self.validator.text_bbox_mapping.keys())
+                selected_index = st.selectbox(
+                    "快速定位文本",
+                    range(len(text_options)),
+                    format_func=lambda x: text_options[x][:30] + "..." if len(text_options[x]) > 30 else text_options[x],
+                    key="compact_quick_text_selector"  # 使用不同的key
+                )
+                
+                if selected_index > 0:
+                    st.session_state.selected_text = text_options[selected_index]
+            
+            # 自定义CSS样式
+            st.markdown(f"""
+            <style>
+            .compact-content {{
+                height: {container_height}px;
+                overflow-y: auto;
+                font-size: {font_size}px !important;
+                line-height: 1.4;
+                border: 1px solid #ddd;
+                padding: 10px;
+                background-color: #fafafa !important;
+                font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', monospace;
+                color: #333333 !important;
+            }}
+            
+            .highlight-text {{
+                background-color: #ffeb3b !important;
+                padding: 2px 4px;
+                border-radius: 3px;
+                cursor: pointer;
+                color: #333333 !important;
+            }}
+            
+            .selected-highlight {{
+                background-color: #4caf50 !important;
+                color: white !important;
+            }}
+            </style>
+            """, unsafe_allow_html=True)
+            
+            # 处理并显示OCR内容
+            if self.validator.md_content:
+                # 高亮可点击文本
+                highlighted_content = self.validator.md_content
+                for text in self.validator.text_bbox_mapping.keys():
+                    if len(text) > 2:  # 避免高亮过短的文本
+                        css_class = "highlight-text selected-highlight" if text == st.session_state.selected_text else "highlight-text"
+                        highlighted_content = highlighted_content.replace(
+                            text, 
+                            f'<span class="{css_class}" title="{text[:50]}...">{text}</span>'
+                        )
+                st.markdown(
+                    f'<div class="compact-content">{highlighted_content}</div>', 
+                    unsafe_allow_html=True
+                )
+        
+        with right_col:
+            # 修复的对齐图片显示
+            self.create_aligned_image_display(zoom_level, "compact")
+    
+    def create_aligned_image_display(self, zoom_level: float = 1.0, layout_type: str = "aligned"):
+        """创建与左侧对齐的图片显示"""
+        # 精确对齐CSS
+        st.markdown(f"""
+        <style>
+        .aligned-image-container-{layout_type} {{
+            margin-top: -70px;
+            padding-top: 0px;
+        }}
+        .aligned-image-container-{layout_type} h1 {{
+            margin-top: 0px !important;
+            padding-top: 0px !important;
+        }}
+        </style>
+        """, unsafe_allow_html=True)
+        
+        st.markdown(f'<div class="aligned-image-container-{layout_type}">', unsafe_allow_html=True)
+        st.header("🖼️ 原图标注")
+        
+        # 图片缩放控制
+        col1, col2 = st.columns(2)
+        with col1:
+            current_zoom = st.slider("图片缩放", 0.3, 2.0, zoom_level, 0.1, key=f"{layout_type}_zoom_level")
+        with col2:
+            show_all_boxes = st.checkbox("显示所有框", value=False, key=f"{layout_type}_show_all_boxes")
+        
+        if self.validator.image_path and Path(self.validator.image_path).exists():
+            try:
+                image = Image.open(self.validator.image_path)
+                
+                # 根据缩放级别调整图片大小
+                new_width = int(image.width * current_zoom)
+                new_height = int(image.height * current_zoom)
+                resized_image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
+                
+                # 在固定容器中显示图片
+                selected_bbox = None
+                if st.session_state.selected_text and st.session_state.selected_text in self.validator.text_bbox_mapping:
+                    info = self.validator.text_bbox_mapping[st.session_state.selected_text][0]
+                    # 根据缩放级别调整bbox坐标
+                    bbox = info['bbox']
+                    selected_bbox = [int(coord * current_zoom) for coord in bbox]
+                
+                # 创建交互式图片 - 确保从顶部开始显示
+                fig = self.create_resized_interactive_plot(resized_image, selected_bbox, current_zoom, show_all_boxes)
+                st.plotly_chart(fig, use_container_width=True, key=f"{layout_type}_plot")
+                
+                # 显示选中文本的详细信息
+                if st.session_state.selected_text and st.session_state.selected_text in self.validator.text_bbox_mapping:
+                    st.subheader("📍 选中文本详情")
+                    
+                    info = self.validator.text_bbox_mapping[st.session_state.selected_text][0]
+                    bbox = info['bbox']
+                    
+                    info_col1, info_col2 = st.columns(2)
+                    with info_col1:
+                        st.write(f"**文本内容:** {st.session_state.selected_text[:30]}...")
+                        st.write(f"**类别:** {info['category']}")
+                    
+                    with info_col2:
+                        st.write(f"**位置:** [{', '.join(map(str, bbox))}]")
+                        if len(bbox) >= 4:
+                            st.write(f"**大小:** {bbox[2] - bbox[0]} x {bbox[3] - bbox[1]} px")
+                    
+                    # 错误标记功能
+                    col1, col2 = st.columns(2)
+                    with col1:
+                        if st.button("❌ 标记为错误", key=f"{layout_type}_mark_error"):
+                            st.session_state.marked_errors.add(st.session_state.selected_text)
+                            st.rerun()
+                    
+                    with col2:
+                        if st.button("✅ 取消错误标记", key=f"{layout_type}_unmark_error"):
+                            st.session_state.marked_errors.discard(st.session_state.selected_text)
+                            st.rerun()
+                            
+            except Exception as e:
+                st.error(f"❌ 图片处理失败: {e}")
+        else:
+            st.error("未找到对应的图片文件")
+            if self.validator.image_path:
+                st.write(f"期望路径: {self.validator.image_path}")
+        
+        st.markdown('</div>', unsafe_allow_html=True)
+    
+    def create_resized_interactive_plot(self, image: Image.Image, selected_bbox: Optional[List[int]], zoom_level: float, show_all_boxes: bool) -> go.Figure:
+        """创建可调整大小的交互式图片 - 优化显示和定位"""
+        fig = go.Figure()
+        
+        fig.add_layout_image(
+            dict(
+                source=image,
+                xref="x", yref="y",
+                x=0, y=0,  # 改为从底部开始,这样图片会从顶部显示
+                sizex=image.width, sizey=image.height,
+                sizing="stretch", opacity=1.0, layer="below"
+            )
+        )
+        
+        # 显示所有bbox(如果开启)
+        if show_all_boxes:
+            for text, info_list in self.validator.text_bbox_mapping.items():
+                for info in info_list:
+                    bbox = info['bbox']
+                    if len(bbox) >= 4:
+                        x1, y1, x2, y2 = [coord * zoom_level for coord in bbox[:4]]
+                        
+                        color = "rgba(0, 100, 200, 0.2)"
+                        if text in self.validator.marked_errors:
+                            color = "rgba(255, 0, 0, 0.3)"
+                        
+                        fig.add_shape(
+                            type="rect",
+                            x0=x1, y0=y1,  # 调整坐标系,不再翻转
+                            x1=x2, y1=y2,
+                            line=dict(color=color.replace('0.2', '0.8').replace('0.3', '1.0'), width=1),
+                            fillcolor=color,
+                        )
+        
+        # 高亮显示选中的bbox
+        if selected_bbox and len(selected_bbox) >= 4:
+            x1, y1, x2, y2 = selected_bbox[:4]
+            fig.add_shape(
+                type="rect",
+                x0=x1, y0=y1,  # 调整坐标系
+                x1=x2, y1=y2,
+                line=dict(color="red", width=2),
+                fillcolor="rgba(255, 0, 0, 0.3)",
+            )
+        
+        # 设置坐标轴范围 - 让图片从顶部开始显示
+        fig.update_xaxes(visible=False, range=[0, image.width])
+        fig.update_yaxes(visible=False, range=[image.height, 0], scaleanchor="x")  # 翻转Y轴让图片从顶部开始
+        
+        # 计算更大的显示尺寸
+        aspect_ratio = image.width / image.height
+        display_height = min(1000, max(600, image.height))  # 动态调整高度
+        display_width = int(display_height * aspect_ratio)
+        
+        fig.update_layout(
+            width=display_width,
+            height=display_height,
+            margin=dict(l=0, r=0, t=0, b=0),
+            showlegend=False,
+            plot_bgcolor='white',
+            # 设置初始视图让图片从顶部开始显示
+            xaxis=dict(
+                range=[0, image.width],
+                constrain="domain"
+            ),
+            yaxis=dict(
+                range=[image.height, 0],  # 翻转范围
+                constrain="domain",
+                scaleanchor="x",
+                scaleratio=1
+            )
+        )
+        
+        return fig