streamlit_validator_ui.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286
  1. """
  2. UI组件和页面配置
  3. """
  4. import streamlit as st
  5. from ocr_validator_file_utils import load_css_styles
  6. from ocr_validator_utils import get_data_source_display_name
  7. def setup_page_config(config):
  8. """设置页面配置"""
  9. ui_config = config['ui']
  10. st.set_page_config(
  11. page_title=ui_config['page_title'],
  12. page_icon=ui_config['page_icon'],
  13. layout=ui_config['layout'],
  14. initial_sidebar_state=ui_config['sidebar_state']
  15. )
  16. css_content = load_css_styles()
  17. st.markdown(f"<style>{css_content}</style>", unsafe_allow_html=True)
  18. def _parse_document_from_source_key(source_key: str, documents: list) -> str:
  19. """
  20. 🎯 从数据源 key 解析文档名
  21. 数据源 key 格式: {文档名}_{result_dir}
  22. 例如: "德_内蒙古银行照_mineru_vllm_results"
  23. Args:
  24. source_key: 数据源 key
  25. documents: 可用文档列表
  26. Returns:
  27. 文档名,如果无法解析则返回 None
  28. """
  29. # 🎯 按文档名长度降序排序,优先匹配长文档名
  30. sorted_docs = sorted(documents, key=len, reverse=True)
  31. for doc in sorted_docs:
  32. # 检查数据源 key 是否以 "文档名_" 开头
  33. if source_key.startswith(f"{doc}_"):
  34. return doc
  35. # 如果没有匹配,尝试直接匹配
  36. if source_key in documents:
  37. return source_key
  38. return None
  39. def create_data_source_selector(validator):
  40. """
  41. 🎯 新版数据源选择器 - 3 列布局
  42. 1. 选择文档
  43. 2. 选择 OCR 工具
  44. 3. 选择验证工具
  45. """
  46. from config_manager import ConfigManager
  47. # 获取配置管理器
  48. if 'config_manager' not in st.session_state:
  49. st.error("配置管理器未初始化")
  50. return
  51. config_manager: ConfigManager = st.session_state.config_manager
  52. # ============================================================
  53. # 3 列布局
  54. # ============================================================
  55. col1, col2, col3 = st.columns(3)
  56. # ============================================================
  57. # 第 1 列:选择文档
  58. # ============================================================
  59. with col1:
  60. st.markdown("#### 📄 选择文档")
  61. documents = config_manager.list_documents()
  62. if not documents:
  63. st.error("未找到任何文档配置")
  64. return
  65. # 🎯 从当前数据源 key 解析文档名
  66. current_doc = None
  67. if validator.current_source_key:
  68. current_doc = _parse_document_from_source_key(validator.current_source_key, documents)
  69. # 文档下拉框
  70. selected_doc_index = 0
  71. if current_doc and current_doc in documents:
  72. selected_doc_index = documents.index(current_doc)
  73. selected_doc = st.selectbox(
  74. "文档",
  75. options=documents,
  76. index=selected_doc_index,
  77. key="selected_document",
  78. label_visibility="collapsed",
  79. help="选择要处理的文档"
  80. )
  81. # 显示文档详情
  82. doc_config = config_manager.get_document(selected_doc)
  83. if doc_config:
  84. enabled_count = len([r for r in doc_config.ocr_results if r.enabled])
  85. with st.expander("📋 文档详情", expanded=False):
  86. st.caption(f"**基础目录:** `{doc_config.base_dir}`")
  87. st.caption(f"**可用工具:** {enabled_count} 个")
  88. # ============================================================
  89. # 第 2 列:选择 OCR 工具
  90. # ============================================================
  91. with col2:
  92. st.markdown("#### 🔧 OCR 数据源")
  93. if not doc_config:
  94. st.error(f"文档配置不存在: {selected_doc}")
  95. return
  96. # 获取该文档的所有启用的 OCR 结果
  97. enabled_ocr_results = [r for r in doc_config.ocr_results if r.enabled]
  98. if not enabled_ocr_results:
  99. st.warning(f"文档 {selected_doc} 没有可用的 OCR 工具")
  100. return
  101. # 🎯 构建 OCR 工具选项(使用 result_dir)
  102. ocr_tool_options = []
  103. ocr_source_keys = [] # 对应的数据源 key
  104. for ocr_result in enabled_ocr_results:
  105. # 🎯 显示名称:优先使用 description
  106. if ocr_result.description:
  107. display_name = ocr_result.description
  108. else:
  109. tool_config = config_manager.get_ocr_tool(ocr_result.tool)
  110. display_name = tool_config.name if tool_config else ocr_result.tool
  111. # 🎯 构建数据源 key(使用 result_dir)
  112. source_key = f"{selected_doc}_{ocr_result.result_dir}"
  113. ocr_tool_options.append(display_name)
  114. ocr_source_keys.append(source_key)
  115. # 获取当前选中的 OCR 工具索引
  116. current_ocr_index = 0
  117. if validator.current_source_key:
  118. # 🎯 检查当前数据源是否属于当前文档
  119. if validator.current_source_key.startswith(f"{selected_doc}_"):
  120. if validator.current_source_key in ocr_source_keys:
  121. current_ocr_index = ocr_source_keys.index(validator.current_source_key)
  122. # OCR 工具下拉框
  123. selected_ocr_index = st.selectbox(
  124. "OCR 工具",
  125. options=range(len(ocr_tool_options)),
  126. format_func=lambda i: ocr_tool_options[i],
  127. index=current_ocr_index,
  128. key="selected_ocr_tool",
  129. label_visibility="collapsed",
  130. help="选择 OCR 识别工具"
  131. )
  132. selected_ocr_source_key = ocr_source_keys[selected_ocr_index]
  133. # 🎯 切换 OCR 数据源
  134. if validator.current_source_key != selected_ocr_source_key:
  135. validator.switch_to_source(selected_ocr_source_key)
  136. st.success(f"✅ 已切换 OCR 工具")
  137. # 重置文件选择
  138. if 'selected_file_index' in st.session_state:
  139. st.session_state.selected_file_index = 0
  140. st.rerun()
  141. # 显示 OCR 数据源详情
  142. if validator.current_source_config:
  143. with st.expander("📋 OCR 详情", expanded=False):
  144. st.caption(f"**工具:** {validator.current_source_config['ocr_tool']}")
  145. st.caption(f"**结果目录:** `{enabled_ocr_results[selected_ocr_index].result_dir}`")
  146. st.caption(f"**文件数:** {len(validator.file_info)}")
  147. # ============================================================
  148. # 第 3 列:选择验证工具
  149. # ============================================================
  150. with col3:
  151. st.markdown("#### 🔍 验证数据源")
  152. # 🎯 验证工具选项(排除当前 OCR 工具)
  153. verify_tool_options = []
  154. verify_source_keys = []
  155. verify_results = [] # 保存对应的 ocr_result,用于显示详情
  156. for i, ocr_result in enumerate(enabled_ocr_results):
  157. # 跳过当前 OCR 工具
  158. if ocr_source_keys[i] == selected_ocr_source_key:
  159. continue
  160. # 🎯 显示名称
  161. if ocr_result.description:
  162. display_name = ocr_result.description
  163. else:
  164. tool_config = config_manager.get_ocr_tool(ocr_result.tool)
  165. display_name = tool_config.name if tool_config else ocr_result.tool
  166. verify_tool_options.append(display_name)
  167. verify_source_keys.append(ocr_source_keys[i])
  168. verify_results.append(ocr_result)
  169. if not verify_tool_options:
  170. st.warning("⚠️ 没有其他可用的验证工具")
  171. st.info("💡 可以添加更多 OCR 工具到配置文件")
  172. return
  173. # 获取当前选中的验证工具索引
  174. current_verify_index = 0
  175. if validator.verify_source_key:
  176. # 🎯 检查验证数据源是否属于当前文档
  177. if validator.verify_source_key.startswith(f"{selected_doc}_"):
  178. if validator.verify_source_key in verify_source_keys:
  179. current_verify_index = verify_source_keys.index(validator.verify_source_key)
  180. # 验证工具下拉框
  181. selected_verify_index = st.selectbox(
  182. "验证工具",
  183. options=range(len(verify_tool_options)),
  184. format_func=lambda i: verify_tool_options[i],
  185. index=current_verify_index,
  186. key="selected_verify_tool",
  187. label_visibility="collapsed",
  188. help="选择用于交叉验证的工具"
  189. )
  190. selected_verify_source_key = verify_source_keys[selected_verify_index]
  191. # 🎯 切换验证数据源
  192. if validator.verify_source_key != selected_verify_source_key:
  193. validator.switch_to_verify_source(selected_verify_source_key)
  194. st.success(f"✅ 已切换验证工具")
  195. st.rerun()
  196. # 显示验证数据源详情
  197. if validator.verify_source_config:
  198. verify_result = verify_results[selected_verify_index]
  199. with st.expander("📋 验证详情", expanded=False):
  200. st.caption(f"**工具:** {validator.verify_source_config['ocr_tool']}")
  201. st.caption(f"**结果目录:** `{verify_result.result_dir}`")
  202. st.caption(f"**文件数:** {len(validator.verify_file_info)}")
  203. # ============================================================
  204. # 状态提示(全宽)
  205. # ============================================================
  206. if validator.current_source_key == validator.verify_source_key:
  207. st.warning("⚠️ OCR数据源和验证数据源相同,建议选择不同的工具进行交叉验证")
  208. else:
  209. # 🎯 检查是否有交叉验证结果
  210. has_results = (
  211. 'cross_validation_batch_result' in st.session_state
  212. and st.session_state.cross_validation_batch_result is not None
  213. )
  214. if has_results:
  215. # 检查验证结果是否与当前数据源匹配
  216. result = st.session_state.cross_validation_batch_result
  217. result_ocr_source = result.get('ocr_source', '')
  218. result_verify_source = result.get('verify_source', '')
  219. current_ocr_source = get_data_source_display_name(validator.current_source_config)
  220. current_verify_source = get_data_source_display_name(validator.verify_source_config)
  221. if result_ocr_source == current_ocr_source and result_verify_source == current_verify_source:
  222. st.success(f"✅ 已选择 **{selected_doc}** 文档,使用 **{ocr_tool_options[selected_ocr_index]}** 与 **{verify_tool_options[selected_verify_index]}** 进行交叉验证 **(已有验证结果)**")
  223. else:
  224. st.info(f"ℹ️ 已选择 **{selected_doc}** 文档,使用 **{ocr_tool_options[selected_ocr_index]}** 与 **{verify_tool_options[selected_verify_index]}** 进行交叉验证 **(验证结果已过期,请重新验证)**")
  225. else:
  226. st.success(f"✅ 已选择 **{selected_doc}** 文档,使用 **{ocr_tool_options[selected_ocr_index]}** 与 **{verify_tool_options[selected_verify_index]}** 进行交叉验证")
  227. @st.dialog("message", width="small", dismissible=True, on_dismiss="rerun")
  228. def message_box(msg: str, msg_type: str = "info"):
  229. """消息对话框"""
  230. if msg_type == "info":
  231. st.info(msg)
  232. elif msg_type == "warning":
  233. st.warning(msg)
  234. elif msg_type == "error":
  235. st.error(msg)