|
@@ -13,6 +13,7 @@ from io import BytesIO
|
|
|
import pandas as pd
|
|
import pandas as pd
|
|
|
import numpy as np
|
|
import numpy as np
|
|
|
import plotly.express as px
|
|
import plotly.express as px
|
|
|
|
|
+import json
|
|
|
|
|
|
|
|
# 导入工具模块
|
|
# 导入工具模块
|
|
|
from ocr_validator_utils import (
|
|
from ocr_validator_utils import (
|
|
@@ -22,6 +23,8 @@ from ocr_validator_utils import (
|
|
|
export_tables_to_excel, get_table_statistics, group_texts_by_category
|
|
export_tables_to_excel, get_table_statistics, group_texts_by_category
|
|
|
)
|
|
)
|
|
|
from ocr_validator_layout import OCRLayoutManager
|
|
from ocr_validator_layout import OCRLayoutManager
|
|
|
|
|
+from ocr_by_vlm import ocr_with_vlm
|
|
|
|
|
+from compare_ocr_results import compare_ocr_results
|
|
|
|
|
|
|
|
|
|
|
|
|
class StreamlitOCRValidator:
|
|
class StreamlitOCRValidator:
|
|
@@ -349,11 +352,443 @@ class StreamlitOCRValidator:
|
|
|
else: # 完整显示
|
|
else: # 完整显示
|
|
|
return table
|
|
return table
|
|
|
|
|
|
|
|
- # 布局方法现在委托给布局管理器
|
|
|
|
|
|
|
@st.dialog("VLM预校验", width="large", dismissible=True, on_dismiss="rerun")
def vlm_pre_validation(self):
    """Run the VLM pre-validation workflow inside a modal dialog.

    The workflow is:

    1. Resolve (and create) the pre-validation working directory from
       ``self.config['paths']['pre_validation_dir']``.
    2. Call ``ocr_with_vlm`` on ``self.image_path``, writing into that
       directory.
    3. Call ``compare_ocr_results`` on the current OCR Markdown file and the
       Markdown file expected from step 2.
    4. Store the outcome in ``st.session_state.comparison_result`` and render
       it through ``display_comparison_results``.

    Everything the two helpers print to stdout is captured and shown in an
    expander. Any exception is reported in the dialog instead of propagating.
    """
    # Local imports: only this method needs them.
    import contextlib
    import io

    def run_captured(title, func, **kwargs):
        """Call ``func(**kwargs)`` and show what it printed inside an expander."""
        with st.expander(title, expanded=True):
            log_area = st.empty()
            captured = io.StringIO()
            try:
                with contextlib.redirect_stdout(captured):
                    return func(**kwargs)
            finally:
                # Render the log even when func raised: the captured output
                # is the most useful clue for diagnosing the failure.
                log_area.code(captured.getvalue(), language='text')

    if not self.image_path or not self.md_content:
        st.error("❌ 请先加载OCR数据文件")
        return

    # Make sure the session slot for the comparison result exists.
    if 'comparison_result' not in st.session_state:
        st.session_state.comparison_result = None

    with st.spinner("正在进行VLM预校验...", show_time=True):
        status_text = st.empty()

        try:
            current_md_path = Path(self.file_paths[self.selected_file_index]).with_suffix('.md')
            if not current_md_path.exists():
                st.error("❌ 当前OCR结果的Markdown文件不存在,无法进行对比")
                return

            # Step 1: prepare the working directory.
            pre_validation_dir = Path(
                self.config['paths'].get('pre_validation_dir', './output/pre_validation/')
            ).resolve()
            pre_validation_dir.mkdir(parents=True, exist_ok=True)
            status_text.write(f"工作目录: {pre_validation_dir}")

            # Step 2: re-recognise the image with the VLM.
            status_text.text("🤖 正在调用VLM进行OCR识别...")
            run_captured(
                "🔍 VLM OCR识别过程",
                ocr_with_vlm,
                image_path=str(self.image_path),
                output_dir=str(pre_validation_dir),
                normalize_numbers=True,
            )
            status_text.text("✅ VLM OCR识别完成")

            # Step 3: locate the Markdown file the VLM run should have produced.
            vlm_md_path = pre_validation_dir / f"{Path(self.image_path).stem}.md"
            if not vlm_md_path.exists():
                st.error("❌ VLM OCR结果文件未生成")
                return

            # Step 4: compare the original OCR output with the VLM output.
            status_text.text("📊 正在对比OCR结果...")
            comparison_result_path = pre_validation_dir / f"{current_md_path.stem}_comparison_result"
            comparison_result = run_captured(
                "🔍 OCR结果对比过程",
                compare_ocr_results,
                file1_path=str(current_md_path),
                file2_path=str(vlm_md_path),
                output_file=str(comparison_result_path),
                output_format='both',
                ignore_images=True,
            )
            status_text.text("✅ VLM预校验完成")

            st.session_state.comparison_result = {
                "image_path": self.image_path,
                "comparison_result_json": f"{comparison_result_path}.json",
                "comparison_result_md": f"{comparison_result_path}.md",
                "comparison_result": comparison_result,
            }

            # Step 5: render the comparison (download options are part of it).
            self.display_comparison_results(comparison_result)

        except Exception as e:
            # UI boundary: surface the failure in the dialog, with traceback.
            st.error(f"❌ VLM预校验失败: {e}")
            st.exception(e)
|
|
|
|
|
+
|
|
|
|
|
def display_comparison_results(self, comparison_result: dict):
    """Render a comparison result: metrics, difference table, detail view, charts.

    Args:
        comparison_result: Mapping with a ``'statistics'`` entry (read keys:
            ``total_differences``, ``table_differences``,
            ``amount_differences``, ``paragraph_differences``) and a
            ``'differences'`` list whose items provide ``position``, ``type``,
            ``file1_value``, ``file2_value`` and ``description``.

    The table, detail view and charts are only rendered when there is at
    least one difference. The export section is always rendered.
    """

    def shorten(text, limit):
        """Cut ``text`` to ``limit`` characters, marking the cut with '...'."""
        return text[:limit] + ('...' if len(text) > limit else '')

    st.header("📊 VLM预校验结果")

    stats = comparison_result['statistics']

    # Overview: one metric per statistic, laid out in four columns.
    metric_specs = (
        ("总差异数", 'total_differences'),
        ("表格差异", 'table_differences'),
        ("金额差异", 'amount_differences'),
        ("段落差异", 'paragraph_differences'),
    )
    for column, (label, stat_key) in zip(st.columns(4), metric_specs):
        with column:
            st.metric(label, stats[stat_key])

    # Overall verdict.
    if stats['total_differences'] == 0:
        st.success("🎉 完美匹配!VLM识别结果与原OCR结果完全一致")
    else:
        st.warning(f"⚠️ 发现 {stats['total_differences']} 个差异,建议人工检查")

    differences = comparison_result['differences']
    if differences:
        # Severity is derived once per difference and reused by the table,
        # the detail view and the charts.
        severities = [self._get_severity_level(diff) for diff in differences]

        st.subheader("🔍 差异详情对比")

        df_differences = pd.DataFrame([
            {
                '序号': number,
                '位置': diff['position'],
                '类型': diff['type'],
                '原OCR结果': shorten(diff['file1_value'], 100),
                'VLM识别结果': shorten(diff['file2_value'], 100),
                '描述': shorten(diff['description'], 80),
                '严重程度': severity,
            }
            for number, (diff, severity) in enumerate(zip(differences, severities), 1)
        ])

        # Cell colours keyed by severity label; unknown labels get no style.
        severity_styles = {
            '高': 'background-color: #ffebee; color: #c62828',
            '中': 'background-color: #fff3e0; color: #ef6c00',
            '低': 'background-color: #e8f5e8; color: #2e7d32',
        }
        styler = df_differences.style
        # Styler.applymap is deprecated in favour of Styler.map; prefer the
        # new name and fall back only on pandas versions that lack it.
        apply_cellwise = getattr(styler, 'map', None) or styler.applymap
        styled_df = apply_cellwise(
            lambda value: severity_styles.get(value, ''),
            subset=['严重程度'],
        ).format({
            '序号': '{:d}',
        })

        st.dataframe(
            styled_df,
            use_container_width=True,
            height=400,
            hide_index=True,
            column_config={
                "序号": st.column_config.NumberColumn(
                    "序号",
                    width=None,  # automatic width
                    pinned=True,
                    help="差异项序号"
                ),
                "位置": st.column_config.TextColumn(
                    "位置",
                    width=None,  # automatic width
                    pinned=True,
                    help="差异在文档中的位置"
                ),
                "类型": st.column_config.TextColumn(
                    "类型",
                    width=None,  # automatic width
                    pinned=True,
                    help="差异类型"
                ),
                "原OCR结果": st.column_config.TextColumn(
                    "原OCR结果",
                    width="large",
                    pinned=True,
                    help="原始OCR识别结果"
                ),
                "VLM识别结果": st.column_config.TextColumn(
                    "VLM识别结果",
                    width="large",
                    help="VLM重新识别的结果"
                ),
                "描述": st.column_config.TextColumn(
                    "描述",
                    width="medium",
                    help="差异详细描述"
                ),
                "严重程度": st.column_config.TextColumn(
                    "严重程度",
                    width=None,  # automatic width
                    help="差异严重程度评级"
                )
            }
        )

        # Detail view: pick one difference and show both full values.
        st.subheader("🔍 详细差异查看")

        selected_diff_index = st.selectbox(
            "选择要查看的差异:",
            options=range(len(differences)),
            format_func=lambda idx: f"差异 {idx+1}: {differences[idx]['position']} - {differences[idx]['type']}",
            key="selected_diff"
        )

        if selected_diff_index is not None:
            diff = differences[selected_diff_index]

            original_col, vlm_col = st.columns(2)

            with original_col:
                st.write("**原OCR结果:**")
                st.text_area(
                    "原OCR结果详情",
                    value=diff['file1_value'],
                    height=200,
                    key=f"original_{selected_diff_index}",
                    label_visibility="collapsed"
                )

            with vlm_col:
                st.write("**VLM识别结果:**")
                st.text_area(
                    "VLM识别结果详情",
                    value=diff['file2_value'],
                    height=200,
                    key=f"vlm_{selected_diff_index}",
                    label_visibility="collapsed"
                )

            st.info(f"**位置:** {diff['position']}")
            st.info(f"**类型:** {diff['type']}")
            st.info(f"**描述:** {diff['description']}")
            st.info(f"**严重程度:** {severities[selected_diff_index]}")

        # Charts: distribution by type and by severity.
        st.subheader("📈 差异类型分布")

        type_counts = {}
        severity_counts = {'高': 0, '中': 0, '低': 0}
        for diff, severity in zip(differences, severities):
            type_counts[diff['type']] = type_counts.get(diff['type'], 0) + 1
            severity_counts[severity] += 1

        pie_col, bar_col = st.columns(2)

        with pie_col:
            fig_type = px.pie(
                values=list(type_counts.values()),
                names=list(type_counts.keys()),
                title="差异类型分布"
            )
            st.plotly_chart(fig_type, use_container_width=True)

        with bar_col:
            fig_severity = px.bar(
                x=list(severity_counts.keys()),
                y=list(severity_counts.values()),
                title="差异严重程度分布",
                color=list(severity_counts.keys()),
                color_discrete_map={'高': '#f44336', '中': '#ff9800', '低': '#4caf50'}
            )
            st.plotly_chart(fig_severity, use_container_width=True)

    # Export / download section.
    self._provide_download_options_in_results(comparison_result)
|
|
|
|
|
+
|
|
|
|
|
+ def _get_severity_level(self, diff: dict) -> str:
|
|
|
|
|
+ """根据差异类型和内容判断严重程度"""
|
|
|
|
|
+ diff_type = diff['type'].lower()
|
|
|
|
|
+
|
|
|
|
|
+ # 金额相关差异为高严重程度
|
|
|
|
|
+ if 'amount' in diff_type or 'number' in diff_type:
|
|
|
|
|
+ return '高'
|
|
|
|
|
+
|
|
|
|
|
+ # 表格结构差异为中等严重程度
|
|
|
|
|
+ if 'table' in diff_type or 'structure' in diff_type:
|
|
|
|
|
+ return '中'
|
|
|
|
|
+
|
|
|
|
|
+ # 检查内容长度差异
|
|
|
|
|
+ len_diff = abs(len(diff['file1_value']) - len(diff['file2_value']))
|
|
|
|
|
+ if len_diff > 50:
|
|
|
|
|
+ return '高'
|
|
|
|
|
+ elif len_diff > 10:
|
|
|
|
|
+ return '中'
|
|
|
|
|
+ else:
|
|
|
|
|
+ return '低'
|
|
|
|
|
+
|
|
|
|
|
def _provide_download_options_in_results(self, comparison_result: dict):
    """Offer Excel, CSV and JSON exports of a comparison result, then show advice.

    Args:
        comparison_result: Mapping with a ``'differences'`` list and a
            ``'statistics'`` mapping, as rendered by
            ``display_comparison_results``.
    """
    import json

    def timestamp():
        """Return the current time formatted for use in a file name."""
        return pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')

    st.subheader("📥 导出预校验结果")

    excel_col, stats_col, json_col = st.columns(3)

    with excel_col:
        # Full (untruncated) difference details as an Excel workbook; only
        # offered when there is something to export.
        differences = comparison_result['differences']
        if differences:
            rows = [
                {
                    '序号': number,
                    '位置': diff['position'],
                    '类型': diff['type'],
                    '原OCR结果': diff['file1_value'],
                    'VLM识别结果': diff['file2_value'],
                    '描述': diff['description'],
                    '严重程度': self._get_severity_level(diff),
                }
                for number, diff in enumerate(differences, 1)
            ]
            workbook = BytesIO()
            pd.DataFrame(rows).to_excel(workbook, index=False, sheet_name='差异详情')

            st.download_button(
                label="📊 下载差异详情(Excel)",
                data=workbook.getvalue(),
                file_name=f"vlm_comparison_differences_{timestamp()}.xlsx",
                mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
                key="download_differences_excel"
            )

    with stats_col:
        # Statistics summary as CSV.
        statistics = comparison_result['statistics']
        stat_rows = (
            ('总差异数', 'total_differences'),
            ('表格差异', 'table_differences'),
            ('金额差异', 'amount_differences'),
            ('段落差异', 'paragraph_differences'),
        )
        stats_frame = pd.DataFrame({
            '统计项目': [label for label, _ in stat_rows],
            '数量': [statistics[stat_key] for _, stat_key in stat_rows],
        })

        st.download_button(
            label="📈 下载统计报告(CSV)",
            data=stats_frame.to_csv(index=False),
            file_name=f"vlm_comparison_stats_{timestamp()}.csv",
            mime="text/csv",
            key="download_stats_csv"
        )

    with json_col:
        # Complete comparison result as JSON.
        st.download_button(
            label="📄 下载完整报告(JSON)",
            data=json.dumps(comparison_result, ensure_ascii=False, indent=2),
            file_name=f"vlm_comparison_full_{timestamp()}.json",
            mime="application/json",
            key="download_full_json"
        )

    # Follow-up advice, graded by the total number of differences.
    st.subheader("🚀 后续操作建议")

    total_diffs = comparison_result['statistics']['total_differences']
    if total_diffs == 0:
        notify, advice = st.success, "✅ VLM识别结果与原OCR完全一致,可信度很高,无需人工校验"
    elif total_diffs <= 5:
        notify, advice = st.warning, "⚠️ 发现少量差异,建议重点检查高严重程度的差异项"
    elif total_diffs <= 20:
        notify, advice = st.warning, "🔍 发现中等数量差异,建议详细检查差异表格中标红的项目"
    else:
        notify, advice = st.error, "❌ 发现大量差异,建议重新进行OCR识别或检查原始图片质量"
    notify(advice)
|
|
|
|
|
+
|
|
|
|
|
@st.dialog("查看预校验结果", width="large", dismissible=True, on_dismiss="rerun")
def show_comparison_results_dialog(self):
    """Show the VLM pre-validation result for the currently selected file.

    Lookup order:

    1. The result cached in ``st.session_state.comparison_result``, but only
       when it was produced for the currently selected file.
    2. The ``<stem>_comparison_result.json`` file in the pre-validation
       directory, loaded when the user presses the load button.
    3. Otherwise an informational message that no result exists yet.
    """
    current_md_path = Path(self.file_paths[self.selected_file_index]).with_suffix('.md')
    pre_validation_dir = Path(
        self.config['paths'].get('pre_validation_dir', './output/pre_validation/')
    ).resolve()
    comparison_result_path = pre_validation_dir / f"{current_md_path.stem}_comparison_result.json"

    cached = st.session_state.get('comparison_result')
    # The cache holds a single result. Without this check, switching to
    # another file would still display the previous file's comparison.
    cache_is_current = bool(cached) and cached.get('comparison_result_json') == str(comparison_result_path)

    if cache_is_current:
        self.display_comparison_results(cached['comparison_result'])
    elif comparison_result_path.exists():
        # A saved result exists on disk; load it on request.
        if st.button("加载预校验结果"):
            try:
                with open(comparison_result_path, "r", encoding="utf-8") as f:
                    comparison_json_result = json.load(f)
            except (OSError, json.JSONDecodeError) as e:
                # Report an unreadable or corrupt file instead of crashing.
                st.error(f"❌ 预校验结果加载失败: {e}")
                return

            st.session_state.comparison_result = {
                "image_path": self.image_path,
                "comparison_result_json": str(comparison_result_path),
                "comparison_result_md": str(comparison_result_path.with_suffix('.md')),
                "comparison_result": comparison_json_result,
            }
            self.display_comparison_results(comparison_json_result)
    else:
        st.info("暂无预校验结果,请先运行VLM预校验")
|
|
|
|
|
+
|
|
|
def create_compact_layout(self, config):
    """Create the compact layout by delegating to the layout manager."""
    manager = self.layout_manager
    return manager.create_compact_layout(config)
|
|
|
|
|
|
|
|
|
|
@st.dialog("message", width="small", dismissible=True, on_dismiss="rerun")
def message_box(msg: str, msg_type: str = "info"):
    """Show ``msg`` in a small modal dialog, styled according to ``msg_type``.

    Args:
        msg: Text to display.
        msg_type: ``"info"``, ``"warning"`` or ``"error"``. Any other value
            renders nothing inside the dialog.
    """
    renderers = (
        ("info", st.info),
        ("warning", st.warning),
        ("error", st.error),
    )
    for kind, render in renderers:
        if msg_type == kind:
            render(msg)
            break
|
|
|
|
|
|
|
|
def main():
|
|
def main():
|
|
|
"""主应用"""
|
|
"""主应用"""
|
|
@@ -377,7 +812,7 @@ def main():
|
|
|
if 'marked_errors' not in st.session_state:
|
|
if 'marked_errors' not in st.session_state:
|
|
|
st.session_state.marked_errors = set()
|
|
st.session_state.marked_errors = set()
|
|
|
|
|
|
|
|
- with st.container(height=100, horizontal=True, horizontal_alignment='left'):
|
|
|
|
|
|
|
+ with st.container(height=75, horizontal=True, horizontal_alignment='left', gap="medium"):
|
|
|
# st.subheader("📁 文件选择")
|
|
# st.subheader("📁 文件选择")
|
|
|
# 初始化session_state中的选择索引
|
|
# 初始化session_state中的选择索引
|
|
|
if 'selected_file_index' not in st.session_state:
|
|
if 'selected_file_index' not in st.session_state:
|
|
@@ -388,6 +823,7 @@ def main():
|
|
|
range(len(validator.display_options)),
|
|
range(len(validator.display_options)),
|
|
|
format_func=lambda i: validator.display_options[i],
|
|
format_func=lambda i: validator.display_options[i],
|
|
|
index=st.session_state.selected_file_index,
|
|
index=st.session_state.selected_file_index,
|
|
|
|
|
+ width=100,
|
|
|
key="selected_selectbox",
|
|
key="selected_selectbox",
|
|
|
label_visibility="collapsed")
|
|
label_visibility="collapsed")
|
|
|
# 更新session_state
|
|
# 更新session_state
|
|
@@ -401,6 +837,7 @@ def main():
|
|
|
current_page = validator.file_info[selected_index]['page']
|
|
current_page = validator.file_info[selected_index]['page']
|
|
|
page_input = st.number_input("输入一个数字",
|
|
page_input = st.number_input("输入一个数字",
|
|
|
placeholder="输入页码",
|
|
placeholder="输入页码",
|
|
|
|
|
+ width=200,
|
|
|
label_visibility="collapsed",
|
|
label_visibility="collapsed",
|
|
|
min_value=1, max_value=len(validator.display_options), value=current_page, step=1,
|
|
min_value=1, max_value=len(validator.display_options), value=current_page, step=1,
|
|
|
key="page_input"
|
|
key="page_input"
|
|
@@ -430,13 +867,15 @@ def main():
|
|
|
st.warning("未找到OCR结果文件")
|
|
st.warning("未找到OCR结果文件")
|
|
|
st.info("请确保output目录下有OCR结果文件")
|
|
st.info("请确保output目录下有OCR结果文件")
|
|
|
|
|
|
|
|
- if st.button("🧹 清除选择"):
|
|
|
|
|
- st.session_state.selected_text = None
|
|
|
|
|
- st.rerun()
|
|
|
|
|
-
|
|
|
|
|
- if st.button("❌ 清除错误标记"):
|
|
|
|
|
- st.session_state.marked_errors = set()
|
|
|
|
|
- st.rerun()
|
|
|
|
|
|
|
+ if st.button("VLM预校验", type="primary", icon=":material/compare_arrows:"):
|
|
|
|
|
+ if validator.image_path and validator.md_content:
|
|
|
|
|
+ # 创建新的页面区域来显示VLM预校验结果
|
|
|
|
|
+ validator.vlm_pre_validation()
|
|
|
|
|
+ else:
|
|
|
|
|
+ message_box("❌ 请先加载OCR数据文件", "error")
|
|
|
|
|
+
|
|
|
|
|
+ if st.button("查看预校验结果", type="secondary", icon=":material/quick_reference_all:"):
|
|
|
|
|
+ validator.show_comparison_results_dialog()
|
|
|
|
|
|
|
|
with st.expander("🔧 OCR工具统计信息", expanded=False):
|
|
with st.expander("🔧 OCR工具统计信息", expanded=False):
|
|
|
# 显示统计信息
|
|
# 显示统计信息
|