Переглянути джерело

feat: 新增交叉验证功能,支持选择不同数据源进行OCR结果比对

zhch158_admin 1 місяць тому
батько
коміт
1a1388c344
1 змінених файлів з 241 додано та 181 видалено
  1. 241 181
      streamlit_ocr_validator.py

+ 241 - 181
streamlit_ocr_validator.py

@@ -55,6 +55,13 @@ class StreamlitOCRValidator:
         self.selected_file_index = -1
         self.display_options = []
         self.file_paths = []
+        
+        # ✅ 新增:交叉验证数据源
+        self.verify_source_key = None
+        self.verify_source_config = None
+        self.verify_file_info = []
+        self.verify_display_options = []
+        self.verify_file_paths = []
 
         # 初始化布局管理器
         self.layout_manager = OCRLayoutManager(self)
@@ -66,13 +73,18 @@ class StreamlitOCRValidator:
         """加载多数据源文件信息"""
         self.all_sources = find_available_ocr_files_multi_source(self.config)
         
-        # 如果有数据源,默认选择第一个
+        # 如果有数据源,默认选择第一个作为OCR源
         if self.all_sources:
-            first_source_key = list(self.all_sources.keys())[0]
+            source_keys = list(self.all_sources.keys())
+            first_source_key = source_keys[0]
             self.switch_to_source(first_source_key)
+            
+            # 如果有第二个数据源,默认作为验证源
+            if len(source_keys) > 1:
+                self.switch_to_verify_source(source_keys[1])
     
     def switch_to_source(self, source_key: str):
-        """切换到指定数据源"""
+        """切换到指定OCR数据源"""
         if source_key in self.all_sources:
             self.current_source_key = source_key
             source_data = self.all_sources[source_key]
@@ -86,11 +98,25 @@ class StreamlitOCRValidator:
                 
                 # 重置文件选择
                 self.selected_file_index = -1
-                
-                print(f"✅ 切换到数据源: {source_key}")
+                print(f"✅ 切换到OCR数据源: {source_key}")
             else:
                 print(f"⚠️ 数据源 {source_key} 没有可用文件")
     
+    def switch_to_verify_source(self, source_key: str):
+        """切换到指定验证数据源"""
+        if source_key in self.all_sources:
+            self.verify_source_key = source_key
+            source_data = self.all_sources[source_key]
+            self.verify_source_config = source_data['config']
+            self.verify_file_info = source_data['files']
+            
+            if self.verify_file_info:
+                self.verify_display_options = [f"{info['display_name']}" for info in self.verify_file_info]
+                self.verify_file_paths = [info['path'] for info in self.verify_file_info]
+                print(f"✅ 切换到验证数据源: {source_key}")
+            else:
+                print(f"⚠️ 验证数据源 {source_key} 没有可用文件")
+
     def setup_page_config(self):
         """设置页面配置"""
         ui_config = self.config['ui']
@@ -106,56 +132,91 @@ class StreamlitOCRValidator:
         st.markdown(f"<style>{css_content}</style>", unsafe_allow_html=True)
 
     def create_data_source_selector(self):
-        """创建数据源选择器"""
+        """创建数据源选择器 - 支持交叉验证"""
         if not self.all_sources:
             st.warning("❌ 未找到任何数据源,请检查配置文件")
             return
         
-        # 数据源选择
+        # 准备数据源选项
         source_options = {}
         for source_key, source_data in self.all_sources.items():
             display_name = get_data_source_display_name(source_data['config'])
             source_options[display_name] = source_key
         
-        # 获取当前选择的显示名称
-        current_display_name = None
-        if self.current_source_key:
-            for display_name, key in source_options.items():
-                if key == self.current_source_key:
-                    current_display_name = display_name
-                    break
-        
-        selected_display_name = st.selectbox(
-            "📁 选择数据源",
-            options=list(source_options.keys()),
-            index=list(source_options.keys()).index(current_display_name) if current_display_name else 0,
-            key="data_source_selector",
-            help="选择要分析的OCR数据源"
-        )
+        # 创建两列布局
+        col1, col2 = st.columns(2)
         
-        selected_source_key = source_options[selected_display_name]
+        with col1:
+            st.markdown("#### 📊 OCR数据源")
+            # OCR数据源选择
+            current_display_name = None
+            if self.current_source_key:
+                for display_name, key in source_options.items():
+                    if key == self.current_source_key:
+                        current_display_name = display_name
+                        break
+            
+            selected_ocr_display = st.selectbox(
+                "选择OCR数据源",
+                options=list(source_options.keys()),
+                index=list(source_options.keys()).index(current_display_name) if current_display_name else 0,
+                key="ocr_source_selector",
+                label_visibility="collapsed",
+                help="选择要分析的OCR数据源"
+            )
+            
+            selected_ocr_key = source_options[selected_ocr_display]
+            
+            # 如果OCR数据源发生变化,切换数据源
+            if selected_ocr_key != self.current_source_key:
+                self.switch_to_source(selected_ocr_key)
+                if 'selected_file_index' in st.session_state:
+                    st.session_state.selected_file_index = 0
+                st.rerun()
+            
+            # 显示OCR数据源信息
+            if self.current_source_config:
+                with st.expander("📋 OCR数据源详情", expanded=False):
+                    st.write(f"**工具:** {self.current_source_config['ocr_tool']}")
+                    st.write(f"**文件数:** {len(self.file_info)}")
         
-        # 如果数据源发生变化,切换数据源
-        if selected_source_key != self.current_source_key:
-            self.switch_to_source(selected_source_key)
-            # 重置session state
-            if 'selected_file_index' in st.session_state:
-                st.session_state.selected_file_index = 0
-            st.rerun()
+        with col2:
+            st.markdown("#### 🔍 验证数据源")
+            # 验证数据源选择
+            verify_display_name = None
+            if self.verify_source_key:
+                for display_name, key in source_options.items():
+                    if key == self.verify_source_key:
+                        verify_display_name = display_name
+                        break
+            
+            selected_verify_display = st.selectbox(
+                "选择验证数据源",
+                options=list(source_options.keys()),
+                index=list(source_options.keys()).index(verify_display_name) if verify_display_name else (1 if len(source_options) > 1 else 0),
+                key="verify_source_selector",
+                label_visibility="collapsed",
+                help="选择用于交叉验证的数据源"
+            )
+            
+            selected_verify_key = source_options[selected_verify_display]
+            
+            # 如果验证数据源发生变化,切换数据源
+            if selected_verify_key != self.verify_source_key:
+                self.switch_to_verify_source(selected_verify_key)
+                st.rerun()
+            
+            # 显示验证数据源信息
+            if self.verify_source_config:
+                with st.expander("📋 验证数据源详情", expanded=False):
+                    st.write(f"**工具:** {self.verify_source_config['ocr_tool']}")
+                    st.write(f"**文件数:** {len(self.verify_file_info)}")
         
-        # 显示数据源信息
-        if self.current_source_config:
-            with st.expander("📋 数据源详情", expanded=False):
-                col1, col2, col3 = st.columns(3)
-                with col1:
-                    st.write(f"**名称:** {self.current_source_config['name']}")
-                    st.write(f"**OCR工具:** {self.current_source_config['ocr_tool']}")
-                with col2:
-                    st.write(f"**输出目录:** {self.current_source_config['ocr_out_dir']}")
-                    st.write(f"**图片目录:** {self.current_source_config.get('src_img_dir', 'N/A')}")
-                with col3:
-                    st.write(f"**描述:** {self.current_source_config.get('description', 'N/A')}")
-                    st.write(f"**文件数量:** {len(self.file_info)}")
+        # 数据源对比提示
+        if self.current_source_key == self.verify_source_key:
+            st.warning("⚠️ OCR数据源和验证数据源相同,建议选择不同的数据源进行交叉验证")
+        else:
+            st.success(f"✅ 已选择 {selected_ocr_display} 与 {selected_verify_display} 进行交叉验证")    
     
     def load_ocr_data(self, json_path: str, md_path: Optional[str] = None, image_path: Optional[str] = None):
         """加载OCR相关数据 - 支持多数据源配置"""
@@ -456,107 +517,151 @@ class StreamlitOCRValidator:
             
         else:  # 完整显示
             return table
-    
-    @st.dialog("VLM预校验", width="large", dismissible=True, on_dismiss="rerun")
-    def vlm_pre_validation(self):
-        """VLM预校验功能 - 封装OCR识别和结果对比"""
+
+    def find_verify_md_path(self, selected_file_index: int) -> Optional[Path]:
+        """查找当前OCR文件对应的验证文件路径"""
+        current_page = self.file_info[selected_file_index]['page']
+        verify_md_path = None
+
+        for i, info in enumerate(self.verify_file_info):
+            if info['page'] == current_page:
+                verify_md_path = Path(self.verify_file_paths[i]).with_suffix('.md')
+                break
+
+        return verify_md_path
+
+    @st.dialog("交叉验证", width="large", dismissible=True, on_dismiss="rerun")
+    def cross_validation(self):
+        """交叉验证功能 - 比对两个数据源的OCR结果"""
         
         if not self.image_path or not self.md_content:
             st.error("❌ 请先加载OCR数据文件")
             return
+        if self.current_source_key == self.verify_source_key:
+            st.error("❌ OCR数据源和验证数据源不能相同")
+            return
         # 初始化对比结果存储
-        if 'comparison_result' not in st.session_state:
-            st.session_state.comparison_result = None
+        if 'cross_validation_result' not in st.session_state:
+            st.session_state.cross_validation_result = None
+        
+        # 初始化对比结果存储
+        if 'cross_validation_result' not in st.session_state:
+            st.session_state.cross_validation_result = None
 
         # 创建进度条和状态显示
-        with st.spinner("正在进行VLM预校验...", show_time=True):
+        with st.spinner("正在进行交叉验证...", show_time=True):
             status_text = st.empty()
             
             try:
+                # 第一步:获取当前OCR结果文件路径
                 current_md_path = Path(self.file_paths[self.selected_file_index]).with_suffix('.md')
                 if not current_md_path.exists():
-                    st.error("❌ 当前OCR结果的Markdown文件不存在,无法进行对比")
+                    st.error("❌ 当前OCR结果的Markdown文件不存在")
                     return
-                # 第一步:准备目录
+                
+                status_text.text(f"📄 OCR文件: {current_md_path.name}")
+                
+                # 第二步:查找对应的验证文件
+                verify_md_path = self.find_verify_md_path(self.selected_file_index)
+                
+                if not verify_md_path or not verify_md_path.exists():
+                    st.error(f"❌ 未找到验证数据源中第{current_md_path}页的对应文件")
+                    return
+                
+                status_text.text(f"🔍 验证文件: {verify_md_path.name}")
+                
+                # 第三步:准备输出目录
                 pre_validation_dir = Path(self.config['pre_validation'].get('out_dir', './output/pre_validation/')).resolve()
                 pre_validation_dir.mkdir(parents=True, exist_ok=True)
-                status_text.write(f"工作目录: {pre_validation_dir}")
-
-                # 第二步:调用VLM进行OCR识别
-                status_text.text("🤖 正在调用VLM进行OCR识别...")
                 
-                # 在expander中显示OCR过程
-                with st.expander("🔍 VLM OCR识别过程", expanded=True):
-                    ocr_output = st.empty()
+                # 第四步:调用对比功能
+                status_text.text("📊 正在对比OCR结果...")
+                
+                comparison_result_path = pre_validation_dir / f"{current_md_path.stem}_cross_validation"
+                
+                # 在expander中显示对比过程
+                with st.expander("🔍 交叉验证对比过程", expanded=True):
+                    compare_output = st.empty()
                     
-                    # 捕获OCR输出
+                    # 捕获对比输出
                     import io
                     import contextlib
                     
-                    # 创建字符串缓冲区来捕获print输出
                     output_buffer = io.StringIO()
                     
                     with contextlib.redirect_stdout(output_buffer):
-                        ocr_result = ocr_with_vlm(
-                            image_path=str(self.image_path),
-                            output_dir=str(pre_validation_dir),
-                            normalize_numbers=True
+                        comparison_result = compare_ocr_results(
+                            file1_path=str(current_md_path),
+                            file2_path=str(verify_md_path),
+                            output_file=str(comparison_result_path),
+                            output_format='both',
+                            ignore_images=True,
+                            table_mode='flow_list',  # ✅ 使用流水表格模式
+                            similarity_algorithm='ratio'
                         )
                     
-                    # 显示OCR过程输出
-                    ocr_output.code(output_buffer.getvalue(), language='text')
-                    
-                    status_text.text("✅ VLM OCR识别完成")
-                    
-                    # 第三步:获取VLM生成的文件路径
-                    vlm_md_path = pre_validation_dir / f"{Path(self.image_path).stem}.md"
-                    
-                    if not vlm_md_path.exists():
-                        st.error("❌ VLM OCR结果文件未生成")
-                        return
-                    
-                    # 第四步:调用对比功能
-                    status_text.text("📊 正在对比OCR结果...")
-                    
-                    # 在expander中显示对比过程
-                    comparison_result_path = pre_validation_dir / f"{current_md_path.stem}_comparison_result"
-                    with st.expander("🔍 OCR结果对比过程", expanded=True):
-                        compare_output = st.empty()
-                        
-                        # 捕获对比输出
-                        output_buffer = io.StringIO()
-                        
-                        with contextlib.redirect_stdout(output_buffer):
-                            comparison_result = compare_ocr_results(
-                                file1_path=str(current_md_path),
-                                file2_path=str(vlm_md_path),
-                                output_file=str(comparison_result_path),
-                                output_format='both',
-                                ignore_images=True
-                            )
-                        
-                        # 显示对比过程输出
-                        compare_output.code(output_buffer.getvalue(), language='text')
-                    
-                    status_text.text("✅ VLM预校验完成")
+                    # 显示对比过程输出
+                    compare_output.code(output_buffer.getvalue(), language='text')
+                
+                status_text.text("✅ 交叉验证完成")
 
-                    st.session_state.comparison_result = {
-                        "image_path": self.image_path,
-                        "comparison_result_json": f"{comparison_result_path}.json",
-                        "comparison_result_md": f"{comparison_result_path}.md",
-                        "comparison_result": comparison_result
-                    }
+                st.session_state.cross_validation_result = {
+                    "ocr_source": get_data_source_display_name(self.current_source_config),
+                    "verify_source": get_data_source_display_name(self.verify_source_config),
+                    "ocr_file": str(current_md_path),
+                    "verify_file": str(verify_md_path),
+                    "comparison_result_json": f"{comparison_result_path}.json",
+                    "comparison_result_md": f"{comparison_result_path}.md",
+                    "comparison_result": comparison_result
+                }
                 
                 # 第五步:显示对比结果
                 self.display_comparison_results(comparison_result, detailed=False)
                 
-                # 第六步:提供文件下载
-                # self.provide_download_options(pre_validation_dir, vlm_md_path, comparison_result)
-                
             except Exception as e:
-                st.error(f"❌ VLM预校验失败: {e}")
+                st.error(f"❌ 交叉验证失败: {e}")
                 st.exception(e)
-    
+
+    @st.dialog("查看交叉验证结果", width="large", dismissible=True, on_dismiss="rerun")
+    def show_cross_validation_results_dialog(self):
+        """显示交叉验证结果的对话框"""
+        current_md_path = Path(self.file_paths[self.selected_file_index]).with_suffix('.md')
+        pre_validation_dir = Path(self.config['pre_validation'].get('out_dir', './output/pre_validation/')).resolve()
+        comparison_result_path = pre_validation_dir / f"{current_md_path.stem}_cross_validation.json"
+        
+        if 'cross_validation_result' in st.session_state and st.session_state.cross_validation_result:
+            result = st.session_state.cross_validation_result
+            
+            # 显示数据源信息
+            col1, col2 = st.columns(2)
+            with col1:
+                st.info(f"**OCR数据源:** {result['ocr_source']}")
+            with col2:
+                st.info(f"**验证数据源:** {result['verify_source']}")
+            
+            self.display_comparison_results(result['comparison_result'])
+            
+        elif comparison_result_path.exists():
+            # 如果有历史结果文件,提示加载
+            if st.button("📂 加载历史验证结果"):
+                with open(comparison_result_path, "r", encoding="utf-8") as f:
+                    comparison_json_result = json.load(f)
+                
+                cross_validation_result = {
+                    "ocr_source": get_data_source_display_name(self.current_source_config),
+                    "verify_source": get_data_source_display_name(self.verify_source_config),
+                    "ocr_file": comparison_json_result['file1_path'],
+                    "verify_file": comparison_json_result['file2_path'],
+                    "comparison_result_json": str(comparison_result_path),
+                    "comparison_result_md": str(comparison_result_path.with_suffix('.md')),
+                    "comparison_result": comparison_json_result
+                }
+                
+                st.session_state.cross_validation_result = cross_validation_result
+                self.display_comparison_results(comparison_json_result)
+        else:
+            st.info("暂无交叉验证结果,请先运行交叉验证")
+
     def display_comparison_results(self, comparison_result: dict, detailed: bool = True):
         """显示对比结果摘要 - 使用DataFrame展示"""
         
@@ -698,9 +803,9 @@ class StreamlitOCRValidator:
                             )
                         
                         with col2:
-                            st.write("**VLM识别结果:**")
+                            st.write("**验证数据源识别结果:**")
                             st.text_area(
-                                "VLM识别结果详情",
+                                "验证数据源识别结果详情",
                                 value=diff['file2_value'],
                                 height=200,
                                 key=f"vlm_{selected_diff_index}",
@@ -875,31 +980,6 @@ class StreamlitOCRValidator:
         else:
             st.error("❌ 发现大量差异,建议重新进行OCR识别或检查原始图片质量")
     
-    @st.dialog("查看预校验结果", width="large", dismissible=True, on_dismiss="rerun")
-    def show_comparison_results_dialog(self):
-        """显示VLM预校验结果的对话框"""
-        current_md_path = Path(self.file_paths[self.selected_file_index]).with_suffix('.md')
-        pre_validation_dir = Path(self.config['pre_validation'].get('out_dir', './output/pre_validation/')).resolve()
-        comparison_result_path = pre_validation_dir / f"{current_md_path.stem}_comparison_result.json"
-        if 'comparison_result' in st.session_state and st.session_state.comparison_result:
-            self.display_comparison_results(st.session_state.comparison_result['comparison_result'])
-        elif comparison_result_path.exists():
-            # 如果pre_validation_dir下有结果文件,提示用户加载
-            if st.button("加载预校验结果"):
-                with open(comparison_result_path, "r", encoding="utf-8") as f:
-                    comparison_json_result = json.load(f)
-                comparison_result = {
-                    "image_path": self.image_path,
-                    "comparison_result_json": str(comparison_result_path),
-                    "comparison_result_md": str(comparison_result_path.with_suffix('.md')),
-                    "comparison_result": comparison_json_result
-                }
-                    
-                st.session_state.comparison_result = comparison_result
-                self.display_comparison_results(comparison_json_result)
-        else:
-            st.info("暂无预校验结果,请先运行VLM预校验")
-
     def create_compact_layout(self, config):
         """创建滚动凑布局"""
         return self.layout_manager.create_compact_layout(config)
@@ -999,17 +1079,17 @@ def main():
                 st.rerun()
         else:
             st.warning("当前数据源中未找到OCR结果文件")
-        
-        # VLM预校验按钮
-        if st.button("VLM预校验", type="primary", icon=":material/compare_arrows:"):
+
+        # 交叉验证按钮
+        if st.button("交叉验证", type="primary", icon=":material/compare_arrows:"):
             if validator.image_path and validator.md_content:
-                validator.vlm_pre_validation()
+                validator.cross_validation()
             else:
                 message_box("❌ 请先选择OCR数据文件", "error")
 
         # 查看预校验结果按钮
-        if st.button("查看预校验结果", type="secondary", icon=":material/quick_reference_all:"):
-            validator.show_comparison_results_dialog()
+        if st.button("查看验结果", type="secondary", icon=":material/quick_reference_all:"):
+            validator.show_cross_validation_results_dialog()
 
     # 显示当前数据源统计信息
     with st.expander("🔧 OCR工具统计信息", expanded=False):
@@ -1035,7 +1115,7 @@ def main():
             st.write("**详细信息:**", stats['tool_info'])
     
     # 其余标签页保持不变...
-    tab1, tab2, tab3 = st.tabs(["📄 内容校验", "📄 VLM预校验识别结果", "📊 表格分析"])
+    tab1, tab2, tab3 = st.tabs(["📄 内容人工检查", "🔍 交叉验证结果", "📊 表格分析"])
     
     with tab1:
         validator.create_compact_layout(config)
@@ -1044,9 +1124,15 @@ def main():
         # st.header("📄 VLM预校验识别结果")
         current_md_path = Path(validator.file_paths[validator.selected_file_index]).with_suffix('.md')
         pre_validation_dir = Path(validator.config['pre_validation'].get('out_dir', './output/pre_validation/')).resolve()
-        comparison_result_path = pre_validation_dir / f"{current_md_path.stem}_comparison_result.json"
-        pre_validation_path = pre_validation_dir / f"{current_md_path.stem}.md"
+        comparison_result_path = pre_validation_dir / f"{current_md_path.stem}_cross_validation.json"
+        # pre_validation_path = pre_validation_dir / f"{current_md_path.stem}.md"
+        verify_md_path = validator.find_verify_md_path(validator.selected_file_index)
+        
         if comparison_result_path.exists():
+            # 加载并显示验证结果
+            with open(comparison_result_path, "r", encoding="utf-8") as f:
+                comparison_result = json.load(f)
+
             # 左边显示OCR结果,右边显示VLM结果
             col1, col2 = st.columns([1,1])
             with col1:
@@ -1059,12 +1145,16 @@ def main():
                 validator.layout_manager.render_content_by_mode(original_md_content, "HTML渲染", font_size, height, layout_type)
             with col2:
                 st.subheader("🤖 VLM识别结果")
-                with open(pre_validation_path, "r", encoding="utf-8") as f:
-                    pre_validation_md_content = f.read()
+                with open(str(verify_md_path), "r", encoding="utf-8") as f:
+                    verify_md_content = f.read()
                 font_size = config['styles'].get('font_size', 10)
                 height = config['styles']['layout'].get('default_height', 800)
                 layout_type = "compact"
-                validator.layout_manager.render_content_by_mode(pre_validation_md_content, "HTML渲染", font_size, height, layout_type)
+                validator.layout_manager.render_content_by_mode(verify_md_content, "HTML渲染", font_size, height, layout_type)
+
+            # 显示差异统计
+            st.markdown("---")
+            validator.display_comparison_results(comparison_result, detailed=True)
         else:
             st.info("暂无预校验结果,请先运行VLM预校验")
 
@@ -1079,35 +1169,5 @@ def main():
         else:
             st.info("当前OCR结果中没有检测到表格数据")
     
-    # with tab4:
-    #     # 数据统计页面 - 保持原有逻辑
-    #     st.header("📈 OCR数据统计")
-        
-    #     # 添加数据源特定的统计信息
-    #     if validator.current_source_config:
-    #         st.subheader(f"📊 {get_data_source_display_name(validator.current_source_config)} - 统计信息")
-        
-    #     if stats['categories']:
-    #         st.subheader("📊 类别分布")
-    #         fig_pie = px.pie(
-    #             values=list(stats['categories'].values()),
-    #             names=list(stats['categories'].keys()),
-    #             title="文本类别分布"
-    #         )
-    #         st.plotly_chart(fig_pie, use_container_width=True)
-        
-    #     # 错误率分析
-    #     st.subheader("📈 质量分析")
-    #     accuracy_data = {
-    #         '状态': ['正确', '错误'],
-    #         '数量': [stats['clickable_texts'] - stats['marked_errors'], stats['marked_errors']]
-    #     }
-        
-    #     fig_bar = px.bar(
-    #         accuracy_data, x='状态', y='数量', title="识别质量分布",
-    #         color='状态', color_discrete_map={'正确': 'green', '错误': 'red'}
-    #     )
-    #     st.plotly_chart(fig_bar, use_container_width=True)
-    
 if __name__ == "__main__":
     main()