streamlit_ocr_validator.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356
  1. #!/usr/bin/env python3
  2. """
  3. 基于Streamlit的OCR可视化校验工具(主入口)
  4. """
  5. import streamlit as st
  6. from pathlib import Path
  7. import json
  8. import sys
  9. # 添加 ocr_platform 根目录到 Python 路径(用于导入 ocr_utils)
  10. # 使用 resolve() 确保路径是绝对路径,避免相对路径导致的 IndexError
  11. _file_path = Path(__file__).resolve()
  12. ocr_platform_root = _file_path.parents[1] # streamlit_ocr_validator.py -> ocr_validator -> ocr_platform
  13. if str(ocr_platform_root) not in sys.path:
  14. sys.path.insert(0, str(ocr_platform_root))
  15. from streamlit_validator_core import StreamlitOCRValidator
  16. from streamlit_validator_ui import (
  17. setup_page_config, create_data_source_selector, message_box
  18. )
  19. from streamlit_validator_table import display_html_table_as_dataframe
  20. from streamlit_validator_cross import (
  21. cross_validation_dialog, show_batch_cross_validation_results_dialog
  22. )
  23. from streamlit_validator_result import display_single_page_cross_validation
  24. from ocr_validator_utils import get_data_source_display_name
  25. from config_manager import load_config # 🎯 使用新配置管理器
  26. def reset_cross_validation_results():
  27. """重置交叉验证结果"""
  28. if 'cross_validation_batch_result' in st.session_state:
  29. st.session_state.cross_validation_batch_result = None
  30. print("🔄 数据源已变更,交叉验证结果已清空")
  31. def main():
  32. """主应用"""
  33. # 🎯 初始化配置管理器
  34. if 'config_manager' not in st.session_state:
  35. try:
  36. st.session_state.config_manager = load_config(config_dir="config")
  37. # 🎯 生成 OCRValidator 所需的配置
  38. st.session_state.validator_config = st.session_state.config_manager.to_validator_config()
  39. print("✅ 配置管理器初始化成功")
  40. print(f"📄 发现 {len(st.session_state.config_manager.list_documents())} 个文档配置")
  41. print(f"🔧 发现 {len(st.session_state.config_manager.list_ocr_tools())} 个 OCR 工具")
  42. except Exception as e:
  43. st.error(f"❌ 配置加载失败: {e}")
  44. st.stop()
  45. config_manager = st.session_state.config_manager
  46. validator_config = st.session_state.validator_config
  47. # 初始化应用
  48. if 'validator' not in st.session_state:
  49. # 🎯 直接传递配置字典给 OCRValidator
  50. validator = StreamlitOCRValidator(config_dict=validator_config)
  51. st.session_state.validator = validator
  52. setup_page_config(validator_config)
  53. # 页面标题
  54. st.title(validator_config['ui']['page_title'])
  55. # 初始化数据源追踪
  56. st.session_state.current_ocr_source = validator.current_source_key
  57. st.session_state.current_verify_source = validator.verify_source_key
  58. else:
  59. validator = st.session_state.validator
  60. if 'selected_text' not in st.session_state:
  61. st.session_state.selected_text = None
  62. st.session_state.compact_search_query = None
  63. if 'marked_errors' not in st.session_state:
  64. st.session_state.marked_errors = set()
  65. # 数据源选择器
  66. create_data_source_selector(validator)
  67. # ✅ 检测数据源是否变更
  68. ocr_source_changed = False
  69. verify_source_changed = False
  70. if 'current_ocr_source' in st.session_state:
  71. if st.session_state.current_ocr_source != validator.current_source_key:
  72. ocr_source_changed = True
  73. st.session_state.current_ocr_source = validator.current_source_key
  74. print(f"🔄 OCR数据源已切换到: {validator.current_source_key}")
  75. if 'current_verify_source' in st.session_state:
  76. if st.session_state.current_verify_source != validator.verify_source_key:
  77. verify_source_changed = True
  78. st.session_state.current_verify_source = validator.verify_source_key
  79. print(f"🔄 验证数据源已切换到: {validator.verify_source_key}")
  80. # ✅ 如果任一数据源变更,清空交叉验证结果
  81. if ocr_source_changed or verify_source_changed:
  82. reset_cross_validation_results()
  83. # 显示提示信息
  84. if ocr_source_changed and verify_source_changed:
  85. st.info("ℹ️ OCR数据源和验证数据源已变更,请重新运行交叉验证")
  86. elif ocr_source_changed:
  87. st.info("ℹ️ OCR数据源已变更,请重新运行交叉验证")
  88. elif verify_source_changed:
  89. st.info("ℹ️ 验证数据源已变更,请重新运行交叉验证")
  90. # 如果没有可用的数据源,提前返回
  91. if not validator.all_sources:
  92. st.warning("⚠️ 未找到任何数据源,请检查配置文件")
  93. # 🎯 显示配置信息帮助调试
  94. with st.expander("🔍 配置信息", expanded=True):
  95. st.write("**已加载的文档:**")
  96. docs = config_manager.list_documents()
  97. if docs:
  98. for doc in docs:
  99. doc_config = config_manager.get_document(doc)
  100. st.write(f"- **{doc}**")
  101. st.write(f" - 基础目录: `{doc_config.base_dir}`")
  102. st.write(f" - OCR 结果: {len([r for r in doc_config.ocr_results if r.enabled])} 个已启用")
  103. else:
  104. st.write("无")
  105. st.write("**已加载的 OCR 工具:**")
  106. tools = config_manager.list_ocr_tools()
  107. if tools:
  108. for tool in tools:
  109. tool_config = config_manager.get_ocr_tool(tool)
  110. st.write(f"- **{tool_config.name}** (`{tool}`)")
  111. else:
  112. st.write("无")
  113. st.write("**配置文件路径:**")
  114. st.code(str(config_manager.config_dir / "global.yaml"))
  115. st.write("**生成的数据源:**")
  116. data_sources = config_manager.get_data_sources()
  117. if data_sources:
  118. for ds in data_sources:
  119. st.write(f"- `{ds.name}`")
  120. st.write(f" - 工具: {ds.ocr_tool}")
  121. st.write(f" - 结果目录: {ds.ocr_out_dir}")
  122. st.write(f" - 图片目录: {ds.src_img_dir}")
  123. else:
  124. st.write("无")
  125. st.stop()
  126. # 文件选择区域
  127. with st.container(height=75, horizontal=True, horizontal_alignment='left', gap="medium"):
  128. if 'selected_file_index' not in st.session_state:
  129. st.session_state.selected_file_index = 0
  130. st.session_state.file_selectbox = 0
  131. if validator.display_options:
  132. # 确保 selected_file_index 在有效范围内
  133. if st.session_state.selected_file_index >= len(validator.display_options):
  134. st.session_state.selected_file_index = 0
  135. st.session_state.file_selectbox = 0
  136. # 使用独立的 key 给 selectbox,避免 Streamlit 锁定 selected_file_index
  137. # 在创建 selectbox 之前,同步 file_selectbox 的值到 selected_file_index
  138. if 'file_selectbox' not in st.session_state:
  139. st.session_state.file_selectbox = st.session_state.selected_file_index
  140. elif st.session_state.file_selectbox != st.session_state.selected_file_index:
  141. # 如果 selected_file_index 被外部更新(如通过页码输入),同步到 file_selectbox
  142. st.session_state.file_selectbox = st.session_state.selected_file_index
  143. selected_index = st.selectbox(
  144. "选择OCR结果文件",
  145. range(len(validator.display_options)),
  146. format_func=lambda i: validator.display_options[i],
  147. index=st.session_state.selected_file_index,
  148. key="file_selectbox",
  149. label_visibility="collapsed"
  150. )
  151. # 手动同步 selectbox 的值到 selected_file_index
  152. if selected_index != st.session_state.selected_file_index:
  153. st.session_state.selected_file_index = selected_index
  154. selected_file = validator.file_paths[selected_index]
  155. current_page = validator.file_info[selected_index]['page']
  156. # 初始化或同步页码值
  157. if 'page_input_value' not in st.session_state:
  158. st.session_state.page_input_value = current_page
  159. # 如果当前页码与 session_state 中的值不一致,更新 session_state
  160. # 这会在下拉框改变或通过其他方式改变文件时同步页码
  161. if current_page != st.session_state.page_input_value:
  162. st.session_state.page_input_value = current_page
  163. page_input = st.number_input(
  164. "输入页码",
  165. placeholder="输入页码",
  166. label_visibility="collapsed",
  167. min_value=1,
  168. max_value=len(validator.display_options),
  169. value=st.session_state.page_input_value,
  170. step=1,
  171. key="page_input"
  172. )
  173. # 更新 session_state 中的页码值
  174. if page_input != st.session_state.page_input_value:
  175. st.session_state.page_input_value = page_input
  176. if page_input != current_page:
  177. for i, info in enumerate(validator.file_info):
  178. if info['page'] == page_input:
  179. # 更新 selected_file_index,selectbox 会在下一个运行周期自动同步
  180. st.session_state.selected_file_index = i
  181. selected_file = validator.file_paths[i]
  182. # 同步页码值
  183. st.session_state.page_input_value = page_input
  184. st.rerun()
  185. break
  186. if (st.session_state.selected_file_index >= 0
  187. and validator.selected_file_index != st.session_state.selected_file_index
  188. and selected_file):
  189. validator.selected_file_index = st.session_state.selected_file_index
  190. st.session_state.validator.load_ocr_data(selected_file)
  191. current_source_name = get_data_source_display_name(validator.current_source_config)
  192. st.success(f"✅ 已加载 {current_source_name} - 第{validator.file_info[st.session_state.selected_file_index]['page']}页")
  193. st.rerun()
  194. else:
  195. st.warning("当前数据源中未找到OCR结果文件")
  196. # ✅ 交叉验证按钮 - 添加数据源检查
  197. cross_validation_enabled = (
  198. validator.current_source_key != validator.verify_source_key
  199. and validator.image_path
  200. and validator.md_content
  201. )
  202. if st.button(
  203. "交叉验证",
  204. type="primary",
  205. icon=":material/compare_arrows:",
  206. disabled=not cross_validation_enabled,
  207. help="需要选择不同的OCR数据源和验证数据源" if not cross_validation_enabled else "开始批量交叉验证"
  208. ):
  209. cross_validation_dialog(validator)
  210. # ✅ 查看验证结果按钮 - 检查是否有验证结果
  211. has_validation_results = (
  212. 'cross_validation_batch_result' in st.session_state
  213. and st.session_state.cross_validation_batch_result is not None
  214. )
  215. if st.button(
  216. "查看验证结果",
  217. type="secondary",
  218. icon=":material/quick_reference_all:",
  219. disabled=not has_validation_results,
  220. help="暂无验证结果,请先运行交叉验证" if not has_validation_results else "查看批量验证结果"
  221. ):
  222. show_batch_cross_validation_results_dialog()
  223. # 显示当前数据源统计信息
  224. with st.expander("OCR工具统计信息", expanded=False):
  225. stats = validator.get_statistics()
  226. col1, col2, col3, col4, col5 = st.columns(5)
  227. with col1:
  228. st.metric("📊 总文本块", stats['total_texts'])
  229. with col2:
  230. st.metric("🔗 可点击文本", stats['clickable_texts'])
  231. with col3:
  232. st.metric("❌ 标记错误", stats['marked_errors'])
  233. with col4:
  234. st.metric("✅ 准确率", f"{stats['accuracy_rate']:.1f}%")
  235. with col5:
  236. if validator.current_source_config:
  237. tool_id = validator.current_source_config['ocr_tool']
  238. # 🎯 从配置管理器获取工具名称
  239. tool_config = config_manager.get_ocr_tool(tool_id)
  240. tool_display = tool_config.name if tool_config else tool_id.upper()
  241. st.metric("🔧 OCR工具", tool_display)
  242. if stats['tool_info']:
  243. st.write("**详细信息:**", stats['tool_info'])
  244. # 🎯 显示当前文档和 OCR 结果信息
  245. if validator.current_source_config:
  246. source_name = validator.current_source_config['name']
  247. # 解析数据源名称,提取文档名(更精确的解析)
  248. parts = source_name.split('_', 1)
  249. doc_name = parts[0] if parts else source_name
  250. doc_config = config_manager.get_document(doc_name)
  251. if doc_config:
  252. st.write("**文档信息:**")
  253. st.write(f"- 文档名称: {doc_config.name}")
  254. st.write(f"- 基础目录: {doc_config.base_dir}")
  255. st.write(f"- 可用 OCR 工具: {len([r for r in doc_config.ocr_results if r.enabled])} 个")
  256. # 🎯 添加配置管理面板
  257. with st.expander("⚙️ 配置管理", expanded=False):
  258. col1, col2 = st.columns(2)
  259. with col1:
  260. st.subheader("📄 已加载文档")
  261. docs = config_manager.list_documents()
  262. for doc_name in docs:
  263. doc_config = config_manager.get_document(doc_name)
  264. enabled_count = len([r for r in doc_config.ocr_results if r.enabled])
  265. total_count = len(doc_config.ocr_results)
  266. with st.container():
  267. st.write(f"✅ **{doc_name}**")
  268. st.caption(f"📊 {enabled_count}/{total_count} 工具已启用")
  269. # 显示每个 OCR 工具的状态
  270. for ocr_result in doc_config.ocr_results:
  271. status_icon = "🟢" if ocr_result.enabled else "⚪"
  272. tool_config = config_manager.get_ocr_tool(ocr_result.tool)
  273. tool_name = tool_config.name if tool_config else ocr_result.tool
  274. st.caption(f" {status_icon} {tool_name} - {ocr_result.description or ocr_result.result_dir}")
  275. with col2:
  276. st.subheader("🔧 已加载 OCR 工具")
  277. tools = config_manager.list_ocr_tools()
  278. for tool_id in tools:
  279. tool_config = config_manager.get_ocr_tool(tool_id)
  280. with st.container():
  281. st.write(f"🔧 **{tool_config.name}**")
  282. st.caption(f"ID: `{tool_id}`")
  283. st.caption(f"描述: {tool_config.description}")
  284. tab1, tab2, tab3 = st.tabs(["📄 内容人工检查", "🔍 交叉验证结果", "📊 表格分析"])
  285. with tab1:
  286. validator.create_compact_layout(validator_config)
  287. with tab2:
  288. # ✅ 使用封装的函数显示单页交叉验证结果
  289. display_single_page_cross_validation(validator, validator_config)
  290. with tab3:
  291. st.header("📊 表格数据分析")
  292. if validator.md_content and '<table' in validator.md_content.lower():
  293. st.subheader("🔍 表格数据预览")
  294. display_html_table_as_dataframe(validator.md_content)
  295. else:
  296. st.info("当前OCR结果中没有检测到表格数据")
# Entry point: render the validator app when this module is run as the main script.
if __name__ == "__main__":
    main()