소스 검색

添加Streamlit OCR可视化校验工具,支持OCR数据加载、图像标注和交互式展示

zhch158_admin 2 달 전
부모
커밋
aa704bde06
1개의 변경된 파일639개의 추가작업 그리고 0개의 파일을 삭제
  1. 639 0
      streamlit_ocr_validator.py

+ 639 - 0
streamlit_ocr_validator.py

@@ -0,0 +1,639 @@
+#!/usr/bin/env python3
+"""
+基于Streamlit的OCR可视化校验工具(修复版)
+提供丰富的交互组件和更好的用户体验
+"""
+
+import streamlit as st
+import json
+import pandas as pd
+from pathlib import Path
+import numpy as np
+from PIL import Image, ImageDraw, ImageFont
+import cv2
+import base64
+from typing import Dict, List, Optional, Tuple
+import plotly.express as px
+import plotly.graph_objects as go
+from plotly.subplots import make_subplots
+
+# 设置页面配置
+st.set_page_config(
+    page_title="OCR可视化校验工具",
+    page_icon="🔍",
+    layout="wide",
+    initial_sidebar_state="expanded"
+)
+
+# 自定义CSS样式
+st.markdown("""
+<style>
+    .main > div {
+        padding-top: 2rem;
+    }
+    
+    .stSelectbox > div > div > div {
+        background-color: #f0f2f6;
+    }
+    
+    .clickable-text {
+        background-color: #e1f5fe;
+        padding: 2px 6px;
+        border-radius: 4px;
+        border: 1px solid #0288d1;
+        cursor: pointer;
+        margin: 2px;
+        display: inline-block;
+    }
+    
+    .selected-text {
+        background-color: #fff3e0;
+        border-color: #ff9800;
+        font-weight: bold;
+    }
+    
+    .error-text {
+        background-color: #ffebee;
+        border-color: #f44336;
+        color: #d32f2f;
+    }
+    
+    .stats-container {
+        background-color: #f8f9fa;
+        padding: 1rem;
+        border-radius: 8px;
+        border-left: 4px solid #28a745;
+    }
+</style>
+""", unsafe_allow_html=True)
+
+class StreamlitOCRValidator:
+    def __init__(self):
+        self.ocr_data = []
+        self.md_content = ""
+        self.image_path = ""
+        self.text_bbox_mapping = {}
+        self.selected_text = None
+        self.marked_errors = set()
+        
+    def load_ocr_data(self, json_path: str, md_path: Optional[str] = None, image_path: Optional[str] = None):
+        """加载OCR相关数据"""
+        json_file = Path(json_path)
+        
+        # 加载JSON数据
+        try:
+            with open(json_file, 'r', encoding='utf-8') as f:
+                data = json.load(f)
+                # 确保数据是列表格式
+                if isinstance(data, list):
+                    self.ocr_data = data
+                elif isinstance(data, dict) and 'results' in data:
+                    self.ocr_data = data['results']
+                else:
+                    st.error(f"❌ 不支持的JSON格式: {json_path}")
+                    return
+        except Exception as e:
+            st.error(f"❌ 加载JSON文件失败: {e}")
+            return
+        
+        # 推断MD文件路径
+        if md_path is None:
+            md_file = json_file.with_suffix('.md')
+        else:
+            md_file = Path(md_path)
+        
+        if md_file.exists():
+            with open(md_file, 'r', encoding='utf-8') as f:
+                self.md_content = f.read()
+        
+        # 推断图片路径
+        if image_path is None:
+            image_name = json_file.stem
+            sample_data_dir = Path("./sample_data")
+            
+            image_candidates = [
+                sample_data_dir / f"{image_name}.png",
+                sample_data_dir / f"{image_name}.jpg",
+                json_file.parent / f"{image_name}.png",
+                json_file.parent / f"{image_name}.jpg",
+            ]
+            
+            for candidate in image_candidates:
+                if candidate.exists():
+                    self.image_path = str(candidate)
+                    break
+        else:
+            self.image_path = image_path
+        
+        # 处理数据
+        self.process_data()
+    
+    def process_data(self):
+        """处理OCR数据,建立文本到bbox的映射"""
+        self.text_bbox_mapping = {}
+        
+        # 确保 ocr_data 是列表
+        if not isinstance(self.ocr_data, list):
+            st.warning("⚠️ OCR数据格式不正确,期望列表格式")
+            return
+        
+        for i, item in enumerate(self.ocr_data):
+            # 确保 item 是字典类型
+            if not isinstance(item, dict):
+                continue
+                
+            if 'text' in item and 'bbox' in item:
+                text = str(item['text']).strip()
+                if text and text not in ['Picture', '']:
+                    bbox = item['bbox']
+                    # 确保bbox是4个数字的列表
+                    if isinstance(bbox, list) and len(bbox) == 4:
+                        if text not in self.text_bbox_mapping:
+                            self.text_bbox_mapping[text] = []
+                        self.text_bbox_mapping[text].append({
+                            'bbox': bbox,
+                            'category': item.get('category', 'Text'),
+                            'index': i,
+                            'confidence': item.get('confidence', 1.0)
+                        })
+    
+    def draw_bbox_on_image(self, image: Image.Image, bbox: List[int], color: str = "red", width: int = 3) -> Image.Image:
+        """在图片上绘制bbox框"""
+        img_copy = image.copy()
+        draw = ImageDraw.Draw(img_copy)
+        
+        x1, y1, x2, y2 = bbox
+        
+        # 绘制矩形框
+        draw.rectangle([x1, y1, x2, y2], outline=color, width=width)
+        
+        # 添加半透明填充
+        overlay = Image.new('RGBA', img_copy.size, (0, 0, 0, 0))
+        overlay_draw = ImageDraw.Draw(overlay)
+        
+        if color == "red":
+            fill_color = (255, 0, 0, 30)
+        elif color == "blue":
+            fill_color = (0, 0, 255, 30)
+        elif color == "green":
+            fill_color = (0, 255, 0, 30)
+        else:
+            fill_color = (255, 255, 0, 30)
+        
+        overlay_draw.rectangle([x1, y1, x2, y2], fill=fill_color)
+        img_copy = Image.alpha_composite(img_copy.convert('RGBA'), overlay).convert('RGB')
+        
+        return img_copy
+    
+    def create_interactive_plot(self, image: Image.Image, selected_bbox: Optional[List[int]] = None) -> go.Figure:
+        """创建交互式图片显示"""
+        fig = go.Figure()
+        
+        # 添加图片
+        fig.add_layout_image(
+            dict(
+                source=image,
+                xref="x",
+                yref="y",
+                x=0,
+                y=image.height,
+                sizex=image.width,
+                sizey=image.height,
+                sizing="stretch",
+                opacity=1.0,
+                layer="below"
+            )
+        )
+        
+        # 添加所有bbox(浅色显示)
+        for text, info_list in self.text_bbox_mapping.items():
+            for info in info_list:
+                bbox = info['bbox']
+                if len(bbox) >= 4:  # 确保bbox有足够的坐标
+                    x1, y1, x2, y2 = bbox[:4]
+                    
+                    color = "rgba(0, 100, 200, 0.2)"  # 默认浅蓝色
+                    if text in self.marked_errors:
+                        color = "rgba(255, 0, 0, 0.3)"  # 错误标记为红色
+                    
+                    fig.add_shape(
+                        type="rect",
+                        x0=x1, y0=image.height-y2,
+                        x1=x2, y1=image.height-y1,
+                        line=dict(color=color.replace('0.2', '0.8').replace('0.3', '1.0'), width=1),
+                        fillcolor=color,
+                    )
+        
+        # 高亮显示选中的bbox
+        if selected_bbox and len(selected_bbox) >= 4:
+            x1, y1, x2, y2 = selected_bbox[:4]
+            fig.add_shape(
+                type="rect",
+                x0=x1, y0=image.height-y2,
+                x1=x2, y1=image.height-y1,
+                line=dict(color="red", width=3),
+                fillcolor="rgba(255, 0, 0, 0.2)",
+            )
+        
+        # 设置布局
+        fig.update_xaxes(
+            visible=False,
+            range=[0, image.width]
+        )
+        
+        fig.update_yaxes(
+            visible=False,
+            range=[0, image.height],
+            scaleanchor="x"
+        )
+        
+        fig.update_layout(
+            width=800,
+            height=600,
+            margin=dict(l=0, r=0, t=0, b=0),
+            xaxis_showgrid=False,
+            yaxis_showgrid=False,
+            plot_bgcolor='white'
+        )
+        
+        return fig
+    
+    def get_statistics(self) -> Dict:
+        """获取统计信息"""
+        # 先确保 ocr_data 不为空且是列表
+        if not isinstance(self.ocr_data, list) or not self.ocr_data:
+            return {
+                'total_texts': 0,
+                'clickable_texts': 0,
+                'marked_errors': 0,
+                'categories': {},
+                'accuracy_rate': 0
+            }
+        
+        total_texts = len(self.ocr_data)
+        clickable_texts = len(self.text_bbox_mapping)
+        marked_errors = len(self.marked_errors)
+        
+        # 按类别统计 - 添加类型检查
+        categories = {}
+        for item in self.ocr_data:
+            # 确保 item 是字典类型
+            if isinstance(item, dict):
+                category = item.get('category', 'Unknown')
+            elif isinstance(item, str):
+                category = 'Text'  # 字符串类型默认为 Text 类别
+            else:
+                category = 'Unknown'
+            
+            categories[category] = categories.get(category, 0) + 1
+        
+        return {
+            'total_texts': total_texts,
+            'clickable_texts': clickable_texts,
+            'marked_errors': marked_errors,
+            'categories': categories,
+            'accuracy_rate': (clickable_texts - marked_errors) / clickable_texts * 100 if clickable_texts > 0 else 0
+        }
+    
+    def convert_html_table_to_markdown(self, content: str) -> str:
+        """将HTML表格转换为Markdown表格格式"""
+        import re
+        from html import unescape
+        
+        # 简单的HTML表格到Markdown转换
+        def replace_table(match):
+            table_html = match.group(0)
+            
+            # 提取所有行
+            rows = re.findall(r'<tr>(.*?)</tr>', table_html, re.DOTALL | re.IGNORECASE)
+            if not rows:
+                return table_html  # 如果没有找到行,返回原始内容
+            
+            markdown_rows = []
+            for i, row in enumerate(rows):
+                # 提取单元格
+                cells = re.findall(r'<td[^>]*>(.*?)</td>', row, re.DOTALL | re.IGNORECASE)
+                if cells:
+                    # 清理单元格内容
+                    clean_cells = []
+                    for cell in cells:
+                        # 移除HTML标签,保留文本
+                        cell_text = re.sub(r'<[^>]+>', '', cell).strip()
+                        cell_text = unescape(cell_text)  # 解码HTML实体
+                        clean_cells.append(cell_text)
+                    
+                    # 构建Markdown行
+                    markdown_row = '| ' + ' | '.join(clean_cells) + ' |'
+                    markdown_rows.append(markdown_row)
+                    
+                    # 在第一行后添加分隔符
+                    if i == 0:
+                        separator = '| ' + ' | '.join(['---'] * len(clean_cells)) + ' |'
+                        markdown_rows.append(separator)
+            
+            return '\n'.join(markdown_rows) if markdown_rows else table_html
+        
+        # 替换所有HTML表格
+        converted = re.sub(r'<table[^>]*>.*?</table>', replace_table, content, flags=re.DOTALL | re.IGNORECASE)
+        return converted
+    
+    def render_markdown_with_options(self, markdown_content: str, table_format: str = "grid", escape_html: bool = True):
+        """自定义Markdown渲染方法,支持多种选项"""
+        import markdown
+        
+        # 处理HTML表格
+        if escape_html:
+            markdown_content = self.convert_html_table_to_markdown(markdown_content)
+        
+        # 渲染Markdown
+        html_content = markdown.markdown(markdown_content)
+        
+        # 根据选项包裹在特定的HTML结构中
+        if table_format == "grid":
+            # 网格布局
+            wrapped_content = f"""
+            <div class="markdown-grid">
+                {html_content}
+            </div>
+            """
+        elif table_format == "list":
+            # 列表布局
+            wrapped_content = f"""
+            <div class="markdown-list">
+                {html_content}
+            </div>
+            """
+        else:
+            # 默认直接返回
+            wrapped_content = html_content
+        
+        return wrapped_content
+    
+    def display_html_table_as_dataframe(self, html_content: str):
+        """将HTML表格解析为DataFrame显示"""
+        import pandas as pd
+        from io import StringIO
+        
+        try:
+            # 使用pandas直接读取HTML表格
+            tables = pd.read_html(StringIO(html_content))
+            if tables:
+                for i, table in enumerate(tables):
+                    st.subheader(f"表格 {i+1}")
+                    st.dataframe(table, use_container_width=True)
+        except Exception as e:
+            st.error(f"表格解析失败: {e}")
+            # 回退到HTML渲染
+            st.markdown(html_content, unsafe_allow_html=True)
+
+def main():
+    """主应用"""
+    st.title("🔍 OCR可视化校验工具")
+    st.markdown("---")
+    
+    # 初始化session state
+    if 'validator' not in st.session_state:
+        st.session_state.validator = StreamlitOCRValidator()
+    
+    if 'selected_text' not in st.session_state:
+        st.session_state.selected_text = None
+    
+    if 'marked_errors' not in st.session_state:
+        st.session_state.marked_errors = set()
+    
+    # 同步标记的错误到validator
+    st.session_state.validator.marked_errors = st.session_state.marked_errors
+    
+    # 侧边栏 - 文件选择和控制
+    with st.sidebar:
+        st.header("📁 文件选择")
+        
+        # 查找可用的OCR文件
+        output_dir = Path("output")
+        available_files = []
+        
+        if output_dir.exists():
+            for json_file in output_dir.rglob("*.json"):
+                available_files.append(str(json_file))
+        
+        if available_files:
+            selected_file = st.selectbox(
+                "选择OCR结果文件",
+                available_files,
+                index=0
+            )
+            
+            if st.button("🔄 加载文件", type="primary") and selected_file:
+                try:
+                    st.session_state.validator.load_ocr_data(selected_file)
+                    st.success("✅ 文件加载成功!")
+                    st.rerun()  # 重新运行应用以更新界面
+                except Exception as e:
+                    st.error(f"❌ 加载失败: {e}")
+        else:
+            st.warning("未找到OCR结果文件")
+            st.info("请确保output目录下有OCR结果文件")
+        
+        st.markdown("---")
+        
+        # 控制面板
+        st.header("🎛️ 控制面板")
+        
+        if st.button("🧹 清除选择"):
+            st.session_state.selected_text = None
+            st.rerun()
+        
+        if st.button("❌ 清除错误标记"):
+            st.session_state.marked_errors = set()
+            st.rerun()
+        
+        # 显示调试信息
+        if st.checkbox("🔧 调试信息"):
+            st.write("**当前状态:**")
+            st.write(f"- OCR数据项数: {len(st.session_state.validator.ocr_data)}")
+            st.write(f"- 可点击文本: {len(st.session_state.validator.text_bbox_mapping)}")
+            st.write(f"- 选中文本: {st.session_state.selected_text}")
+            st.write(f"- 标记错误数: {len(st.session_state.marked_errors)}")
+            
+            if st.session_state.validator.ocr_data:
+                st.write("**数据类型检查:**")
+                sample_item = st.session_state.validator.ocr_data[0] if st.session_state.validator.ocr_data else None
+                st.write(f"- 第一项类型: {type(sample_item)}")
+                if isinstance(sample_item, dict):
+                    st.write(f"- 第一项键: {list(sample_item.keys())}")
+    
+    # 主内容区域
+    if not st.session_state.validator.ocr_data:
+        st.info("👈 请在左侧选择并加载OCR结果文件")
+        return
+    
+    # 显示统计信息
+    try:
+        stats = st.session_state.validator.get_statistics()
+        
+        col1, col2, col3, col4 = st.columns(4)
+        with col1:
+            st.metric("📊 总文本块", stats['total_texts'])
+        with col2:
+            st.metric("🔗 可点击文本", stats['clickable_texts'])
+        with col3:
+            st.metric("❌ 标记错误", stats['marked_errors'])
+        with col4:
+            st.metric("✅ 准确率", f"{stats['accuracy_rate']:.1f}%")
+        
+        st.markdown("---")
+    except Exception as e:
+        st.error(f"❌ 统计信息计算失败: {e}")
+        return
+    
+    # 主要布局 - 左右分栏
+    left_col, right_col = st.columns([1, 1])
+    
+    # 左侧 - OCR文本内容
+    with left_col:
+        st.header("📄 OCR识别内容")
+        
+        # 文本选择器
+        if st.session_state.validator.text_bbox_mapping:
+            text_options = ["请选择文本..."] + list(st.session_state.validator.text_bbox_mapping.keys())
+            selected_index = st.selectbox(
+                "选择要校验的文本",
+                range(len(text_options)),
+                format_func=lambda x: text_options[x],
+                key="text_selector"
+            )
+            
+            if selected_index > 0:
+                st.session_state.selected_text = text_options[selected_index]
+        else:
+            st.warning("没有找到可点击的文本")
+        
+        # 显示MD内容(可搜索和过滤)
+        if st.session_state.validator.md_content:
+            search_term = st.text_input("🔍 搜索文本内容", placeholder="输入关键词搜索...")
+            
+            display_content = st.session_state.validator.md_content
+            if search_term:
+                lines = display_content.split('\n')
+                filtered_lines = [line for line in lines if search_term.lower() in line.lower()]
+                display_content = '\n'.join(filtered_lines)
+                if filtered_lines:
+                    st.success(f"找到 {len(filtered_lines)} 行包含 '{search_term}'")
+                else:
+                    st.warning(f"未找到包含 '{search_term}' 的内容")
+            
+            # 渲染方式选择
+            render_mode = st.radio(
+                "选择渲染方式",
+                ["HTML渲染", "Markdown渲染", "DataFrame表格", "原始文本"],  # 添加DataFrame选项
+                horizontal=True
+            )
+
+            if render_mode == "HTML渲染":
+                # 使用unsafe_allow_html=True来渲染HTML表格
+                st.markdown(display_content, unsafe_allow_html=True)
+            elif render_mode == "Markdown渲染":
+                # 转换HTML表格为Markdown格式
+                converted_content = st.session_state.validator.convert_html_table_to_markdown(display_content)
+                st.markdown(converted_content)
+            elif render_mode == "DataFrame表格":
+                # 新增:使用DataFrame显示表格
+                if '<table>' in display_content.lower():
+                    st.session_state.validator.display_html_table_as_dataframe(display_content)
+                else:
+                    st.info("当前内容中没有检测到HTML表格")
+                    st.markdown(display_content)
+            else:
+                # 原始文本显示
+                st.text_area(
+                    "MD内容预览",
+                    display_content,
+                    height=300,
+                    help="OCR识别的文本内容"
+                )
+        
+        # 可点击文本列表
+        st.subheader("🎯 可点击文本列表")
+        
+        if st.session_state.validator.text_bbox_mapping:
+            for text, info_list in st.session_state.validator.text_bbox_mapping.items():
+                info = info_list[0]  # 使用第一个bbox信息
+                
+                # 确定显示样式
+                is_selected = (text == st.session_state.selected_text)
+                is_error = (text in st.session_state.marked_errors)
+                
+                # 创建按钮行
+                button_col, error_col = st.columns([4, 1])
+                
+                with button_col:
+                    button_type = "primary" if is_selected else "secondary"
+                    if st.button(f"📍 {text}", key=f"btn_{text}", type=button_type):
+                        st.session_state.selected_text = text
+                        st.rerun()
+                
+                with error_col:
+                    if is_error:
+                        if st.button("✅", key=f"fix_{text}", help="取消错误标记"):
+                            st.session_state.marked_errors.discard(text)
+                            st.rerun()
+                    else:
+                        if st.button("❌", key=f"error_{text}", help="标记为错误"):
+                            st.session_state.marked_errors.add(text)
+                            st.rerun()
+        else:
+            st.info("没有可点击的文本项目")
+    
+    # 右侧 - 图像显示
+    with right_col:
+        st.header("🖼️ 原图标注")
+        
+        if st.session_state.validator.image_path and Path(st.session_state.validator.image_path).exists():
+            try:
+                # 加载图片
+                image = Image.open(st.session_state.validator.image_path)
+                
+                # 创建交互式图片
+                selected_bbox = None
+                if st.session_state.selected_text and st.session_state.selected_text in st.session_state.validator.text_bbox_mapping:
+                    info = st.session_state.validator.text_bbox_mapping[st.session_state.selected_text][0]
+                    selected_bbox = info['bbox']
+                
+                fig = st.session_state.validator.create_interactive_plot(image, selected_bbox)
+                st.plotly_chart(fig, use_container_width=True)
+                
+                # 显示选中文本的详细信息
+                if st.session_state.selected_text:
+                    st.subheader("📍 选中文本详情")
+                    
+                    if st.session_state.selected_text in st.session_state.validator.text_bbox_mapping:
+                        info = st.session_state.validator.text_bbox_mapping[st.session_state.selected_text][0]
+                        bbox = info['bbox']
+                        
+                        info_col1, info_col2 = st.columns(2)
+                        with info_col1:
+                            st.write(f"**文本内容:** {st.session_state.selected_text}")
+                            st.write(f"**类别:** {info['category']}")
+                            st.write(f"**置信度:** {info.get('confidence', 'N/A')}")
+                        
+                        with info_col2:
+                            st.write(f"**位置:** [{', '.join(map(str, bbox))}]")
+                            if len(bbox) >= 4:
+                                st.write(f"**宽度:** {bbox[2] - bbox[0]} px")
+                                st.write(f"**高度:** {bbox[3] - bbox[1]} px")
+                        
+                        # 标记状态
+                        is_error = st.session_state.selected_text in st.session_state.marked_errors
+                        if is_error:
+                            st.error("⚠️ 此文本已标记为错误")
+                        else:
+                            st.success("✅ 此文本未标记错误")
+            except Exception as e:
+                st.error(f"❌ 图片处理失败: {e}")
+        else:
+            st.error("未找到对应的图片文件")
+            if st.session_state.validator.image_path:
+                st.write(f"期望路径: {st.session_state.validator.image_path}")
+
+if __name__ == "__main__":
+    main()