streamlit_ocr_validator.py 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003
  1. #!/usr/bin/env python3
  2. """
  3. 基于Streamlit的OCR可视化校验工具(重构版)
  4. 提供丰富的交互组件和更好的用户体验
  5. """
  6. import streamlit as st
  7. from pathlib import Path
  8. from PIL import Image
  9. from typing import Dict, List, Optional
  10. import plotly.graph_objects as go
  11. from io import BytesIO
  12. import pandas as pd
  13. import numpy as np
  14. import plotly.express as px
  15. import json
  16. # 导入工具模块
  17. from ocr_validator_utils import (
  18. load_config, load_css_styles, load_ocr_data_file, process_ocr_data,
  19. draw_bbox_on_image, get_ocr_statistics, convert_html_table_to_markdown,
  20. parse_html_tables, find_available_ocr_files, create_dynamic_css,
  21. export_tables_to_excel, get_table_statistics, group_texts_by_category
  22. )
  23. from ocr_validator_layout import OCRLayoutManager
  24. from ocr_by_vlm import ocr_with_vlm
  25. from compare_ocr_results import compare_ocr_results
  26. class StreamlitOCRValidator:
  27. def __init__(self):
  28. self.config = load_config()
  29. self.ocr_data = []
  30. self.md_content = ""
  31. self.image_path = ""
  32. self.text_bbox_mapping = {}
  33. self.selected_text = None
  34. self.marked_errors = set()
  35. self.file_info = []
  36. self.selected_file_index = -1 # 初始化不指向有效文件index
  37. self.display_options = []
  38. self.file_paths = []
  39. # 初始化布局管理器
  40. self.layout_manager = OCRLayoutManager(self)
  41. # 加载文件信息
  42. self.load_file_info()
  43. def load_file_info(self):
  44. # 查找可用的OCR文件
  45. self.file_info = find_available_ocr_files(self.config['paths']['ocr_out_dir'])
  46. # 初始化session_state中的选择索引
  47. if self.file_info:
  48. # 创建显示选项列表
  49. self.display_options = [f"{info['display_name']}" for info in self.file_info]
  50. self.file_paths = [info['path'] for info in self.file_info]
  51. def setup_page_config(self):
  52. """设置页面配置"""
  53. ui_config = self.config['ui']
  54. st.set_page_config(
  55. page_title=ui_config['page_title'],
  56. page_icon=ui_config['page_icon'],
  57. layout=ui_config['layout'],
  58. initial_sidebar_state=ui_config['sidebar_state']
  59. )
  60. # 加载CSS样式
  61. css_content = load_css_styles()
  62. st.markdown(f"<style>{css_content}</style>", unsafe_allow_html=True)
  63. def load_ocr_data(self, json_path: str, md_path: Optional[str] = None, image_path: Optional[str] = None):
  64. """加载OCR相关数据"""
  65. try:
  66. self.ocr_data, self.md_content, self.image_path = load_ocr_data_file(json_path, self.config)
  67. self.process_data()
  68. except Exception as e:
  69. st.error(f"❌ 加载失败: {e}")
  70. def process_data(self):
  71. """处理OCR数据"""
  72. self.text_bbox_mapping = process_ocr_data(self.ocr_data, self.config)
  73. def get_statistics(self) -> Dict:
  74. """获取统计信息"""
  75. return get_ocr_statistics(self.ocr_data, self.text_bbox_mapping, self.marked_errors)
  76. def display_html_table_as_dataframe(self, html_content: str, enable_editing: bool = False):
  77. """将HTML表格解析为DataFrame显示 - 增强版本支持横向滚动"""
  78. tables = parse_html_tables(html_content)
  79. if not tables:
  80. st.warning("未找到可解析的表格")
  81. # 对于无法解析的HTML表格,使用自定义CSS显示
  82. st.markdown("""
  83. <style>
  84. .scrollable-table {
  85. overflow-x: auto;
  86. white-space: nowrap;
  87. border: 1px solid #ddd;
  88. border-radius: 5px;
  89. margin: 10px 0;
  90. }
  91. .scrollable-table table {
  92. width: 100%;
  93. border-collapse: collapse;
  94. }
  95. .scrollable-table th, .scrollable-table td {
  96. border: 1px solid #ddd;
  97. padding: 8px;
  98. text-align: left;
  99. min-width: 100px;
  100. }
  101. .scrollable-table th {
  102. background-color: #f5f5f5;
  103. font-weight: bold;
  104. }
  105. </style>
  106. """, unsafe_allow_html=True)
  107. st.markdown(f'<div class="scrollable-table">{html_content}</div>', unsafe_allow_html=True)
  108. return
  109. for i, table in enumerate(tables):
  110. st.subheader(f"📊 表格 {i+1}")
  111. # 表格信息显示
  112. col_info1, col_info2, col_info3, col_info4 = st.columns(4)
  113. with col_info1:
  114. st.metric("行数", len(table))
  115. with col_info2:
  116. st.metric("列数", len(table.columns))
  117. with col_info3:
  118. # 检查是否有超宽表格
  119. is_wide_table = len(table.columns) > 8
  120. st.metric("表格类型", "超宽表格" if is_wide_table else "普通表格")
  121. with col_info4:
  122. # 表格操作模式选择
  123. display_mode = st.selectbox(
  124. f"显示模式 (表格{i+1})",
  125. ["完整显示", "分页显示", "筛选列显示"],
  126. key=f"display_mode_{i}"
  127. )
  128. # 创建表格操作按钮
  129. col1, col2, col3, col4 = st.columns(4)
  130. with col1:
  131. show_info = st.checkbox(f"显示详细信息", key=f"info_{i}")
  132. with col2:
  133. show_stats = st.checkbox(f"显示统计信息", key=f"stats_{i}")
  134. with col3:
  135. enable_filter = st.checkbox(f"启用过滤", key=f"filter_{i}")
  136. with col4:
  137. enable_sort = st.checkbox(f"启用排序", key=f"sort_{i}")
  138. # 根据显示模式处理表格
  139. display_table = self._process_table_display_mode(table, i, display_mode)
  140. # 数据过滤和排序逻辑
  141. filtered_table = self._apply_table_filters_and_sorts(display_table, i, enable_filter, enable_sort)
  142. # 显示表格 - 使用自定义CSS支持横向滚动
  143. st.markdown("""
  144. <style>
  145. .dataframe-container {
  146. overflow-x: auto;
  147. border: 1px solid #ddd;
  148. border-radius: 5px;
  149. margin: 10px 0;
  150. }
  151. /* 为超宽表格特殊样式 */
  152. .wide-table-container {
  153. overflow-x: auto;
  154. max-height: 500px;
  155. overflow-y: auto;
  156. border: 2px solid #0288d1;
  157. border-radius: 8px;
  158. background: linear-gradient(90deg, #f8f9fa 0%, #ffffff 100%);
  159. }
  160. .dataframe thead th {
  161. position: sticky;
  162. top: 0;
  163. background-color: #f5f5f5 !important;
  164. z-index: 10;
  165. border-bottom: 2px solid #0288d1;
  166. }
  167. .dataframe tbody td {
  168. white-space: nowrap;
  169. min-width: 100px;
  170. max-width: 300px;
  171. overflow: hidden;
  172. text-overflow: ellipsis;
  173. }
  174. </style>
  175. """, unsafe_allow_html=True)
  176. # 根据表格宽度选择显示容器
  177. container_class = "wide-table-container" if len(table.columns) > 8 else "dataframe-container"
  178. if enable_editing:
  179. st.markdown(f'<div class="{container_class}">', unsafe_allow_html=True)
  180. edited_table = st.data_editor(
  181. filtered_table,
  182. use_container_width=True,
  183. key=f"editor_{i}",
  184. height=400 if len(table.columns) > 8 else None
  185. )
  186. st.markdown('</div>', unsafe_allow_html=True)
  187. if not edited_table.equals(filtered_table):
  188. st.success("✏️ 表格已编辑,可以导出修改后的数据")
  189. else:
  190. st.markdown(f'<div class="{container_class}">', unsafe_allow_html=True)
  191. st.dataframe(
  192. filtered_table,
  193. # use_container_width=True,
  194. width =400 if len(table.columns) > 8 else "stretch"
  195. )
  196. st.markdown('</div>', unsafe_allow_html=True)
  197. # 显示表格信息和统计
  198. self._display_table_info_and_stats(table, filtered_table, show_info, show_stats, i)
  199. st.markdown("---")
  200. def _apply_table_filters_and_sorts(self, table: pd.DataFrame, table_index: int, enable_filter: bool, enable_sort: bool) -> pd.DataFrame:
  201. """应用表格过滤和排序"""
  202. filtered_table = table.copy()
  203. # 数据过滤
  204. if enable_filter and not table.empty:
  205. filter_col = st.selectbox(
  206. f"选择过滤列 (表格 {table_index+1})",
  207. options=['无'] + list(table.columns),
  208. key=f"filter_col_{table_index}"
  209. )
  210. if filter_col != '无':
  211. filter_value = st.text_input(f"过滤值 (表格 {table_index+1})", key=f"filter_value_{table_index}")
  212. if filter_value:
  213. filtered_table = table[table[filter_col].astype(str).str.contains(filter_value, na=False)]
  214. # 数据排序
  215. if enable_sort and not filtered_table.empty:
  216. sort_col = st.selectbox(
  217. f"选择排序列 (表格 {table_index+1})",
  218. options=['无'] + list(filtered_table.columns),
  219. key=f"sort_col_{table_index}"
  220. )
  221. if sort_col != '无':
  222. sort_order = st.radio(
  223. f"排序方式 (表格 {table_index+1})",
  224. options=['升序', '降序'],
  225. horizontal=True,
  226. key=f"sort_order_{table_index}"
  227. )
  228. ascending = (sort_order == '升序')
  229. filtered_table = filtered_table.sort_values(sort_col, ascending=ascending)
  230. return filtered_table
  231. def _display_table_info_and_stats(self, original_table: pd.DataFrame, filtered_table: pd.DataFrame,
  232. show_info: bool, show_stats: bool, table_index: int):
  233. """显示表格信息和统计数据"""
  234. if show_info:
  235. st.write("**表格信息:**")
  236. st.write(f"- 原始行数: {len(original_table)}")
  237. st.write(f"- 过滤后行数: {len(filtered_table)}")
  238. st.write(f"- 列数: {len(original_table.columns)}")
  239. st.write(f"- 列名: {', '.join(original_table.columns)}")
  240. if show_stats:
  241. st.write("**统计信息:**")
  242. numeric_cols = filtered_table.select_dtypes(include=[np.number]).columns
  243. if len(numeric_cols) > 0:
  244. st.dataframe(filtered_table[numeric_cols].describe())
  245. else:
  246. st.info("表格中没有数值列")
  247. # 导出功能
  248. if st.button(f"📥 导出表格 {table_index+1}", key=f"export_{table_index}"):
  249. self._create_export_buttons(filtered_table, table_index)
  250. def _create_export_buttons(self, table: pd.DataFrame, table_index: int):
  251. """创建导出按钮"""
  252. # CSV导出
  253. csv_data = table.to_csv(index=False)
  254. st.download_button(
  255. label=f"下载CSV (表格 {table_index+1})",
  256. data=csv_data,
  257. file_name=f"table_{table_index+1}.csv",
  258. mime="text/csv",
  259. key=f"download_csv_{table_index}"
  260. )
  261. # Excel导出
  262. excel_buffer = BytesIO()
  263. table.to_excel(excel_buffer, index=False)
  264. st.download_button(
  265. label=f"下载Excel (表格 {table_index+1})",
  266. data=excel_buffer.getvalue(),
  267. file_name=f"table_{table_index+1}.xlsx",
  268. mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
  269. key=f"download_excel_{table_index}"
  270. )
  271. def _process_table_display_mode(self, table: pd.DataFrame, table_index: int, display_mode: str) -> pd.DataFrame:
  272. """根据显示模式处理表格"""
  273. if display_mode == "分页显示":
  274. # 分页显示
  275. page_size = st.selectbox(
  276. f"每页显示行数 (表格 {table_index+1})",
  277. [10, 20, 50, 100],
  278. key=f"page_size_{table_index}"
  279. )
  280. total_pages = (len(table) - 1) // page_size + 1
  281. if total_pages > 1:
  282. page_number = st.selectbox(
  283. f"页码 (表格 {table_index+1})",
  284. range(1, total_pages + 1),
  285. key=f"page_number_{table_index}"
  286. )
  287. start_idx = (page_number - 1) * page_size
  288. end_idx = start_idx + page_size
  289. return table.iloc[start_idx:end_idx]
  290. return table
  291. elif display_mode == "筛选列显示":
  292. # 列筛选显示
  293. if len(table.columns) > 5:
  294. selected_columns = st.multiselect(
  295. f"选择要显示的列 (表格 {table_index+1})",
  296. table.columns.tolist(),
  297. default=table.columns.tolist()[:5], # 默认显示前5列
  298. key=f"selected_columns_{table_index}"
  299. )
  300. if selected_columns:
  301. return table[selected_columns]
  302. return table
  303. else: # 完整显示
  304. return table
  305. @st.dialog("VLM预校验", width="large", dismissible=True, on_dismiss="rerun")
  306. def vlm_pre_validation(self):
  307. """VLM预校验功能 - 封装OCR识别和结果对比"""
  308. if not self.image_path or not self.md_content:
  309. st.error("❌ 请先加载OCR数据文件")
  310. return
  311. # 初始化对比结果存储
  312. if 'comparison_result' not in st.session_state:
  313. st.session_state.comparison_result = None
  314. # 创建进度条和状态显示
  315. with st.spinner("正在进行VLM预校验...", show_time=True):
  316. status_text = st.empty()
  317. try:
  318. current_md_path = Path(self.file_paths[self.selected_file_index]).with_suffix('.md')
  319. if not current_md_path.exists():
  320. st.error("❌ 当前OCR结果的Markdown文件不存在,无法进行对比")
  321. return
  322. # 第一步:准备目录
  323. pre_validation_dir = Path(self.config['paths'].get('pre_validation_dir', './output/pre_validation/')).resolve()
  324. pre_validation_dir.mkdir(parents=True, exist_ok=True)
  325. status_text.write(f"工作目录: {pre_validation_dir}")
  326. # 第二步:调用VLM进行OCR识别
  327. status_text.text("🤖 正在调用VLM进行OCR识别...")
  328. # 在expander中显示OCR过程
  329. with st.expander("🔍 VLM OCR识别过程", expanded=True):
  330. ocr_output = st.empty()
  331. # 捕获OCR输出
  332. import io
  333. import contextlib
  334. # 创建字符串缓冲区来捕获print输出
  335. output_buffer = io.StringIO()
  336. with contextlib.redirect_stdout(output_buffer):
  337. ocr_result = ocr_with_vlm(
  338. image_path=str(self.image_path),
  339. output_dir=str(pre_validation_dir),
  340. normalize_numbers=True
  341. )
  342. # 显示OCR过程输出
  343. ocr_output.code(output_buffer.getvalue(), language='text')
  344. status_text.text("✅ VLM OCR识别完成")
  345. # 第三步:获取VLM生成的文件路径
  346. vlm_md_path = pre_validation_dir / f"{Path(self.image_path).stem}.md"
  347. if not vlm_md_path.exists():
  348. st.error("❌ VLM OCR结果文件未生成")
  349. return
  350. # 第四步:调用对比功能
  351. status_text.text("📊 正在对比OCR结果...")
  352. # 在expander中显示对比过程
  353. comparison_result_path = pre_validation_dir / f"{current_md_path.stem}_comparison_result"
  354. with st.expander("🔍 OCR结果对比过程", expanded=True):
  355. compare_output = st.empty()
  356. # 捕获对比输出
  357. output_buffer = io.StringIO()
  358. with contextlib.redirect_stdout(output_buffer):
  359. comparison_result = compare_ocr_results(
  360. file1_path=str(current_md_path),
  361. file2_path=str(vlm_md_path),
  362. output_file=str(comparison_result_path),
  363. output_format='both',
  364. ignore_images=True
  365. )
  366. # 显示对比过程输出
  367. compare_output.code(output_buffer.getvalue(), language='text')
  368. status_text.text("✅ VLM预校验完成")
  369. st.session_state.comparison_result = {
  370. "image_path": self.image_path,
  371. "comparison_result_json": f"{comparison_result_path}.json",
  372. "comparison_result_md": f"{comparison_result_path}.md",
  373. "comparison_result": comparison_result
  374. }
  375. # 第五步:显示对比结果
  376. self.display_comparison_results(comparison_result)
  377. # 第六步:提供文件下载
  378. # self.provide_download_options(pre_validation_dir, vlm_md_path, comparison_result)
  379. except Exception as e:
  380. st.error(f"❌ VLM预校验失败: {e}")
  381. st.exception(e)
  382. def display_comparison_results(self, comparison_result: dict):
  383. """显示对比结果摘要 - 使用DataFrame展示"""
  384. st.header("📊 VLM预校验结果")
  385. # 统计信息
  386. stats = comparison_result['statistics']
  387. # 统计信息概览
  388. col1, col2, col3, col4 = st.columns(4)
  389. with col1:
  390. st.metric("总差异数", stats['total_differences'])
  391. with col2:
  392. st.metric("表格差异", stats['table_differences'])
  393. with col3:
  394. st.metric("金额差异", stats['amount_differences'])
  395. with col4:
  396. st.metric("段落差异", stats['paragraph_differences'])
  397. # 结果判断
  398. if stats['total_differences'] == 0:
  399. st.success("🎉 完美匹配!VLM识别结果与原OCR结果完全一致")
  400. else:
  401. st.warning(f"⚠️ 发现 {stats['total_differences']} 个差异,建议人工检查")
  402. # 使用DataFrame显示差异详情
  403. if comparison_result['differences']:
  404. st.subheader("🔍 差异详情对比")
  405. # 准备DataFrame数据
  406. diff_data = []
  407. for i, diff in enumerate(comparison_result['differences'], 1):
  408. diff_data.append({
  409. '序号': i,
  410. '位置': diff['position'],
  411. '类型': diff['type'],
  412. '原OCR结果': diff['file1_value'][:100] + ('...' if len(diff['file1_value']) > 100 else ''),
  413. 'VLM识别结果': diff['file2_value'][:100] + ('...' if len(diff['file2_value']) > 100 else ''),
  414. '描述': diff['description'][:80] + ('...' if len(diff['description']) > 80 else ''),
  415. '严重程度': self._get_severity_level(diff)
  416. })
  417. # 创建DataFrame
  418. df_differences = pd.DataFrame(diff_data)
  419. # 添加样式
  420. def highlight_severity(val):
  421. """根据严重程度添加颜色"""
  422. if val == '高':
  423. return 'background-color: #ffebee; color: #c62828'
  424. elif val == '中':
  425. return 'background-color: #fff3e0; color: #ef6c00'
  426. elif val == '低':
  427. return 'background-color: #e8f5e8; color: #2e7d32'
  428. return ''
  429. # 显示DataFrame
  430. styled_df = df_differences.style.applymap(
  431. highlight_severity,
  432. subset=['严重程度']
  433. ).format({
  434. '序号': '{:d}',
  435. })
  436. st.dataframe(
  437. styled_df,
  438. use_container_width=True,
  439. height=400,
  440. hide_index=True,
  441. column_config={
  442. "序号": st.column_config.NumberColumn(
  443. "序号",
  444. width=None, # 自动调整宽度
  445. pinned=True,
  446. help="差异项序号"
  447. ),
  448. "位置": st.column_config.TextColumn(
  449. "位置",
  450. width=None, # 自动调整宽度
  451. pinned=True,
  452. help="差异在文档中的位置"
  453. ),
  454. "类型": st.column_config.TextColumn(
  455. "类型",
  456. width=None, # 自动调整宽度
  457. pinned=True,
  458. help="差异类型"
  459. ),
  460. "原OCR结果": st.column_config.TextColumn(
  461. "原OCR结果",
  462. width="large", # 自动调整宽度
  463. pinned=True,
  464. help="原始OCR识别结果"
  465. ),
  466. "VLM识别结果": st.column_config.TextColumn(
  467. "VLM识别结果",
  468. width="large", # 自动调整宽度
  469. help="VLM重新识别的结果"
  470. ),
  471. "描述": st.column_config.TextColumn(
  472. "描述",
  473. width="medium", # 自动调整宽度
  474. help="差异详细描述"
  475. ),
  476. "严重程度": st.column_config.TextColumn(
  477. "严重程度",
  478. width=None, # 自动调整宽度
  479. help="差异严重程度评级"
  480. )
  481. }
  482. )
  483. # 详细差异查看
  484. st.subheader("🔍 详细差异查看")
  485. # 选择要查看的差异
  486. selected_diff_index = st.selectbox(
  487. "选择要查看的差异:",
  488. options=range(len(comparison_result['differences'])),
  489. format_func=lambda x: f"差异 {x+1}: {comparison_result['differences'][x]['position']} - {comparison_result['differences'][x]['type']}",
  490. key="selected_diff"
  491. )
  492. if selected_diff_index is not None:
  493. diff = comparison_result['differences'][selected_diff_index]
  494. # 并排显示完整内容
  495. col1, col2 = st.columns(2)
  496. with col1:
  497. st.write("**原OCR结果:**")
  498. st.text_area(
  499. "原OCR结果详情",
  500. value=diff['file1_value'],
  501. height=200,
  502. key=f"original_{selected_diff_index}",
  503. label_visibility="collapsed"
  504. )
  505. with col2:
  506. st.write("**VLM识别结果:**")
  507. st.text_area(
  508. "VLM识别结果详情",
  509. value=diff['file2_value'],
  510. height=200,
  511. key=f"vlm_{selected_diff_index}",
  512. label_visibility="collapsed"
  513. )
  514. # 差异详细信息
  515. st.info(f"**位置:** {diff['position']}")
  516. st.info(f"**类型:** {diff['type']}")
  517. st.info(f"**描述:** {diff['description']}")
  518. st.info(f"**严重程度:** {self._get_severity_level(diff)}")
  519. # 差异统计图表
  520. st.subheader("📈 差异类型分布")
  521. # 按类型统计差异
  522. type_counts = {}
  523. severity_counts = {'高': 0, '中': 0, '低': 0}
  524. for diff in comparison_result['differences']:
  525. diff_type = diff['type']
  526. type_counts[diff_type] = type_counts.get(diff_type, 0) + 1
  527. severity = self._get_severity_level(diff)
  528. severity_counts[severity] += 1
  529. col1, col2 = st.columns(2)
  530. with col1:
  531. # 类型分布饼图
  532. if type_counts:
  533. fig_type = px.pie(
  534. values=list(type_counts.values()),
  535. names=list(type_counts.keys()),
  536. title="差异类型分布"
  537. )
  538. st.plotly_chart(fig_type, use_container_width=True)
  539. with col2:
  540. # 严重程度分布条形图
  541. fig_severity = px.bar(
  542. x=list(severity_counts.keys()),
  543. y=list(severity_counts.values()),
  544. title="差异严重程度分布",
  545. color=list(severity_counts.keys()),
  546. color_discrete_map={'高': '#f44336', '中': '#ff9800', '低': '#4caf50'}
  547. )
  548. st.plotly_chart(fig_severity, use_container_width=True)
  549. # 下载选项
  550. self._provide_download_options_in_results(comparison_result)
  551. def _get_severity_level(self, diff: dict) -> str:
  552. """根据差异类型和内容判断严重程度"""
  553. # 如果差异中已经包含严重程度,直接使用
  554. if 'severity' in diff:
  555. severity_map = {'high': '高', 'medium': '中', 'low': '低'}
  556. return severity_map.get(diff['severity'], '中')
  557. # 原有的逻辑作为后备
  558. diff_type = diff['type'].lower()
  559. # 金额相关差异为高严重程度
  560. if 'amount' in diff_type or 'number' in diff_type:
  561. return '高'
  562. # 表格结构差异为中等严重程度
  563. if 'table' in diff_type or 'structure' in diff_type:
  564. return '中'
  565. # 检查相似度
  566. if 'similarity' in diff:
  567. similarity = diff['similarity']
  568. if similarity < 50:
  569. return '高'
  570. elif similarity < 85:
  571. return '中'
  572. else:
  573. return '低'
  574. # 检查内容长度差异
  575. len_diff = abs(len(diff['file1_value']) - len(diff['file2_value']))
  576. if len_diff > 50:
  577. return '高'
  578. elif len_diff > 10:
  579. return '中'
  580. else:
  581. return '低'
  582. def _provide_download_options_in_results(self, comparison_result: dict):
  583. """在结果页面提供下载选项"""
  584. st.subheader("📥 导出预校验结果")
  585. col1, col2, col3 = st.columns(3)
  586. with col1:
  587. # 导出差异详情为Excel
  588. if comparison_result['differences']:
  589. diff_data = []
  590. for i, diff in enumerate(comparison_result['differences'], 1):
  591. diff_data.append({
  592. '序号': i,
  593. '位置': diff['position'],
  594. '类型': diff['type'],
  595. '原OCR结果': diff['file1_value'],
  596. 'VLM识别结果': diff['file2_value'],
  597. '描述': diff['description'],
  598. '严重程度': self._get_severity_level(diff)
  599. })
  600. df_export = pd.DataFrame(diff_data)
  601. excel_buffer = BytesIO()
  602. df_export.to_excel(excel_buffer, index=False, sheet_name='差异详情')
  603. st.download_button(
  604. label="📊 下载差异详情(Excel)",
  605. data=excel_buffer.getvalue(),
  606. file_name=f"vlm_comparison_differences_{pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')}.xlsx",
  607. mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
  608. key="download_differences_excel"
  609. )
  610. with col2:
  611. # 导出统计报告
  612. stats_data = {
  613. '统计项目': ['总差异数', '表格差异', '金额差异', '段落差异'],
  614. '数量': [
  615. comparison_result['statistics']['total_differences'],
  616. comparison_result['statistics']['table_differences'],
  617. comparison_result['statistics']['amount_differences'],
  618. comparison_result['statistics']['paragraph_differences']
  619. ]
  620. }
  621. df_stats = pd.DataFrame(stats_data)
  622. csv_stats = df_stats.to_csv(index=False)
  623. st.download_button(
  624. label="📈 下载统计报告(CSV)",
  625. data=csv_stats,
  626. file_name=f"vlm_comparison_stats_{pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')}.csv",
  627. mime="text/csv",
  628. key="download_stats_csv"
  629. )
  630. with col3:
  631. # 导出完整报告为JSON
  632. import json
  633. report_json = json.dumps(comparison_result, ensure_ascii=False, indent=2)
  634. st.download_button(
  635. label="📄 下载完整报告(JSON)",
  636. data=report_json,
  637. file_name=f"vlm_comparison_full_{pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')}.json",
  638. mime="application/json",
  639. key="download_full_json"
  640. )
  641. # 操作建议
  642. st.subheader("🚀 后续操作建议")
  643. total_diffs = comparison_result['statistics']['total_differences']
  644. if total_diffs == 0:
  645. st.success("✅ VLM识别结果与原OCR完全一致,可信度很高,无需人工校验")
  646. elif total_diffs <= 5:
  647. st.warning("⚠️ 发现少量差异,建议重点检查高严重程度的差异项")
  648. elif total_diffs <= 20:
  649. st.warning("🔍 发现中等数量差异,建议详细检查差异表格中标红的项目")
  650. else:
  651. st.error("❌ 发现大量差异,建议重新进行OCR识别或检查原始图片质量")
  652. @st.dialog("查看预校验结果", width="large", dismissible=True, on_dismiss="rerun")
  653. def show_comparison_results_dialog(self):
  654. """显示VLM预校验结果的对话框"""
  655. current_md_path = Path(self.file_paths[self.selected_file_index]).with_suffix('.md')
  656. pre_validation_dir = Path(self.config['paths'].get('pre_validation_dir', './output/pre_validation/')).resolve()
  657. comparison_result_path = pre_validation_dir / f"{current_md_path.stem}_comparison_result.json"
  658. if 'comparison_result' in st.session_state and st.session_state.comparison_result:
  659. self.display_comparison_results(st.session_state.comparison_result['comparison_result'])
  660. elif comparison_result_path.exists():
  661. # 如果pre_validation_dir下有结果文件,提示用户加载
  662. if st.button("加载预校验结果"):
  663. with open(comparison_result_path, "r", encoding="utf-8") as f:
  664. comparison_json_result = json.load(f)
  665. comparison_result = {
  666. "image_path": self.image_path,
  667. "comparison_result_json": str(comparison_result_path),
  668. "comparison_result_md": str(comparison_result_path.with_suffix('.md')),
  669. "comparison_result": comparison_json_result
  670. }
  671. st.session_state.comparison_result = comparison_result
  672. self.display_comparison_results(comparison_json_result)
  673. else:
  674. st.info("暂无预校验结果,请先运行VLM预校验")
  675. def create_compact_layout(self, config):
  676. """创建滚动凑布局"""
  677. return self.layout_manager.create_compact_layout(config)
  678. @st.dialog("message", width="small", dismissible=True, on_dismiss="rerun")
  679. def message_box(msg: str, msg_type: str = "info"):
  680. if msg_type == "info":
  681. st.info(msg)
  682. elif msg_type == "warning":
  683. st.warning(msg)
  684. elif msg_type == "error":
  685. st.error(msg)
  686. def main():
  687. """主应用"""
  688. # 初始化应用
  689. if 'validator' not in st.session_state:
  690. validator = StreamlitOCRValidator()
  691. st.session_state.validator = validator
  692. st.session_state.validator.setup_page_config()
  693. # 页面标题
  694. config = st.session_state.validator.config
  695. st.title(config['ui']['page_title'])
  696. # st.markdown("---")
  697. else:
  698. # 主内容区域
  699. validator = st.session_state.validator
  700. config = st.session_state.validator.config
  701. if 'selected_text' not in st.session_state:
  702. st.session_state.selected_text = None
  703. if 'marked_errors' not in st.session_state:
  704. st.session_state.marked_errors = set()
  705. with st.container(height=75, horizontal=True, horizontal_alignment='left', gap="medium"):
  706. # st.subheader("📁 文件选择")
  707. # 初始化session_state中的选择索引
  708. if 'selected_file_index' not in st.session_state:
  709. st.session_state.selected_file_index = 0
  710. if validator.display_options:
  711. # 创建显示选项列表
  712. selected_index = st.selectbox("选择OCR结果文件",
  713. range(len(validator.display_options)),
  714. format_func=lambda i: validator.display_options[i],
  715. index=st.session_state.selected_file_index,
  716. width=100,
  717. key="selected_selectbox",
  718. label_visibility="collapsed")
  719. # 更新session_state
  720. if selected_index != st.session_state.selected_file_index:
  721. st.session_state.selected_file_index = selected_index
  722. selected_file = validator.file_paths[selected_index]
  723. # number_input, 范围是文件数量,默认值是1,步长是1
  724. # 页码输入器
  725. current_page = validator.file_info[selected_index]['page']
  726. page_input = st.number_input("输入一个数字",
  727. placeholder="输入页码",
  728. width=200,
  729. label_visibility="collapsed",
  730. min_value=1, max_value=len(validator.display_options), value=current_page, step=1,
  731. key="page_input"
  732. )
  733. # 当页码输入改变时,更新文件选择
  734. if page_input != current_page:
  735. # 查找对应页码的文件索引
  736. for i, info in enumerate(validator.file_info):
  737. if info['page'] == page_input:
  738. st.session_state.selected_file_index = i
  739. selected_file = validator.file_paths[i]
  740. st.rerun()
  741. break
  742. if (st.session_state.selected_file_index >= 0
  743. and validator.selected_file_index != st.session_state.selected_file_index
  744. and selected_file):
  745. validator.selected_file_index = st.session_state.selected_file_index
  746. st.session_state.validator.load_ocr_data(selected_file)
  747. st.success(f"✅ 已加载第{validator.file_info[st.session_state.selected_file_index]['page']}页")
  748. st.rerun()
  749. # if st.button("🔄 加载文件", type="secondary") and selected_file:
  750. # st.session_state.validator.load_ocr_data(selected_file)
  751. # st.success(f"✅ 已加载第{validator.file_info[selected_index]['page']}页")
  752. # st.rerun()
  753. else:
  754. st.warning("未找到OCR结果文件")
  755. st.info("请确保output目录下有OCR结果文件")
  756. if st.button("VLM预校验", type="primary", icon=":material/compare_arrows:"):
  757. if validator.image_path and validator.md_content:
  758. # 创建新的页面区域来显示VLM预校验结果
  759. validator.vlm_pre_validation()
  760. else:
  761. message_box("❌ 请先加载OCR数据文件", "error")
  762. if st.button("查看预校验结果", type="secondary", icon=":material/quick_reference_all:"):
  763. validator.show_comparison_results_dialog()
  764. with st.expander("🔧 OCR工具统计信息", expanded=False):
  765. # 显示统计信息
  766. stats = validator.get_statistics()
  767. col1, col2, col3, col4, col5 = st.columns(5) # 增加一列
  768. with col1:
  769. st.metric("📊 总文本块", stats['total_texts'])
  770. with col2:
  771. st.metric("🔗 可点击文本", stats['clickable_texts'])
  772. with col3:
  773. st.metric("❌ 标记错误", stats['marked_errors'])
  774. with col4:
  775. st.metric("✅ 准确率", f"{stats['accuracy_rate']:.1f}%")
  776. with col5:
  777. # 显示OCR工具信息
  778. if stats['tool_info']:
  779. tool_names = list(stats['tool_info'].keys())
  780. main_tool = tool_names[0] if tool_names else "未知"
  781. st.metric("🔧 OCR工具", main_tool)
  782. # 详细工具信息
  783. if stats['tool_info']:
  784. st.write(stats['tool_info'])
  785. # st.markdown("---")
  786. # 创建标签页
  787. tab1, tab2, tab3, tab4 = st.tabs(["📄 内容校验", "📊 表格分析", "📈 数据统计", "🚀 快速导航"])
  788. with tab1:
  789. validator.create_compact_layout(config)
  790. with tab2:
  791. # 表格分析页面
  792. st.header("📊 表格数据分析")
  793. if validator.md_content and '<table' in validator.md_content.lower():
  794. col1, col2 = st.columns([2, 1])
  795. with col1:
  796. st.subheader("🔍 表格数据预览")
  797. validator.display_html_table_as_dataframe(validator.md_content)
  798. with col2:
  799. st.subheader("⚙️ 表格操作")
  800. if st.button("📥 导出表格数据", type="primary"):
  801. tables = parse_html_tables(validator.md_content)
  802. if tables:
  803. output = export_tables_to_excel(tables)
  804. st.download_button(
  805. label="📥 下载Excel文件",
  806. data=output.getvalue(),
  807. file_name="ocr_tables.xlsx",
  808. mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
  809. )
  810. else:
  811. st.info("当前OCR结果中没有检测到表格数据")
  812. with tab3:
  813. # 数据统计页面
  814. st.header("📈 OCR数据统计")
  815. if stats['categories']:
  816. st.subheader("📊 类别分布")
  817. fig_pie = px.pie(
  818. values=list(stats['categories'].values()),
  819. names=list(stats['categories'].keys()),
  820. title="文本类别分布"
  821. )
  822. st.plotly_chart(fig_pie, use_container_width=True)
  823. # 错误率分析
  824. st.subheader("📈 质量分析")
  825. accuracy_data = {
  826. '状态': ['正确', '错误'],
  827. '数量': [stats['clickable_texts'] - stats['marked_errors'], stats['marked_errors']]
  828. }
  829. fig_bar = px.bar(
  830. accuracy_data, x='状态', y='数量', title="识别质量分布",
  831. color='状态', color_discrete_map={'正确': 'green', '错误': 'red'}
  832. )
  833. st.plotly_chart(fig_bar, use_container_width=True)
  834. with tab4:
  835. # 快速导航功能
  836. st.header("🚀 快速导航")
  837. if not validator.text_bbox_mapping:
  838. st.info("没有可用的文本项进行导航")
  839. else:
  840. # 按类别分组
  841. categories = group_texts_by_category(validator.text_bbox_mapping)
  842. # 创建导航按钮
  843. for category, texts in categories.items():
  844. with st.expander(f"{category} ({len(texts)}项)", expanded=False):
  845. cols = st.columns(3) # 每行3个按钮
  846. for i, text in enumerate(texts):
  847. col_idx = i % 3
  848. with cols[col_idx]:
  849. display_text = text[:15] + "..." if len(text) > 15 else text
  850. if st.button(display_text, key=f"nav_{category}_{i}"):
  851. st.session_state.selected_text = text
  852. st.rerun()
# Start the app when this file is executed as a script.
if __name__ == "__main__":
    main()