| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987 |
- #!/usr/bin/env python3
- """
- 基于Streamlit的OCR可视化校验工具(重构版)
- 提供丰富的交互组件和更好的用户体验
- """
- import streamlit as st
- from pathlib import Path
- from PIL import Image
- from typing import Dict, List, Optional
- import plotly.graph_objects as go
- from io import BytesIO
- import pandas as pd
- import numpy as np
- import plotly.express as px
- import json
- # 导入工具模块
- from ocr_validator_utils import (
- load_config, load_css_styles, load_ocr_data_file, process_ocr_data,
- draw_bbox_on_image, get_ocr_statistics, convert_html_table_to_markdown,
- parse_html_tables, find_available_ocr_files, create_dynamic_css,
- export_tables_to_excel, get_table_statistics, group_texts_by_category
- )
- from ocr_validator_layout import OCRLayoutManager
- from ocr_by_vlm import ocr_with_vlm
- from compare_ocr_results import compare_ocr_results
class StreamlitOCRValidator:
    """Streamlit controller for reviewing OCR results.

    Holds the currently loaded OCR data (parsed items, Markdown text and page
    image path), the list of available OCR result files, and the layout
    manager that renders the main review view. It also provides the table
    analysis widgets and the VLM pre-validation dialogs.
    """

    # Tables with more columns than this are treated as "wide".
    _WIDE_TABLE_COLUMNS = 8

    # Stylesheet used when the HTML cannot be parsed and is rendered as-is.
    _FALLBACK_TABLE_CSS = """
    <style>
    .scrollable-table {
        overflow-x: auto;
        white-space: nowrap;
        border: 1px solid #ddd;
        border-radius: 5px;
        margin: 10px 0;
    }
    .scrollable-table table {
        width: 100%;
        border-collapse: collapse;
    }
    .scrollable-table th, .scrollable-table td {
        border: 1px solid #ddd;
        padding: 8px;
        text-align: left;
        min-width: 100px;
    }
    .scrollable-table th {
        background-color: #f5f5f5;
        font-weight: bold;
    }
    </style>
    """

    # Stylesheet for parsed tables (horizontal scrolling, sticky header).
    _DATAFRAME_CSS = """
    <style>
    .dataframe-container {
        overflow-x: auto;
        border: 1px solid #ddd;
        border-radius: 5px;
        margin: 10px 0;
    }

    /* 为超宽表格特殊样式 */
    .wide-table-container {
        overflow-x: auto;
        max-height: 500px;
        overflow-y: auto;
        border: 2px solid #0288d1;
        border-radius: 8px;
        background: linear-gradient(90deg, #f8f9fa 0%, #ffffff 100%);
    }

    .dataframe thead th {
        position: sticky;
        top: 0;
        background-color: #f5f5f5 !important;
        z-index: 10;
        border-bottom: 2px solid #0288d1;
    }

    .dataframe tbody td {
        white-space: nowrap;
        min-width: 100px;
        max-width: 300px;
        overflow: hidden;
        text-overflow: ellipsis;
    }
    </style>
    """

    def __init__(self):
        """Load the configuration and discover the available OCR result files."""
        self.config = load_config()
        self.ocr_data = []
        self.md_content = ""
        self.image_path = ""
        self.text_bbox_mapping = {}
        self.selected_text = None
        self.marked_errors = set()
        self.file_info = []
        # -1 means "no file loaded yet"; it never matches a valid index.
        self.selected_file_index = -1
        self.display_options = []
        self.file_paths = []
        # The layout manager receives this validator as its data source.
        self.layout_manager = OCRLayoutManager(self)
        self.load_file_info()

    def load_file_info(self):
        """Scan the configured OCR output directory and rebuild the file lists.

        ``display_options`` and ``file_paths`` are always rebuilt so that they
        cannot go stale when a later scan finds no files.
        """
        self.file_info = find_available_ocr_files(self.config['paths']['ocr_out_dir'])
        entries = self.file_info or []
        self.display_options = [f"{info['display_name']}" for info in entries]
        self.file_paths = [info['path'] for info in entries]

    def setup_page_config(self):
        """Apply the page configuration from ``config['ui']`` and inject the CSS."""
        ui_config = self.config['ui']
        st.set_page_config(
            page_title=ui_config['page_title'],
            page_icon=ui_config['page_icon'],
            layout=ui_config['layout'],
            initial_sidebar_state=ui_config['sidebar_state'],
        )

        css_content = load_css_styles()
        st.markdown(f"<style>{css_content}</style>", unsafe_allow_html=True)

    def load_ocr_data(self, json_path: str, md_path: Optional[str] = None, image_path: Optional[str] = None):
        """Load the OCR data for ``json_path`` and rebuild the text/bbox mapping.

        Args:
            json_path: Path of the OCR result file to load.
            md_path: Accepted for interface compatibility; not used here.
            image_path: Accepted for interface compatibility; not used here.

        Any failure is reported in the UI via ``st.error`` instead of being
        raised, so a bad file does not abort the whole page.
        """
        try:
            self.ocr_data, self.md_content, self.image_path = load_ocr_data_file(json_path, self.config)
            self.process_data()
        except Exception as e:
            st.error(f"❌ 加载失败: {e}")

    def process_data(self):
        """Rebuild ``text_bbox_mapping`` from the currently loaded OCR data."""
        self.text_bbox_mapping = process_ocr_data(self.ocr_data, self.config)

    def get_statistics(self) -> Dict:
        """Return the statistics dictionary computed for the loaded OCR data."""
        return get_ocr_statistics(self.ocr_data, self.text_bbox_mapping, self.marked_errors)

    def display_html_table_as_dataframe(self, html_content: str, enable_editing: bool = False):
        """Render every HTML table found in ``html_content`` as a DataFrame.

        Args:
            html_content: Text containing zero or more HTML tables.
            enable_editing: When True the tables are shown in an editable
                ``st.data_editor`` instead of a read-only ``st.dataframe``.

        When no table can be parsed, the raw HTML is rendered inside a
        horizontally scrollable container instead.
        """
        tables = parse_html_tables(html_content)

        if not tables:
            st.warning("未找到可解析的表格")
            st.markdown(self._FALLBACK_TABLE_CSS, unsafe_allow_html=True)
            # NOTE(review): html_content is embedded unescaped; confirm that
            # the OCR output is trusted before exposing this to other users.
            st.markdown(f'<div class="scrollable-table">{html_content}</div>', unsafe_allow_html=True)
            return

        # The stylesheet is the same for every table, so emit it only once.
        st.markdown(self._DATAFRAME_CSS, unsafe_allow_html=True)

        for i, table in enumerate(tables):
            is_wide_table = len(table.columns) > self._WIDE_TABLE_COLUMNS

            st.subheader(f"📊 表格 {i+1}")

            # Summary metrics plus the display-mode selector.
            col_info1, col_info2, col_info3, col_info4 = st.columns(4)
            with col_info1:
                st.metric("行数", len(table))
            with col_info2:
                st.metric("列数", len(table.columns))
            with col_info3:
                st.metric("表格类型", "超宽表格" if is_wide_table else "普通表格")
            with col_info4:
                display_mode = st.selectbox(
                    f"显示模式 (表格{i+1})",
                    ["完整显示", "分页显示", "筛选列显示"],
                    key=f"display_mode_{i}",
                )

            # Per-table option toggles; the keys keep the widgets distinct.
            col1, col2, col3, col4 = st.columns(4)
            with col1:
                show_info = st.checkbox("显示详细信息", key=f"info_{i}")
            with col2:
                show_stats = st.checkbox("显示统计信息", key=f"stats_{i}")
            with col3:
                enable_filter = st.checkbox("启用过滤", key=f"filter_{i}")
            with col4:
                enable_sort = st.checkbox("启用排序", key=f"sort_{i}")

            display_table = self._process_table_display_mode(table, i, display_mode)
            filtered_table = self._apply_table_filters_and_sorts(display_table, i, enable_filter, enable_sort)

            container_class = "wide-table-container" if is_wide_table else "dataframe-container"

            st.markdown(f'<div class="{container_class}">', unsafe_allow_html=True)
            if enable_editing:
                edited_table = st.data_editor(
                    filtered_table,
                    use_container_width=True,
                    key=f"editor_{i}",
                    height=400 if is_wide_table else None,
                )
            else:
                st.dataframe(
                    filtered_table,
                    width=400 if is_wide_table else "stretch",
                )
            st.markdown('</div>', unsafe_allow_html=True)

            if enable_editing and not edited_table.equals(filtered_table):
                st.success("✏️ 表格已编辑,可以导出修改后的数据")

            self._display_table_info_and_stats(table, filtered_table, show_info, show_stats, i)

            st.markdown("---")

    def _apply_table_filters_and_sorts(self, table: pd.DataFrame, table_index: int, enable_filter: bool, enable_sort: bool) -> pd.DataFrame:
        """Return a copy of ``table`` with the user's filter and sort applied.

        Args:
            table: Table to filter/sort; it is never modified.
            table_index: Zero-based table number, used for labels and keys.
            enable_filter: Whether to show the filter widgets.
            enable_sort: Whether to show the sort widgets.
        """
        filtered_table = table.copy()

        if enable_filter and not table.empty:
            filter_col = st.selectbox(
                f"选择过滤列 (表格 {table_index+1})",
                options=['无'] + list(table.columns),
                key=f"filter_col_{table_index}",
            )

            if filter_col != '无':
                filter_value = st.text_input(f"过滤值 (表格 {table_index+1})", key=f"filter_value_{table_index}")
                if filter_value:
                    # regex=False: the input is a plain substring. Treating it
                    # as a pattern makes text such as "(" raise re.error.
                    mask = table[filter_col].astype(str).str.contains(filter_value, na=False, regex=False)
                    filtered_table = table[mask]

        if enable_sort and not filtered_table.empty:
            sort_col = st.selectbox(
                f"选择排序列 (表格 {table_index+1})",
                options=['无'] + list(filtered_table.columns),
                key=f"sort_col_{table_index}",
            )

            if sort_col != '无':
                sort_order = st.radio(
                    f"排序方式 (表格 {table_index+1})",
                    options=['升序', '降序'],
                    horizontal=True,
                    key=f"sort_order_{table_index}",
                )
                filtered_table = filtered_table.sort_values(sort_col, ascending=(sort_order == '升序'))

        return filtered_table

    def _display_table_info_and_stats(self, original_table: pd.DataFrame, filtered_table: pd.DataFrame,
                                      show_info: bool, show_stats: bool, table_index: int):
        """Show optional table details, numeric statistics and the export button.

        Args:
            original_table: Table as parsed from the HTML.
            filtered_table: Table after display mode, filter and sort.
            show_info: Whether to list row/column counts and column names.
            show_stats: Whether to show ``describe()`` for numeric columns.
            table_index: Zero-based table number, used for labels and keys.
        """
        if show_info:
            st.write("**表格信息:**")
            st.write(f"- 原始行数: {len(original_table)}")
            st.write(f"- 过滤后行数: {len(filtered_table)}")
            st.write(f"- 列数: {len(original_table.columns)}")
            # Column labels are not guaranteed to be strings (e.g. integers).
            st.write(f"- 列名: {', '.join(map(str, original_table.columns))}")

        if show_stats:
            st.write("**统计信息:**")
            numeric_cols = filtered_table.select_dtypes(include=[np.number]).columns
            if len(numeric_cols) > 0:
                st.dataframe(filtered_table[numeric_cols].describe())
            else:
                st.info("表格中没有数值列")

        if st.button(f"📥 导出表格 {table_index+1}", key=f"export_{table_index}"):
            self._create_export_buttons(filtered_table, table_index)

    def _create_export_buttons(self, table: pd.DataFrame, table_index: int):
        """Render CSV and Excel download buttons for ``table``."""
        csv_data = table.to_csv(index=False)
        st.download_button(
            label=f"下载CSV (表格 {table_index+1})",
            data=csv_data,
            file_name=f"table_{table_index+1}.csv",
            mime="text/csv",
            key=f"download_csv_{table_index}",
        )

        excel_buffer = BytesIO()
        table.to_excel(excel_buffer, index=False)
        st.download_button(
            label=f"下载Excel (表格 {table_index+1})",
            data=excel_buffer.getvalue(),
            file_name=f"table_{table_index+1}.xlsx",
            mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
            key=f"download_excel_{table_index}",
        )

    def _process_table_display_mode(self, table: pd.DataFrame, table_index: int, display_mode: str) -> pd.DataFrame:
        """Return the part of ``table`` selected by ``display_mode``.

        "分页显示" returns the chosen page of rows, "筛选列显示" returns the
        chosen columns (offered only for tables with more than five columns),
        and any other mode returns the table unchanged.
        """
        if display_mode == "分页显示":
            page_size = st.selectbox(
                f"每页显示行数 (表格 {table_index+1})",
                [10, 20, 50, 100],
                key=f"page_size_{table_index}",
            )

            # Ceiling division; yields 0 for an empty table.
            total_pages = (len(table) - 1) // page_size + 1
            if total_pages <= 1:
                return table

            page_number = st.selectbox(
                f"页码 (表格 {table_index+1})",
                range(1, total_pages + 1),
                key=f"page_number_{table_index}",
            )
            start_idx = (page_number - 1) * page_size
            return table.iloc[start_idx:start_idx + page_size]

        if display_mode == "筛选列显示" and len(table.columns) > 5:
            selected_columns = st.multiselect(
                f"选择要显示的列 (表格 {table_index+1})",
                table.columns.tolist(),
                default=table.columns.tolist()[:5],  # first five columns by default
                key=f"selected_columns_{table_index}",
            )
            if selected_columns:
                return table[selected_columns]

        return table

    @staticmethod
    def _run_with_captured_stdout(placeholder, func, **kwargs):
        """Call ``func(**kwargs)`` and show everything it prints in ``placeholder``.

        The captured output is rendered in a ``finally`` block so that it is
        still visible when ``func`` raises.
        """
        import contextlib
        import io

        buffer = io.StringIO()
        try:
            with contextlib.redirect_stdout(buffer):
                return func(**kwargs)
        finally:
            placeholder.code(buffer.getvalue(), language='text')

    @st.dialog("VLM预校验", width="large", dismissible=True, on_dismiss="rerun")
    def vlm_pre_validation(self):
        """Re-run OCR on the current image with the VLM and compare the results.

        Steps: run ``ocr_with_vlm`` on the loaded image, compare its Markdown
        output with the current OCR Markdown via ``compare_ocr_results``, store
        the outcome in ``st.session_state.comparison_result`` and render it.
        Errors are reported inside the dialog rather than raised.
        """
        if not self.image_path or not self.md_content:
            st.error("❌ 请先加载OCR数据文件")
            return

        if 'comparison_result' not in st.session_state:
            st.session_state.comparison_result = None

        with st.spinner("正在进行VLM预校验...", show_time=True):
            status_text = st.empty()

            try:
                current_md_path = Path(self.file_paths[self.selected_file_index]).with_suffix('.md')
                if not current_md_path.exists():
                    st.error("❌ 当前OCR结果的Markdown文件不存在,无法进行对比")
                    return

                # Step 1: prepare the working directory.
                pre_validation_dir = Path(self.config['paths'].get('pre_validation_dir', './output/pre_validation/')).resolve()
                pre_validation_dir.mkdir(parents=True, exist_ok=True)
                status_text.write(f"工作目录: {pre_validation_dir}")

                # Step 2: run the VLM OCR, showing what it prints.
                status_text.text("🤖 正在调用VLM进行OCR识别...")
                with st.expander("🔍 VLM OCR识别过程", expanded=True):
                    self._run_with_captured_stdout(
                        st.empty(),
                        ocr_with_vlm,
                        image_path=str(self.image_path),
                        output_dir=str(pre_validation_dir),
                        normalize_numbers=True,
                    )
                status_text.text("✅ VLM OCR识别完成")

                # Step 3: locate the Markdown file the VLM run should have written.
                vlm_md_path = pre_validation_dir / f"{Path(self.image_path).stem}.md"
                if not vlm_md_path.exists():
                    st.error("❌ VLM OCR结果文件未生成")
                    return

                # Step 4: compare the two Markdown files.
                status_text.text("📊 正在对比OCR结果...")
                comparison_result_path = pre_validation_dir / f"{current_md_path.stem}_comparison_result"
                with st.expander("🔍 OCR结果对比过程", expanded=True):
                    comparison_result = self._run_with_captured_stdout(
                        st.empty(),
                        compare_ocr_results,
                        file1_path=str(current_md_path),
                        file2_path=str(vlm_md_path),
                        output_file=str(comparison_result_path),
                        output_format='both',
                        ignore_images=True,
                    )

                status_text.text("✅ VLM预校验完成")
                st.session_state.comparison_result = {
                    "image_path": self.image_path,
                    "comparison_result_json": f"{comparison_result_path}.json",
                    "comparison_result_md": f"{comparison_result_path}.md",
                    "comparison_result": comparison_result,
                }

                # Step 5: render the comparison.
                self.display_comparison_results(comparison_result)

            except Exception as e:
                st.error(f"❌ VLM预校验失败: {e}")
                st.exception(e)

    @staticmethod
    def _as_text(value) -> str:
        """Return ``value`` as a string, mapping ``None`` to an empty string."""
        return '' if value is None else str(value)

    @staticmethod
    def _truncate(text: str, limit: Optional[int]) -> str:
        """Cut ``text`` to ``limit`` characters, appending '...' when shortened."""
        if limit is None or len(text) <= limit:
            return text
        return text[:limit] + '...'

    def _difference_rows(self, differences, value_limit: Optional[int] = None,
                         description_limit: Optional[int] = None) -> list:
        """Build one table row per difference, optionally truncating long text."""
        return [
            {
                '序号': number,
                '位置': diff['position'],
                '类型': diff['type'],
                '原OCR结果': self._truncate(self._as_text(diff['file1_value']), value_limit),
                'VLM识别结果': self._truncate(self._as_text(diff['file2_value']), value_limit),
                '描述': self._truncate(self._as_text(diff['description']), description_limit),
                '严重程度': self._get_severity_level(diff),
            }
            for number, diff in enumerate(differences, 1)
        ]

    def display_comparison_results(self, comparison_result: dict):
        """Render the VLM comparison summary, details, charts and downloads.

        Args:
            comparison_result: Dictionary with a ``statistics`` mapping and a
                ``differences`` list, as read by the code below.
        """
        st.header("📊 VLM预校验结果")

        stats = comparison_result['statistics']
        differences = comparison_result['differences']

        col1, col2, col3, col4 = st.columns(4)
        with col1:
            st.metric("总差异数", stats['total_differences'])
        with col2:
            st.metric("表格差异", stats['table_differences'])
        with col3:
            st.metric("金额差异", stats['amount_differences'])
        with col4:
            st.metric("段落差异", stats['paragraph_differences'])

        if stats['total_differences'] == 0:
            st.success("🎉 完美匹配!VLM识别结果与原OCR结果完全一致")
        else:
            st.warning(f"⚠️ 发现 {stats['total_differences']} 个差异,建议人工检查")

        # The difference sections are only meaningful when differences exist.
        if differences:
            self._render_difference_table(differences)
            self._render_difference_detail(differences)
            self._render_difference_charts(differences)

        self._provide_download_options_in_results(comparison_result)

    def _render_difference_table(self, differences: list):
        """Show all differences in a styled, truncated overview table."""
        st.subheader("🔍 差异详情对比")

        df_differences = pd.DataFrame(
            self._difference_rows(differences, value_limit=100, description_limit=80)
        )

        severity_styles = {
            '高': 'background-color: #ffebee; color: #c62828',
            '中': 'background-color: #fff3e0; color: #ef6c00',
            '低': 'background-color: #e8f5e8; color: #2e7d32',
        }

        def highlight_severity(val):
            """Return the CSS for a severity cell ('' for unknown values)."""
            return severity_styles.get(val, '')

        styler = df_differences.style
        # Styler.applymap was renamed to Styler.map in pandas 2.1; prefer the
        # new name and fall back for older pandas versions.
        apply_cellwise = getattr(styler, 'map', None) or styler.applymap
        styled_df = apply_cellwise(highlight_severity, subset=['严重程度']).format({
            '序号': '{:d}',
        })

        st.dataframe(
            styled_df,
            use_container_width=True,
            height=400,
            hide_index=True,
            column_config={
                "序号": st.column_config.NumberColumn(
                    "序号",
                    width=None,  # automatic width
                    pinned=True,
                    help="差异项序号",
                ),
                "位置": st.column_config.TextColumn(
                    "位置",
                    width=None,
                    pinned=True,
                    help="差异在文档中的位置",
                ),
                "类型": st.column_config.TextColumn(
                    "类型",
                    width=None,
                    pinned=True,
                    help="差异类型",
                ),
                "原OCR结果": st.column_config.TextColumn(
                    "原OCR结果",
                    width="large",
                    pinned=True,
                    help="原始OCR识别结果",
                ),
                "VLM识别结果": st.column_config.TextColumn(
                    "VLM识别结果",
                    width="large",
                    help="VLM重新识别的结果",
                ),
                "描述": st.column_config.TextColumn(
                    "描述",
                    width="medium",
                    help="差异详细描述",
                ),
                "严重程度": st.column_config.TextColumn(
                    "严重程度",
                    width=None,
                    help="差异严重程度评级",
                ),
            },
        )

    def _render_difference_detail(self, differences: list):
        """Let the user pick one difference and show its full, untruncated text."""
        st.subheader("🔍 详细差异查看")

        selected_diff_index = st.selectbox(
            "选择要查看的差异:",
            options=range(len(differences)),
            format_func=lambda x: f"差异 {x+1}: {differences[x]['position']} - {differences[x]['type']}",
            key="selected_diff",
        )
        if selected_diff_index is None:
            return

        diff = differences[selected_diff_index]

        # Side-by-side view of both versions.
        col1, col2 = st.columns(2)
        with col1:
            st.write("**原OCR结果:**")
            st.text_area(
                "原OCR结果详情",
                value=diff['file1_value'],
                height=200,
                key=f"original_{selected_diff_index}",
                label_visibility="collapsed",
            )
        with col2:
            st.write("**VLM识别结果:**")
            st.text_area(
                "VLM识别结果详情",
                value=diff['file2_value'],
                height=200,
                key=f"vlm_{selected_diff_index}",
                label_visibility="collapsed",
            )

        st.info(f"**位置:** {diff['position']}")
        st.info(f"**类型:** {diff['type']}")
        st.info(f"**描述:** {diff['description']}")
        st.info(f"**严重程度:** {self._get_severity_level(diff)}")

    def _render_difference_charts(self, differences: list):
        """Plot the distribution of difference types and severity levels."""
        st.subheader("📈 差异类型分布")

        type_counts = {}
        severity_counts = {'高': 0, '中': 0, '低': 0}
        for diff in differences:
            diff_type = diff['type']
            type_counts[diff_type] = type_counts.get(diff_type, 0) + 1
            severity_counts[self._get_severity_level(diff)] += 1

        col1, col2 = st.columns(2)
        with col1:
            if type_counts:
                fig_type = px.pie(
                    values=list(type_counts.values()),
                    names=list(type_counts.keys()),
                    title="差异类型分布",
                )
                st.plotly_chart(fig_type, use_container_width=True)
        with col2:
            fig_severity = px.bar(
                x=list(severity_counts.keys()),
                y=list(severity_counts.values()),
                title="差异严重程度分布",
                color=list(severity_counts.keys()),
                color_discrete_map={'高': '#f44336', '中': '#ff9800', '低': '#4caf50'},
            )
            st.plotly_chart(fig_severity, use_container_width=True)

    def _get_severity_level(self, diff: dict) -> str:
        """Classify a difference as '高', '中' or '低'.

        A type containing 'amount' or 'number' is high, one containing
        'table' or 'structure' is medium. Otherwise the rating depends on how
        much the two values differ in length (>50 high, >10 medium, else low).
        Missing or non-string values are treated as text rather than raising.
        """
        diff_type = self._as_text(diff['type']).lower()

        if 'amount' in diff_type or 'number' in diff_type:
            return '高'
        if 'table' in diff_type or 'structure' in diff_type:
            return '中'

        len_diff = abs(
            len(self._as_text(diff.get('file1_value')))
            - len(self._as_text(diff.get('file2_value')))
        )
        if len_diff > 50:
            return '高'
        if len_diff > 10:
            return '中'
        return '低'

    def _provide_download_options_in_results(self, comparison_result: dict):
        """Offer Excel/CSV/JSON downloads of the comparison and a short advice note."""
        st.subheader("📥 导出预校验结果")

        statistics = comparison_result['statistics']
        differences = comparison_result['differences']
        # One timestamp so that all three file names agree.
        timestamp = pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')

        col1, col2, col3 = st.columns(3)

        with col1:
            # Full (untruncated) difference list as Excel.
            if differences:
                df_export = pd.DataFrame(self._difference_rows(differences))
                excel_buffer = BytesIO()
                df_export.to_excel(excel_buffer, index=False, sheet_name='差异详情')

                st.download_button(
                    label="📊 下载差异详情(Excel)",
                    data=excel_buffer.getvalue(),
                    file_name=f"vlm_comparison_differences_{timestamp}.xlsx",
                    mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
                    key="download_differences_excel",
                )

        with col2:
            # Summary counts as CSV.
            df_stats = pd.DataFrame({
                '统计项目': ['总差异数', '表格差异', '金额差异', '段落差异'],
                '数量': [
                    statistics['total_differences'],
                    statistics['table_differences'],
                    statistics['amount_differences'],
                    statistics['paragraph_differences'],
                ],
            })

            st.download_button(
                label="📈 下载统计报告(CSV)",
                data=df_stats.to_csv(index=False),
                file_name=f"vlm_comparison_stats_{timestamp}.csv",
                mime="text/csv",
                key="download_stats_csv",
            )

        with col3:
            # Complete comparison result as JSON.
            report_json = json.dumps(comparison_result, ensure_ascii=False, indent=2)

            st.download_button(
                label="📄 下载完整报告(JSON)",
                data=report_json,
                file_name=f"vlm_comparison_full_{timestamp}.json",
                mime="application/json",
                key="download_full_json",
            )

        st.subheader("🚀 后续操作建议")

        total_diffs = statistics['total_differences']
        if total_diffs == 0:
            st.success("✅ VLM识别结果与原OCR完全一致,可信度很高,无需人工校验")
        elif total_diffs <= 5:
            st.warning("⚠️ 发现少量差异,建议重点检查高严重程度的差异项")
        elif total_diffs <= 20:
            st.warning("🔍 发现中等数量差异,建议详细检查差异表格中标红的项目")
        else:
            st.error("❌ 发现大量差异,建议重新进行OCR识别或检查原始图片质量")

    @st.dialog("查看预校验结果", width="large", dismissible=True, on_dismiss="rerun")
    def show_comparison_results_dialog(self):
        """Show the last pre-validation result, from session state or from disk.

        A result held in ``st.session_state.comparison_result`` is shown
        directly. Otherwise, if a saved ``*_comparison_result.json`` exists for
        the current file, the user can load it with a button.
        """
        current_md_path = Path(self.file_paths[self.selected_file_index]).with_suffix('.md')
        pre_validation_dir = Path(self.config['paths'].get('pre_validation_dir', './output/pre_validation/')).resolve()
        comparison_result_path = pre_validation_dir / f"{current_md_path.stem}_comparison_result.json"

        if 'comparison_result' in st.session_state and st.session_state.comparison_result:
            self.display_comparison_results(st.session_state.comparison_result['comparison_result'])
        elif comparison_result_path.exists():
            if st.button("加载预校验结果"):
                try:
                    with open(comparison_result_path, "r", encoding="utf-8") as f:
                        comparison_json_result = json.load(f)
                except (OSError, json.JSONDecodeError) as e:
                    # An unreadable or corrupt file is reported, not raised.
                    st.error(f"❌ 加载失败: {e}")
                    return

                st.session_state.comparison_result = {
                    "image_path": self.image_path,
                    "comparison_result_json": str(comparison_result_path),
                    "comparison_result_md": str(comparison_result_path.with_suffix('.md')),
                    "comparison_result": comparison_json_result,
                }
                self.display_comparison_results(comparison_json_result)
        else:
            st.info("暂无预校验结果,请先运行VLM预校验")

    def create_compact_layout(self, config):
        """Render the compact review layout by delegating to the layout manager."""
        return self.layout_manager.create_compact_layout(config)
@st.dialog("message", width="small", dismissible=True, on_dismiss="rerun")
def message_box(msg: str, msg_type: str = "info"):
    """Show ``msg`` in a small modal dialog.

    Args:
        msg: Text to display.
        msg_type: One of "info", "success", "warning" or "error". Any other
            value is shown as an info message so the dialog is never empty.
    """
    renderers = {
        "info": st.info,
        "success": st.success,
        "warning": st.warning,
        "error": st.error,
    }
    renderers.get(msg_type, st.info)(msg)
def _render_table_tab(validator):
    """Render the table-analysis tab for the validator's Markdown content."""
    st.header("📊 表格数据分析")

    has_table = bool(validator.md_content) and '<table' in validator.md_content.lower()
    if not has_table:
        st.info("当前OCR结果中没有检测到表格数据")
        return

    preview_col, action_col = st.columns([2, 1])

    with preview_col:
        st.subheader("🔍 表格数据预览")
        validator.display_html_table_as_dataframe(validator.md_content)

    with action_col:
        st.subheader("⚙️ 表格操作")

        if st.button("📥 导出表格数据", type="primary"):
            tables = parse_html_tables(validator.md_content)
            if tables:
                output = export_tables_to_excel(tables)
                st.download_button(
                    label="📥 下载Excel文件",
                    data=output.getvalue(),
                    file_name="ocr_tables.xlsx",
                    mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
                )


def _render_statistics_tab(stats):
    """Render the statistics tab: category pie chart and quality bar chart."""
    st.header("📈 OCR数据统计")

    categories = stats['categories']
    if categories:
        st.subheader("📊 类别分布")
        category_chart = px.pie(
            values=list(categories.values()),
            names=list(categories.keys()),
            title="文本类别分布"
        )
        st.plotly_chart(category_chart, use_container_width=True)

    # NOTE(review): assumed to sit at tab level rather than under the
    # category check; the source indentation was lost. Confirm.
    st.subheader("📈 质量分析")
    error_count = stats['marked_errors']
    quality_data = {
        '状态': ['正确', '错误'],
        '数量': [stats['clickable_texts'] - error_count, error_count]
    }
    quality_chart = px.bar(
        quality_data, x='状态', y='数量', title="识别质量分布",
        color='状态', color_discrete_map={'正确': 'green', '错误': 'red'}
    )
    st.plotly_chart(quality_chart, use_container_width=True)


def _render_navigation_tab(validator):
    """Render the quick-navigation tab: one button per text, grouped by category."""
    st.header("🚀 快速导航")

    if not validator.text_bbox_mapping:
        st.info("没有可用的文本项进行导航")
        return

    grouped = group_texts_by_category(validator.text_bbox_mapping)
    for category, texts in grouped.items():
        with st.expander(f"{category} ({len(texts)}项)", expanded=False):
            button_cols = st.columns(3)  # three buttons per row
            for position, text in enumerate(texts):
                with button_cols[position % 3]:
                    label = text[:15] + "..." if len(text) > 15 else text
                    if st.button(label, key=f"nav_{category}_{position}"):
                        st.session_state.selected_text = text
                        st.rerun()


def main():
    """Run the Streamlit app: file toolbar, statistics expander and four tabs."""
    state = st.session_state

    # The validator is created once per session; page config and the title
    # are only issued on the run that creates it.
    if 'validator' in state:
        validator = state.validator
        config = validator.config
    else:
        validator = StreamlitOCRValidator()
        state.validator = validator
        validator.setup_page_config()
        config = validator.config
        st.title(config['ui']['page_title'])

    for state_key, initial in (('selected_text', None), ('marked_errors', set())):
        if state_key not in state:
            state[state_key] = initial

    # NOTE(review): the two buttons are assumed to live inside this horizontal
    # container; the source indentation was lost. Confirm.
    with st.container(height=75, horizontal=True, horizontal_alignment='left', gap="medium"):
        if 'selected_file_index' not in state:
            state.selected_file_index = 0

        if not validator.display_options:
            st.warning("未找到OCR结果文件")
            st.info("请确保output目录下有OCR结果文件")
        else:
            option_count = len(validator.display_options)
            selected_index = st.selectbox(
                "选择OCR结果文件",
                range(option_count),
                format_func=lambda i: validator.display_options[i],
                index=state.selected_file_index,
                width=100,
                key="selected_selectbox",
                label_visibility="collapsed",
            )
            # Session state follows the select box.
            state.selected_file_index = selected_index
            selected_file = validator.file_paths[selected_index]

            # Page-number input next to the select box.
            current_page = validator.file_info[selected_index]['page']
            page_input = st.number_input(
                "输入一个数字",
                placeholder="输入页码",
                width=200,
                label_visibility="collapsed",
                min_value=1,
                max_value=option_count,
                value=current_page,
                step=1,
                key="page_input",
            )

            if page_input != current_page:
                # Jump to the first file whose page number matches the input.
                target = next(
                    (idx for idx, info in enumerate(validator.file_info) if info['page'] == page_input),
                    None,
                )
                if target is not None:
                    state.selected_file_index = target
                    selected_file = validator.file_paths[target]
                    st.rerun()

            needs_load = (
                state.selected_file_index >= 0
                and validator.selected_file_index != state.selected_file_index
                and selected_file
            )
            if needs_load:
                validator.selected_file_index = state.selected_file_index
                state.validator.load_ocr_data(selected_file)
                st.success(f"✅ 已加载第{validator.file_info[state.selected_file_index]['page']}页")
                st.rerun()

        if st.button("VLM预校验", type="primary", icon=":material/compare_arrows:"):
            if validator.image_path and validator.md_content:
                validator.vlm_pre_validation()
            else:
                message_box("❌ 请先加载OCR数据文件", "error")

        if st.button("查看预校验结果", type="secondary", icon=":material/quick_reference_all:"):
            validator.show_comparison_results_dialog()

    with st.expander("🔧 OCR工具统计信息", expanded=False):
        stats = validator.get_statistics()
        metric_cols = st.columns(5)

        headline_metrics = (
            ("📊 总文本块", stats['total_texts']),
            ("🔗 可点击文本", stats['clickable_texts']),
            ("❌ 标记错误", stats['marked_errors']),
            ("✅ 准确率", f"{stats['accuracy_rate']:.1f}%"),
        )
        for column, (label, value) in zip(metric_cols, headline_metrics):
            with column:
                st.metric(label, value)

        tool_info = stats['tool_info']
        with metric_cols[4]:
            # The first tool listed is shown as the main OCR tool.
            if tool_info:
                tool_names = list(tool_info.keys())
                st.metric("🔧 OCR工具", tool_names[0] if tool_names else "未知")

        if tool_info:
            st.write(tool_info)

    content_tab, table_tab, stats_tab, nav_tab = st.tabs(
        ["📄 内容校验", "📊 表格分析", "📈 数据统计", "🚀 快速导航"]
    )

    with content_tab:
        validator.create_compact_layout(config)
    with table_tab:
        _render_table_tab(validator)
    with stats_tab:
        _render_statistics_tab(stats)
    with nav_tab:
        _render_navigation_tab(validator)


if __name__ == "__main__":
    main()
|