Răsfoiți Sursa

新增多数据源支持,重构文件加载逻辑,优化数据源选择器和统计信息显示

zhch158_admin 1 lună în urmă
părinte
comite
fb8276c6fb
1 fișier modificat, cu 162 adăugiri și 53 ștergeri
  1. 162 53
      streamlit_ocr_validator.py

+ 162 - 53
streamlit_ocr_validator.py

@@ -20,7 +20,8 @@ from ocr_validator_utils import (
     load_config, load_css_styles, load_ocr_data_file, process_ocr_data,
     draw_bbox_on_image, get_ocr_statistics, convert_html_table_to_markdown,
     parse_html_tables, find_available_ocr_files, create_dynamic_css,
-    export_tables_to_excel, get_table_statistics, group_texts_by_category
+    export_tables_to_excel, get_table_statistics, group_texts_by_category,
+    find_available_ocr_files_multi_source, get_data_source_display_name
 )
 from ocr_validator_layout import OCRLayoutManager
 from ocr_by_vlm import ocr_with_vlm
@@ -36,26 +37,51 @@ class StreamlitOCRValidator:
         self.text_bbox_mapping = {}
         self.selected_text = None
         self.marked_errors = set()
+        
+        # 多数据源相关
+        self.all_sources = {}
+        self.current_source_key = None
+        self.current_source_config = None
         self.file_info = []
-        self.selected_file_index = -1 # 初始化不指向有效文件index
+        self.selected_file_index = -1
         self.display_options = []
         self.file_paths = []
 
         # 初始化布局管理器
         self.layout_manager = OCRLayoutManager(self)
 
-        # 加载文件信息
-        self.load_file_info()
+        # 加载多数据源文件信息
+        self.load_multi_source_info()
         
-    def load_file_info(self):
-        # 查找可用的OCR文件
-        self.file_info = find_available_ocr_files(self.config['paths']['ocr_out_dir'])
-        # 初始化session_state中的选择索引
-        if self.file_info:
-            # 创建显示选项列表
-            self.display_options = [f"{info['display_name']}" for info in self.file_info]
-            self.file_paths = [info['path'] for info in self.file_info]
-
+    def load_multi_source_info(self):
+        """Discover all configured data sources and activate the first one found."""
+        self.all_sources = find_available_ocr_files_multi_source(self.config)
+
+        # Nothing to activate when discovery returned no sources.
+        if not self.all_sources:
+            return
+        self.switch_to_source(next(iter(self.all_sources)))
+    
+    def switch_to_source(self, source_key: str):
+        """Activate source_key and rebuild the file lists (unknown keys are ignored)."""
+        if source_key not in self.all_sources:
+            return
+
+        source_data = self.all_sources[source_key]
+        self.current_source_key = source_key
+        self.current_source_config = source_data['config']
+        self.file_info = source_data['files'] or []
+
+        # Rebuild unconditionally: a source without files must not keep showing the
+        # previous source's options, paths or selected index.
+        self.display_options = [f"{info['display_name']}" for info in self.file_info]
+        self.file_paths = [info['path'] for info in self.file_info]
+        self.selected_file_index = -1
+        if self.file_info:
+            print(f"✅ 切换到数据源: {source_key}")
+        else:
+            print(f"⚠️ 数据源 {source_key} 没有可用文件")
+    
     def setup_page_config(self):
         """设置页面配置"""
         ui_config = self.config['ui']
@@ -69,14 +95,83 @@ class StreamlitOCRValidator:
         # 加载CSS样式
         css_content = load_css_styles()
         st.markdown(f"<style>{css_content}</style>", unsafe_allow_html=True)
+
+    def create_data_source_selector(self):
+        """Render the data-source selectbox and a details expander.
+
+        The selectbox options are the source keys themselves and only their
+        labels come from get_data_source_display_name(), so two sources with
+        the same display name stay individually selectable. Picking a different
+        source switches to it, resets the stored file index and reruns.
+        """
+        if not self.all_sources:
+            st.warning("❌ 未找到任何数据源,请检查配置文件")
+            return
+
+        # Keys are the option values; display names are used as labels only.
+        source_keys = list(self.all_sources)
+
+        # Preselect the active source; fall back to the first entry.
+        if self.current_source_key in source_keys:
+            current_index = source_keys.index(self.current_source_key)
+        else:
+            current_index = 0
+
+        selected_source_key = st.selectbox(
+            "📁 选择数据源",
+            options=source_keys,
+            index=current_index,
+            format_func=lambda key: get_data_source_display_name(self.all_sources[key]['config']),
+            key="data_source_selector",
+            help="选择要分析的OCR数据源"
+        )
+
+        # Switch sources when the selection changed.
+        if selected_source_key != self.current_source_key:
+            self.switch_to_source(selected_source_key)
+            # Point the file selector back at the first file of the new source.
+            if 'selected_file_index' in st.session_state:
+                st.session_state.selected_file_index = 0
+            st.rerun()
+
+        # Show the active source's configuration.
+        if self.current_source_config:
+            with st.expander("📋 数据源详情", expanded=False):
+                col1, col2, col3 = st.columns(3)
+                with col1:
+                    st.write(f"**名称:** {self.current_source_config['name']}")
+                    st.write(f"**OCR工具:** {self.current_source_config['ocr_tool']}")
+                with col2:
+                    st.write(f"**输出目录:** {self.current_source_config['ocr_out_dir']}")
+                    st.write(f"**图片目录:** {self.current_source_config.get('src_img_dir', 'N/A')}")
+                with col3:
+                    st.write(f"**描述:** {self.current_source_config.get('description', 'N/A')}")
+                    st.write(f"**文件数量:** {len(self.file_info)}")
     
     def load_ocr_data(self, json_path: str, md_path: Optional[str] = None, image_path: Optional[str] = None):
-        """加载OCR相关数据"""
+        """Load OCR data for json_path, using the active data source's paths and OCR tool when one is set."""
         try:
-            self.ocr_data, self.md_content, self.image_path = load_ocr_data_file(json_path, self.config)
+            # Load the data with the current data source's configuration
+            if self.current_source_config:
+                # Temporarily override the config with the current data source's settings
+                temp_config = self.config.copy()
+                temp_config['paths'] = {
+                    'ocr_out_dir': self.current_source_config['ocr_out_dir'],
+                    'src_img_dir': self.current_source_config.get('src_img_dir', ''),
+                    'pre_validation_dir': self.config['pre_validation']['out_dir']
+                }
+                
+                # Record which OCR tool produced this data source
+                temp_config['current_ocr_tool'] = self.current_source_config['ocr_tool']
+                
+                self.ocr_data, self.md_content, self.image_path = load_ocr_data_file(json_path, temp_config)
+            else:
+                self.ocr_data, self.md_content, self.image_path = load_ocr_data_file(json_path, self.config)
+                
             self.process_data()
         except Exception as e:
             st.error(f"❌ 加载失败: {e}")
+            st.exception(e)
     
     def process_data(self):
         """处理OCR数据"""
@@ -373,7 +468,7 @@ class StreamlitOCRValidator:
                     st.error("❌ 当前OCR结果的Markdown文件不存在,无法进行对比")
                     return
                 # 第一步:准备目录
-                pre_validation_dir = Path(self.config['paths'].get('pre_validation_dir', './output/pre_validation/')).resolve()
+                pre_validation_dir = Path(self.config['pre_validation'].get('out_dir', './output/pre_validation/')).resolve()
                 pre_validation_dir.mkdir(parents=True, exist_ok=True)
                 status_text.write(f"工作目录: {pre_validation_dir}")
 
@@ -815,12 +910,11 @@ def main():
         validator = StreamlitOCRValidator()
         st.session_state.validator = validator
         st.session_state.validator.setup_page_config()
+        
         # 页面标题
         config = st.session_state.validator.config
         st.title(config['ui']['page_title'])
-        # st.markdown("---")
     else:
-        # 主内容区域
         validator = st.session_state.validator
         config = st.session_state.validator.config
     
@@ -830,39 +924,51 @@ def main():
     if 'marked_errors' not in st.session_state:
         st.session_state.marked_errors = set()
     
+    # 数据源选择器
+    validator.create_data_source_selector()
+    
+    # 如果没有可用的数据源,提前返回
+    if not validator.all_sources:
+        st.stop()
+    
+    # 文件选择区域
     with st.container(height=75, horizontal=True, horizontal_alignment='left', gap="medium"):
-        # st.subheader("📁 文件选择")
         # 初始化session_state中的选择索引
         if 'selected_file_index' not in st.session_state:
             st.session_state.selected_file_index = 0
+            
         if validator.display_options:
-            # 创建显示选项列表
-            selected_index = st.selectbox("选择OCR结果文件", 
+            # 文件选择下拉框
+            selected_index = st.selectbox(
+                "选择OCR结果文件", 
                 range(len(validator.display_options)),
                 format_func=lambda i: validator.display_options[i],
                 index=st.session_state.selected_file_index,
-                width=100,
                 key="selected_selectbox",
-                label_visibility="collapsed")
+                label_visibility="collapsed"
+            )
+            
             # 更新session_state
             if selected_index != st.session_state.selected_file_index:
                 st.session_state.selected_file_index = selected_index
 
             selected_file = validator.file_paths[selected_index]
 
-            # number_input, 范围是文件数量,默认值是1,步长是1
             # 页码输入器
             current_page = validator.file_info[selected_index]['page']
-            page_input = st.number_input("输入一个数字", 
+            page_input = st.number_input(
+                "输入页码", 
                 placeholder="输入页码", 
-                width=200,
                 label_visibility="collapsed",
-                min_value=1, max_value=len(validator.display_options), value=current_page, step=1,
+                min_value=1, 
+                max_value=len(validator.display_options), 
+                value=current_page, 
+                step=1,
                 key="page_input"
             )
+            
             # 当页码输入改变时,更新文件选择
             if page_input != current_page:
-                # 查找对应页码的文件索引
                 for i, info in enumerate(validator.file_info):
                     if info['page'] == page_input:
                         st.session_state.selected_file_index = i
@@ -870,35 +976,36 @@ def main():
                         st.rerun()
                         break
 
+            # 自动加载文件
             if (st.session_state.selected_file_index >= 0
                 and validator.selected_file_index != st.session_state.selected_file_index
                 and selected_file):
                 validator.selected_file_index = st.session_state.selected_file_index
                 st.session_state.validator.load_ocr_data(selected_file)
-                st.success(f"✅ 已加载第{validator.file_info[st.session_state.selected_file_index]['page']}页")
+                
+                # 显示加载成功信息
+                current_source_name = get_data_source_display_name(validator.current_source_config)
+                st.success(f"✅ 已加载 {current_source_name} - 第{validator.file_info[st.session_state.selected_file_index]['page']}页")
                 st.rerun()
-            # if st.button("🔄 加载文件", type="secondary") and selected_file:
-            #     st.session_state.validator.load_ocr_data(selected_file)
-            #     st.success(f"✅ 已加载第{validator.file_info[selected_index]['page']}页")
-            #     st.rerun()
         else:
-            st.warning("未找到OCR结果文件")
-            st.info("请确保output目录下有OCR结果文件")
+            st.warning("当前数据源中未找到OCR结果文件")
         
+        # VLM预校验按钮
         if st.button("VLM预校验", type="primary", icon=":material/compare_arrows:"):
             if validator.image_path and validator.md_content:
-                # 创建新的页面区域来显示VLM预校验结果
                 validator.vlm_pre_validation()
             else:
-                message_box("❌ 请先加载OCR数据文件", "error")
+                message_box("❌ 请先选择OCR数据文件", "error")
 
+        # 查看预校验结果按钮
         if st.button("查看预校验结果", type="secondary", icon=":material/quick_reference_all:"):
             validator.show_comparison_results_dialog()
 
+    # 显示当前数据源统计信息
     with st.expander("🔧 OCR工具统计信息", expanded=False):
-        # 显示统计信息
         stats = validator.get_statistics()
-        col1, col2, col3, col4, col5 = st.columns(5)  # 增加一列
+        col1, col2, col3, col4, col5 = st.columns(5)
+        
         with col1:
             st.metric("📊 总文本块", stats['total_texts'])
         with col2:
@@ -908,25 +1015,23 @@ def main():
         with col4:
             st.metric("✅ 准确率", f"{stats['accuracy_rate']:.1f}%")
         with col5:
-            # 显示OCR工具信息
-            if stats['tool_info']:
-                tool_names = list(stats['tool_info'].keys())
-                main_tool = tool_names[0] if tool_names else "未知"
-                st.metric("🔧 OCR工具", main_tool)
+            # 显示当前数据源信息
+            if validator.current_source_config:
+                tool_display = validator.current_source_config['ocr_tool'].upper()
+                st.metric("🔧 OCR工具", tool_display)
+        
         # 详细工具信息
         if stats['tool_info']:
-            st.write(stats['tool_info'])
+            st.write("**详细信息:**", stats['tool_info'])
     
-    # st.markdown("---")
-    
-    # 创建标签页
+    # 其余标签页保持不变...
     tab1, tab2, tab3, tab4 = st.tabs(["📄 内容校验", "📊 表格分析", "📈 数据统计", "🚀 快速导航"])
     
     with tab1:
         validator.create_compact_layout(config)
 
     with tab2:
-        # 表格分析页面
+        # 表格分析页面 - 保持原有逻辑
         st.header("📊 表格数据分析")
         
         if validator.md_content and '<table' in validator.md_content.lower():
@@ -946,16 +1051,20 @@ def main():
                         st.download_button(
                             label="📥 下载Excel文件",
                             data=output.getvalue(),
-                            file_name="ocr_tables.xlsx",
+                            file_name=f"ocr_tables_{validator.current_source_config['ocr_tool']}.xlsx",
                             mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
                         )
         else:
             st.info("当前OCR结果中没有检测到表格数据")
     
     with tab3:
-        # 数据统计页面
+        # 数据统计页面 - 保持原有逻辑
         st.header("📈 OCR数据统计")
         
+        # 添加数据源特定的统计信息
+        if validator.current_source_config:
+            st.subheader(f"📊 {get_data_source_display_name(validator.current_source_config)} - 统计信息")
+        
         if stats['categories']:
             st.subheader("📊 类别分布")
             fig_pie = px.pie(
@@ -979,7 +1088,7 @@ def main():
         st.plotly_chart(fig_bar, use_container_width=True)
     
     with tab4:
-        # 快速导航功能
+        # 快速导航功能 - 保持原有逻辑
         st.header("🚀 快速导航")
         
         if not validator.text_bbox_mapping:
@@ -991,7 +1100,7 @@ def main():
             # 创建导航按钮
             for category, texts in categories.items():
                 with st.expander(f"{category} ({len(texts)}项)", expanded=False):
-                    cols = st.columns(3)  # 每行3个按钮
+                    cols = st.columns(3)
                     for i, text in enumerate(texts):
                         col_idx = i % 3
                         with cols[col_idx]: