streamlit_ocr_validator.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386
  1. #!/usr/bin/env python3
  2. """
  3. 基于Streamlit的OCR可视化校验工具(重构版)
  4. 提供丰富的交互组件和更好的用户体验
  5. """
  6. import streamlit as st
  7. from pathlib import Path
  8. from PIL import Image
  9. from typing import Dict, List, Optional
  10. import plotly.graph_objects as go
  11. from io import BytesIO
  12. import pandas as pd
  13. import numpy as np
  14. import plotly.express as px
  15. # 导入工具模块
  16. from ocr_validator_utils import (
  17. load_config, load_css_styles, load_ocr_data_file, process_ocr_data,
  18. draw_bbox_on_image, get_ocr_statistics, convert_html_table_to_markdown,
  19. parse_html_tables, find_available_ocr_files, create_dynamic_css,
  20. export_tables_to_excel, get_table_statistics, group_texts_by_category
  21. )
  22. from ocr_validator_layout import OCRLayoutManager
  23. class StreamlitOCRValidator:
  24. def __init__(self):
  25. self.config = load_config()
  26. self.ocr_data = []
  27. self.md_content = ""
  28. self.image_path = ""
  29. self.text_bbox_mapping = {}
  30. self.selected_text = None
  31. self.marked_errors = set()
  32. # 初始化布局管理器
  33. self.layout_manager = OCRLayoutManager(self)
  34. def setup_page_config(self):
  35. """设置页面配置"""
  36. ui_config = self.config['ui']
  37. st.set_page_config(
  38. page_title=ui_config['page_title'],
  39. page_icon=ui_config['page_icon'],
  40. layout=ui_config['layout'],
  41. initial_sidebar_state=ui_config['sidebar_state']
  42. )
  43. # 加载CSS样式
  44. css_content = load_css_styles()
  45. st.markdown(f"<style>{css_content}</style>", unsafe_allow_html=True)
  46. def load_ocr_data(self, json_path: str, md_path: Optional[str] = None, image_path: Optional[str] = None):
  47. """加载OCR相关数据"""
  48. try:
  49. self.ocr_data, self.md_content, self.image_path = load_ocr_data_file(json_path, self.config)
  50. self.process_data()
  51. except Exception as e:
  52. st.error(f"❌ 加载失败: {e}")
  53. def process_data(self):
  54. """处理OCR数据"""
  55. self.text_bbox_mapping = process_ocr_data(self.ocr_data, self.config)
  56. def get_statistics(self) -> Dict:
  57. """获取统计信息"""
  58. return get_ocr_statistics(self.ocr_data, self.text_bbox_mapping, self.marked_errors)
  59. def display_html_table_as_dataframe(self, html_content: str, enable_editing: bool = False):
  60. """将HTML表格解析为DataFrame显示"""
  61. tables = parse_html_tables(html_content)
  62. if not tables:
  63. st.warning("未找到可解析的表格")
  64. st.markdown(html_content, unsafe_allow_html=True)
  65. return
  66. for i, table in enumerate(tables):
  67. st.subheader(f"📊 表格 {i+1}")
  68. # 创建表格操作按钮
  69. col1, col2, col3, col4 = st.columns(4)
  70. with col1:
  71. show_info = st.checkbox(f"显示表格信息", key=f"info_{i}")
  72. with col2:
  73. show_stats = st.checkbox(f"显示统计信息", key=f"stats_{i}")
  74. with col3:
  75. enable_filter = st.checkbox(f"启用过滤", key=f"filter_{i}")
  76. with col4:
  77. enable_sort = st.checkbox(f"启用排序", key=f"sort_{i}")
  78. # 数据过滤和排序逻辑
  79. filtered_table = self._apply_table_filters_and_sorts(table, i, enable_filter, enable_sort)
  80. # 显示表格
  81. if enable_editing:
  82. edited_table = st.data_editor(filtered_table, use_container_width=True, key=f"editor_{i}")
  83. if not edited_table.equals(filtered_table):
  84. st.success("✏️ 表格已编辑,可以导出修改后的数据")
  85. else:
  86. st.dataframe(filtered_table, use_container_width=True)
  87. # 显示表格信息和统计
  88. self._display_table_info_and_stats(table, filtered_table, show_info, show_stats, i)
  89. st.markdown("---")
  90. def _apply_table_filters_and_sorts(self, table: pd.DataFrame, table_index: int, enable_filter: bool, enable_sort: bool) -> pd.DataFrame:
  91. """应用表格过滤和排序"""
  92. filtered_table = table.copy()
  93. # 数据过滤
  94. if enable_filter and not table.empty:
  95. filter_col = st.selectbox(
  96. f"选择过滤列 (表格 {table_index+1})",
  97. options=['无'] + list(table.columns),
  98. key=f"filter_col_{table_index}"
  99. )
  100. if filter_col != '无':
  101. filter_value = st.text_input(f"过滤值 (表格 {table_index+1})", key=f"filter_value_{table_index}")
  102. if filter_value:
  103. filtered_table = table[table[filter_col].astype(str).str.contains(filter_value, na=False)]
  104. # 数据排序
  105. if enable_sort and not filtered_table.empty:
  106. sort_col = st.selectbox(
  107. f"选择排序列 (表格 {table_index+1})",
  108. options=['无'] + list(filtered_table.columns),
  109. key=f"sort_col_{table_index}"
  110. )
  111. if sort_col != '无':
  112. sort_order = st.radio(
  113. f"排序方式 (表格 {table_index+1})",
  114. options=['升序', '降序'],
  115. horizontal=True,
  116. key=f"sort_order_{table_index}"
  117. )
  118. ascending = (sort_order == '升序')
  119. filtered_table = filtered_table.sort_values(sort_col, ascending=ascending)
  120. return filtered_table
  121. def _display_table_info_and_stats(self, original_table: pd.DataFrame, filtered_table: pd.DataFrame,
  122. show_info: bool, show_stats: bool, table_index: int):
  123. """显示表格信息和统计数据"""
  124. if show_info:
  125. st.write("**表格信息:**")
  126. st.write(f"- 原始行数: {len(original_table)}")
  127. st.write(f"- 过滤后行数: {len(filtered_table)}")
  128. st.write(f"- 列数: {len(original_table.columns)}")
  129. st.write(f"- 列名: {', '.join(original_table.columns)}")
  130. if show_stats:
  131. st.write("**统计信息:**")
  132. numeric_cols = filtered_table.select_dtypes(include=[np.number]).columns
  133. if len(numeric_cols) > 0:
  134. st.dataframe(filtered_table[numeric_cols].describe())
  135. else:
  136. st.info("表格中没有数值列")
  137. # 导出功能
  138. if st.button(f"📥 导出表格 {table_index+1}", key=f"export_{table_index}"):
  139. self._create_export_buttons(filtered_table, table_index)
  140. def _create_export_buttons(self, table: pd.DataFrame, table_index: int):
  141. """创建导出按钮"""
  142. # CSV导出
  143. csv_data = table.to_csv(index=False)
  144. st.download_button(
  145. label=f"下载CSV (表格 {table_index+1})",
  146. data=csv_data,
  147. file_name=f"table_{table_index+1}.csv",
  148. mime="text/csv",
  149. key=f"download_csv_{table_index}"
  150. )
  151. # Excel导出
  152. excel_buffer = BytesIO()
  153. table.to_excel(excel_buffer, index=False)
  154. st.download_button(
  155. label=f"下载Excel (表格 {table_index+1})",
  156. data=excel_buffer.getvalue(),
  157. file_name=f"table_{table_index+1}.xlsx",
  158. mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
  159. key=f"download_excel_{table_index}"
  160. )
  161. # 布局方法现在委托给布局管理器
  162. def create_standard_layout(self, font_size: int = 12, zoom_level: float = 1.0):
  163. """创建标准布局"""
  164. return self.layout_manager.create_standard_layout(font_size, zoom_level)
  165. def create_compact_layout(self, font_size: int = 12, zoom_level: float = 1.0):
  166. """创建滚动凑布局"""
  167. return self.layout_manager.create_compact_layout(font_size, zoom_level)
  168. def main():
  169. """主应用"""
  170. # 初始化应用
  171. if 'validator' not in st.session_state:
  172. st.session_state.validator = StreamlitOCRValidator()
  173. st.session_state.validator.setup_page_config()
  174. if 'selected_text' not in st.session_state:
  175. st.session_state.selected_text = None
  176. if 'marked_errors' not in st.session_state:
  177. st.session_state.marked_errors = set()
  178. # 同步标记的错误到validator
  179. st.session_state.validator.marked_errors = st.session_state.marked_errors
  180. # 页面标题
  181. config = st.session_state.validator.config
  182. st.title(config['ui']['page_title'])
  183. st.markdown("---")
  184. # 侧边栏 - 文件选择和控制
  185. with st.sidebar:
  186. st.header("📁 文件选择")
  187. # 查找可用的OCR文件
  188. available_files = find_available_ocr_files(config['paths']['ocr_out_dir'])
  189. if available_files:
  190. selected_file = st.selectbox("选择OCR结果文件", available_files, index=0)
  191. if st.button("🔄 加载文件", type="primary") and selected_file:
  192. st.session_state.validator.load_ocr_data(selected_file)
  193. st.success("✅ 文件加载成功!")
  194. st.rerun()
  195. else:
  196. st.warning("未找到OCR结果文件")
  197. st.info("请确保output目录下有OCR结果文件")
  198. st.markdown("---")
  199. # 控制面板
  200. st.header("🎛️ 控制面板")
  201. if st.button("🧹 清除选择"):
  202. st.session_state.selected_text = None
  203. st.rerun()
  204. if st.button("❌ 清除错误标记"):
  205. st.session_state.marked_errors = set()
  206. st.rerun()
  207. # 主内容区域
  208. validator = st.session_state.validator
  209. if not validator.ocr_data:
  210. st.info("👈 请在左侧选择并加载OCR结果文件")
  211. return
  212. # 显示统计信息
  213. stats = validator.get_statistics()
  214. col1, col2, col3, col4, col5 = st.columns(5) # 增加一列
  215. with col1:
  216. st.metric("📊 总文本块", stats['total_texts'])
  217. with col2:
  218. st.metric("🔗 可点击文本", stats['clickable_texts'])
  219. with col3:
  220. st.metric("❌ 标记错误", stats['marked_errors'])
  221. with col4:
  222. st.metric("✅ 准确率", f"{stats['accuracy_rate']:.1f}%")
  223. with col5:
  224. # 显示OCR工具信息
  225. if stats['tool_info']:
  226. tool_names = list(stats['tool_info'].keys())
  227. main_tool = tool_names[0] if tool_names else "未知"
  228. st.metric("🔧 OCR工具", main_tool)
  229. # 详细工具信息
  230. if stats['tool_info']:
  231. st.expander("🔧 OCR工具详情", expanded=False).write(stats['tool_info'])
  232. st.markdown("---")
  233. # 创建标签页
  234. tab1, tab2, tab3, tab4 = st.tabs(["📄 内容校验", "📊 表格分析", "📈 数据统计", "🚀 快速导航"])
  235. with tab1:
  236. # 顶部控制区域
  237. control_col1, control_col2 = st.columns(2)
  238. with control_col1:
  239. layout_mode = st.selectbox(
  240. "布局模式",
  241. ["标准布局", "滚动布局"],
  242. key="layout_mode"
  243. )
  244. with control_col2:
  245. font_size = st.selectbox("字体大小", [10, 12, 14, 16], index=0, key="font_size_select")
  246. # 根据选择的布局模式显示不同的界面,传递参数
  247. if layout_mode == "滚动布局":
  248. validator.create_compact_layout(font_size, 1.0)
  249. else:
  250. # 调用封装的标准布局方法
  251. validator.create_standard_layout(font_size, 1.0)
  252. with tab2:
  253. # 表格分析页面
  254. st.header("📊 表格数据分析")
  255. if validator.md_content and '<table' in validator.md_content.lower():
  256. col1, col2 = st.columns([2, 1])
  257. with col1:
  258. st.subheader("🔍 表格数据预览")
  259. validator.display_html_table_as_dataframe(validator.md_content)
  260. with col2:
  261. st.subheader("⚙️ 表格操作")
  262. if st.button("📥 导出表格数据", type="primary"):
  263. tables = parse_html_tables(validator.md_content)
  264. if tables:
  265. output = export_tables_to_excel(tables)
  266. st.download_button(
  267. label="📥 下载Excel文件",
  268. data=output.getvalue(),
  269. file_name="ocr_tables.xlsx",
  270. mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
  271. )
  272. else:
  273. st.info("当前OCR结果中没有检测到表格数据")
  274. with tab3:
  275. # 数据统计页面
  276. st.header("📈 OCR数据统计")
  277. if stats['categories']:
  278. st.subheader("📊 类别分布")
  279. fig_pie = px.pie(
  280. values=list(stats['categories'].values()),
  281. names=list(stats['categories'].keys()),
  282. title="文本类别分布"
  283. )
  284. st.plotly_chart(fig_pie, use_container_width=True)
  285. # 错误率分析
  286. st.subheader("📈 质量分析")
  287. accuracy_data = {
  288. '状态': ['正确', '错误'],
  289. '数量': [stats['clickable_texts'] - stats['marked_errors'], stats['marked_errors']]
  290. }
  291. fig_bar = px.bar(
  292. accuracy_data, x='状态', y='数量', title="识别质量分布",
  293. color='状态', color_discrete_map={'正确': 'green', '错误': 'red'}
  294. )
  295. st.plotly_chart(fig_bar, use_container_width=True)
  296. with tab4:
  297. # 快速导航功能
  298. st.header("🚀 快速导航")
  299. if not validator.text_bbox_mapping:
  300. st.info("没有可用的文本项进行导航")
  301. else:
  302. # 按类别分组
  303. categories = group_texts_by_category(validator.text_bbox_mapping)
  304. # 创建导航按钮
  305. for category, texts in categories.items():
  306. with st.expander(f"{category} ({len(texts)}项)", expanded=False):
  307. cols = st.columns(3) # 每行3个按钮
  308. for i, text in enumerate(texts):
  309. col_idx = i % 3
  310. with cols[col_idx]:
  311. display_text = text[:15] + "..." if len(text) > 15 else text
  312. if st.button(display_text, key=f"nav_{category}_{i}"):
  313. st.session_state.selected_text = text
  314. st.rerun()
  315. if __name__ == "__main__":
  316. main()