streamlit_ocr_validator.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. #!/usr/bin/env python3
  2. """
  3. 基于Streamlit的OCR可视化校验工具(主入口)
  4. """
  5. import streamlit as st
  6. from pathlib import Path
  7. import json
  8. from streamlit_validator_core import StreamlitOCRValidator
  9. from streamlit_validator_ui import (
  10. setup_page_config, create_data_source_selector, message_box
  11. )
  12. from streamlit_validator_table import display_html_table_as_dataframe
  13. from streamlit_validator_cross import (
  14. cross_validation_dialog, show_batch_cross_validation_results_dialog
  15. )
  16. from streamlit_validator_result import display_single_page_cross_validation
  17. from ocr_validator_utils import get_data_source_display_name
  18. def reset_cross_validation_results():
  19. """重置交叉验证结果"""
  20. if 'cross_validation_batch_result' in st.session_state:
  21. st.session_state.cross_validation_batch_result = None
  22. print("🔄 数据源已变更,交叉验证结果已清空")
  23. def main():
  24. """主应用"""
  25. # 初始化应用
  26. if 'validator' not in st.session_state:
  27. validator = StreamlitOCRValidator()
  28. st.session_state.validator = validator
  29. setup_page_config(validator.config)
  30. # 页面标题
  31. config = st.session_state.validator.config
  32. st.title(config['ui']['page_title'])
  33. # 初始化数据源追踪
  34. st.session_state.current_ocr_source = validator.current_source_key
  35. st.session_state.current_verify_source = validator.verify_source_key
  36. else:
  37. validator = st.session_state.validator
  38. config = st.session_state.validator.config
  39. if 'selected_text' not in st.session_state:
  40. st.session_state.selected_text = None
  41. if 'marked_errors' not in st.session_state:
  42. st.session_state.marked_errors = set()
  43. # 数据源选择器
  44. create_data_source_selector(validator)
  45. # ✅ 检测数据源是否变更
  46. ocr_source_changed = False
  47. verify_source_changed = False
  48. if 'current_ocr_source' in st.session_state:
  49. if st.session_state.current_ocr_source != validator.current_source_key:
  50. ocr_source_changed = True
  51. st.session_state.current_ocr_source = validator.current_source_key
  52. print(f"🔄 OCR数据源已切换到: {validator.current_source_key}")
  53. if 'current_verify_source' in st.session_state:
  54. if st.session_state.current_verify_source != validator.verify_source_key:
  55. verify_source_changed = True
  56. st.session_state.current_verify_source = validator.verify_source_key
  57. print(f"🔄 验证数据源已切换到: {validator.verify_source_key}")
  58. # ✅ 如果任一数据源变更,清空交叉验证结果
  59. if ocr_source_changed or verify_source_changed:
  60. reset_cross_validation_results()
  61. # 显示提示信息
  62. if ocr_source_changed and verify_source_changed:
  63. st.info("ℹ️ OCR数据源和验证数据源已变更,请重新运行交叉验证")
  64. elif ocr_source_changed:
  65. st.info("ℹ️ OCR数据源已变更,请重新运行交叉验证")
  66. elif verify_source_changed:
  67. st.info("ℹ️ 验证数据源已变更,请重新运行交叉验证")
  68. # 如果没有可用的数据源,提前返回
  69. if not validator.all_sources:
  70. st.stop()
  71. # 文件选择区域
  72. with st.container(height=75, horizontal=True, horizontal_alignment='left', gap="medium"):
  73. if 'selected_file_index' not in st.session_state:
  74. st.session_state.selected_file_index = 0
  75. if validator.display_options:
  76. selected_index = st.selectbox(
  77. "选择OCR结果文件",
  78. range(len(validator.display_options)),
  79. format_func=lambda i: validator.display_options[i],
  80. index=st.session_state.selected_file_index,
  81. key="selected_selectbox",
  82. label_visibility="collapsed"
  83. )
  84. if selected_index != st.session_state.selected_file_index:
  85. st.session_state.selected_file_index = selected_index
  86. selected_file = validator.file_paths[selected_index]
  87. current_page = validator.file_info[selected_index]['page']
  88. page_input = st.number_input(
  89. "输入页码",
  90. placeholder="输入页码",
  91. label_visibility="collapsed",
  92. min_value=1,
  93. max_value=len(validator.display_options),
  94. value=current_page,
  95. step=1,
  96. key="page_input"
  97. )
  98. if page_input != current_page:
  99. for i, info in enumerate(validator.file_info):
  100. if info['page'] == page_input:
  101. st.session_state.selected_file_index = i
  102. selected_file = validator.file_paths[i]
  103. st.rerun()
  104. break
  105. if (st.session_state.selected_file_index >= 0
  106. and validator.selected_file_index != st.session_state.selected_file_index
  107. and selected_file):
  108. validator.selected_file_index = st.session_state.selected_file_index
  109. st.session_state.validator.load_ocr_data(selected_file)
  110. current_source_name = get_data_source_display_name(validator.current_source_config)
  111. st.success(f"✅ 已加载 {current_source_name} - 第{validator.file_info[st.session_state.selected_file_index]['page']}页")
  112. st.rerun()
  113. else:
  114. st.warning("当前数据源中未找到OCR结果文件")
  115. # ✅ 交叉验证按钮 - 添加数据源检查
  116. cross_validation_enabled = (
  117. validator.current_source_key != validator.verify_source_key
  118. and validator.image_path
  119. and validator.md_content
  120. )
  121. if st.button(
  122. "交叉验证",
  123. type="primary",
  124. icon=":material/compare_arrows:",
  125. disabled=not cross_validation_enabled,
  126. help="需要选择不同的OCR数据源和验证数据源" if not cross_validation_enabled else "开始批量交叉验证"
  127. ):
  128. cross_validation_dialog(validator)
  129. # ✅ 查看验证结果按钮 - 检查是否有验证结果
  130. has_validation_results = (
  131. 'cross_validation_batch_result' in st.session_state
  132. and st.session_state.cross_validation_batch_result is not None
  133. )
  134. if st.button(
  135. "查看验证结果",
  136. type="secondary",
  137. icon=":material/quick_reference_all:",
  138. disabled=not has_validation_results,
  139. help="暂无验证结果,请先运行交叉验证" if not has_validation_results else "查看批量验证结果"
  140. ):
  141. show_batch_cross_validation_results_dialog()
  142. # 显示当前数据源统计信息
  143. with st.expander("🔧 OCR工具统计信息", expanded=False):
  144. stats = validator.get_statistics()
  145. col1, col2, col3, col4, col5 = st.columns(5)
  146. with col1:
  147. st.metric("📊 总文本块", stats['total_texts'])
  148. with col2:
  149. st.metric("🔗 可点击文本", stats['clickable_texts'])
  150. with col3:
  151. st.metric("❌ 标记错误", stats['marked_errors'])
  152. with col4:
  153. st.metric("✅ 准确率", f"{stats['accuracy_rate']:.1f}%")
  154. with col5:
  155. if validator.current_source_config:
  156. tool_display = validator.current_source_config['ocr_tool'].upper()
  157. st.metric("🔧 OCR工具", tool_display)
  158. if stats['tool_info']:
  159. st.write("**详细信息:**", stats['tool_info'])
  160. tab1, tab2, tab3 = st.tabs(["📄 内容人工检查", "🔍 交叉验证结果", "📊 表格分析"])
  161. with tab1:
  162. validator.create_compact_layout(config)
  163. with tab2:
  164. # ✅ 使用封装的函数显示单页交叉验证结果
  165. display_single_page_cross_validation(validator, config)
  166. with tab3:
  167. st.header("📊 表格数据分析")
  168. if validator.md_content and '<table' in validator.md_content.lower():
  169. st.subheader("🔍 表格数据预览")
  170. display_html_table_as_dataframe(validator.md_content)
  171. else:
  172. st.info("当前OCR结果中没有检测到表格数据")
# Entry point when the script is executed directly (e.g. via `streamlit run`).
if __name__ == "__main__":
    main()