streamlit_ocr_validator.py 50 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173
  1. #!/usr/bin/env python3
  2. """
  3. 基于Streamlit的OCR可视化校验工具(重构版)
  4. 提供丰富的交互组件和更好的用户体验
  5. """
  6. import streamlit as st
  7. from pathlib import Path
  8. from PIL import Image
  9. from typing import Dict, List, Optional
  10. import plotly.graph_objects as go
  11. from io import BytesIO
  12. import pandas as pd
  13. import numpy as np
  14. import plotly.express as px
  15. import json
  16. # 导入工具模块
  17. from ocr_validator_utils import (
  18. load_config, load_ocr_data_file, process_ocr_data,
  19. get_ocr_statistics,
  20. find_available_ocr_files,
  21. group_texts_by_category,
  22. find_available_ocr_files_multi_source, get_data_source_display_name
  23. )
  24. from ocr_validator_file_utils import (
  25. load_css_styles,
  26. draw_bbox_on_image,
  27. convert_html_table_to_markdown,
  28. parse_html_tables,
  29. create_dynamic_css,
  30. export_tables_to_excel,
  31. get_table_statistics,
  32. )
  33. from ocr_validator_layout import OCRLayoutManager
  34. from ocr_by_vlm import ocr_with_vlm
  35. from compare_ocr_results import compare_ocr_results
  36. class StreamlitOCRValidator:
  def __init__(self):
      """Initialise validator state and discover the configured OCR data sources.

      Loads the global config, sets up empty per-file state plus empty state for
      the active OCR source and the verification (cross-validation) source,
      builds the layout manager and finally selects default sources via
      ``load_multi_source_info``.
      """
      self.config = load_config()

      # State of the currently loaded OCR file
      self.ocr_data, self.md_content, self.image_path = [], "", ""
      self.text_bbox_mapping = {}
      self.selected_text = None
      self.marked_errors = set()

      # Active OCR data source and its files
      self.all_sources = {}
      self.current_source_key = self.current_source_config = None
      self.file_info, self.display_options, self.file_paths = [], [], []
      self.selected_file_index = -1

      # Verification (cross-validation) data source and its files
      self.verify_source_key = self.verify_source_config = None
      self.verify_file_info, self.verify_display_options, self.verify_file_paths = [], [], []

      # The layout manager is handed the validator itself, before sources load.
      self.layout_manager = OCRLayoutManager(self)
      self.load_multi_source_info()
  def load_multi_source_info(self):
      """Discover all OCR data sources and pick default OCR / verification sources.

      The first discovered source becomes the OCR source. When at least two
      sources exist, the second becomes the verification source.
      """
      self.all_sources = find_available_ocr_files_multi_source(self.config)
      if not self.all_sources:
          return

      first_key, *other_keys = self.all_sources.keys()
      self.switch_to_source(first_key)
      if other_keys:
          self.switch_to_verify_source(other_keys[0])
  74. def switch_to_source(self, source_key: str):
  75. """切换到指定OCR数据源"""
  76. if source_key in self.all_sources:
  77. self.current_source_key = source_key
  78. source_data = self.all_sources[source_key]
  79. self.current_source_config = source_data['config']
  80. self.file_info = source_data['files']
  81. if self.file_info:
  82. # 创建显示选项列表
  83. self.display_options = [f"{info['display_name']}" for info in self.file_info]
  84. self.file_paths = [info['path'] for info in self.file_info]
  85. # 重置文件选择
  86. self.selected_file_index = -1
  87. print(f"✅ 切换到OCR数据源: {source_key}")
  88. else:
  89. print(f"⚠️ 数据源 {source_key} 没有可用文件")
  90. def switch_to_verify_source(self, source_key: str):
  91. """切换到指定验证数据源"""
  92. if source_key in self.all_sources:
  93. self.verify_source_key = source_key
  94. source_data = self.all_sources[source_key]
  95. self.verify_source_config = source_data['config']
  96. self.verify_file_info = source_data['files']
  97. if self.verify_file_info:
  98. self.verify_display_options = [f"{info['display_name']}" for info in self.verify_file_info]
  99. self.verify_file_paths = [info['path'] for info in self.verify_file_info]
  100. print(f"✅ 切换到验证数据源: {source_key}")
  101. else:
  102. print(f"⚠️ 验证数据源 {source_key} 没有可用文件")
  def setup_page_config(self):
      """Configure the Streamlit page from ``config['ui']`` and inject the app CSS."""
      ui = self.config['ui']
      page_settings = {
          'page_title': ui['page_title'],
          'page_icon': ui['page_icon'],
          'layout': ui['layout'],
          'initial_sidebar_state': ui['sidebar_state'],
      }
      st.set_page_config(**page_settings)

      # Global stylesheet provided by ocr_validator_file_utils.load_css_styles
      st.markdown(f"<style>{load_css_styles()}</style>", unsafe_allow_html=True)
  def create_data_source_selector(self):
      """Render side-by-side pickers for the OCR source and the verification source.

      The left column selects the OCR data source and the right column selects
      the source used for cross-validation. Changing either selection switches
      the corresponding source and calls ``st.rerun()``. Below the columns, a
      hint warns when both pickers point at the same source.
      """
      if not self.all_sources:
          st.warning("❌ 未找到任何数据源,请检查配置文件")
          return

      # display label -> source key (a later source with the same label wins)
      label_to_key = {
          get_data_source_display_name(data['config']): key
          for key, data in self.all_sources.items()
      }
      labels = list(label_to_key.keys())

      def label_of(source_key):
          """Return the first label mapped to ``source_key``, or None if unset/unknown."""
          if not source_key:
              return None
          return next((label for label, key in label_to_key.items() if key == source_key), None)

      ocr_col, verify_col = st.columns(2)

      with ocr_col:
          st.markdown("#### 📊 OCR数据源")
          active_label = label_of(self.current_source_key)
          selected_ocr_display = st.selectbox(
              "选择OCR数据源",
              options=labels,
              index=labels.index(active_label) if active_label else 0,
              key="ocr_source_selector",
              label_visibility="collapsed",
              help="选择要分析的OCR数据源"
          )
          chosen_ocr_key = label_to_key[selected_ocr_display]
          if chosen_ocr_key != self.current_source_key:
              self.switch_to_source(chosen_ocr_key)
              if 'selected_file_index' in st.session_state:
                  st.session_state.selected_file_index = 0
              st.rerun()

          if self.current_source_config:
              with st.expander("📋 OCR数据源详情", expanded=False):
                  st.write(f"**工具:** {self.current_source_config['ocr_tool']}")
                  st.write(f"**文件数:** {len(self.file_info)}")

      with verify_col:
          st.markdown("#### 🔍 验证数据源")
          verify_label = label_of(self.verify_source_key)
          if verify_label:
              verify_index = labels.index(verify_label)
          else:
              # Default to the second source when there is one.
              verify_index = 1 if len(label_to_key) > 1 else 0
          selected_verify_display = st.selectbox(
              "选择验证数据源",
              options=labels,
              index=verify_index,
              key="verify_source_selector",
              label_visibility="collapsed",
              help="选择用于交叉验证的数据源"
          )
          chosen_verify_key = label_to_key[selected_verify_display]
          if chosen_verify_key != self.verify_source_key:
              self.switch_to_verify_source(chosen_verify_key)
              st.rerun()

          if self.verify_source_config:
              with st.expander("📋 验证数据源详情", expanded=False):
                  st.write(f"**工具:** {self.verify_source_config['ocr_tool']}")
                  st.write(f"**文件数:** {len(self.verify_file_info)}")

      if self.current_source_key == self.verify_source_key:
          st.warning("⚠️ OCR数据源和验证数据源相同,建议选择不同的数据源进行交叉验证")
      else:
          st.success(f"✅ 已选择 {selected_ocr_display} 与 {selected_verify_display} 进行交叉验证")
  def load_ocr_data(self, json_path: str, md_path: Optional[str] = None, image_path: Optional[str] = None):
      """Load the OCR result at ``json_path`` and rebuild the text/bbox mapping.

      When an OCR data source is active, a shallow copy of the global config is
      pointed at that source's directories and OCR tool before loading. On
      success ``ocr_data``, ``md_content`` and ``image_path`` are replaced and
      ``process_data`` is run. Errors are shown in the Streamlit UI, not raised.

      Args:
          json_path: Path of the OCR result JSON file.
          md_path: Currently unused; kept for interface compatibility.
          image_path: Currently unused; kept for interface compatibility.
      """
      try:
          config = self.config
          if self.current_source_config:
              source = self.current_source_config
              # Shallow copy: only top-level keys are replaced, so the nested
              # dicts shared with self.config are never mutated.
              config = self.config.copy()
              config['paths'] = {
                  'ocr_out_dir': source['ocr_out_dir'],
                  'src_img_dir': source.get('src_img_dir', ''),
                  # Same default as the cross-validation code uses. A config
                  # without an explicit out_dir previously made every load fail.
                  'pre_validation_dir': self.config['pre_validation'].get('out_dir', './output/pre_validation/'),
              }
              config['current_ocr_tool'] = source['ocr_tool']

          self.ocr_data, self.md_content, self.image_path = load_ocr_data_file(json_path, config)
          self.process_data()
      except Exception as e:
          st.error(f"❌ 加载失败: {e}")
          st.exception(e)
  def process_data(self):
      """Rebuild ``self.text_bbox_mapping`` from ``self.ocr_data`` via ``process_ocr_data``."""
      mapping = process_ocr_data(self.ocr_data, self.config)
      self.text_bbox_mapping = mapping
  def get_statistics(self) -> Dict:
      """Return statistics for the loaded OCR data via ``get_ocr_statistics``.

      Passes the OCR records, ``text_bbox_mapping`` and the set of marked errors.
      """
      records, mapping, errors = self.ocr_data, self.text_bbox_mapping, self.marked_errors
      return get_ocr_statistics(records, mapping, errors)
  def display_html_table_as_dataframe(self, html_content: str, enable_editing: bool = False,
                                      wide_table_threshold: int = 15):
      """Render the HTML tables in ``html_content`` as interactive DataFrames.

      Each table parsed by ``parse_html_tables`` gets:
      - a metrics row (rows, columns, normal/wide classification);
      - a display-mode selector (full / paged / column subset);
      - optional filter/sort controls;
      - an editable or read-only grid;
      - optional info, stats and export sections.
      If nothing can be parsed, the raw HTML is shown in a horizontally
      scrollable container instead.

      Args:
          html_content: HTML fragment that may contain one or more tables.
          enable_editing: Render tables with ``st.data_editor`` instead of
              ``st.dataframe``.
          wide_table_threshold: Tables with more columns than this count as
              wide tables ("超宽表格"). The default, 15, is the previously
              hard-coded value.
      """
      tables = parse_html_tables(html_content)

      if not tables:
          st.warning("未找到可解析的表格")
          # Fallback: show the unparsed HTML inside a horizontally scrollable box.
          st.markdown("""
          <style>
          .scrollable-table {
              overflow-x: auto;
              white-space: nowrap;
              border: 1px solid #ddd;
              border-radius: 5px;
              margin: 10px 0;
          }
          .scrollable-table table {
              width: 100%;
              border-collapse: collapse;
          }
          .scrollable-table th, .scrollable-table td {
              border: 1px solid #ddd;
              padding: 8px;
              text-align: left;
              min-width: 100px;
          }
          .scrollable-table th {
              background-color: #f5f5f5;
              font-weight: bold;
          }
          </style>
          """, unsafe_allow_html=True)
          st.markdown(f'<div class="scrollable-table">{html_content}</div>', unsafe_allow_html=True)
          return

      # This CSS does not depend on the table, so inject it once per call
      # instead of once per table.
      st.markdown("""
      <style>
      .dataframe-container {
          overflow-x: auto;
          border: 1px solid #ddd;
          border-radius: 5px;
          margin: 10px 0;
      }
      /* Extra styling for very wide tables */
      .wide-table-container {
          overflow-x: auto;
          max-height: 500px;
          overflow-y: auto;
          border: 2px solid #0288d1;
          border-radius: 8px;
          background: linear-gradient(90deg, #f8f9fa 0%, #ffffff 100%);
      }
      .dataframe thead th {
          position: sticky;
          top: 0;
          background-color: #f5f5f5 !important;
          z-index: 10;
          border-bottom: 2px solid #0288d1;
      }
      .dataframe tbody td {
          white-space: nowrap;
          min-width: 100px;
          max-width: 300px;
          overflow: hidden;
          text-overflow: ellipsis;
      }
      </style>
      """, unsafe_allow_html=True)

      for i, table in enumerate(tables):
          column_count = len(table.columns)
          is_wide_table = column_count > wide_table_threshold

          st.subheader(f"📊 表格 {i+1}")

          # Summary metrics and display-mode selector
          col_info1, col_info2, col_info3, col_info4 = st.columns(4)
          with col_info1:
              st.metric("行数", len(table))
          with col_info2:
              st.metric("列数", column_count)
          with col_info3:
              st.metric("表格类型", "超宽表格" if is_wide_table else "普通表格")
          with col_info4:
              display_mode = st.selectbox(
                  f"显示模式 (表格{i+1})",
                  ["完整显示", "分页显示", "筛选列显示"],
                  key=f"display_mode_{i}"
              )

          # Per-table option toggles
          col1, col2, col3, col4 = st.columns(4)
          with col1:
              show_info = st.checkbox("显示详细信息", key=f"info_{i}")
          with col2:
              show_stats = st.checkbox("显示统计信息", key=f"stats_{i}")
          with col3:
              enable_filter = st.checkbox("启用过滤", key=f"filter_{i}")
          with col4:
              enable_sort = st.checkbox("启用排序", key=f"sort_{i}")

          display_table = self._process_table_display_mode(table, i, display_mode)
          filtered_table = self._apply_table_filters_and_sorts(display_table, i, enable_filter, enable_sort)

          container_class = "wide-table-container" if is_wide_table else "dataframe-container"
          # NOTE(review): every st.markdown call renders as its own element, so
          # this opening/closing <div> pair most likely does not wrap the grid
          # rendered between them. Confirm whether these wrappers have any effect.
          if enable_editing:
              st.markdown(f'<div class="{container_class}">', unsafe_allow_html=True)
              edited_table = st.data_editor(
                  filtered_table,
                  use_container_width=True,
                  key=f"editor_{i}",
                  height=400 if column_count > 8 else None
              )
              st.markdown('</div>', unsafe_allow_html=True)
              if not edited_table.equals(filtered_table):
                  st.success("✏️ 表格已编辑,可以导出修改后的数据")
          else:
              st.markdown(f'<div class="{container_class}">', unsafe_allow_html=True)
              # NOTE(review): wide tables get a fixed 400px frame while normal
              # tables stretch. Confirm this is the intended way round.
              st.dataframe(
                  filtered_table,
                  width=400 if is_wide_table else "stretch"
              )
              st.markdown('</div>', unsafe_allow_html=True)

          self._display_table_info_and_stats(table, filtered_table, show_info, show_stats, i)
          st.markdown("---")
  340. def _apply_table_filters_and_sorts(self, table: pd.DataFrame, table_index: int, enable_filter: bool, enable_sort: bool) -> pd.DataFrame:
  341. """应用表格过滤和排序"""
  342. filtered_table = table.copy()
  343. # 数据过滤
  344. if enable_filter and not table.empty:
  345. filter_col = st.selectbox(
  346. f"选择过滤列 (表格 {table_index+1})",
  347. options=['无'] + list(table.columns),
  348. key=f"filter_col_{table_index}"
  349. )
  350. if filter_col != '无':
  351. filter_value = st.text_input(f"过滤值 (表格 {table_index+1})", key=f"filter_value_{table_index}")
  352. if filter_value:
  353. filtered_table = table[table[filter_col].astype(str).str.contains(filter_value, na=False)]
  354. # 数据排序
  355. if enable_sort and not filtered_table.empty:
  356. sort_col = st.selectbox(
  357. f"选择排序列 (表格 {table_index+1})",
  358. options=['无'] + list(filtered_table.columns),
  359. key=f"sort_col_{table_index}"
  360. )
  361. if sort_col != '无':
  362. sort_order = st.radio(
  363. f"排序方式 (表格 {table_index+1})",
  364. options=['升序', '降序'],
  365. horizontal=True,
  366. key=f"sort_order_{table_index}"
  367. )
  368. ascending = (sort_order == '升序')
  369. filtered_table = filtered_table.sort_values(sort_col, ascending=ascending)
  370. return filtered_table
  def _display_table_info_and_stats(self, original_table: pd.DataFrame, filtered_table: pd.DataFrame,
                                    show_info: bool, show_stats: bool, table_index: int):
      """Render optional table details and numeric statistics, plus an export trigger.

      Args:
          original_table: Table as parsed, before paging and filtering.
          filtered_table: Table after paging, filtering and sorting.
          show_info: Show row and column counts and the column names.
          show_stats: Show ``describe()`` for the numeric columns of ``filtered_table``.
          table_index: Zero-based table position, used for labels and widget keys.
      """
      if show_info:
          # Column labels are not guaranteed to be strings (for example integer
          # position labels), and str.join raises TypeError on non-strings.
          column_names = ', '.join(str(col) for col in original_table.columns)
          st.write("**表格信息:**")
          st.write(f"- 原始行数: {len(original_table)}")
          st.write(f"- 过滤后行数: {len(filtered_table)}")
          st.write(f"- 列数: {len(original_table.columns)}")
          st.write(f"- 列名: {column_names}")

      if show_stats:
          st.write("**统计信息:**")
          numeric_cols = filtered_table.select_dtypes(include=[np.number]).columns
          if len(numeric_cols) > 0:
              st.dataframe(filtered_table[numeric_cols].describe())
          else:
              st.info("表格中没有数值列")

      # The download buttons are only rendered on the run in which this button
      # was clicked (st.button is True for that single rerun).
      if st.button(f"📥 导出表格 {table_index+1}", key=f"export_{table_index}"):
          self._create_export_buttons(filtered_table, table_index)
  def _create_export_buttons(self, table: pd.DataFrame, table_index: int):
      """Render CSV and Excel download buttons for ``table``.

      Args:
          table: Data to export; written without its index.
          table_index: Zero-based table position, used in labels, file names
              and widget keys.
      """
      number = table_index + 1

      st.download_button(
          label=f"下载CSV (表格 {number})",
          data=table.to_csv(index=False),
          file_name=f"table_{number}.csv",
          mime="text/csv",
          key=f"download_csv_{table_index}",
      )

      xlsx_bytes = BytesIO()
      table.to_excel(xlsx_bytes, index=False)
      st.download_button(
          label=f"下载Excel (表格 {number})",
          data=xlsx_bytes.getvalue(),
          file_name=f"table_{number}.xlsx",
          mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
          key=f"download_excel_{table_index}",
      )
  411. def _process_table_display_mode(self, table: pd.DataFrame, table_index: int, display_mode: str) -> pd.DataFrame:
  412. """根据显示模式处理表格"""
  413. if display_mode == "分页显示":
  414. # 分页显示
  415. page_size = st.selectbox(
  416. f"每页显示行数 (表格 {table_index+1})",
  417. [10, 20, 50, 100],
  418. key=f"page_size_{table_index}"
  419. )
  420. total_pages = (len(table) - 1) // page_size + 1
  421. if total_pages > 1:
  422. page_number = st.selectbox(
  423. f"页码 (表格 {table_index+1})",
  424. range(1, total_pages + 1),
  425. key=f"page_number_{table_index}"
  426. )
  427. start_idx = (page_number - 1) * page_size
  428. end_idx = start_idx + page_size
  429. return table.iloc[start_idx:end_idx]
  430. return table
  431. elif display_mode == "筛选列显示":
  432. # 列筛选显示
  433. if len(table.columns) > 5:
  434. selected_columns = st.multiselect(
  435. f"选择要显示的列 (表格 {table_index+1})",
  436. table.columns.tolist(),
  437. default=table.columns.tolist()[:5], # 默认显示前5列
  438. key=f"selected_columns_{table_index}"
  439. )
  440. if selected_columns:
  441. return table[selected_columns]
  442. return table
  443. else: # 完整显示
  444. return table
  445. def find_verify_md_path(self, selected_file_index: int) -> Optional[Path]:
  446. """查找当前OCR文件对应的验证文件路径"""
  447. current_page = self.file_info[selected_file_index]['page']
  448. verify_md_path = None
  449. for i, info in enumerate(self.verify_file_info):
  450. if info['page'] == current_page:
  451. verify_md_path = Path(self.verify_file_paths[i]).with_suffix('.md')
  452. break
  453. return verify_md_path
  @st.dialog("交叉验证", width="large", dismissible=True, on_dismiss="rerun")
  def cross_validation(self):
      """Cross-validate the current file against the verification data source.

      Steps:
      1. Find the Markdown output of the selected OCR file.
      2. Find the Markdown output of the same page in the verification source.
      3. Run ``compare_ocr_results`` on the pair, writing its output under the
         pre-validation directory.
      4. Show the captured stdout, store a summary in
         ``st.session_state.cross_validation_result`` and render the result
         compactly.

      Rendered as a Streamlit dialog. Problems are reported in the UI.
      """
      import contextlib
      import io

      if not self.image_path or not self.md_content:
          st.error("❌ 请先加载OCR数据文件")
          return
      if self.current_source_key == self.verify_source_key:
          st.error("❌ OCR数据源和验证数据源不能相同")
          return

      # Make sure the session slot exists (this was previously done twice).
      if 'cross_validation_result' not in st.session_state:
          st.session_state.cross_validation_result = None

      with st.spinner("正在进行交叉验证...", show_time=True):
          status_text = st.empty()
          try:
              # Step 1: Markdown output of the selected OCR file
              current_md_path = Path(self.file_paths[self.selected_file_index]).with_suffix('.md')
              if not current_md_path.exists():
                  st.error("❌ 当前OCR结果的Markdown文件不存在")
                  return
              status_text.text(f"📄 OCR文件: {current_md_path.name}")

              # Step 2: same page in the verification source
              verify_md_path = self.find_verify_md_path(self.selected_file_index)
              if not verify_md_path or not verify_md_path.exists():
                  # The message names a page ("第X页"), so report the page number.
                  # Previously it interpolated the Markdown file path here.
                  current_page = self.file_info[self.selected_file_index]['page']
                  st.error(f"❌ 未找到验证数据源中第{current_page}页的对应文件")
                  return
              status_text.text(f"🔍 验证文件: {verify_md_path.name}")

              # Step 3: output directory
              pre_validation_dir = Path(self.config['pre_validation'].get('out_dir', './output/pre_validation/')).resolve()
              pre_validation_dir.mkdir(parents=True, exist_ok=True)

              # Step 4: run the comparison and show what it printed
              status_text.text("📊 正在对比OCR结果...")
              comparison_result_path = pre_validation_dir / f"{current_md_path.stem}_cross_validation"

              with st.expander("🔍 交叉验证对比过程", expanded=True):
                  compare_output = st.empty()
                  output_buffer = io.StringIO()
                  # NOTE(review): redirect_stdout swaps the process-wide sys.stdout,
                  # so prints from other concurrent sessions/threads may be captured.
                  with contextlib.redirect_stdout(output_buffer):
                      comparison_result = compare_ocr_results(
                          file1_path=str(current_md_path),
                          file2_path=str(verify_md_path),
                          output_file=str(comparison_result_path),
                          output_format='both',
                          ignore_images=True,
                          table_mode='flow_list',
                          similarity_algorithm='ratio'
                      )
                  compare_output.code(output_buffer.getvalue(), language='text')

              status_text.text("✅ 交叉验证完成")

              st.session_state.cross_validation_result = {
                  "ocr_source": get_data_source_display_name(self.current_source_config),
                  "verify_source": get_data_source_display_name(self.verify_source_config),
                  "ocr_file": str(current_md_path),
                  "verify_file": str(verify_md_path),
                  "comparison_result_json": f"{comparison_result_path}.json",
                  "comparison_result_md": f"{comparison_result_path}.md",
                  "comparison_result": comparison_result
              }

              # Step 5: compact summary (no detail viewer or downloads here)
              self.display_comparison_results(comparison_result, detailed=False)

          except Exception as e:
              st.error(f"❌ 交叉验证失败: {e}")
              st.exception(e)
  @st.dialog("查看交叉验证结果", width="large", dismissible=True, on_dismiss="rerun")
  def show_cross_validation_results_dialog(self):
      """Show the cross-validation result for the currently selected file.

      Reuses ``st.session_state.cross_validation_result`` when it was produced
      for this file. Otherwise, if ``<stem>_cross_validation.json`` exists in the
      pre-validation directory, offers a button to load it. If neither is
      available, shows a hint to run cross-validation first.
      """
      current_md_path = Path(self.file_paths[self.selected_file_index]).with_suffix('.md')
      pre_validation_dir = Path(self.config['pre_validation'].get('out_dir', './output/pre_validation/')).resolve()
      comparison_result_path = pre_validation_dir / f"{current_md_path.stem}_cross_validation.json"

      cached = st.session_state.cross_validation_result if 'cross_validation_result' in st.session_state else None
      # One session slot is shared by all files. Only reuse it when its JSON
      # output path matches this file; otherwise a result produced for a
      # previously selected file would be shown as this file's result.
      if cached and cached.get('comparison_result_json') == str(comparison_result_path):
          col1, col2 = st.columns(2)
          with col1:
              st.info(f"**OCR数据源:** {cached['ocr_source']}")
          with col2:
              st.info(f"**验证数据源:** {cached['verify_source']}")
          self.display_comparison_results(cached['comparison_result'])

      elif comparison_result_path.exists():
          # A saved result exists on disk; load it on request.
          if st.button("📂 加载历史验证结果"):
              with open(comparison_result_path, "r", encoding="utf-8") as f:
                  comparison_json_result = json.load(f)
              st.session_state.cross_validation_result = {
                  "ocr_source": get_data_source_display_name(self.current_source_config),
                  "verify_source": get_data_source_display_name(self.verify_source_config),
                  "ocr_file": comparison_json_result['file1_path'],
                  "verify_file": comparison_json_result['file2_path'],
                  "comparison_result_json": str(comparison_result_path),
                  "comparison_result_md": str(comparison_result_path.with_suffix('.md')),
                  "comparison_result": comparison_json_result
              }
              self.display_comparison_results(comparison_json_result)

      else:
          st.info("暂无交叉验证结果,请先运行交叉验证")
  def display_comparison_results(self, comparison_result: dict, detailed: bool = True):
      """Render a comparison result produced by ``compare_ocr_results``.

      Always shows the summary metrics from ``comparison_result['statistics']``
      and an overall verdict. When there are differences, it also shows a
      styled overview table, a per-difference viewer (only if ``detailed``) and
      type/severity distribution charts. Download options are rendered when
      ``detailed`` is True.

      Args:
          comparison_result: Dict with ``statistics`` and ``differences``.
              ``statistics`` holds ``total_differences``, ``table_differences``,
              ``amount_differences`` and ``paragraph_differences``.
              ``differences`` is a list of dicts with ``position``, ``type``,
              ``file1_value``, ``file2_value``, ``description`` and optionally
              ``severity`` / ``similarity``.
          detailed: Include the per-difference viewer and the download options.
      """
      st.header("📊 VLM预校验结果")

      stats = comparison_result['statistics']
      metric_specs = [
          ("总差异数", 'total_differences'),
          ("表格差异", 'table_differences'),
          ("金额差异", 'amount_differences'),
          ("段落差异", 'paragraph_differences'),
      ]
      for column, (label, stat_key) in zip(st.columns(4), metric_specs):
          with column:
              st.metric(label, stats[stat_key])

      if stats['total_differences'] == 0:
          st.success("🎉 完美匹配!VLM识别结果与原OCR结果完全一致")
      else:
          st.warning(f"⚠️ 发现 {stats['total_differences']} 个差异,建议人工检查")

      differences = comparison_result['differences']
      if differences:
          self._render_comparison_diff_table(differences)
          if detailed:
              # Emit the heading only together with the viewer it titles.
              # Previously it was also shown, empty, when detailed=False.
              st.subheader("🔍 详细差异查看")
              self._render_comparison_diff_detail(differences)
          # Charts are only drawn when there are differences to plot.
          self._render_comparison_diff_charts(differences)

      if detailed:
          self._provide_download_options_in_results(comparison_result)

  def _render_comparison_diff_table(self, differences: list):
      """Show the differences as a truncated overview table with coloured severity cells."""

      def clip(text, limit):
          # Truncate long cell text and mark the cut with '...'.
          return text[:limit] + ('...' if len(text) > limit else '')

      def highlight_severity(val):
          """Return the CSS for a severity cell ('高' / '中' / '低')."""
          if val == '高':
              return 'background-color: #ffebee; color: #c62828'
          elif val == '中':
              return 'background-color: #fff3e0; color: #ef6c00'
          elif val == '低':
              return 'background-color: #e8f5e8; color: #2e7d32'
          return ''

      st.subheader("🔍 差异详情对比")

      df_differences = pd.DataFrame([
          {
              '序号': i,
              '位置': diff['position'],
              '类型': diff['type'],
              '原OCR结果': clip(diff['file1_value'], 100),
              'VLM识别结果': clip(diff['file2_value'], 100),
              '描述': clip(diff['description'], 80),
              '严重程度': self._get_severity_level(diff)
          }
          for i, diff in enumerate(differences, 1)
      ])

      styler = df_differences.style
      # Styler.applymap is deprecated since pandas 2.1 in favour of Styler.map.
      # Use map when it exists and fall back to applymap on older pandas.
      apply_elementwise = getattr(styler, 'map', None) or styler.applymap
      styled_df = apply_elementwise(highlight_severity, subset=['严重程度']).format({
          '序号': '{:d}',
      })

      st.dataframe(
          styled_df,
          use_container_width=True,
          height=400,
          hide_index=True,
          column_config={
              "序号": st.column_config.NumberColumn(
                  "序号",
                  width=None,  # automatic width
                  pinned=True,
                  help="差异项序号"
              ),
              "位置": st.column_config.TextColumn(
                  "位置",
                  width=None,  # automatic width
                  pinned=True,
                  help="差异在文档中的位置"
              ),
              "类型": st.column_config.TextColumn(
                  "类型",
                  width=None,  # automatic width
                  pinned=True,
                  help="差异类型"
              ),
              "原OCR结果": st.column_config.TextColumn(
                  "原OCR结果",
                  width="large",
                  pinned=True,
                  help="原始OCR识别结果"
              ),
              "VLM识别结果": st.column_config.TextColumn(
                  "VLM识别结果",
                  width="large",
                  help="VLM重新识别的结果"
              ),
              "描述": st.column_config.TextColumn(
                  "描述",
                  width="medium",
                  help="差异详细描述"
              ),
              "严重程度": st.column_config.TextColumn(
                  "严重程度",
                  width=None,  # automatic width
                  help="差异严重程度评级"
              )
          }
      )

  def _render_comparison_diff_detail(self, differences: list):
      """Let the user pick one difference and show both full values side by side."""
      selected_diff_index = st.selectbox(
          "选择要查看的差异:",
          options=range(len(differences)),
          format_func=lambda x: f"差异 {x+1}: {differences[x]['position']} - {differences[x]['type']}",
          key="selected_diff"
      )
      if selected_diff_index is None:
          return

      diff = differences[selected_diff_index]

      col1, col2 = st.columns(2)
      with col1:
          st.write("**原OCR结果:**")
          st.text_area(
              "原OCR结果详情",
              value=diff['file1_value'],
              height=200,
              key=f"original_{selected_diff_index}",
              label_visibility="collapsed"
          )
      with col2:
          st.write("**验证数据源识别结果:**")
          st.text_area(
              "验证数据源识别结果详情",
              value=diff['file2_value'],
              height=200,
              key=f"vlm_{selected_diff_index}",
              label_visibility="collapsed"
          )

      st.info(f"**位置:** {diff['position']}")
      st.info(f"**类型:** {diff['type']}")
      st.info(f"**描述:** {diff['description']}")
      st.info(f"**严重程度:** {self._get_severity_level(diff)}")

  def _render_comparison_diff_charts(self, differences: list):
      """Plot how the differences are distributed by type (pie) and severity (bar)."""
      st.subheader("📈 差异类型分布")

      type_counts = {}
      severity_counts = {'高': 0, '中': 0, '低': 0}
      for diff in differences:
          type_counts[diff['type']] = type_counts.get(diff['type'], 0) + 1
          severity_counts[self._get_severity_level(diff)] += 1

      col1, col2 = st.columns(2)
      with col1:
          if type_counts:
              fig_type = px.pie(
                  values=list(type_counts.values()),
                  names=list(type_counts.keys()),
                  title="差异类型分布"
              )
              st.plotly_chart(fig_type, use_container_width=True)
      with col2:
          fig_severity = px.bar(
              x=list(severity_counts.keys()),
              y=list(severity_counts.values()),
              title="差异严重程度分布",
              color=list(severity_counts.keys()),
              color_discrete_map={'高': '#f44336', '中': '#ff9800', '低': '#4caf50'}
          )
          st.plotly_chart(fig_severity, use_container_width=True)
  729. def _get_severity_level(self, diff: dict) -> str:
  730. """根据差异类型和内容判断严重程度"""
  731. # 如果差异中已经包含严重程度,直接使用
  732. if 'severity' in diff:
  733. severity_map = {'high': '高', 'medium': '中', 'low': '低'}
  734. return severity_map.get(diff['severity'], '中')
  735. # 原有的逻辑作为后备
  736. diff_type = diff['type'].lower()
  737. # 金额相关差异为高严重程度
  738. if 'amount' in diff_type or 'number' in diff_type:
  739. return '高'
  740. # 表格结构差异为中等严重程度
  741. if 'table' in diff_type or 'structure' in diff_type:
  742. return '中'
  743. # 检查相似度
  744. if 'similarity' in diff:
  745. similarity = diff['similarity']
  746. if similarity < 50:
  747. return '高'
  748. elif similarity < 85:
  749. return '中'
  750. else:
  751. return '低'
  752. # 检查内容长度差异
  753. len_diff = abs(len(diff['file1_value']) - len(diff['file2_value']))
  754. if len_diff > 50:
  755. return '高'
  756. elif len_diff > 10:
  757. return '中'
  758. else:
  759. return '低'
  def _provide_download_options_in_results(self, comparison_result: dict):
      """Render Excel, CSV and JSON downloads for a comparison result.

      - Excel: full difference details (only when there are differences).
      - CSV: the four summary statistics.
      - JSON: the complete ``comparison_result``.

      All three file names share one timestamp, taken when this section renders.
      """
      st.subheader("📥 导出预校验结果")

      # One timestamp for the whole export set, so files exported together get
      # matching names. Previously each button read the clock separately.
      timestamp = pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')

      col1, col2, col3 = st.columns(3)

      with col1:
          differences = comparison_result['differences']
          if differences:
              df_export = pd.DataFrame([
                  {
                      '序号': i,
                      '位置': diff['position'],
                      '类型': diff['type'],
                      '原OCR结果': diff['file1_value'],
                      'VLM识别结果': diff['file2_value'],
                      '描述': diff['description'],
                      '严重程度': self._get_severity_level(diff)
                  }
                  for i, diff in enumerate(differences, 1)
              ])
              excel_buffer = BytesIO()
              df_export.to_excel(excel_buffer, index=False, sheet_name='差异详情')
              st.download_button(
                  label="📊 下载差异详情(Excel)",
                  data=excel_buffer.getvalue(),
                  file_name=f"vlm_comparison_differences_{timestamp}.xlsx",
                  mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
                  key="download_differences_excel"
              )

      with col2:
          stats = comparison_result['statistics']
          df_stats = pd.DataFrame({
              '统计项目': ['总差异数', '表格差异', '金额差异', '段落差异'],
              '数量': [
                  stats['total_differences'],
                  stats['table_differences'],
                  stats['amount_differences'],
                  stats['paragraph_differences']
              ]
          })
          st.download_button(
              label="📈 下载统计报告(CSV)",
              data=df_stats.to_csv(index=False),
              file_name=f"vlm_comparison_stats_{timestamp}.csv",
              mime="text/csv",
              key="download_stats_csv"
          )

      with col3:
          # Uses the module-level json import; the local re-import was redundant.
          report_json = json.dumps(comparison_result, ensure_ascii=False, indent=2)
          st.download_button(
              label="📄 下载完整报告(JSON)",
              data=report_json,
              file_name=f"vlm_comparison_full_{timestamp}.json",
              mime="application/json",
              key="download_full_json"
          )
  819. # 操作建议
  820. st.subheader("🚀 后续操作建议")
  821. total_diffs = comparison_result['statistics']['total_differences']
  822. if total_diffs == 0:
  823. st.success("✅ VLM识别结果与原OCR完全一致,可信度很高,无需人工校验")
  824. elif total_diffs <= 5:
  825. st.warning("⚠️ 发现少量差异,建议重点检查高严重程度的差异项")
  826. elif total_diffs <= 20:
  827. st.warning("🔍 发现中等数量差异,建议详细检查差异表格中标红的项目")
  828. else:
  829. st.error("❌ 发现大量差异,建议重新进行OCR识别或检查原始图片质量")
  830. def create_compact_layout(self, config):
  831. """创建滚动凑布局"""
  832. return self.layout_manager.create_compact_layout(config)
  833. @st.dialog("message", width="small", dismissible=True, on_dismiss="rerun")
  834. def message_box(msg: str, msg_type: str = "info"):
  835. if msg_type == "info":
  836. st.info(msg)
  837. elif msg_type == "warning":
  838. st.warning(msg)
  839. elif msg_type == "error":
  840. st.error(msg)
  841. def main():
  842. """主应用"""
  843. # 初始化应用
  844. if 'validator' not in st.session_state:
  845. validator = StreamlitOCRValidator()
  846. st.session_state.validator = validator
  847. st.session_state.validator.setup_page_config()
  848. # 页面标题
  849. config = st.session_state.validator.config
  850. st.title(config['ui']['page_title'])
  851. else:
  852. validator = st.session_state.validator
  853. config = st.session_state.validator.config
  854. if 'selected_text' not in st.session_state:
  855. st.session_state.selected_text = None
  856. if 'marked_errors' not in st.session_state:
  857. st.session_state.marked_errors = set()
  858. # 数据源选择器
  859. validator.create_data_source_selector()
  860. # 如果没有可用的数据源,提前返回
  861. if not validator.all_sources:
  862. st.stop()
  863. # 文件选择区域
  864. with st.container(height=75, horizontal=True, horizontal_alignment='left', gap="medium"):
  865. # 初始化session_state中的选择索引
  866. if 'selected_file_index' not in st.session_state:
  867. st.session_state.selected_file_index = 0
  868. if validator.display_options:
  869. # 文件选择下拉框
  870. selected_index = st.selectbox(
  871. "选择OCR结果文件",
  872. range(len(validator.display_options)),
  873. format_func=lambda i: validator.display_options[i],
  874. index=st.session_state.selected_file_index,
  875. key="selected_selectbox",
  876. label_visibility="collapsed"
  877. )
  878. # 更新session_state
  879. if selected_index != st.session_state.selected_file_index:
  880. st.session_state.selected_file_index = selected_index
  881. selected_file = validator.file_paths[selected_index]
  882. # 页码输入器
  883. current_page = validator.file_info[selected_index]['page']
  884. page_input = st.number_input(
  885. "输入页码",
  886. placeholder="输入页码",
  887. label_visibility="collapsed",
  888. min_value=1,
  889. max_value=len(validator.display_options),
  890. value=current_page,
  891. step=1,
  892. key="page_input"
  893. )
  894. # 当页码输入改变时,更新文件选择
  895. if page_input != current_page:
  896. for i, info in enumerate(validator.file_info):
  897. if info['page'] == page_input:
  898. st.session_state.selected_file_index = i
  899. selected_file = validator.file_paths[i]
  900. st.rerun()
  901. break
  902. # 自动加载文件
  903. if (st.session_state.selected_file_index >= 0
  904. and validator.selected_file_index != st.session_state.selected_file_index
  905. and selected_file):
  906. validator.selected_file_index = st.session_state.selected_file_index
  907. st.session_state.validator.load_ocr_data(selected_file)
  908. # 显示加载成功信息
  909. current_source_name = get_data_source_display_name(validator.current_source_config)
  910. st.success(f"✅ 已加载 {current_source_name} - 第{validator.file_info[st.session_state.selected_file_index]['page']}页")
  911. st.rerun()
  912. else:
  913. st.warning("当前数据源中未找到OCR结果文件")
  914. # 交叉验证按钮
  915. if st.button("交叉验证", type="primary", icon=":material/compare_arrows:"):
  916. if validator.image_path and validator.md_content:
  917. validator.cross_validation()
  918. else:
  919. message_box("❌ 请先选择OCR数据文件", "error")
  920. # 查看预校验结果按钮
  921. if st.button("查看验证结果", type="secondary", icon=":material/quick_reference_all:"):
  922. validator.show_cross_validation_results_dialog()
  923. # 显示当前数据源统计信息
  924. with st.expander("🔧 OCR工具统计信息", expanded=False):
  925. stats = validator.get_statistics()
  926. col1, col2, col3, col4, col5 = st.columns(5)
  927. with col1:
  928. st.metric("📊 总文本块", stats['total_texts'])
  929. with col2:
  930. st.metric("🔗 可点击文本", stats['clickable_texts'])
  931. with col3:
  932. st.metric("❌ 标记错误", stats['marked_errors'])
  933. with col4:
  934. st.metric("✅ 准确率", f"{stats['accuracy_rate']:.1f}%")
  935. with col5:
  936. # 显示当前数据源信息
  937. if validator.current_source_config:
  938. tool_display = validator.current_source_config['ocr_tool'].upper()
  939. st.metric("🔧 OCR工具", tool_display)
  940. # 详细工具信息
  941. if stats['tool_info']:
  942. st.write("**详细信息:**", stats['tool_info'])
  943. # 其余标签页保持不变...
  944. tab1, tab2, tab3 = st.tabs(["📄 内容人工检查", "🔍 交叉验证结果", "📊 表格分析"])
  945. with tab1:
  946. validator.create_compact_layout(config)
  947. with tab2:
  948. # st.header("📄 VLM预校验识别结果")
  949. current_md_path = Path(validator.file_paths[validator.selected_file_index]).with_suffix('.md')
  950. pre_validation_dir = Path(validator.config['pre_validation'].get('out_dir', './output/pre_validation/')).resolve()
  951. comparison_result_path = pre_validation_dir / f"{current_md_path.stem}_cross_validation.json"
  952. # pre_validation_path = pre_validation_dir / f"{current_md_path.stem}.md"
  953. verify_md_path = validator.find_verify_md_path(validator.selected_file_index)
  954. if comparison_result_path.exists():
  955. # 加载并显示验证结果
  956. with open(comparison_result_path, "r", encoding="utf-8") as f:
  957. comparison_result = json.load(f)
  958. # 左边显示OCR结果,右边显示VLM结果
  959. col1, col2 = st.columns([1,1])
  960. with col1:
  961. st.subheader("🤖 原OCR识别结果")
  962. with open(current_md_path, "r", encoding="utf-8") as f:
  963. original_md_content = f.read()
  964. font_size = config['styles'].get('font_size', 10)
  965. height = config['styles']['layout'].get('default_height', 800)
  966. layout_type = "compact"
  967. validator.layout_manager.render_content_by_mode(original_md_content, "HTML渲染", font_size, height, layout_type)
  968. with col2:
  969. st.subheader("🤖 VLM识别结果")
  970. with open(str(verify_md_path), "r", encoding="utf-8") as f:
  971. verify_md_content = f.read()
  972. font_size = config['styles'].get('font_size', 10)
  973. height = config['styles']['layout'].get('default_height', 800)
  974. layout_type = "compact"
  975. validator.layout_manager.render_content_by_mode(verify_md_content, "HTML渲染", font_size, height, layout_type)
  976. # 显示差异统计
  977. st.markdown("---")
  978. validator.display_comparison_results(comparison_result, detailed=True)
  979. else:
  980. st.info("暂无预校验结果,请先运行VLM预校验")
  981. with tab3:
  982. # 表格分析页面 - 保持原有逻辑
  983. st.header("📊 表格数据分析")
  984. if validator.md_content and '<table' in validator.md_content.lower():
  985. st.subheader("🔍 表格数据预览")
  986. validator.display_html_table_as_dataframe(validator.md_content)
  987. else:
  988. st.info("当前OCR结果中没有检测到表格数据")
  989. if __name__ == "__main__":
  990. main()