#!/usr/bin/env python3
"""
Streamlit-based visual validation tool for OCR results (refactored version).

Provides interactive widgets for loading OCR output, reviewing recognised
text, inspecting HTML tables as DataFrames, and marking recognition errors.
"""

import re
from io import BytesIO
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
from PIL import Image

from ocr_validator_layout import OCRLayoutManager
# Helper functions shared with the layout module.
from ocr_validator_utils import (
    load_config, load_css_styles, load_ocr_data_file, process_ocr_data,
    draw_bbox_on_image, get_ocr_statistics, convert_html_table_to_markdown,
    parse_html_tables, find_available_ocr_files, create_dynamic_css,
    export_tables_to_excel, get_table_statistics, group_texts_by_category
)


class StreamlitOCRValidator:
    """Holds the loaded OCR data and renders the table-analysis widgets.

    Layout rendering is delegated to an ``OCRLayoutManager`` that receives
    this validator instance.
    """

    def __init__(self):
        """Load the configuration and initialise empty OCR state."""
        self.config = load_config()
        self.ocr_data = []
        self.md_content = ""
        self.image_path = ""
        self.text_bbox_mapping = {}
        self.selected_text = None
        self.marked_errors = set()

        # The layout manager keeps a back-reference to this validator.
        self.layout_manager = OCRLayoutManager(self)

    def setup_page_config(self):
        """Apply the Streamlit page configuration and inject the CSS styles."""
        ui_config = self.config['ui']
        st.set_page_config(
            page_title=ui_config['page_title'],
            page_icon=ui_config['page_icon'],
            layout=ui_config['layout'],
            initial_sidebar_state=ui_config['sidebar_state']
        )

        # The stylesheet has to be wrapped in a <style> element to take
        # effect; the previous code rendered an empty string, so the loaded
        # CSS was never applied.
        css_content = load_css_styles()
        st.markdown(f"<style>{css_content}</style>", unsafe_allow_html=True)

    def load_ocr_data(self, json_path: str, md_path: Optional[str] = None,
                      image_path: Optional[str] = None) -> bool:
        """Load the OCR result file and rebuild the text/bbox mapping.

        Args:
            json_path: Path of the OCR result file handed to
                ``load_ocr_data_file``.
            md_path: Accepted for interface compatibility; not used here
                (the markdown content comes from ``load_ocr_data_file``).
            image_path: Accepted for interface compatibility; not used here
                (the image path comes from ``load_ocr_data_file``).

        Returns:
            True when loading and processing succeeded. On any exception the
            error is shown with ``st.error`` and False is returned, so the
            caller can avoid reporting success.
        """
        try:
            self.ocr_data, self.md_content, self.image_path = load_ocr_data_file(
                json_path, self.config
            )
            self.process_data()
        except Exception as e:
            # UI boundary: surface the failure to the user instead of crashing.
            st.error(f"❌ 加载失败: {e}")
            return False
        return True

    def process_data(self):
        """Rebuild ``text_bbox_mapping`` from the currently loaded OCR data."""
        self.text_bbox_mapping = process_ocr_data(self.ocr_data, self.config)

    def get_statistics(self) -> Dict:
        """Return the statistics computed by ``get_ocr_statistics``."""
        return get_ocr_statistics(
            self.ocr_data, self.text_bbox_mapping, self.marked_errors
        )

    def display_html_table_as_dataframe(self, html_content: str,
                                        enable_editing: bool = False):
        """Parse the HTML tables in *html_content* and render each one.

        When no table can be parsed, a warning is shown and the raw HTML is
        rendered instead.

        Args:
            html_content: Markup that may contain HTML tables.
            enable_editing: Render each table in an editable ``st.data_editor``
                instead of a read-only ``st.dataframe``.
        """
        tables = parse_html_tables(html_content)

        if not tables:
            st.warning("未找到可解析的表格")
            st.markdown(html_content, unsafe_allow_html=True)
            return

        for i, table in enumerate(tables):
            st.subheader(f"📊 表格 {i+1}")

            # Per-table toggles; widget keys carry the table index so that
            # several tables can coexist on one page.
            col1, col2, col3, col4 = st.columns(4)
            with col1:
                show_info = st.checkbox("显示表格信息", key=f"info_{i}")
            with col2:
                show_stats = st.checkbox("显示统计信息", key=f"stats_{i}")
            with col3:
                enable_filter = st.checkbox("启用过滤", key=f"filter_{i}")
            with col4:
                enable_sort = st.checkbox("启用排序", key=f"sort_{i}")

            filtered_table = self._apply_table_filters_and_sorts(
                table, i, enable_filter, enable_sort
            )

            if enable_editing:
                edited_table = st.data_editor(
                    filtered_table, use_container_width=True, key=f"editor_{i}"
                )
                if not edited_table.equals(filtered_table):
                    st.success("✏️ 表格已编辑,可以导出修改后的数据")
            else:
                st.dataframe(filtered_table, use_container_width=True)

            self._display_table_info_and_stats(
                table, filtered_table, show_info, show_stats, i
            )

            st.markdown("---")

    def _apply_table_filters_and_sorts(self, table: pd.DataFrame, table_index: int,
                                       enable_filter: bool,
                                       enable_sort: bool) -> pd.DataFrame:
        """Return a copy of *table* with the user's filter and sort applied.

        Args:
            table: Source table; it is never modified.
            table_index: Zero-based index, used for widget labels and keys.
            enable_filter: Show the filter widgets and apply the filter.
            enable_sort: Show the sort widgets and apply the sort.

        Returns:
            The filtered and/or sorted table, or a plain copy when neither
            option is active.
        """
        filtered_table = table.copy()

        if enable_filter and not table.empty:
            filter_col = st.selectbox(
                f"选择过滤列 (表格 {table_index+1})",
                options=['无'] + list(table.columns),
                key=f"filter_col_{table_index}"
            )

            if filter_col != '无':
                filter_value = st.text_input(
                    f"过滤值 (表格 {table_index+1})",
                    key=f"filter_value_{table_index}"
                )
                if filter_value:
                    column_text = table[filter_col].astype(str)
                    try:
                        mask = column_text.str.contains(filter_value, na=False)
                    except re.error:
                        # The input is not a valid regular expression (for
                        # example a lone "("); match it literally instead of
                        # crashing the page.
                        mask = column_text.str.contains(
                            filter_value, na=False, regex=False
                        )
                    filtered_table = table[mask]

        if enable_sort and not filtered_table.empty:
            sort_col = st.selectbox(
                f"选择排序列 (表格 {table_index+1})",
                options=['无'] + list(filtered_table.columns),
                key=f"sort_col_{table_index}"
            )

            if sort_col != '无':
                sort_order = st.radio(
                    f"排序方式 (表格 {table_index+1})",
                    options=['升序', '降序'],
                    horizontal=True,
                    key=f"sort_order_{table_index}"
                )
                ascending = (sort_order == '升序')
                filtered_table = filtered_table.sort_values(
                    sort_col, ascending=ascending
                )

        return filtered_table

    def _display_table_info_and_stats(self, original_table: pd.DataFrame,
                                      filtered_table: pd.DataFrame,
                                      show_info: bool, show_stats: bool,
                                      table_index: int):
        """Render optional table info, numeric statistics and the export button.

        Args:
            original_table: Table as parsed, before filtering.
            filtered_table: Table after filtering/sorting; statistics and the
                export operate on this one.
            show_info: Show row/column counts and column names.
            show_stats: Show ``describe()`` for the numeric columns.
            table_index: Zero-based index, used for widget labels and keys.
        """
        if show_info:
            st.write("**表格信息:**")
            st.write(f"- 原始行数: {len(original_table)}")
            st.write(f"- 过滤后行数: {len(filtered_table)}")
            st.write(f"- 列数: {len(original_table.columns)}")
            # Column labels are not guaranteed to be strings (header-less
            # tables get integer labels), so convert before joining.
            column_names = ', '.join(map(str, original_table.columns))
            st.write(f"- 列名: {column_names}")

        if show_stats:
            st.write("**统计信息:**")
            numeric_cols = filtered_table.select_dtypes(include=[np.number]).columns
            if len(numeric_cols) > 0:
                st.dataframe(filtered_table[numeric_cols].describe())
            else:
                st.info("表格中没有数值列")

        if st.button(f"📥 导出表格 {table_index+1}", key=f"export_{table_index}"):
            self._create_export_buttons(filtered_table, table_index)

    def _create_export_buttons(self, table: pd.DataFrame, table_index: int):
        """Render CSV and Excel download buttons for *table*.

        Args:
            table: Table to export.
            table_index: Zero-based index, used for labels, file names and keys.
        """
        csv_data = table.to_csv(index=False)
        st.download_button(
            label=f"下载CSV (表格 {table_index+1})",
            data=csv_data,
            file_name=f"table_{table_index+1}.csv",
            mime="text/csv",
            key=f"download_csv_{table_index}"
        )

        # The workbook is written to memory so no temporary file is needed.
        excel_buffer = BytesIO()
        table.to_excel(excel_buffer, index=False)
        st.download_button(
            label=f"下载Excel (表格 {table_index+1})",
            data=excel_buffer.getvalue(),
            file_name=f"table_{table_index+1}.xlsx",
            mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
            key=f"download_excel_{table_index}"
        )

    # Layout methods delegate to the layout manager.
    def create_standard_layout(self, font_size: int = 12, zoom_level: float = 1.0):
        """Render the standard layout through the layout manager."""
        return self.layout_manager.create_standard_layout(font_size, zoom_level)

    def create_compact_layout(self, font_size: int = 12, zoom_level: float = 1.0):
        """Render the compact (scrolling) layout through the layout manager."""
        return self.layout_manager.create_compact_layout(font_size, zoom_level)


def main():
    """Run the application: sidebar controls, summary metrics and four tabs."""
    if 'validator' not in st.session_state:
        st.session_state.validator = StreamlitOCRValidator()

    # Page configuration and injected CSS only last for one script run, so
    # they are applied on every run, before any other Streamlit command.
    st.session_state.validator.setup_page_config()

    if 'selected_text' not in st.session_state:
        st.session_state.selected_text = None

    if 'marked_errors' not in st.session_state:
        st.session_state.marked_errors = set()

    # Keep the validator's error marks in sync with the session state.
    st.session_state.validator.marked_errors = st.session_state.marked_errors

    config = st.session_state.validator.config
    st.title(config['ui']['page_title'])
    st.markdown("---")

    # Sidebar: file selection and controls.
    with st.sidebar:
        st.header("📁 文件选择")

        available_files = find_available_ocr_files(config['paths']['output_dir'])

        if available_files:
            selected_file = st.selectbox("选择OCR结果文件", available_files, index=0)

            if st.button("🔄 加载文件", type="primary") and selected_file:
                # Only report success (and rerun) when loading really worked;
                # on failure the error shown by load_ocr_data stays visible.
                if st.session_state.validator.load_ocr_data(selected_file):
                    st.success("✅ 文件加载成功!")
                    st.rerun()
        else:
            st.warning("未找到OCR结果文件")
            st.info("请确保output目录下有OCR结果文件")

        st.markdown("---")

        st.header("🎛️ 控制面板")

        if st.button("🧹 清除选择"):
            st.session_state.selected_text = None
            st.rerun()

        if st.button("❌ 清除错误标记"):
            st.session_state.marked_errors = set()
            st.rerun()

    # Main content area.
    validator = st.session_state.validator

    if not validator.ocr_data:
        st.info("👈 请在左侧选择并加载OCR结果文件")
        return

    stats = validator.get_statistics()
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("📊 总文本块", stats['total_texts'])
    with col2:
        st.metric("🔗 可点击文本", stats['clickable_texts'])
    with col3:
        st.metric("❌ 标记错误", stats['marked_errors'])
    with col4:
        st.metric("✅ 准确率", f"{stats['accuracy_rate']:.1f}%")

    st.markdown("---")

    tab1, tab2, tab3, tab4 = st.tabs(["📄 内容校验", "📊 表格分析", "📈 数据统计", "🚀 快速导航"])

    with tab1:
        control_col1, control_col2 = st.columns(2)

        with control_col1:
            layout_mode = st.selectbox(
                "布局模式",
                ["标准布局", "滚动布局"],
                key="layout_mode"
            )

        with control_col2:
            font_size = st.selectbox("字体大小", [10, 12, 14, 16], index=0, key="font_size_select")

        # Render the layout selected above with the chosen font size.
        if layout_mode == "滚动布局":
            validator.create_compact_layout(font_size, 1.0)
        else:
            validator.create_standard_layout(font_size, 1.0)

    # NOTE(review): in the reviewed copy of this file the text between the
    # tab 2 condition and the tab 4 button label was lost (markup stripping).
    # The bodies of tab 2, tab 3 and tab 4 below are a minimal reconstruction
    # from the surviving fragments -- confirm against version control.
    with tab2:
        st.header("📊 表格数据分析")

        # presumably the original tested for a '<table' tag; verify.
        if validator.md_content and '<table' in validator.md_content.lower():
            validator.display_html_table_as_dataframe(validator.md_content)
        else:
            st.info("当前OCR结果中没有检测到表格数据")

    with tab3:
        st.header("📈 数据统计")
        # NOTE(review): original statistics view was lost; show the raw dict.
        st.write(stats)

    with tab4:
        st.header("🚀 快速导航")

        # assumes group_texts_by_category maps a category name to a list of
        # text strings taken from text_bbox_mapping -- TODO confirm.
        categories = group_texts_by_category(validator.text_bbox_mapping)
        for category, texts in categories.items():
            with st.expander(f"{category} ({len(texts)})"):
                for i, text in enumerate(texts):
                    # Long texts are shortened for the button label.
                    display_text = text[:15] + "..." if len(text) > 15 else text
                    if st.button(display_text, key=f"nav_{category}_{i}"):
                        st.session_state.selected_text = text
                        st.rerun()


if __name__ == "__main__":
    main()