| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286 |
- """
- UI组件和页面配置
- """
from typing import Iterable, Optional

import streamlit as st

from ocr_validator_file_utils import load_css_styles
from ocr_validator_utils import get_data_source_display_name
def setup_page_config(config):
    """Apply Streamlit page settings and inject the app's CSS.

    Args:
        config: Application configuration mapping. Its ``'ui'`` section must
            provide ``page_title``, ``page_icon``, ``layout`` and
            ``sidebar_state``.
    """
    ui = config['ui']
    page_settings = dict(
        page_title=ui['page_title'],
        page_icon=ui['page_icon'],
        layout=ui['layout'],
        initial_sidebar_state=ui['sidebar_state'],
    )
    st.set_page_config(**page_settings)

    # Inline the stylesheet as a raw <style> tag (needs unsafe_allow_html).
    style_tag = f"<style>{load_css_styles()}</style>"
    st.markdown(style_tag, unsafe_allow_html=True)
- def _parse_document_from_source_key(source_key: str, documents: list) -> str:
- """
- 🎯 从数据源 key 解析文档名
-
- 数据源 key 格式: {文档名}_{result_dir}
- 例如: "德_内蒙古银行照_mineru_vllm_results"
-
- Args:
- source_key: 数据源 key
- documents: 可用文档列表
-
- Returns:
- 文档名,如果无法解析则返回 None
- """
- # 🎯 按文档名长度降序排序,优先匹配长文档名
- sorted_docs = sorted(documents, key=len, reverse=True)
-
- for doc in sorted_docs:
- # 检查数据源 key 是否以 "文档名_" 开头
- if source_key.startswith(f"{doc}_"):
- return doc
-
- # 如果没有匹配,尝试直接匹配
- if source_key in documents:
- return source_key
-
- return None
def _ocr_result_display_name(config_manager, ocr_result) -> str:
    """Return the selector label for one OCR result entry.

    Uses the entry's ``description`` when set. Otherwise uses the configured
    OCR tool's ``name``. If the config manager has no such tool, falls back
    to the raw ``tool`` identifier.
    """
    if ocr_result.description:
        return ocr_result.description
    tool_config = config_manager.get_ocr_tool(ocr_result.tool)
    return tool_config.name if tool_config else ocr_result.tool


def create_data_source_selector(validator):
    """Render the three-column data-source selector.

    Columns:
        1. Document picker.
        2. OCR data source: one of the document's enabled OCR results.
        3. Verification data source: any other enabled OCR result of the same
           document, used for cross-validation.

    Changing the OCR or verification selection calls
    ``validator.switch_to_source`` / ``validator.switch_to_verify_source``
    and then ``st.rerun()``. Below the columns, a full-width status line
    reports whether the cross-validation result stored in
    ``st.session_state.cross_validation_batch_result`` matches the current
    selection.

    Args:
        validator: Object exposing ``current_source_key``,
            ``verify_source_key``, ``current_source_config``,
            ``verify_source_config``, ``file_info``, ``verify_file_info``,
            ``switch_to_source`` and ``switch_to_verify_source``.

    Returns:
        None. Returns early, after showing an error or warning, when the
        config manager is missing or there is nothing to select.
    """
    from config_manager import ConfigManager

    if 'config_manager' not in st.session_state:
        st.error("配置管理器未初始化")
        return

    config_manager: ConfigManager = st.session_state.config_manager

    col1, col2, col3 = st.columns(3)

    # ------------------------------------------------------------
    # Column 1: document
    # ------------------------------------------------------------
    with col1:
        st.markdown("#### 📄 选择文档")

        documents = config_manager.list_documents()
        if not documents:
            st.error("未找到任何文档配置")
            return

        # Pre-select the document that the current OCR source belongs to.
        current_doc = None
        if validator.current_source_key:
            current_doc = _parse_document_from_source_key(validator.current_source_key, documents)

        selected_doc_index = 0
        if current_doc and current_doc in documents:
            selected_doc_index = documents.index(current_doc)

        selected_doc = st.selectbox(
            "文档",
            options=documents,
            index=selected_doc_index,
            key="selected_document",
            label_visibility="collapsed",
            help="选择要处理的文档",
        )

        doc_config = config_manager.get_document(selected_doc)
        if doc_config:
            enabled_count = len([r for r in doc_config.ocr_results if r.enabled])
            with st.expander("📋 文档详情", expanded=False):
                st.caption(f"**基础目录:** `{doc_config.base_dir}`")
                st.caption(f"**可用工具:** {enabled_count} 个")

    # ------------------------------------------------------------
    # Column 2: OCR data source
    # ------------------------------------------------------------
    with col2:
        st.markdown("#### 🔧 OCR 数据源")

        if not doc_config:
            st.error(f"文档配置不存在: {selected_doc}")
            return

        enabled_ocr_results = [r for r in doc_config.ocr_results if r.enabled]
        if not enabled_ocr_results:
            st.warning(f"文档 {selected_doc} 没有可用的 OCR 工具")
            return

        # Labels and data-source keys ("{document}_{result_dir}"), both
        # index-aligned with enabled_ocr_results.
        ocr_tool_options = [_ocr_result_display_name(config_manager, r) for r in enabled_ocr_results]
        ocr_source_keys = [f"{selected_doc}_{r.result_dir}" for r in enabled_ocr_results]

        # Keep the current OCR source selected when it belongs to this
        # document. Every key in ocr_source_keys already carries the
        # "{selected_doc}_" prefix, so membership alone is sufficient.
        current_ocr_index = 0
        if validator.current_source_key in ocr_source_keys:
            current_ocr_index = ocr_source_keys.index(validator.current_source_key)

        selected_ocr_index = st.selectbox(
            "OCR 工具",
            options=range(len(ocr_tool_options)),
            format_func=lambda i: ocr_tool_options[i],
            index=current_ocr_index,
            key="selected_ocr_tool",
            label_visibility="collapsed",
            help="选择 OCR 识别工具",
        )
        selected_ocr_source_key = ocr_source_keys[selected_ocr_index]

        if validator.current_source_key != selected_ocr_source_key:
            validator.switch_to_source(selected_ocr_source_key)
            st.success("✅ 已切换 OCR 工具")
            # A new source has a new file list, so restart at the first file.
            if 'selected_file_index' in st.session_state:
                st.session_state.selected_file_index = 0
            st.rerun()

        if validator.current_source_config:
            with st.expander("📋 OCR 详情", expanded=False):
                st.caption(f"**工具:** {validator.current_source_config['ocr_tool']}")
                st.caption(f"**结果目录:** `{enabled_ocr_results[selected_ocr_index].result_dir}`")
                st.caption(f"**文件数:** {len(validator.file_info)}")

    # ------------------------------------------------------------
    # Column 3: verification data source
    # ------------------------------------------------------------
    with col3:
        st.markdown("#### 🔍 验证数据源")

        # Candidates are every enabled result except the selected OCR source.
        # Reusing the column-2 lists keeps labels and keys consistent.
        verify_indices = [i for i, key in enumerate(ocr_source_keys) if key != selected_ocr_source_key]
        verify_tool_options = [ocr_tool_options[i] for i in verify_indices]
        verify_source_keys = [ocr_source_keys[i] for i in verify_indices]
        verify_results = [enabled_ocr_results[i] for i in verify_indices]

        if not verify_tool_options:
            st.warning("⚠️ 没有其他可用的验证工具")
            st.info("💡 可以添加更多 OCR 工具到配置文件")
            return

        # Same reasoning as column 2: every candidate key belongs to this
        # document, so membership alone is sufficient.
        current_verify_index = 0
        if validator.verify_source_key in verify_source_keys:
            current_verify_index = verify_source_keys.index(validator.verify_source_key)

        selected_verify_index = st.selectbox(
            "验证工具",
            options=range(len(verify_tool_options)),
            format_func=lambda i: verify_tool_options[i],
            index=current_verify_index,
            key="selected_verify_tool",
            label_visibility="collapsed",
            help="选择用于交叉验证的工具",
        )
        selected_verify_source_key = verify_source_keys[selected_verify_index]

        if validator.verify_source_key != selected_verify_source_key:
            validator.switch_to_verify_source(selected_verify_source_key)
            st.success("✅ 已切换验证工具")
            st.rerun()

        if validator.verify_source_config:
            verify_result = verify_results[selected_verify_index]
            with st.expander("📋 验证详情", expanded=False):
                st.caption(f"**工具:** {validator.verify_source_config['ocr_tool']}")
                st.caption(f"**结果目录:** `{verify_result.result_dir}`")
                st.caption(f"**文件数:** {len(validator.verify_file_info)}")

    # ------------------------------------------------------------
    # Full-width status line
    # ------------------------------------------------------------
    if validator.current_source_key == validator.verify_source_key:
        st.warning("⚠️ OCR数据源和验证数据源相同,建议选择不同的工具进行交叉验证")
        return

    summary = (
        f"已选择 **{selected_doc}** 文档,使用 **{ocr_tool_options[selected_ocr_index]}** "
        f"与 **{verify_tool_options[selected_verify_index]}** 进行交叉验证"
    )

    batch_result = st.session_state.get('cross_validation_batch_result')
    if batch_result is None:
        st.success(f"✅ {summary}")
        return

    # A stored result is only current if it was produced for the same
    # OCR/verify pair. Both names are computed up front, without
    # short-circuiting.
    current_ocr_source = get_data_source_display_name(validator.current_source_config)
    current_verify_source = get_data_source_display_name(validator.verify_source_config)
    is_current = (
        batch_result.get('ocr_source', '') == current_ocr_source
        and batch_result.get('verify_source', '') == current_verify_source
    )

    if is_current:
        st.success(f"✅ {summary} **(已有验证结果)**")
    else:
        st.info(f"ℹ️ {summary} **(验证结果已过期,请重新验证)**")
@st.dialog("message", width="small", dismissible=True, on_dismiss="rerun")
def message_box(msg: str, msg_type: str = "info"):
    """Show ``msg`` in a small modal dialog styled according to ``msg_type``.

    Args:
        msg: Text to display.
        msg_type: ``"info"`` (default), ``"warning"``, ``"error"`` or
            ``"success"``. Any other value falls back to ``"info"`` styling,
            so the dialog never opens empty.
    """
    renderers = {
        "info": st.info,
        "warning": st.warning,
        "error": st.error,
        "success": st.success,
    }
    renderers.get(msg_type, st.info)(msg)
|