streamlit_validator_ui.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. """
  2. UI组件和页面配置
  3. """
  4. import streamlit as st
  5. from ocr_validator_file_utils import load_css_styles
  6. from ocr_validator_utils import get_data_source_display_name
  7. def setup_page_config(config):
  8. """设置页面配置"""
  9. ui_config = config['ui']
  10. st.set_page_config(
  11. page_title=ui_config['page_title'],
  12. page_icon=ui_config['page_icon'],
  13. layout=ui_config['layout'],
  14. initial_sidebar_state=ui_config['sidebar_state']
  15. )
  16. css_content = load_css_styles()
  17. st.markdown(f"<style>{css_content}</style>", unsafe_allow_html=True)
  18. def create_data_source_selector(validator):
  19. """创建双数据源选择器"""
  20. if not validator.all_sources:
  21. st.warning("❌ 未找到任何数据源,请检查配置文件")
  22. return
  23. source_options = {}
  24. for source_key, source_data in validator.all_sources.items():
  25. display_name = get_data_source_display_name(source_data['config'])
  26. source_options[display_name] = source_key
  27. col1, col2 = st.columns(2)
  28. with col1:
  29. st.markdown("#### 📊 OCR数据源")
  30. current_display_name = None
  31. if validator.current_source_key:
  32. for display_name, key in source_options.items():
  33. if key == validator.current_source_key:
  34. current_display_name = display_name
  35. break
  36. selected_ocr_display = st.selectbox(
  37. "选择OCR数据源",
  38. options=list(source_options.keys()),
  39. index=list(source_options.keys()).index(current_display_name) if current_display_name else 0,
  40. key="ocr_source_selector",
  41. label_visibility="collapsed",
  42. help="选择要分析的OCR数据源"
  43. )
  44. selected_ocr_key = source_options[selected_ocr_display]
  45. if selected_ocr_key != validator.current_source_key:
  46. validator.switch_to_source(selected_ocr_key)
  47. if 'selected_file_index' in st.session_state:
  48. st.session_state.selected_file_index = 0
  49. # ✅ 数据源变更会在主函数中检测并重置验证结果
  50. st.rerun()
  51. if validator.current_source_config:
  52. with st.expander("📋 OCR数据源详情", expanded=False):
  53. st.write(f"**工具:** {validator.current_source_config['ocr_tool']}")
  54. st.write(f"**文件数:** {len(validator.file_info)}")
  55. with col2:
  56. st.markdown("#### 🔍 验证数据源")
  57. verify_display_name = None
  58. if validator.verify_source_key:
  59. for display_name, key in source_options.items():
  60. if key == validator.verify_source_key:
  61. verify_display_name = display_name
  62. break
  63. selected_verify_display = st.selectbox(
  64. "选择验证数据源",
  65. options=list(source_options.keys()),
  66. index=list(source_options.keys()).index(verify_display_name) if verify_display_name else (1 if len(source_options) > 1 else 0),
  67. key="verify_source_selector",
  68. label_visibility="collapsed",
  69. help="选择用于交叉验证的数据源"
  70. )
  71. selected_verify_key = source_options[selected_verify_display]
  72. if selected_verify_key != validator.verify_source_key:
  73. validator.switch_to_verify_source(selected_verify_key)
  74. # ✅ 数据源变更会在主函数中检测并重置验证结果
  75. st.rerun()
  76. if validator.verify_source_config:
  77. with st.expander("📋 验证数据源详情", expanded=False):
  78. st.write(f"**工具:** {validator.verify_source_config['ocr_tool']}")
  79. st.write(f"**文件数:** {len(validator.verify_file_info)}")
  80. # ✅ 显示数据源状态提示
  81. if validator.current_source_key == validator.verify_source_key:
  82. st.warning("⚠️ OCR数据源和验证数据源相同,建议选择不同的数据源进行交叉验证")
  83. else:
  84. # 检查是否有交叉验证结果
  85. has_results = 'cross_validation_batch_result' in st.session_state and st.session_state.cross_validation_batch_result is not None
  86. if has_results:
  87. # 检查验证结果是否与当前数据源匹配
  88. result = st.session_state.cross_validation_batch_result
  89. result_ocr_source = result.get('ocr_source', '')
  90. result_verify_source = result.get('verify_source', '')
  91. current_ocr_source = get_data_source_display_name(validator.current_source_config)
  92. current_verify_source = get_data_source_display_name(validator.verify_source_config)
  93. if result_ocr_source == current_ocr_source and result_verify_source == current_verify_source:
  94. st.success(f"✅ 已选择 {selected_ocr_display} 与 {selected_verify_display} 进行交叉验证(已有验证结果)")
  95. else:
  96. st.info(f"ℹ️ 已选择 {selected_ocr_display} 与 {selected_verify_display} 进行交叉验证(验证结果已过期,请重新验证)")
  97. else:
  98. st.success(f"✅ 已选择 {selected_ocr_display} 与 {selected_verify_display} 进行交叉验证")
  99. @st.dialog("message", width="small", dismissible=True, on_dismiss="rerun")
  100. def message_box(msg: str, msg_type: str = "info"):
  101. """消息对话框"""
  102. if msg_type == "info":
  103. st.info(msg)
  104. elif msg_type == "warning":
  105. st.warning(msg)
  106. elif msg_type == "error":
  107. st.error(msg)