#!/usr/bin/env python3 """ 基于Streamlit的OCR可视化校验工具(重构版) 提供丰富的交互组件和更好的用户体验 """ import streamlit as st from pathlib import Path from PIL import Image from typing import Dict, List, Optional import plotly.graph_objects as go from io import BytesIO import pandas as pd import numpy as np import plotly.express as px import json # 导入工具模块 from ocr_validator_utils import ( load_config, load_ocr_data_file, process_ocr_data, get_ocr_statistics, find_available_ocr_files, group_texts_by_category, find_available_ocr_files_multi_source, get_data_source_display_name ) from ocr_validator_file_utils import ( load_css_styles, draw_bbox_on_image, convert_html_table_to_markdown, parse_html_tables, create_dynamic_css, export_tables_to_excel, get_table_statistics, ) from ocr_validator_layout import OCRLayoutManager from ocr_by_vlm import ocr_with_vlm from compare_ocr_results import compare_ocr_results class StreamlitOCRValidator: def __init__(self): self.config = load_config() self.ocr_data = [] self.md_content = "" self.image_path = "" self.text_bbox_mapping = {} self.selected_text = None self.marked_errors = set() # 多数据源相关 self.all_sources = {} self.current_source_key = None self.current_source_config = None self.file_info = [] self.selected_file_index = -1 self.display_options = [] self.file_paths = [] # ✅ 新增:交叉验证数据源 self.verify_source_key = None self.verify_source_config = None self.verify_file_info = [] self.verify_display_options = [] self.verify_file_paths = [] # 初始化布局管理器 self.layout_manager = OCRLayoutManager(self) # 加载多数据源文件信息 self.load_multi_source_info() def load_multi_source_info(self): """加载多数据源文件信息""" self.all_sources = find_available_ocr_files_multi_source(self.config) # 如果有数据源,默认选择第一个作为OCR源 if self.all_sources: source_keys = list(self.all_sources.keys()) first_source_key = source_keys[0] self.switch_to_source(first_source_key) # 如果有第二个数据源,默认作为验证源 if len(source_keys) > 1: self.switch_to_verify_source(source_keys[1]) def switch_to_source(self, source_key: str): """切换到指定OCR数据源""" if source_key in self.all_sources: self.current_source_key = source_key source_data = self.all_sources[source_key] self.current_source_config = source_data['config'] self.file_info = source_data['files'] if self.file_info: # 创建显示选项列表 self.display_options = [f"{info['display_name']}" for info in self.file_info] self.file_paths = [info['path'] for info in self.file_info] # 重置文件选择 self.selected_file_index = -1 print(f"✅ 切换到OCR数据源: {source_key}") else: print(f"⚠️ 数据源 {source_key} 没有可用文件") def switch_to_verify_source(self, source_key: str): """切换到指定验证数据源""" if source_key in self.all_sources: self.verify_source_key = source_key source_data = self.all_sources[source_key] self.verify_source_config = source_data['config'] self.verify_file_info = source_data['files'] if self.verify_file_info: self.verify_display_options = [f"{info['display_name']}" for info in self.verify_file_info] self.verify_file_paths = [info['path'] for info in self.verify_file_info] print(f"✅ 切换到验证数据源: {source_key}") else: print(f"⚠️ 验证数据源 {source_key} 没有可用文件") def setup_page_config(self): """设置页面配置""" ui_config = self.config['ui'] st.set_page_config( page_title=ui_config['page_title'], page_icon=ui_config['page_icon'], layout=ui_config['layout'], initial_sidebar_state=ui_config['sidebar_state'] ) # 加载CSS样式 css_content = load_css_styles() st.markdown(f"", unsafe_allow_html=True) def create_data_source_selector(self): """创建双数据源选择器 - 支持交叉验证""" if not self.all_sources: st.warning("❌ 未找到任何数据源,请检查配置文件") return # 准备数据源选项 source_options = {} for source_key, source_data in self.all_sources.items(): display_name = get_data_source_display_name(source_data['config']) source_options[display_name] = source_key # 创建两列布局 col1, col2 = st.columns(2) with col1: st.markdown("#### 📊 OCR数据源") # OCR数据源选择 current_display_name = None if self.current_source_key: for display_name, key in source_options.items(): if key == self.current_source_key: current_display_name = display_name break selected_ocr_display = st.selectbox( "选择OCR数据源", options=list(source_options.keys()), index=list(source_options.keys()).index(current_display_name) if current_display_name else 0, key="ocr_source_selector", label_visibility="collapsed", help="选择要分析的OCR数据源" ) selected_ocr_key = source_options[selected_ocr_display] # 如果OCR数据源发生变化,切换数据源 if selected_ocr_key != self.current_source_key: self.switch_to_source(selected_ocr_key) if 'selected_file_index' in st.session_state: st.session_state.selected_file_index = 0 st.rerun() # 显示OCR数据源信息 if self.current_source_config: with st.expander("📋 OCR数据源详情", expanded=False): st.write(f"**工具:** {self.current_source_config['ocr_tool']}") st.write(f"**文件数:** {len(self.file_info)}") with col2: st.markdown("#### 🔍 验证数据源") # 验证数据源选择 verify_display_name = None if self.verify_source_key: for display_name, key in source_options.items(): if key == self.verify_source_key: verify_display_name = display_name break selected_verify_display = st.selectbox( "选择验证数据源", options=list(source_options.keys()), index=list(source_options.keys()).index(verify_display_name) if verify_display_name else (1 if len(source_options) > 1 else 0), key="verify_source_selector", label_visibility="collapsed", help="选择用于交叉验证的数据源" ) selected_verify_key = source_options[selected_verify_display] # 如果验证数据源发生变化,切换数据源 if selected_verify_key != self.verify_source_key: self.switch_to_verify_source(selected_verify_key) st.rerun() # 显示验证数据源信息 if self.verify_source_config: with st.expander("📋 验证数据源详情", expanded=False): st.write(f"**工具:** {self.verify_source_config['ocr_tool']}") st.write(f"**文件数:** {len(self.verify_file_info)}") # 数据源对比提示 if self.current_source_key == self.verify_source_key: st.warning("⚠️ OCR数据源和验证数据源相同,建议选择不同的数据源进行交叉验证") else: st.success(f"✅ 已选择 {selected_ocr_display} 与 {selected_verify_display} 进行交叉验证") def load_ocr_data(self, json_path: str, md_path: Optional[str] = None, image_path: Optional[str] = None): """加载OCR相关数据 - 支持多数据源配置""" try: # 使用当前数据源的配置加载数据 if self.current_source_config: # 临时修改config以使用当前数据源的配置 temp_config = self.config.copy() temp_config['paths'] = { 'ocr_out_dir': self.current_source_config['ocr_out_dir'], 'src_img_dir': self.current_source_config.get('src_img_dir', ''), 'pre_validation_dir': self.config['pre_validation']['out_dir'] } # 设置OCR工具类型 temp_config['current_ocr_tool'] = self.current_source_config['ocr_tool'] self.ocr_data, self.md_content, self.image_path = load_ocr_data_file(json_path, temp_config) else: self.ocr_data, self.md_content, self.image_path = load_ocr_data_file(json_path, self.config) self.process_data() except Exception as e: st.error(f"❌ 加载失败: {e}") st.exception(e) def process_data(self): """处理OCR数据""" self.text_bbox_mapping = process_ocr_data(self.ocr_data, self.config) def get_statistics(self) -> Dict: """获取统计信息""" return get_ocr_statistics(self.ocr_data, self.text_bbox_mapping, self.marked_errors) def display_html_table_as_dataframe(self, html_content: str, enable_editing: bool = False): """将HTML表格解析为DataFrame显示 - 增强版本支持横向滚动""" tables = parse_html_tables(html_content) wide_table_threshold = 15 # 超宽表格列数阈值 if not tables: st.warning("未找到可解析的表格") # 对于无法解析的HTML表格,使用自定义CSS显示 st.markdown(""" """, unsafe_allow_html=True) st.markdown(f'
{html_content}
', unsafe_allow_html=True) return for i, table in enumerate(tables): st.subheader(f"📊 表格 {i+1}") # 表格信息显示 col_info1, col_info2, col_info3, col_info4 = st.columns(4) with col_info1: st.metric("行数", len(table)) with col_info2: st.metric("列数", len(table.columns)) with col_info3: # 检查是否有超宽表格 is_wide_table = len(table.columns) > wide_table_threshold st.metric("表格类型", "超宽表格" if is_wide_table else "普通表格") with col_info4: # 表格操作模式选择 display_mode = st.selectbox( f"显示模式 (表格{i+1})", ["完整显示", "分页显示", "筛选列显示"], key=f"display_mode_{i}" ) # 创建表格操作按钮 col1, col2, col3, col4 = st.columns(4) with col1: show_info = st.checkbox(f"显示详细信息", key=f"info_{i}") with col2: show_stats = st.checkbox(f"显示统计信息", key=f"stats_{i}") with col3: enable_filter = st.checkbox(f"启用过滤", key=f"filter_{i}") with col4: enable_sort = st.checkbox(f"启用排序", key=f"sort_{i}") # 根据显示模式处理表格 display_table = self._process_table_display_mode(table, i, display_mode) # 数据过滤和排序逻辑 filtered_table = self._apply_table_filters_and_sorts(display_table, i, enable_filter, enable_sort) # 显示表格 - 使用自定义CSS支持横向滚动 st.markdown(""" """, unsafe_allow_html=True) # 根据表格宽度选择显示容器 container_class = "wide-table-container" if len(table.columns) > wide_table_threshold else "dataframe-container" if enable_editing: st.markdown(f'
', unsafe_allow_html=True) edited_table = st.data_editor( filtered_table, use_container_width=True, key=f"editor_{i}", height=400 if len(table.columns) > 8 else None ) st.markdown('
', unsafe_allow_html=True) if not edited_table.equals(filtered_table): st.success("✏️ 表格已编辑,可以导出修改后的数据") else: st.markdown(f'
', unsafe_allow_html=True) st.dataframe( filtered_table, # use_container_width=True, width =400 if len(table.columns) > wide_table_threshold else "stretch" ) st.markdown('
', unsafe_allow_html=True) # 显示表格信息和统计 self._display_table_info_and_stats(table, filtered_table, show_info, show_stats, i) st.markdown("---") def _apply_table_filters_and_sorts(self, table: pd.DataFrame, table_index: int, enable_filter: bool, enable_sort: bool) -> pd.DataFrame: """应用表格过滤和排序""" filtered_table = table.copy() # 数据过滤 if enable_filter and not table.empty: filter_col = st.selectbox( f"选择过滤列 (表格 {table_index+1})", options=['无'] + list(table.columns), key=f"filter_col_{table_index}" ) if filter_col != '无': filter_value = st.text_input(f"过滤值 (表格 {table_index+1})", key=f"filter_value_{table_index}") if filter_value: filtered_table = table[table[filter_col].astype(str).str.contains(filter_value, na=False)] # 数据排序 if enable_sort and not filtered_table.empty: sort_col = st.selectbox( f"选择排序列 (表格 {table_index+1})", options=['无'] + list(filtered_table.columns), key=f"sort_col_{table_index}" ) if sort_col != '无': sort_order = st.radio( f"排序方式 (表格 {table_index+1})", options=['升序', '降序'], horizontal=True, key=f"sort_order_{table_index}" ) ascending = (sort_order == '升序') filtered_table = filtered_table.sort_values(sort_col, ascending=ascending) return filtered_table def _display_table_info_and_stats(self, original_table: pd.DataFrame, filtered_table: pd.DataFrame, show_info: bool, show_stats: bool, table_index: int): """显示表格信息和统计数据""" if show_info: st.write("**表格信息:**") st.write(f"- 原始行数: {len(original_table)}") st.write(f"- 过滤后行数: {len(filtered_table)}") st.write(f"- 列数: {len(original_table.columns)}") st.write(f"- 列名: {', '.join(original_table.columns)}") if show_stats: st.write("**统计信息:**") numeric_cols = filtered_table.select_dtypes(include=[np.number]).columns if len(numeric_cols) > 0: st.dataframe(filtered_table[numeric_cols].describe()) else: st.info("表格中没有数值列") # 导出功能 if st.button(f"📥 导出表格 {table_index+1}", key=f"export_{table_index}"): self._create_export_buttons(filtered_table, table_index) def _create_export_buttons(self, table: pd.DataFrame, table_index: int): """创建导出按钮""" # CSV导出 csv_data = table.to_csv(index=False) st.download_button( label=f"下载CSV (表格 {table_index+1})", data=csv_data, file_name=f"table_{table_index+1}.csv", mime="text/csv", key=f"download_csv_{table_index}" ) # Excel导出 excel_buffer = BytesIO() table.to_excel(excel_buffer, index=False) st.download_button( label=f"下载Excel (表格 {table_index+1})", data=excel_buffer.getvalue(), file_name=f"table_{table_index+1}.xlsx", mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", key=f"download_excel_{table_index}" ) def _process_table_display_mode(self, table: pd.DataFrame, table_index: int, display_mode: str) -> pd.DataFrame: """根据显示模式处理表格""" if display_mode == "分页显示": # 分页显示 page_size = st.selectbox( f"每页显示行数 (表格 {table_index+1})", [10, 20, 50, 100], key=f"page_size_{table_index}" ) total_pages = (len(table) - 1) // page_size + 1 if total_pages > 1: page_number = st.selectbox( f"页码 (表格 {table_index+1})", range(1, total_pages + 1), key=f"page_number_{table_index}" ) start_idx = (page_number - 1) * page_size end_idx = start_idx + page_size return table.iloc[start_idx:end_idx] return table elif display_mode == "筛选列显示": # 列筛选显示 if len(table.columns) > 5: selected_columns = st.multiselect( f"选择要显示的列 (表格 {table_index+1})", table.columns.tolist(), default=table.columns.tolist()[:5], # 默认显示前5列 key=f"selected_columns_{table_index}" ) if selected_columns: return table[selected_columns] return table else: # 完整显示 return table def find_verify_md_path(self, selected_file_index: int) -> Optional[Path]: """查找当前OCR文件对应的验证文件路径""" current_page = self.file_info[selected_file_index]['page'] verify_md_path = None for i, info in enumerate(self.verify_file_info): if info['page'] == current_page: verify_md_path = Path(self.verify_file_paths[i]).with_suffix('.md') break return verify_md_path @st.dialog("交叉验证", width="large", dismissible=True, on_dismiss="rerun") def cross_validation(self): """交叉验证功能 - 比对两个数据源的OCR结果""" if not self.image_path or not self.md_content: st.error("❌ 请先加载OCR数据文件") return if self.current_source_key == self.verify_source_key: st.error("❌ OCR数据源和验证数据源不能相同") return # 初始化对比结果存储 if 'cross_validation_result' not in st.session_state: st.session_state.cross_validation_result = None # 初始化对比结果存储 if 'cross_validation_result' not in st.session_state: st.session_state.cross_validation_result = None # 创建进度条和状态显示 with st.spinner("正在进行交叉验证...", show_time=True): status_text = st.empty() try: # 第一步:获取当前OCR结果文件路径 current_md_path = Path(self.file_paths[self.selected_file_index]).with_suffix('.md') if not current_md_path.exists(): st.error("❌ 当前OCR结果的Markdown文件不存在") return status_text.text(f"📄 OCR文件: {current_md_path.name}") # 第二步:查找对应的验证文件 verify_md_path = self.find_verify_md_path(self.selected_file_index) if not verify_md_path or not verify_md_path.exists(): st.error(f"❌ 未找到验证数据源中第{current_md_path}页的对应文件") return status_text.text(f"🔍 验证文件: {verify_md_path.name}") # 第三步:准备输出目录 pre_validation_dir = Path(self.config['pre_validation'].get('out_dir', './output/pre_validation/')).resolve() pre_validation_dir.mkdir(parents=True, exist_ok=True) # 第四步:调用对比功能 status_text.text("📊 正在对比OCR结果...") comparison_result_path = pre_validation_dir / f"{current_md_path.stem}_cross_validation" # 在expander中显示对比过程 with st.expander("🔍 交叉验证对比过程", expanded=True): compare_output = st.empty() # 捕获对比输出 import io import contextlib output_buffer = io.StringIO() with contextlib.redirect_stdout(output_buffer): comparison_result = compare_ocr_results( file1_path=str(current_md_path), file2_path=str(verify_md_path), output_file=str(comparison_result_path), output_format='both', ignore_images=True, table_mode='flow_list', # ✅ 使用流水表格模式 similarity_algorithm='ratio' ) # 显示对比过程输出 compare_output.code(output_buffer.getvalue(), language='text') status_text.text("✅ 交叉验证完成") st.session_state.cross_validation_result = { "ocr_source": get_data_source_display_name(self.current_source_config), "verify_source": get_data_source_display_name(self.verify_source_config), "ocr_file": str(current_md_path), "verify_file": str(verify_md_path), "comparison_result_json": f"{comparison_result_path}.json", "comparison_result_md": f"{comparison_result_path}.md", "comparison_result": comparison_result } # 第五步:显示对比结果 self.display_comparison_results(comparison_result, detailed=False) except Exception as e: st.error(f"❌ 交叉验证失败: {e}") st.exception(e) @st.dialog("查看交叉验证结果", width="large", dismissible=True, on_dismiss="rerun") def show_cross_validation_results_dialog(self): """显示交叉验证结果的对话框""" current_md_path = Path(self.file_paths[self.selected_file_index]).with_suffix('.md') pre_validation_dir = Path(self.config['pre_validation'].get('out_dir', './output/pre_validation/')).resolve() comparison_result_path = pre_validation_dir / f"{current_md_path.stem}_cross_validation.json" if 'cross_validation_result' in st.session_state and st.session_state.cross_validation_result: result = st.session_state.cross_validation_result # 显示数据源信息 col1, col2 = st.columns(2) with col1: st.info(f"**OCR数据源:** {result['ocr_source']}") with col2: st.info(f"**验证数据源:** {result['verify_source']}") self.display_comparison_results(result['comparison_result']) elif comparison_result_path.exists(): # 如果有历史结果文件,提示加载 if st.button("📂 加载历史验证结果"): with open(comparison_result_path, "r", encoding="utf-8") as f: comparison_json_result = json.load(f) cross_validation_result = { "ocr_source": get_data_source_display_name(self.current_source_config), "verify_source": get_data_source_display_name(self.verify_source_config), "ocr_file": comparison_json_result['file1_path'], "verify_file": comparison_json_result['file2_path'], "comparison_result_json": str(comparison_result_path), "comparison_result_md": str(comparison_result_path.with_suffix('.md')), "comparison_result": comparison_json_result } st.session_state.cross_validation_result = cross_validation_result self.display_comparison_results(comparison_json_result) else: st.info("暂无交叉验证结果,请先运行交叉验证") def display_comparison_results(self, comparison_result: dict, detailed: bool = True): """显示对比结果摘要 - 使用DataFrame展示""" st.header("📊 VLM预校验结果") # 统计信息 stats = comparison_result['statistics'] # 统计信息概览 col1, col2, col3, col4 = st.columns(4) with col1: st.metric("总差异数", stats['total_differences']) with col2: st.metric("表格差异", stats['table_differences']) with col3: st.metric("金额差异", stats['amount_differences']) with col4: st.metric("段落差异", stats['paragraph_differences']) # 结果判断 if stats['total_differences'] == 0: st.success("🎉 完美匹配!VLM识别结果与原OCR结果完全一致") else: st.warning(f"⚠️ 发现 {stats['total_differences']} 个差异,建议人工检查") # 使用DataFrame显示差异详情 if comparison_result['differences']: st.subheader("🔍 差异详情对比") # 准备DataFrame数据 diff_data = [] for i, diff in enumerate(comparison_result['differences'], 1): diff_data.append({ '序号': i, '位置': diff['position'], '类型': diff['type'], '原OCR结果': diff['file1_value'][:100] + ('...' if len(diff['file1_value']) > 100 else ''), 'VLM识别结果': diff['file2_value'][:100] + ('...' if len(diff['file2_value']) > 100 else ''), '描述': diff['description'][:80] + ('...' if len(diff['description']) > 80 else ''), '严重程度': self._get_severity_level(diff) }) # 创建DataFrame df_differences = pd.DataFrame(diff_data) # 添加样式 def highlight_severity(val): """根据严重程度添加颜色""" if val == '高': return 'background-color: #ffebee; color: #c62828' elif val == '中': return 'background-color: #fff3e0; color: #ef6c00' elif val == '低': return 'background-color: #e8f5e8; color: #2e7d32' return '' # 显示DataFrame styled_df = df_differences.style.applymap( highlight_severity, subset=['严重程度'] ).format({ '序号': '{:d}', }) st.dataframe( styled_df, use_container_width=True, height=400, hide_index=True, column_config={ "序号": st.column_config.NumberColumn( "序号", width=None, # 自动调整宽度 pinned=True, help="差异项序号" ), "位置": st.column_config.TextColumn( "位置", width=None, # 自动调整宽度 pinned=True, help="差异在文档中的位置" ), "类型": st.column_config.TextColumn( "类型", width=None, # 自动调整宽度 pinned=True, help="差异类型" ), "原OCR结果": st.column_config.TextColumn( "原OCR结果", width="large", # 自动调整宽度 pinned=True, help="原始OCR识别结果" ), "VLM识别结果": st.column_config.TextColumn( "VLM识别结果", width="large", # 自动调整宽度 help="VLM重新识别的结果" ), "描述": st.column_config.TextColumn( "描述", width="medium", # 自动调整宽度 help="差异详细描述" ), "严重程度": st.column_config.TextColumn( "严重程度", width=None, # 自动调整宽度 help="差异严重程度评级" ) } ) # 详细差异查看 st.subheader("🔍 详细差异查看") if detailed: # 选择要查看的差异 selected_diff_index = st.selectbox( "选择要查看的差异:", options=range(len(comparison_result['differences'])), format_func=lambda x: f"差异 {x+1}: {comparison_result['differences'][x]['position']} - {comparison_result['differences'][x]['type']}", key="selected_diff" ) if selected_diff_index is not None: diff = comparison_result['differences'][selected_diff_index] # 并排显示完整内容 col1, col2 = st.columns(2) with col1: st.write("**原OCR结果:**") st.text_area( "原OCR结果详情", value=diff['file1_value'], height=200, key=f"original_{selected_diff_index}", label_visibility="collapsed" ) with col2: st.write("**验证数据源识别结果:**") st.text_area( "验证数据源识别结果详情", value=diff['file2_value'], height=200, key=f"vlm_{selected_diff_index}", label_visibility="collapsed" ) # 差异详细信息 st.info(f"**位置:** {diff['position']}") st.info(f"**类型:** {diff['type']}") st.info(f"**描述:** {diff['description']}") st.info(f"**严重程度:** {self._get_severity_level(diff)}") # 差异统计图表 st.subheader("📈 差异类型分布") # 按类型统计差异 type_counts = {} severity_counts = {'高': 0, '中': 0, '低': 0} for diff in comparison_result['differences']: diff_type = diff['type'] type_counts[diff_type] = type_counts.get(diff_type, 0) + 1 severity = self._get_severity_level(diff) severity_counts[severity] += 1 col1, col2 = st.columns(2) with col1: # 类型分布饼图 if type_counts: fig_type = px.pie( values=list(type_counts.values()), names=list(type_counts.keys()), title="差异类型分布" ) st.plotly_chart(fig_type, use_container_width=True) with col2: # 严重程度分布条形图 fig_severity = px.bar( x=list(severity_counts.keys()), y=list(severity_counts.values()), title="差异严重程度分布", color=list(severity_counts.keys()), color_discrete_map={'高': '#f44336', '中': '#ff9800', '低': '#4caf50'} ) st.plotly_chart(fig_severity, use_container_width=True) # 下载选项 if detailed: self._provide_download_options_in_results(comparison_result) def _get_severity_level(self, diff: dict) -> str: """根据差异类型和内容判断严重程度""" # 如果差异中已经包含严重程度,直接使用 if 'severity' in diff: severity_map = {'high': '高', 'medium': '中', 'low': '低'} return severity_map.get(diff['severity'], '中') # 原有的逻辑作为后备 diff_type = diff['type'].lower() # 金额相关差异为高严重程度 if 'amount' in diff_type or 'number' in diff_type: return '高' # 表格结构差异为中等严重程度 if 'table' in diff_type or 'structure' in diff_type: return '中' # 检查相似度 if 'similarity' in diff: similarity = diff['similarity'] if similarity < 50: return '高' elif similarity < 85: return '中' else: return '低' # 检查内容长度差异 len_diff = abs(len(diff['file1_value']) - len(diff['file2_value'])) if len_diff > 50: return '高' elif len_diff > 10: return '中' else: return '低' def _provide_download_options_in_results(self, comparison_result: dict): """在结果页面提供下载选项""" st.subheader("📥 导出预校验结果") col1, col2, col3 = st.columns(3) with col1: # 导出差异详情为Excel if comparison_result['differences']: diff_data = [] for i, diff in enumerate(comparison_result['differences'], 1): diff_data.append({ '序号': i, '位置': diff['position'], '类型': diff['type'], '原OCR结果': diff['file1_value'], 'VLM识别结果': diff['file2_value'], '描述': diff['description'], '严重程度': self._get_severity_level(diff) }) df_export = pd.DataFrame(diff_data) excel_buffer = BytesIO() df_export.to_excel(excel_buffer, index=False, sheet_name='差异详情') st.download_button( label="📊 下载差异详情(Excel)", data=excel_buffer.getvalue(), file_name=f"vlm_comparison_differences_{pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')}.xlsx", mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", key="download_differences_excel" ) with col2: # 导出统计报告 stats_data = { '统计项目': ['总差异数', '表格差异', '金额差异', '段落差异'], '数量': [ comparison_result['statistics']['total_differences'], comparison_result['statistics']['table_differences'], comparison_result['statistics']['amount_differences'], comparison_result['statistics']['paragraph_differences'] ] } df_stats = pd.DataFrame(stats_data) csv_stats = df_stats.to_csv(index=False) st.download_button( label="📈 下载统计报告(CSV)", data=csv_stats, file_name=f"vlm_comparison_stats_{pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')}.csv", mime="text/csv", key="download_stats_csv" ) with col3: # 导出完整报告为JSON import json report_json = json.dumps(comparison_result, ensure_ascii=False, indent=2) st.download_button( label="📄 下载完整报告(JSON)", data=report_json, file_name=f"vlm_comparison_full_{pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')}.json", mime="application/json", key="download_full_json" ) # 操作建议 st.subheader("🚀 后续操作建议") total_diffs = comparison_result['statistics']['total_differences'] if total_diffs == 0: st.success("✅ VLM识别结果与原OCR完全一致,可信度很高,无需人工校验") elif total_diffs <= 5: st.warning("⚠️ 发现少量差异,建议重点检查高严重程度的差异项") elif total_diffs <= 20: st.warning("🔍 发现中等数量差异,建议详细检查差异表格中标红的项目") else: st.error("❌ 发现大量差异,建议重新进行OCR识别或检查原始图片质量") def create_compact_layout(self, config): """创建滚动凑布局""" return self.layout_manager.create_compact_layout(config) @st.dialog("message", width="small", dismissible=True, on_dismiss="rerun") def message_box(msg: str, msg_type: str = "info"): if msg_type == "info": st.info(msg) elif msg_type == "warning": st.warning(msg) elif msg_type == "error": st.error(msg) def main(): """主应用""" # 初始化应用 if 'validator' not in st.session_state: validator = StreamlitOCRValidator() st.session_state.validator = validator st.session_state.validator.setup_page_config() # 页面标题 config = st.session_state.validator.config st.title(config['ui']['page_title']) else: validator = st.session_state.validator config = st.session_state.validator.config if 'selected_text' not in st.session_state: st.session_state.selected_text = None if 'marked_errors' not in st.session_state: st.session_state.marked_errors = set() # 数据源选择器 validator.create_data_source_selector() # 如果没有可用的数据源,提前返回 if not validator.all_sources: st.stop() # 文件选择区域 with st.container(height=75, horizontal=True, horizontal_alignment='left', gap="medium"): # 初始化session_state中的选择索引 if 'selected_file_index' not in st.session_state: st.session_state.selected_file_index = 0 if validator.display_options: # 文件选择下拉框 selected_index = st.selectbox( "选择OCR结果文件", range(len(validator.display_options)), format_func=lambda i: validator.display_options[i], index=st.session_state.selected_file_index, key="selected_selectbox", label_visibility="collapsed" ) # 更新session_state if selected_index != st.session_state.selected_file_index: st.session_state.selected_file_index = selected_index selected_file = validator.file_paths[selected_index] # 页码输入器 current_page = validator.file_info[selected_index]['page'] page_input = st.number_input( "输入页码", placeholder="输入页码", label_visibility="collapsed", min_value=1, max_value=len(validator.display_options), value=current_page, step=1, key="page_input" ) # 当页码输入改变时,更新文件选择 if page_input != current_page: for i, info in enumerate(validator.file_info): if info['page'] == page_input: st.session_state.selected_file_index = i selected_file = validator.file_paths[i] st.rerun() break # 自动加载文件 if (st.session_state.selected_file_index >= 0 and validator.selected_file_index != st.session_state.selected_file_index and selected_file): validator.selected_file_index = st.session_state.selected_file_index st.session_state.validator.load_ocr_data(selected_file) # 显示加载成功信息 current_source_name = get_data_source_display_name(validator.current_source_config) st.success(f"✅ 已加载 {current_source_name} - 第{validator.file_info[st.session_state.selected_file_index]['page']}页") st.rerun() else: st.warning("当前数据源中未找到OCR结果文件") # 交叉验证按钮 if st.button("交叉验证", type="primary", icon=":material/compare_arrows:"): if validator.image_path and validator.md_content: validator.cross_validation() else: message_box("❌ 请先选择OCR数据文件", "error") # 查看预校验结果按钮 if st.button("查看验证结果", type="secondary", icon=":material/quick_reference_all:"): validator.show_cross_validation_results_dialog() # 显示当前数据源统计信息 with st.expander("🔧 OCR工具统计信息", expanded=False): stats = validator.get_statistics() col1, col2, col3, col4, col5 = st.columns(5) with col1: st.metric("📊 总文本块", stats['total_texts']) with col2: st.metric("🔗 可点击文本", stats['clickable_texts']) with col3: st.metric("❌ 标记错误", stats['marked_errors']) with col4: st.metric("✅ 准确率", f"{stats['accuracy_rate']:.1f}%") with col5: # 显示当前数据源信息 if validator.current_source_config: tool_display = validator.current_source_config['ocr_tool'].upper() st.metric("🔧 OCR工具", tool_display) # 详细工具信息 if stats['tool_info']: st.write("**详细信息:**", stats['tool_info']) # 其余标签页保持不变... tab1, tab2, tab3 = st.tabs(["📄 内容人工检查", "🔍 交叉验证结果", "📊 表格分析"]) with tab1: validator.create_compact_layout(config) with tab2: # st.header("📄 VLM预校验识别结果") current_md_path = Path(validator.file_paths[validator.selected_file_index]).with_suffix('.md') pre_validation_dir = Path(validator.config['pre_validation'].get('out_dir', './output/pre_validation/')).resolve() comparison_result_path = pre_validation_dir / f"{current_md_path.stem}_cross_validation.json" # pre_validation_path = pre_validation_dir / f"{current_md_path.stem}.md" verify_md_path = validator.find_verify_md_path(validator.selected_file_index) if comparison_result_path.exists(): # 加载并显示验证结果 with open(comparison_result_path, "r", encoding="utf-8") as f: comparison_result = json.load(f) # 左边显示OCR结果,右边显示VLM结果 col1, col2 = st.columns([1,1]) with col1: st.subheader("🤖 原OCR识别结果") with open(current_md_path, "r", encoding="utf-8") as f: original_md_content = f.read() font_size = config['styles'].get('font_size', 10) height = config['styles']['layout'].get('default_height', 800) layout_type = "compact" validator.layout_manager.render_content_by_mode(original_md_content, "HTML渲染", font_size, height, layout_type) with col2: st.subheader("🤖 VLM识别结果") with open(str(verify_md_path), "r", encoding="utf-8") as f: verify_md_content = f.read() font_size = config['styles'].get('font_size', 10) height = config['styles']['layout'].get('default_height', 800) layout_type = "compact" validator.layout_manager.render_content_by_mode(verify_md_content, "HTML渲染", font_size, height, layout_type) # 显示差异统计 st.markdown("---") validator.display_comparison_results(comparison_result, detailed=True) else: st.info("暂无预校验结果,请先运行VLM预校验") with tab3: # 表格分析页面 - 保持原有逻辑 st.header("📊 表格数据分析") if validator.md_content and '