| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386 |
- #!/usr/bin/env python3
- """
- 基于Streamlit的OCR可视化校验工具(重构版)
- 提供丰富的交互组件和更好的用户体验
- """
- import streamlit as st
- from pathlib import Path
- from PIL import Image
- from typing import Dict, List, Optional
- import plotly.graph_objects as go
- from io import BytesIO
- import pandas as pd
- import numpy as np
- import plotly.express as px
- # 导入工具模块
- from ocr_validator_utils import (
- load_config, load_css_styles, load_ocr_data_file, process_ocr_data,
- draw_bbox_on_image, get_ocr_statistics, convert_html_table_to_markdown,
- parse_html_tables, find_available_ocr_files, create_dynamic_css,
- export_tables_to_excel, get_table_statistics, group_texts_by_category
- )
- from ocr_validator_layout import OCRLayoutManager
class StreamlitOCRValidator:
    """Hold the state of one OCR review session and render its table widgets.

    The instance keeps the loaded OCR records, the accompanying markdown/HTML
    text, the image path and the text -> bbox mapping.  Page layouts are
    delegated to an ``OCRLayoutManager`` that receives this instance.
    """

    def __init__(self):
        """Load the configuration and start with an empty document."""
        self.config = load_config()
        self.ocr_data = []
        self.md_content = ""
        self.image_path = ""
        self.text_bbox_mapping = {}
        self.selected_text = None
        self.marked_errors = set()

        # The layout manager gets a back-reference to this validator.
        self.layout_manager = OCRLayoutManager(self)

    def setup_page_config(self):
        """Apply the page settings from ``config['ui']`` and inject the CSS."""
        ui_config = self.config['ui']
        st.set_page_config(
            page_title=ui_config['page_title'],
            page_icon=ui_config['page_icon'],
            layout=ui_config['layout'],
            initial_sidebar_state=ui_config['sidebar_state']
        )

        css_content = load_css_styles()
        st.markdown(f"<style>{css_content}</style>", unsafe_allow_html=True)

    def load_ocr_data(self, json_path: str, md_path: Optional[str] = None,
                      image_path: Optional[str] = None) -> bool:
        """Load an OCR result file and rebuild the text -> bbox mapping.

        Args:
            json_path: Path of the OCR result file handed to
                ``load_ocr_data_file``.
            md_path: Optional file whose text replaces the markdown content
                returned by the loader (read as UTF-8 -- NOTE(review):
                encoding is an assumption, confirm against the producers).
            image_path: Optional image path that replaces the one returned
                by the loader.

        Returns:
            ``True`` when everything loaded, ``False`` when an error occurred.
            Errors are shown with ``st.error`` instead of being raised, and
            the previously loaded document is restored so the instance never
            ends up with new OCR data paired with a stale mapping.
        """
        previous_state = (
            self.ocr_data, self.md_content, self.image_path, self.text_bbox_mapping
        )
        try:
            self.ocr_data, self.md_content, self.image_path = load_ocr_data_file(
                json_path, self.config
            )
            if md_path is not None:
                self.md_content = Path(md_path).read_text(encoding="utf-8")
            if image_path is not None:
                self.image_path = image_path
            self.process_data()
        except Exception as e:  # UI boundary: report the failure, do not crash the page
            (self.ocr_data, self.md_content,
             self.image_path, self.text_bbox_mapping) = previous_state
            st.error(f"❌ 加载失败: {e}")
            return False
        return True

    def process_data(self):
        """Rebuild ``text_bbox_mapping`` from the currently loaded OCR data."""
        self.text_bbox_mapping = process_ocr_data(self.ocr_data, self.config)

    def get_statistics(self) -> Dict:
        """Return the statistics computed by ``get_ocr_statistics``."""
        return get_ocr_statistics(self.ocr_data, self.text_bbox_mapping, self.marked_errors)

    def display_html_table_as_dataframe(self, html_content: str, enable_editing: bool = False):
        """Parse the HTML tables in ``html_content`` and render each one.

        Every table gets its own checkboxes (info / stats / filter / sort).
        When no table can be parsed, a warning is shown and the raw HTML is
        rendered instead.

        Args:
            html_content: Text that may contain HTML ``<table>`` markup.
            enable_editing: Render tables with ``st.data_editor`` instead of
                the read-only ``st.dataframe``.
        """
        tables = parse_html_tables(html_content)

        if not tables:
            st.warning("未找到可解析的表格")
            st.markdown(html_content, unsafe_allow_html=True)
            return

        for i, table in enumerate(tables):
            st.subheader(f"📊 表格 {i+1}")

            # Per-table toggles; keys carry the table index to stay unique.
            col1, col2, col3, col4 = st.columns(4)
            with col1:
                show_info = st.checkbox("显示表格信息", key=f"info_{i}")
            with col2:
                show_stats = st.checkbox("显示统计信息", key=f"stats_{i}")
            with col3:
                enable_filter = st.checkbox("启用过滤", key=f"filter_{i}")
            with col4:
                enable_sort = st.checkbox("启用排序", key=f"sort_{i}")

            filtered_table = self._apply_table_filters_and_sorts(table, i, enable_filter, enable_sort)

            if enable_editing:
                edited_table = st.data_editor(filtered_table, use_container_width=True, key=f"editor_{i}")
                if not edited_table.equals(filtered_table):
                    st.success("✏️ 表格已编辑,可以导出修改后的数据")
            else:
                st.dataframe(filtered_table, use_container_width=True)

            self._display_table_info_and_stats(table, filtered_table, show_info, show_stats, i)

            st.markdown("---")

    @staticmethod
    def _filter_rows(table: pd.DataFrame, column, value: str) -> pd.DataFrame:
        """Keep the rows whose ``column`` text contains ``value`` literally.

        ``regex=False`` is deliberate: the value is typed by the user, and
        characters such as ``(`` or ``[`` would otherwise be compiled as a
        regular expression and raise ``re.error``.
        """
        mask = table[column].astype(str).str.contains(value, na=False, regex=False)
        return table[mask]

    @staticmethod
    def _sort_rows(table: pd.DataFrame, column, ascending: bool) -> pd.DataFrame:
        """Sort ``table`` by ``column``; mixed-type columns are ordered as text."""
        try:
            return table.sort_values(column, ascending=ascending)
        except TypeError:
            # Object columns mixing e.g. int and str cannot be compared
            # directly; fall back to ordering by their string form.
            return table.sort_values(
                column, ascending=ascending, key=lambda values: values.astype(str)
            )

    def _apply_table_filters_and_sorts(self, table: pd.DataFrame, table_index: int, enable_filter: bool, enable_sort: bool) -> pd.DataFrame:
        """Render the filter/sort controls for one table and apply them.

        Args:
            table: The parsed table; it is never modified.
            table_index: Zero-based table position, used in labels and keys.
            enable_filter: Show the column/value filter controls.
            enable_sort: Show the sort column/order controls.

        Returns:
            A new DataFrame with the selected filter and sort applied.
        """
        filtered_table = table.copy()

        if enable_filter and not table.empty:
            filter_col = st.selectbox(
                f"选择过滤列 (表格 {table_index+1})",
                options=['无'] + list(table.columns),
                key=f"filter_col_{table_index}"
            )

            if filter_col != '无':
                filter_value = st.text_input(f"过滤值 (表格 {table_index+1})", key=f"filter_value_{table_index}")
                if filter_value:
                    filtered_table = self._filter_rows(table, filter_col, filter_value)

        if enable_sort and not filtered_table.empty:
            sort_col = st.selectbox(
                f"选择排序列 (表格 {table_index+1})",
                options=['无'] + list(filtered_table.columns),
                key=f"sort_col_{table_index}"
            )

            if sort_col != '无':
                sort_order = st.radio(
                    f"排序方式 (表格 {table_index+1})",
                    options=['升序', '降序'],
                    horizontal=True,
                    key=f"sort_order_{table_index}"
                )
                ascending = (sort_order == '升序')
                filtered_table = self._sort_rows(filtered_table, sort_col, ascending)

        return filtered_table

    def _display_table_info_and_stats(self, original_table: pd.DataFrame, filtered_table: pd.DataFrame,
                                      show_info: bool, show_stats: bool, table_index: int):
        """Show the optional info/statistics sections and the export button.

        Args:
            original_table: Table as parsed, before filtering.
            filtered_table: Table after filtering/sorting; this is what gets
                described and exported.
            show_info: Show row/column counts and the column names.
            show_stats: Show ``describe()`` for the numeric columns.
            table_index: Zero-based table position, used in labels and keys.
        """
        if show_info:
            st.write("**表格信息:**")
            st.write(f"- 原始行数: {len(original_table)}")
            st.write(f"- 过滤后行数: {len(filtered_table)}")
            st.write(f"- 列数: {len(original_table.columns)}")
            # Column labels are not guaranteed to be strings (header-less
            # tables get integer labels), so convert before joining.
            column_names = ', '.join(str(name) for name in original_table.columns)
            st.write(f"- 列名: {column_names}")

        if show_stats:
            st.write("**统计信息:**")
            numeric_cols = filtered_table.select_dtypes(include=[np.number]).columns
            if len(numeric_cols) > 0:
                st.dataframe(filtered_table[numeric_cols].describe())
            else:
                st.info("表格中没有数值列")

        if st.button(f"📥 导出表格 {table_index+1}", key=f"export_{table_index}"):
            self._create_export_buttons(filtered_table, table_index)

    def _create_export_buttons(self, table: pd.DataFrame, table_index: int):
        """Render CSV and Excel download buttons for ``table``."""
        csv_data = table.to_csv(index=False)
        st.download_button(
            label=f"下载CSV (表格 {table_index+1})",
            data=csv_data,
            file_name=f"table_{table_index+1}.csv",
            mime="text/csv",
            key=f"download_csv_{table_index}"
        )

        # The workbook is written to memory so nothing touches the disk.
        excel_buffer = BytesIO()
        table.to_excel(excel_buffer, index=False)
        st.download_button(
            label=f"下载Excel (表格 {table_index+1})",
            data=excel_buffer.getvalue(),
            file_name=f"table_{table_index+1}.xlsx",
            mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
            key=f"download_excel_{table_index}"
        )

    # Layout rendering is delegated to the layout manager.
    def create_standard_layout(self, font_size: int = 12, zoom_level: float = 1.0):
        """Render the standard layout through the layout manager."""
        return self.layout_manager.create_standard_layout(font_size, zoom_level)

    def create_compact_layout(self, font_size: int = 12, zoom_level: float = 1.0):
        """Render the compact (scrolling) layout through the layout manager."""
        return self.layout_manager.create_compact_layout(font_size, zoom_level)
def _render_sidebar(validator, config):
    """Render the file picker and the reset buttons in the sidebar."""
    with st.sidebar:
        st.header("📁 文件选择")

        available_files = find_available_ocr_files(config['paths']['ocr_out_dir'])

        if available_files:
            selected_file = st.selectbox("选择OCR结果文件", available_files, index=0)

            if st.button("🔄 加载文件", type="primary") and selected_file:
                loaded = validator.load_ocr_data(selected_file)
                # load_ocr_data shows its own error message.  Only an explicit
                # False counts as failure (None is treated as success), and in
                # that case we neither claim success nor rerun, because a
                # rerun would immediately wipe the error from the page.
                if loaded is not False:
                    st.success("✅ 文件加载成功!")
                    st.rerun()
        else:
            st.warning("未找到OCR结果文件")
            st.info("请确保output目录下有OCR结果文件")

        st.markdown("---")

        st.header("🎛️ 控制面板")

        if st.button("🧹 清除选择"):
            st.session_state.selected_text = None
            st.rerun()

        if st.button("❌ 清除错误标记"):
            st.session_state.marked_errors = set()
            st.rerun()


def _render_summary_metrics(stats):
    """Render the headline metrics row and the OCR tool details expander."""
    col1, col2, col3, col4, col5 = st.columns(5)
    with col1:
        st.metric("📊 总文本块", stats['total_texts'])
    with col2:
        st.metric("🔗 可点击文本", stats['clickable_texts'])
    with col3:
        st.metric("❌ 标记错误", stats['marked_errors'])
    with col4:
        st.metric("✅ 准确率", f"{stats['accuracy_rate']:.1f}%")
    with col5:
        if stats['tool_info']:
            # The first tool listed is shown as the main one.
            tool_names = list(stats['tool_info'].keys())
            main_tool = tool_names[0] if tool_names else "未知"
            st.metric("🔧 OCR工具", main_tool)

    if stats['tool_info']:
        st.expander("🔧 OCR工具详情", expanded=False).write(stats['tool_info'])


def _render_validation_tab(validator):
    """Render the layout/font controls and the selected validation layout."""
    control_col1, control_col2 = st.columns(2)

    with control_col1:
        layout_mode = st.selectbox(
            "布局模式",
            ["标准布局", "滚动布局"],
            key="layout_mode"
        )

    with control_col2:
        font_size = st.selectbox("字体大小", [10, 12, 14, 16], index=0, key="font_size_select")

    # Zoom is fixed at 1.0 here; only the font size is user-selectable.
    if layout_mode == "滚动布局":
        validator.create_compact_layout(font_size, 1.0)
    else:
        validator.create_standard_layout(font_size, 1.0)


def _render_table_tab(validator):
    """Render the table preview and the export-all-tables action."""
    st.header("📊 表格数据分析")

    if validator.md_content and '<table' in validator.md_content.lower():
        col1, col2 = st.columns([2, 1])

        with col1:
            st.subheader("🔍 表格数据预览")
            validator.display_html_table_as_dataframe(validator.md_content)

        with col2:
            st.subheader("⚙️ 表格操作")

            if st.button("📥 导出表格数据", type="primary"):
                tables = parse_html_tables(validator.md_content)
                if tables:
                    output = export_tables_to_excel(tables)
                    st.download_button(
                        label="📥 下载Excel文件",
                        data=output.getvalue(),
                        file_name="ocr_tables.xlsx",
                        mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
                    )
    else:
        st.info("当前OCR结果中没有检测到表格数据")


def _render_statistics_tab(stats):
    """Render the category pie chart and the correct/error bar chart."""
    st.header("📈 OCR数据统计")

    if stats['categories']:
        st.subheader("📊 类别分布")
        fig_pie = px.pie(
            values=list(stats['categories'].values()),
            names=list(stats['categories'].keys()),
            title="文本类别分布"
        )
        st.plotly_chart(fig_pie, use_container_width=True)

    st.subheader("📈 质量分析")
    # "Correct" is every clickable text that has not been marked as an error.
    accuracy_data = {
        '状态': ['正确', '错误'],
        '数量': [stats['clickable_texts'] - stats['marked_errors'], stats['marked_errors']]
    }

    fig_bar = px.bar(
        accuracy_data, x='状态', y='数量', title="识别质量分布",
        color='状态', color_discrete_map={'正确': 'green', '错误': 'red'}
    )
    st.plotly_chart(fig_bar, use_container_width=True)


def _render_navigation_tab(validator):
    """Render one expander per category with a button for every text item."""
    st.header("🚀 快速导航")

    if not validator.text_bbox_mapping:
        st.info("没有可用的文本项进行导航")
        return

    categories = group_texts_by_category(validator.text_bbox_mapping)

    for category, texts in categories.items():
        with st.expander(f"{category} ({len(texts)}项)", expanded=False):
            cols = st.columns(3)  # three buttons per row
            for i, text in enumerate(texts):
                with cols[i % 3]:
                    # Long texts are shortened to keep the buttons compact.
                    display_text = text[:15] + "..." if len(text) > 15 else text
                    if st.button(display_text, key=f"nav_{category}_{i}"):
                        st.session_state.selected_text = text
                        st.rerun()


def main():
    """Run the OCR validation app: sidebar, summary metrics and four tabs.

    The validator, the selected text and the set of marked errors live in
    ``st.session_state`` so they survive Streamlit reruns.
    """
    if 'validator' not in st.session_state:
        st.session_state.validator = StreamlitOCRValidator()
    validator = st.session_state.validator

    # Streamlit re-executes the script on every interaction, so the page
    # configuration and the CSS must be applied on each run, not only when
    # the validator is first created.
    validator.setup_page_config()

    if 'selected_text' not in st.session_state:
        st.session_state.selected_text = None

    if 'marked_errors' not in st.session_state:
        st.session_state.marked_errors = set()

    # Keep the validator's error set in sync with the session state.
    validator.marked_errors = st.session_state.marked_errors

    config = validator.config
    st.title(config['ui']['page_title'])
    st.markdown("---")

    _render_sidebar(validator, config)

    if not validator.ocr_data:
        st.info("👈 请在左侧选择并加载OCR结果文件")
        return

    stats = validator.get_statistics()
    _render_summary_metrics(stats)

    st.markdown("---")

    tab1, tab2, tab3, tab4 = st.tabs(["📄 内容校验", "📊 表格分析", "📈 数据统计", "🚀 快速导航"])

    with tab1:
        _render_validation_tab(validator)
    with tab2:
        _render_table_tab(validator)
    with tab3:
        _render_statistics_tab(stats)
    with tab4:
        _render_navigation_tab(validator)
# Script entry point: start the app only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|