Просмотр исходного кода

添加VLM预校验功能,封装OCR识别和结果对比,增强用户交互体验

zhch158_admin 2 месяцев назад
Родитель
Сommit
c688df2cdb
1 измененных файлов с 448 добавлено и 9 удалено
  1. 448 9
      streamlit_ocr_validator.py

+ 448 - 9
streamlit_ocr_validator.py

@@ -13,6 +13,7 @@ from io import BytesIO
 import pandas as pd
 import numpy as np
 import plotly.express as px
+import json
 
 # 导入工具模块
 from ocr_validator_utils import (
@@ -22,6 +23,8 @@ from ocr_validator_utils import (
     export_tables_to_excel, get_table_statistics, group_texts_by_category
 )
 from ocr_validator_layout import OCRLayoutManager
+from ocr_by_vlm import ocr_with_vlm
+from compare_ocr_results import compare_ocr_results
 
 
 class StreamlitOCRValidator:
@@ -349,11 +352,443 @@ class StreamlitOCRValidator:
         else:  # 完整显示
             return table
     
-    # 布局方法现在委托给布局管理器
+    @st.dialog("VLM预校验", width="large", dismissible=True, on_dismiss="rerun")
+    def vlm_pre_validation(self):
+        """VLM预校验功能 - 封装OCR识别和结果对比"""
+        
+        if not self.image_path or not self.md_content:
+            st.error("❌ 请先加载OCR数据文件")
+            return
+        # 初始化对比结果存储
+        if 'comparison_result' not in st.session_state:
+            st.session_state.comparison_result = None
+
+        # 创建进度条和状态显示
+        with st.spinner("正在进行VLM预校验...", show_time=True):
+            status_text = st.empty()
+            
+            try:
+                current_md_path = Path(self.file_paths[self.selected_file_index]).with_suffix('.md')
+                if not current_md_path.exists():
+                    st.error("❌ 当前OCR结果的Markdown文件不存在,无法进行对比")
+                    return
+                # 第一步:准备目录
+                pre_validation_dir = Path(self.config['paths'].get('pre_validation_dir', './output/pre_validation/')).resolve()
+                pre_validation_dir.mkdir(parents=True, exist_ok=True)
+                status_text.write(f"工作目录: {pre_validation_dir}")
+
+                # 第二步:调用VLM进行OCR识别
+                status_text.text("🤖 正在调用VLM进行OCR识别...")
+                
+                # 在expander中显示OCR过程
+                with st.expander("🔍 VLM OCR识别过程", expanded=True):
+                    ocr_output = st.empty()
+                    
+                    # 捕获OCR输出
+                    import io
+                    import contextlib
+                    
+                    # 创建字符串缓冲区来捕获print输出
+                    output_buffer = io.StringIO()
+                    
+                    with contextlib.redirect_stdout(output_buffer):
+                        ocr_result = ocr_with_vlm(
+                            image_path=str(self.image_path),
+                            output_dir=str(pre_validation_dir),
+                            normalize_numbers=True
+                        )
+                    
+                    # 显示OCR过程输出
+                    ocr_output.code(output_buffer.getvalue(), language='text')
+                    
+                    status_text.text("✅ VLM OCR识别完成")
+                    
+                    # 第三步:获取VLM生成的文件路径
+                    vlm_md_path = pre_validation_dir / f"{Path(self.image_path).stem}.md"
+                    
+                    if not vlm_md_path.exists():
+                        st.error("❌ VLM OCR结果文件未生成")
+                        return
+                    
+                    # 第四步:调用对比功能
+                    status_text.text("📊 正在对比OCR结果...")
+                    
+                    # 在expander中显示对比过程
+                    comparison_result_path = pre_validation_dir / f"{current_md_path.stem}_comparison_result"
+                    with st.expander("🔍 OCR结果对比过程", expanded=True):
+                        compare_output = st.empty()
+                        
+                        # 捕获对比输出
+                        output_buffer = io.StringIO()
+                        
+                        with contextlib.redirect_stdout(output_buffer):
+                            comparison_result = compare_ocr_results(
+                                file1_path=str(current_md_path),
+                                file2_path=str(vlm_md_path),
+                                output_file=str(comparison_result_path),
+                                output_format='both',
+                                ignore_images=True
+                            )
+                        
+                        # 显示对比过程输出
+                        compare_output.code(output_buffer.getvalue(), language='text')
+                    
+                    status_text.text("✅ VLM预校验完成")
+
+                    st.session_state.comparison_result = {
+                        "image_path": self.image_path,
+                        "comparison_result_json": f"{comparison_result_path}.json",
+                        "comparison_result_md": f"{comparison_result_path}.md",
+                        "comparison_result": comparison_result
+                    }
+                
+                # 第五步:显示对比结果
+                self.display_comparison_results(comparison_result)
+                
+                # 第六步:提供文件下载
+                # self.provide_download_options(pre_validation_dir, vlm_md_path, comparison_result)
+                
+            except Exception as e:
+                st.error(f"❌ VLM预校验失败: {e}")
+                st.exception(e)
+    
+    def display_comparison_results(self, comparison_result: dict):
+        """显示对比结果摘要 - 使用DataFrame展示"""
+        
+        st.header("📊 VLM预校验结果")
+        
+        # 统计信息
+        stats = comparison_result['statistics']
+        
+        # 统计信息概览
+        col1, col2, col3, col4 = st.columns(4)
+        with col1:
+            st.metric("总差异数", stats['total_differences'])
+        with col2:
+            st.metric("表格差异", stats['table_differences'])
+        with col3:
+            st.metric("金额差异", stats['amount_differences'])
+        with col4:
+            st.metric("段落差异", stats['paragraph_differences'])
+        
+        # 结果判断
+        if stats['total_differences'] == 0:
+            st.success("🎉 完美匹配!VLM识别结果与原OCR结果完全一致")
+        else:
+            st.warning(f"⚠️ 发现 {stats['total_differences']} 个差异,建议人工检查")
+            
+            # 使用DataFrame显示差异详情
+            if comparison_result['differences']:
+                st.subheader("🔍 差异详情对比")
+                
+                # 准备DataFrame数据
+                diff_data = []
+                for i, diff in enumerate(comparison_result['differences'], 1):
+                    diff_data.append({
+                        '序号': i,
+                        '位置': diff['position'],
+                        '类型': diff['type'],
+                        '原OCR结果': diff['file1_value'][:100] + ('...' if len(diff['file1_value']) > 100 else ''),
+                        'VLM识别结果': diff['file2_value'][:100] + ('...' if len(diff['file2_value']) > 100 else ''),
+                        '描述': diff['description'][:80] + ('...' if len(diff['description']) > 80 else ''),
+                        '严重程度': self._get_severity_level(diff)
+                    })
+                
+                # 创建DataFrame
+                df_differences = pd.DataFrame(diff_data)
+                
+                # 添加样式
+                def highlight_severity(val):
+                    """根据严重程度添加颜色"""
+                    if val == '高':
+                        return 'background-color: #ffebee; color: #c62828'
+                    elif val == '中':
+                        return 'background-color: #fff3e0; color: #ef6c00'
+                    elif val == '低':
+                        return 'background-color: #e8f5e8; color: #2e7d32'
+                    return ''
+                
+                # 显示DataFrame
+                styled_df = df_differences.style.applymap(
+                    highlight_severity, 
+                    subset=['严重程度']
+                ).format({
+                    '序号': '{:d}',
+                })
+                
+                st.dataframe(
+                    styled_df, 
+                    use_container_width=True,
+                    height=400,
+                    hide_index=True,
+                    column_config={
+                        "序号": st.column_config.NumberColumn(
+                            "序号", 
+                            width=None,  # 自动调整宽度
+                            pinned=True,
+                            help="差异项序号"
+                        ),
+                        "位置": st.column_config.TextColumn(
+                            "位置", 
+                            width=None,  # 自动调整宽度
+                            pinned=True,
+                            help="差异在文档中的位置"
+                        ),
+                        "类型": st.column_config.TextColumn(
+                            "类型", 
+                            width=None,  # 自动调整宽度
+                            pinned=True,
+                            help="差异类型"
+                        ),
+                        "原OCR结果": st.column_config.TextColumn(
+                            "原OCR结果", 
+                            width="large",  # 自动调整宽度
+                            pinned=True,
+                            help="原始OCR识别结果"
+                        ),
+                        "VLM识别结果": st.column_config.TextColumn(
+                            "VLM识别结果", 
+                            width="large",  # 自动调整宽度
+                            help="VLM重新识别的结果"
+                        ),
+                        "描述": st.column_config.TextColumn(
+                            "描述", 
+                            width="medium",  # 自动调整宽度
+                            help="差异详细描述"
+                        ),
+                        "严重程度": st.column_config.TextColumn(
+                            "严重程度", 
+                            width=None,  # 自动调整宽度
+                            help="差异严重程度评级"
+                        )
+                    }
+                )
+                
+                # 详细差异查看
+                st.subheader("🔍 详细差异查看")
+                
+                # 选择要查看的差异
+                selected_diff_index = st.selectbox(
+                    "选择要查看的差异:",
+                    options=range(len(comparison_result['differences'])),
+                    format_func=lambda x: f"差异 {x+1}: {comparison_result['differences'][x]['position']} - {comparison_result['differences'][x]['type']}",
+                    key="selected_diff"
+                )
+                
+                if selected_diff_index is not None:
+                    diff = comparison_result['differences'][selected_diff_index]
+                    
+                    # 并排显示完整内容
+                    col1, col2 = st.columns(2)
+                    
+                    with col1:
+                        st.write("**原OCR结果:**")
+                        st.text_area(
+                            "原OCR结果详情",
+                            value=diff['file1_value'],
+                            height=200,
+                            key=f"original_{selected_diff_index}",
+                            label_visibility="collapsed"
+                        )
+                    
+                    with col2:
+                        st.write("**VLM识别结果:**")
+                        st.text_area(
+                            "VLM识别结果详情",
+                            value=diff['file2_value'],
+                            height=200,
+                            key=f"vlm_{selected_diff_index}",
+                            label_visibility="collapsed"
+                        )
+                    
+                    # 差异详细信息
+                    st.info(f"**位置:** {diff['position']}")
+                    st.info(f"**类型:** {diff['type']}")
+                    st.info(f"**描述:** {diff['description']}")
+                    st.info(f"**严重程度:** {self._get_severity_level(diff)}")
+                
+                # 差异统计图表
+                st.subheader("📈 差异类型分布")
+                
+                # 按类型统计差异
+                type_counts = {}
+                severity_counts = {'高': 0, '中': 0, '低': 0}
+                
+                for diff in comparison_result['differences']:
+                    diff_type = diff['type']
+                    type_counts[diff_type] = type_counts.get(diff_type, 0) + 1
+                    
+                    severity = self._get_severity_level(diff)
+                    severity_counts[severity] += 1
+                
+                col1, col2 = st.columns(2)
+                
+                with col1:
+                    # 类型分布饼图
+                    if type_counts:
+                        fig_type = px.pie(
+                            values=list(type_counts.values()),
+                            names=list(type_counts.keys()),
+                            title="差异类型分布"
+                        )
+                        st.plotly_chart(fig_type, use_container_width=True)
+                
+                with col2:
+                    # 严重程度分布条形图
+                    fig_severity = px.bar(
+                        x=list(severity_counts.keys()),
+                        y=list(severity_counts.values()),
+                        title="差异严重程度分布",
+                        color=list(severity_counts.keys()),
+                        color_discrete_map={'高': '#f44336', '中': '#ff9800', '低': '#4caf50'}
+                    )
+                    st.plotly_chart(fig_severity, use_container_width=True)
+        
+        # 下载选项
+        self._provide_download_options_in_results(comparison_result)
+
+    def _get_severity_level(self, diff: dict) -> str:
+        """根据差异类型和内容判断严重程度"""
+        diff_type = diff['type'].lower()
+        
+        # 金额相关差异为高严重程度
+        if 'amount' in diff_type or 'number' in diff_type:
+            return '高'
+        
+        # 表格结构差异为中等严重程度
+        if 'table' in diff_type or 'structure' in diff_type:
+            return '中'
+        
+        # 检查内容长度差异
+        len_diff = abs(len(diff['file1_value']) - len(diff['file2_value']))
+        if len_diff > 50:
+            return '高'
+        elif len_diff > 10:
+            return '中'
+        else:
+            return '低'
+
+    def _provide_download_options_in_results(self, comparison_result: dict):
+        """在结果页面提供下载选项"""
+        
+        st.subheader("📥 导出预校验结果")
+        
+        col1, col2, col3 = st.columns(3)
+        
+        with col1:
+            # 导出差异详情为Excel
+            if comparison_result['differences']:
+                diff_data = []
+                for i, diff in enumerate(comparison_result['differences'], 1):
+                    diff_data.append({
+                        '序号': i,
+                        '位置': diff['position'],
+                        '类型': diff['type'],
+                        '原OCR结果': diff['file1_value'],
+                        'VLM识别结果': diff['file2_value'],
+                        '描述': diff['description'],
+                        '严重程度': self._get_severity_level(diff)
+                    })
+                
+                df_export = pd.DataFrame(diff_data)
+                excel_buffer = BytesIO()
+                df_export.to_excel(excel_buffer, index=False, sheet_name='差异详情')
+                
+                st.download_button(
+                    label="📊 下载差异详情(Excel)",
+                    data=excel_buffer.getvalue(),
+                    file_name=f"vlm_comparison_differences_{pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')}.xlsx",
+                    mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
+                    key="download_differences_excel"
+                )
+        
+        with col2:
+            # 导出统计报告
+            stats_data = {
+                '统计项目': ['总差异数', '表格差异', '金额差异', '段落差异'],
+                '数量': [
+                    comparison_result['statistics']['total_differences'],
+                    comparison_result['statistics']['table_differences'],
+                    comparison_result['statistics']['amount_differences'],
+                    comparison_result['statistics']['paragraph_differences']
+                ]
+            }
+            
+            df_stats = pd.DataFrame(stats_data)
+            csv_stats = df_stats.to_csv(index=False)
+            
+            st.download_button(
+                label="📈 下载统计报告(CSV)",
+                data=csv_stats,
+                file_name=f"vlm_comparison_stats_{pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')}.csv",
+                mime="text/csv",
+                key="download_stats_csv"
+            )
+        
+        with col3:
+            # 导出完整报告为JSON
+            import json
+            
+            report_json = json.dumps(comparison_result, ensure_ascii=False, indent=2)
+            
+            st.download_button(
+                label="📄 下载完整报告(JSON)",
+                data=report_json,
+                file_name=f"vlm_comparison_full_{pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')}.json",
+                mime="application/json",
+                key="download_full_json"
+            )
+        
+        # 操作建议
+        st.subheader("🚀 后续操作建议")
+        
+        total_diffs = comparison_result['statistics']['total_differences']
+        if total_diffs == 0:
+            st.success("✅ VLM识别结果与原OCR完全一致,可信度很高,无需人工校验")
+        elif total_diffs <= 5:
+            st.warning("⚠️ 发现少量差异,建议重点检查高严重程度的差异项")
+        elif total_diffs <= 20:
+            st.warning("🔍 发现中等数量差异,建议详细检查差异表格中标红的项目")
+        else:
+            st.error("❌ 发现大量差异,建议重新进行OCR识别或检查原始图片质量")
+    
+    @st.dialog("查看预校验结果", width="large", dismissible=True, on_dismiss="rerun")
+    def show_comparison_results_dialog(self):
+        """显示VLM预校验结果的对话框"""
+        current_md_path = Path(self.file_paths[self.selected_file_index]).with_suffix('.md')
+        pre_validation_dir = Path(self.config['paths'].get('pre_validation_dir', './output/pre_validation/')).resolve()
+        comparison_result_path = pre_validation_dir / f"{current_md_path.stem}_comparison_result.json"
+        if 'comparison_result' in st.session_state and st.session_state.comparison_result:
+            self.display_comparison_results(st.session_state.comparison_result['comparison_result'])
+        elif comparison_result_path.exists():
+            # 如果pre_validation_dir下有结果文件,提示用户加载
+            if st.button("加载预校验结果"):
+                with open(comparison_result_path, "r", encoding="utf-8") as f:
+                    comparison_json_result = json.load(f)
+                comparison_result = {
+                    "image_path": self.image_path,
+                    "comparison_result_json": str(comparison_result_path),
+                    "comparison_result_md": str(comparison_result_path.with_suffix('.md')),
+                    "comparison_result": comparison_json_result
+                }
+                    
+                st.session_state.comparison_result = comparison_result
+                self.display_comparison_results(comparison_json_result)
+        else:
+            st.info("暂无预校验结果,请先运行VLM预校验")
+
     def create_compact_layout(self, config):
         """创建滚动凑布局"""
         return self.layout_manager.create_compact_layout(config)
 
+@st.dialog("message", width="small", dismissible=True, on_dismiss="rerun")
+def message_box(msg: str, msg_type: str = "info"):
+    if msg_type == "info":
+        st.info(msg)
+    elif msg_type == "warning":
+        st.warning(msg)
+    elif msg_type == "error":
+        st.error(msg)
 
 def main():
     """主应用"""
@@ -377,7 +812,7 @@ def main():
     if 'marked_errors' not in st.session_state:
         st.session_state.marked_errors = set()
     
-    with st.container(height=100, horizontal=True, horizontal_alignment='left'):
+    with st.container(height=75, horizontal=True, horizontal_alignment='left', gap="medium"):
         # st.subheader("📁 文件选择")
         # 初始化session_state中的选择索引
         if 'selected_file_index' not in st.session_state:
@@ -388,6 +823,7 @@ def main():
                 range(len(validator.display_options)),
                 format_func=lambda i: validator.display_options[i],
                 index=st.session_state.selected_file_index,
+                width=100,
                 key="selected_selectbox",
                 label_visibility="collapsed")
             # 更新session_state
@@ -401,6 +837,7 @@ def main():
             current_page = validator.file_info[selected_index]['page']
             page_input = st.number_input("输入一个数字", 
                 placeholder="输入页码", 
+                width=200,
                 label_visibility="collapsed",
                 min_value=1, max_value=len(validator.display_options), value=current_page, step=1,
                 key="page_input"
@@ -430,13 +867,15 @@ def main():
             st.warning("未找到OCR结果文件")
             st.info("请确保output目录下有OCR结果文件")
         
-        if st.button("🧹 清除选择"):
-            st.session_state.selected_text = None
-            st.rerun()
-        
-        if st.button("❌ 清除错误标记"):
-            st.session_state.marked_errors = set()
-            st.rerun()
+        if st.button("VLM预校验", type="primary", icon=":material/compare_arrows:"):
+            if validator.image_path and validator.md_content:
+                # 创建新的页面区域来显示VLM预校验结果
+                validator.vlm_pre_validation()
+            else:
+                message_box("❌ 请先加载OCR数据文件", "error")
+
+        if st.button("查看预校验结果", type="secondary", icon=":material/quick_reference_all:"):
+            validator.show_comparison_results_dialog()
 
     with st.expander("🔧 OCR工具统计信息", expanded=False):
         # 显示统计信息