streamlit_ocr_validator.py 47 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104
  1. #!/usr/bin/env python3
  2. """
  3. 基于Streamlit的OCR可视化校验工具(重构版)
  4. 提供丰富的交互组件和更好的用户体验
  5. """
  6. import streamlit as st
  7. from pathlib import Path
  8. from PIL import Image
  9. from typing import Dict, List, Optional
  10. import plotly.graph_objects as go
  11. from io import BytesIO
  12. import pandas as pd
  13. import numpy as np
  14. import plotly.express as px
  15. import json
  16. # 导入工具模块
  17. from ocr_validator_utils import (
  18. load_config, load_css_styles, load_ocr_data_file, process_ocr_data,
  19. draw_bbox_on_image, get_ocr_statistics, convert_html_table_to_markdown,
  20. parse_html_tables, find_available_ocr_files, create_dynamic_css,
  21. export_tables_to_excel, get_table_statistics, group_texts_by_category,
  22. find_available_ocr_files_multi_source, get_data_source_display_name
  23. )
  24. from ocr_validator_layout import OCRLayoutManager
  25. from ocr_by_vlm import ocr_with_vlm
  26. from compare_ocr_results import compare_ocr_results
  27. class StreamlitOCRValidator:
  28. def __init__(self):
  29. self.config = load_config()
  30. self.ocr_data = []
  31. self.md_content = ""
  32. self.image_path = ""
  33. self.text_bbox_mapping = {}
  34. self.selected_text = None
  35. self.marked_errors = set()
  36. # 多数据源相关
  37. self.all_sources = {}
  38. self.current_source_key = None
  39. self.current_source_config = None
  40. self.file_info = []
  41. self.selected_file_index = -1
  42. self.display_options = []
  43. self.file_paths = []
  44. # 初始化布局管理器
  45. self.layout_manager = OCRLayoutManager(self)
  46. # 加载多数据源文件信息
  47. self.load_multi_source_info()
  48. def load_multi_source_info(self):
  49. """加载多数据源文件信息"""
  50. self.all_sources = find_available_ocr_files_multi_source(self.config)
  51. # 如果有数据源,默认选择第一个
  52. if self.all_sources:
  53. first_source_key = list(self.all_sources.keys())[0]
  54. self.switch_to_source(first_source_key)
  55. def switch_to_source(self, source_key: str):
  56. """切换到指定数据源"""
  57. if source_key in self.all_sources:
  58. self.current_source_key = source_key
  59. source_data = self.all_sources[source_key]
  60. self.current_source_config = source_data['config']
  61. self.file_info = source_data['files']
  62. if self.file_info:
  63. # 创建显示选项列表
  64. self.display_options = [f"{info['display_name']}" for info in self.file_info]
  65. self.file_paths = [info['path'] for info in self.file_info]
  66. # 重置文件选择
  67. self.selected_file_index = -1
  68. print(f"✅ 切换到数据源: {source_key}")
  69. else:
  70. print(f"⚠️ 数据源 {source_key} 没有可用文件")
  71. def setup_page_config(self):
  72. """设置页面配置"""
  73. ui_config = self.config['ui']
  74. st.set_page_config(
  75. page_title=ui_config['page_title'],
  76. page_icon=ui_config['page_icon'],
  77. layout=ui_config['layout'],
  78. initial_sidebar_state=ui_config['sidebar_state']
  79. )
  80. # 加载CSS样式
  81. css_content = load_css_styles()
  82. st.markdown(f"<style>{css_content}</style>", unsafe_allow_html=True)
  83. def create_data_source_selector(self):
  84. """创建数据源选择器"""
  85. if not self.all_sources:
  86. st.warning("❌ 未找到任何数据源,请检查配置文件")
  87. return
  88. # 数据源选择
  89. source_options = {}
  90. for source_key, source_data in self.all_sources.items():
  91. display_name = get_data_source_display_name(source_data['config'])
  92. source_options[display_name] = source_key
  93. # 获取当前选择的显示名称
  94. current_display_name = None
  95. if self.current_source_key:
  96. for display_name, key in source_options.items():
  97. if key == self.current_source_key:
  98. current_display_name = display_name
  99. break
  100. selected_display_name = st.selectbox(
  101. "📁 选择数据源",
  102. options=list(source_options.keys()),
  103. index=list(source_options.keys()).index(current_display_name) if current_display_name else 0,
  104. key="data_source_selector",
  105. help="选择要分析的OCR数据源"
  106. )
  107. selected_source_key = source_options[selected_display_name]
  108. # 如果数据源发生变化,切换数据源
  109. if selected_source_key != self.current_source_key:
  110. self.switch_to_source(selected_source_key)
  111. # 重置session state
  112. if 'selected_file_index' in st.session_state:
  113. st.session_state.selected_file_index = 0
  114. st.rerun()
  115. # 显示数据源信息
  116. if self.current_source_config:
  117. with st.expander("📋 数据源详情", expanded=False):
  118. col1, col2, col3 = st.columns(3)
  119. with col1:
  120. st.write(f"**名称:** {self.current_source_config['name']}")
  121. st.write(f"**OCR工具:** {self.current_source_config['ocr_tool']}")
  122. with col2:
  123. st.write(f"**输出目录:** {self.current_source_config['ocr_out_dir']}")
  124. st.write(f"**图片目录:** {self.current_source_config.get('src_img_dir', 'N/A')}")
  125. with col3:
  126. st.write(f"**描述:** {self.current_source_config.get('description', 'N/A')}")
  127. st.write(f"**文件数量:** {len(self.file_info)}")
  128. def load_ocr_data(self, json_path: str, md_path: Optional[str] = None, image_path: Optional[str] = None):
  129. """加载OCR相关数据 - 支持多数据源配置"""
  130. try:
  131. # 使用当前数据源的配置加载数据
  132. if self.current_source_config:
  133. # 临时修改config以使用当前数据源的配置
  134. temp_config = self.config.copy()
  135. temp_config['paths'] = {
  136. 'ocr_out_dir': self.current_source_config['ocr_out_dir'],
  137. 'src_img_dir': self.current_source_config.get('src_img_dir', ''),
  138. 'pre_validation_dir': self.config['pre_validation']['out_dir']
  139. }
  140. # 设置OCR工具类型
  141. temp_config['current_ocr_tool'] = self.current_source_config['ocr_tool']
  142. self.ocr_data, self.md_content, self.image_path = load_ocr_data_file(json_path, temp_config)
  143. else:
  144. self.ocr_data, self.md_content, self.image_path = load_ocr_data_file(json_path, self.config)
  145. self.process_data()
  146. except Exception as e:
  147. st.error(f"❌ 加载失败: {e}")
  148. st.exception(e)
  149. def process_data(self):
  150. """处理OCR数据"""
  151. self.text_bbox_mapping = process_ocr_data(self.ocr_data, self.config)
  152. def get_statistics(self) -> Dict:
  153. """获取统计信息"""
  154. return get_ocr_statistics(self.ocr_data, self.text_bbox_mapping, self.marked_errors)
  155. def display_html_table_as_dataframe(self, html_content: str, enable_editing: bool = False):
  156. """将HTML表格解析为DataFrame显示 - 增强版本支持横向滚动"""
  157. tables = parse_html_tables(html_content)
  158. wide_table_threshold = 15 # 超宽表格列数阈值
  159. if not tables:
  160. st.warning("未找到可解析的表格")
  161. # 对于无法解析的HTML表格,使用自定义CSS显示
  162. st.markdown("""
  163. <style>
  164. .scrollable-table {
  165. overflow-x: auto;
  166. white-space: nowrap;
  167. border: 1px solid #ddd;
  168. border-radius: 5px;
  169. margin: 10px 0;
  170. }
  171. .scrollable-table table {
  172. width: 100%;
  173. border-collapse: collapse;
  174. }
  175. .scrollable-table th, .scrollable-table td {
  176. border: 1px solid #ddd;
  177. padding: 8px;
  178. text-align: left;
  179. min-width: 100px;
  180. }
  181. .scrollable-table th {
  182. background-color: #f5f5f5;
  183. font-weight: bold;
  184. }
  185. </style>
  186. """, unsafe_allow_html=True)
  187. st.markdown(f'<div class="scrollable-table">{html_content}</div>', unsafe_allow_html=True)
  188. return
  189. for i, table in enumerate(tables):
  190. st.subheader(f"📊 表格 {i+1}")
  191. # 表格信息显示
  192. col_info1, col_info2, col_info3, col_info4 = st.columns(4)
  193. with col_info1:
  194. st.metric("行数", len(table))
  195. with col_info2:
  196. st.metric("列数", len(table.columns))
  197. with col_info3:
  198. # 检查是否有超宽表格
  199. is_wide_table = len(table.columns) > wide_table_threshold
  200. st.metric("表格类型", "超宽表格" if is_wide_table else "普通表格")
  201. with col_info4:
  202. # 表格操作模式选择
  203. display_mode = st.selectbox(
  204. f"显示模式 (表格{i+1})",
  205. ["完整显示", "分页显示", "筛选列显示"],
  206. key=f"display_mode_{i}"
  207. )
  208. # 创建表格操作按钮
  209. col1, col2, col3, col4 = st.columns(4)
  210. with col1:
  211. show_info = st.checkbox(f"显示详细信息", key=f"info_{i}")
  212. with col2:
  213. show_stats = st.checkbox(f"显示统计信息", key=f"stats_{i}")
  214. with col3:
  215. enable_filter = st.checkbox(f"启用过滤", key=f"filter_{i}")
  216. with col4:
  217. enable_sort = st.checkbox(f"启用排序", key=f"sort_{i}")
  218. # 根据显示模式处理表格
  219. display_table = self._process_table_display_mode(table, i, display_mode)
  220. # 数据过滤和排序逻辑
  221. filtered_table = self._apply_table_filters_and_sorts(display_table, i, enable_filter, enable_sort)
  222. # 显示表格 - 使用自定义CSS支持横向滚动
  223. st.markdown("""
  224. <style>
  225. .dataframe-container {
  226. overflow-x: auto;
  227. border: 1px solid #ddd;
  228. border-radius: 5px;
  229. margin: 10px 0;
  230. }
  231. /* 为超宽表格特殊样式 */
  232. .wide-table-container {
  233. overflow-x: auto;
  234. max-height: 500px;
  235. overflow-y: auto;
  236. border: 2px solid #0288d1;
  237. border-radius: 8px;
  238. background: linear-gradient(90deg, #f8f9fa 0%, #ffffff 100%);
  239. }
  240. .dataframe thead th {
  241. position: sticky;
  242. top: 0;
  243. background-color: #f5f5f5 !important;
  244. z-index: 10;
  245. border-bottom: 2px solid #0288d1;
  246. }
  247. .dataframe tbody td {
  248. white-space: nowrap;
  249. min-width: 100px;
  250. max-width: 300px;
  251. overflow: hidden;
  252. text-overflow: ellipsis;
  253. }
  254. </style>
  255. """, unsafe_allow_html=True)
  256. # 根据表格宽度选择显示容器
  257. container_class = "wide-table-container" if len(table.columns) > wide_table_threshold else "dataframe-container"
  258. if enable_editing:
  259. st.markdown(f'<div class="{container_class}">', unsafe_allow_html=True)
  260. edited_table = st.data_editor(
  261. filtered_table,
  262. use_container_width=True,
  263. key=f"editor_{i}",
  264. height=400 if len(table.columns) > 8 else None
  265. )
  266. st.markdown('</div>', unsafe_allow_html=True)
  267. if not edited_table.equals(filtered_table):
  268. st.success("✏️ 表格已编辑,可以导出修改后的数据")
  269. else:
  270. st.markdown(f'<div class="{container_class}">', unsafe_allow_html=True)
  271. st.dataframe(
  272. filtered_table,
  273. # use_container_width=True,
  274. width =400 if len(table.columns) > wide_table_threshold else "stretch"
  275. )
  276. st.markdown('</div>', unsafe_allow_html=True)
  277. # 显示表格信息和统计
  278. self._display_table_info_and_stats(table, filtered_table, show_info, show_stats, i)
  279. st.markdown("---")
  280. def _apply_table_filters_and_sorts(self, table: pd.DataFrame, table_index: int, enable_filter: bool, enable_sort: bool) -> pd.DataFrame:
  281. """应用表格过滤和排序"""
  282. filtered_table = table.copy()
  283. # 数据过滤
  284. if enable_filter and not table.empty:
  285. filter_col = st.selectbox(
  286. f"选择过滤列 (表格 {table_index+1})",
  287. options=['无'] + list(table.columns),
  288. key=f"filter_col_{table_index}"
  289. )
  290. if filter_col != '无':
  291. filter_value = st.text_input(f"过滤值 (表格 {table_index+1})", key=f"filter_value_{table_index}")
  292. if filter_value:
  293. filtered_table = table[table[filter_col].astype(str).str.contains(filter_value, na=False)]
  294. # 数据排序
  295. if enable_sort and not filtered_table.empty:
  296. sort_col = st.selectbox(
  297. f"选择排序列 (表格 {table_index+1})",
  298. options=['无'] + list(filtered_table.columns),
  299. key=f"sort_col_{table_index}"
  300. )
  301. if sort_col != '无':
  302. sort_order = st.radio(
  303. f"排序方式 (表格 {table_index+1})",
  304. options=['升序', '降序'],
  305. horizontal=True,
  306. key=f"sort_order_{table_index}"
  307. )
  308. ascending = (sort_order == '升序')
  309. filtered_table = filtered_table.sort_values(sort_col, ascending=ascending)
  310. return filtered_table
  311. def _display_table_info_and_stats(self, original_table: pd.DataFrame, filtered_table: pd.DataFrame,
  312. show_info: bool, show_stats: bool, table_index: int):
  313. """显示表格信息和统计数据"""
  314. if show_info:
  315. st.write("**表格信息:**")
  316. st.write(f"- 原始行数: {len(original_table)}")
  317. st.write(f"- 过滤后行数: {len(filtered_table)}")
  318. st.write(f"- 列数: {len(original_table.columns)}")
  319. st.write(f"- 列名: {', '.join(original_table.columns)}")
  320. if show_stats:
  321. st.write("**统计信息:**")
  322. numeric_cols = filtered_table.select_dtypes(include=[np.number]).columns
  323. if len(numeric_cols) > 0:
  324. st.dataframe(filtered_table[numeric_cols].describe())
  325. else:
  326. st.info("表格中没有数值列")
  327. # 导出功能
  328. if st.button(f"📥 导出表格 {table_index+1}", key=f"export_{table_index}"):
  329. self._create_export_buttons(filtered_table, table_index)
  330. def _create_export_buttons(self, table: pd.DataFrame, table_index: int):
  331. """创建导出按钮"""
  332. # CSV导出
  333. csv_data = table.to_csv(index=False)
  334. st.download_button(
  335. label=f"下载CSV (表格 {table_index+1})",
  336. data=csv_data,
  337. file_name=f"table_{table_index+1}.csv",
  338. mime="text/csv",
  339. key=f"download_csv_{table_index}"
  340. )
  341. # Excel导出
  342. excel_buffer = BytesIO()
  343. table.to_excel(excel_buffer, index=False)
  344. st.download_button(
  345. label=f"下载Excel (表格 {table_index+1})",
  346. data=excel_buffer.getvalue(),
  347. file_name=f"table_{table_index+1}.xlsx",
  348. mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
  349. key=f"download_excel_{table_index}"
  350. )
  351. def _process_table_display_mode(self, table: pd.DataFrame, table_index: int, display_mode: str) -> pd.DataFrame:
  352. """根据显示模式处理表格"""
  353. if display_mode == "分页显示":
  354. # 分页显示
  355. page_size = st.selectbox(
  356. f"每页显示行数 (表格 {table_index+1})",
  357. [10, 20, 50, 100],
  358. key=f"page_size_{table_index}"
  359. )
  360. total_pages = (len(table) - 1) // page_size + 1
  361. if total_pages > 1:
  362. page_number = st.selectbox(
  363. f"页码 (表格 {table_index+1})",
  364. range(1, total_pages + 1),
  365. key=f"page_number_{table_index}"
  366. )
  367. start_idx = (page_number - 1) * page_size
  368. end_idx = start_idx + page_size
  369. return table.iloc[start_idx:end_idx]
  370. return table
  371. elif display_mode == "筛选列显示":
  372. # 列筛选显示
  373. if len(table.columns) > 5:
  374. selected_columns = st.multiselect(
  375. f"选择要显示的列 (表格 {table_index+1})",
  376. table.columns.tolist(),
  377. default=table.columns.tolist()[:5], # 默认显示前5列
  378. key=f"selected_columns_{table_index}"
  379. )
  380. if selected_columns:
  381. return table[selected_columns]
  382. return table
  383. else: # 完整显示
  384. return table
  385. @st.dialog("VLM预校验", width="large", dismissible=True, on_dismiss="rerun")
  386. def vlm_pre_validation(self):
  387. """VLM预校验功能 - 封装OCR识别和结果对比"""
  388. if not self.image_path or not self.md_content:
  389. st.error("❌ 请先加载OCR数据文件")
  390. return
  391. # 初始化对比结果存储
  392. if 'comparison_result' not in st.session_state:
  393. st.session_state.comparison_result = None
  394. # 创建进度条和状态显示
  395. with st.spinner("正在进行VLM预校验...", show_time=True):
  396. status_text = st.empty()
  397. try:
  398. current_md_path = Path(self.file_paths[self.selected_file_index]).with_suffix('.md')
  399. if not current_md_path.exists():
  400. st.error("❌ 当前OCR结果的Markdown文件不存在,无法进行对比")
  401. return
  402. # 第一步:准备目录
  403. pre_validation_dir = Path(self.config['pre_validation'].get('out_dir', './output/pre_validation/')).resolve()
  404. pre_validation_dir.mkdir(parents=True, exist_ok=True)
  405. status_text.write(f"工作目录: {pre_validation_dir}")
  406. # 第二步:调用VLM进行OCR识别
  407. status_text.text("🤖 正在调用VLM进行OCR识别...")
  408. # 在expander中显示OCR过程
  409. with st.expander("🔍 VLM OCR识别过程", expanded=True):
  410. ocr_output = st.empty()
  411. # 捕获OCR输出
  412. import io
  413. import contextlib
  414. # 创建字符串缓冲区来捕获print输出
  415. output_buffer = io.StringIO()
  416. with contextlib.redirect_stdout(output_buffer):
  417. ocr_result = ocr_with_vlm(
  418. image_path=str(self.image_path),
  419. output_dir=str(pre_validation_dir),
  420. normalize_numbers=True
  421. )
  422. # 显示OCR过程输出
  423. ocr_output.code(output_buffer.getvalue(), language='text')
  424. status_text.text("✅ VLM OCR识别完成")
  425. # 第三步:获取VLM生成的文件路径
  426. vlm_md_path = pre_validation_dir / f"{Path(self.image_path).stem}.md"
  427. if not vlm_md_path.exists():
  428. st.error("❌ VLM OCR结果文件未生成")
  429. return
  430. # 第四步:调用对比功能
  431. status_text.text("📊 正在对比OCR结果...")
  432. # 在expander中显示对比过程
  433. comparison_result_path = pre_validation_dir / f"{current_md_path.stem}_comparison_result"
  434. with st.expander("🔍 OCR结果对比过程", expanded=True):
  435. compare_output = st.empty()
  436. # 捕获对比输出
  437. output_buffer = io.StringIO()
  438. with contextlib.redirect_stdout(output_buffer):
  439. comparison_result = compare_ocr_results(
  440. file1_path=str(current_md_path),
  441. file2_path=str(vlm_md_path),
  442. output_file=str(comparison_result_path),
  443. output_format='both',
  444. ignore_images=True
  445. )
  446. # 显示对比过程输出
  447. compare_output.code(output_buffer.getvalue(), language='text')
  448. status_text.text("✅ VLM预校验完成")
  449. st.session_state.comparison_result = {
  450. "image_path": self.image_path,
  451. "comparison_result_json": f"{comparison_result_path}.json",
  452. "comparison_result_md": f"{comparison_result_path}.md",
  453. "comparison_result": comparison_result
  454. }
  455. # 第五步:显示对比结果
  456. self.display_comparison_results(comparison_result, detailed=False)
  457. # 第六步:提供文件下载
  458. # self.provide_download_options(pre_validation_dir, vlm_md_path, comparison_result)
  459. except Exception as e:
  460. st.error(f"❌ VLM预校验失败: {e}")
  461. st.exception(e)
  462. def display_comparison_results(self, comparison_result: dict, detailed: bool = True):
  463. """显示对比结果摘要 - 使用DataFrame展示"""
  464. st.header("📊 VLM预校验结果")
  465. # 统计信息
  466. stats = comparison_result['statistics']
  467. # 统计信息概览
  468. col1, col2, col3, col4 = st.columns(4)
  469. with col1:
  470. st.metric("总差异数", stats['total_differences'])
  471. with col2:
  472. st.metric("表格差异", stats['table_differences'])
  473. with col3:
  474. st.metric("金额差异", stats['amount_differences'])
  475. with col4:
  476. st.metric("段落差异", stats['paragraph_differences'])
  477. # 结果判断
  478. if stats['total_differences'] == 0:
  479. st.success("🎉 完美匹配!VLM识别结果与原OCR结果完全一致")
  480. else:
  481. st.warning(f"⚠️ 发现 {stats['total_differences']} 个差异,建议人工检查")
  482. # 使用DataFrame显示差异详情
  483. if comparison_result['differences']:
  484. st.subheader("🔍 差异详情对比")
  485. # 准备DataFrame数据
  486. diff_data = []
  487. for i, diff in enumerate(comparison_result['differences'], 1):
  488. diff_data.append({
  489. '序号': i,
  490. '位置': diff['position'],
  491. '类型': diff['type'],
  492. '原OCR结果': diff['file1_value'][:100] + ('...' if len(diff['file1_value']) > 100 else ''),
  493. 'VLM识别结果': diff['file2_value'][:100] + ('...' if len(diff['file2_value']) > 100 else ''),
  494. '描述': diff['description'][:80] + ('...' if len(diff['description']) > 80 else ''),
  495. '严重程度': self._get_severity_level(diff)
  496. })
  497. # 创建DataFrame
  498. df_differences = pd.DataFrame(diff_data)
  499. # 添加样式
  500. def highlight_severity(val):
  501. """根据严重程度添加颜色"""
  502. if val == '高':
  503. return 'background-color: #ffebee; color: #c62828'
  504. elif val == '中':
  505. return 'background-color: #fff3e0; color: #ef6c00'
  506. elif val == '低':
  507. return 'background-color: #e8f5e8; color: #2e7d32'
  508. return ''
  509. # 显示DataFrame
  510. styled_df = df_differences.style.applymap(
  511. highlight_severity,
  512. subset=['严重程度']
  513. ).format({
  514. '序号': '{:d}',
  515. })
  516. st.dataframe(
  517. styled_df,
  518. use_container_width=True,
  519. height=400,
  520. hide_index=True,
  521. column_config={
  522. "序号": st.column_config.NumberColumn(
  523. "序号",
  524. width=None, # 自动调整宽度
  525. pinned=True,
  526. help="差异项序号"
  527. ),
  528. "位置": st.column_config.TextColumn(
  529. "位置",
  530. width=None, # 自动调整宽度
  531. pinned=True,
  532. help="差异在文档中的位置"
  533. ),
  534. "类型": st.column_config.TextColumn(
  535. "类型",
  536. width=None, # 自动调整宽度
  537. pinned=True,
  538. help="差异类型"
  539. ),
  540. "原OCR结果": st.column_config.TextColumn(
  541. "原OCR结果",
  542. width="large", # 自动调整宽度
  543. pinned=True,
  544. help="原始OCR识别结果"
  545. ),
  546. "VLM识别结果": st.column_config.TextColumn(
  547. "VLM识别结果",
  548. width="large", # 自动调整宽度
  549. help="VLM重新识别的结果"
  550. ),
  551. "描述": st.column_config.TextColumn(
  552. "描述",
  553. width="medium", # 自动调整宽度
  554. help="差异详细描述"
  555. ),
  556. "严重程度": st.column_config.TextColumn(
  557. "严重程度",
  558. width=None, # 自动调整宽度
  559. help="差异严重程度评级"
  560. )
  561. }
  562. )
  563. # 详细差异查看
  564. st.subheader("🔍 详细差异查看")
  565. if detailed:
  566. # 选择要查看的差异
  567. selected_diff_index = st.selectbox(
  568. "选择要查看的差异:",
  569. options=range(len(comparison_result['differences'])),
  570. format_func=lambda x: f"差异 {x+1}: {comparison_result['differences'][x]['position']} - {comparison_result['differences'][x]['type']}",
  571. key="selected_diff"
  572. )
  573. if selected_diff_index is not None:
  574. diff = comparison_result['differences'][selected_diff_index]
  575. # 并排显示完整内容
  576. col1, col2 = st.columns(2)
  577. with col1:
  578. st.write("**原OCR结果:**")
  579. st.text_area(
  580. "原OCR结果详情",
  581. value=diff['file1_value'],
  582. height=200,
  583. key=f"original_{selected_diff_index}",
  584. label_visibility="collapsed"
  585. )
  586. with col2:
  587. st.write("**VLM识别结果:**")
  588. st.text_area(
  589. "VLM识别结果详情",
  590. value=diff['file2_value'],
  591. height=200,
  592. key=f"vlm_{selected_diff_index}",
  593. label_visibility="collapsed"
  594. )
  595. # 差异详细信息
  596. st.info(f"**位置:** {diff['position']}")
  597. st.info(f"**类型:** {diff['type']}")
  598. st.info(f"**描述:** {diff['description']}")
  599. st.info(f"**严重程度:** {self._get_severity_level(diff)}")
  600. # 差异统计图表
  601. st.subheader("📈 差异类型分布")
  602. # 按类型统计差异
  603. type_counts = {}
  604. severity_counts = {'高': 0, '中': 0, '低': 0}
  605. for diff in comparison_result['differences']:
  606. diff_type = diff['type']
  607. type_counts[diff_type] = type_counts.get(diff_type, 0) + 1
  608. severity = self._get_severity_level(diff)
  609. severity_counts[severity] += 1
  610. col1, col2 = st.columns(2)
  611. with col1:
  612. # 类型分布饼图
  613. if type_counts:
  614. fig_type = px.pie(
  615. values=list(type_counts.values()),
  616. names=list(type_counts.keys()),
  617. title="差异类型分布"
  618. )
  619. st.plotly_chart(fig_type, use_container_width=True)
  620. with col2:
  621. # 严重程度分布条形图
  622. fig_severity = px.bar(
  623. x=list(severity_counts.keys()),
  624. y=list(severity_counts.values()),
  625. title="差异严重程度分布",
  626. color=list(severity_counts.keys()),
  627. color_discrete_map={'高': '#f44336', '中': '#ff9800', '低': '#4caf50'}
  628. )
  629. st.plotly_chart(fig_severity, use_container_width=True)
  630. # 下载选项
  631. if detailed:
  632. self._provide_download_options_in_results(comparison_result)
  633. def _get_severity_level(self, diff: dict) -> str:
  634. """根据差异类型和内容判断严重程度"""
  635. # 如果差异中已经包含严重程度,直接使用
  636. if 'severity' in diff:
  637. severity_map = {'high': '高', 'medium': '中', 'low': '低'}
  638. return severity_map.get(diff['severity'], '中')
  639. # 原有的逻辑作为后备
  640. diff_type = diff['type'].lower()
  641. # 金额相关差异为高严重程度
  642. if 'amount' in diff_type or 'number' in diff_type:
  643. return '高'
  644. # 表格结构差异为中等严重程度
  645. if 'table' in diff_type or 'structure' in diff_type:
  646. return '中'
  647. # 检查相似度
  648. if 'similarity' in diff:
  649. similarity = diff['similarity']
  650. if similarity < 50:
  651. return '高'
  652. elif similarity < 85:
  653. return '中'
  654. else:
  655. return '低'
  656. # 检查内容长度差异
  657. len_diff = abs(len(diff['file1_value']) - len(diff['file2_value']))
  658. if len_diff > 50:
  659. return '高'
  660. elif len_diff > 10:
  661. return '中'
  662. else:
  663. return '低'
  664. def _provide_download_options_in_results(self, comparison_result: dict):
  665. """在结果页面提供下载选项"""
  666. st.subheader("📥 导出预校验结果")
  667. col1, col2, col3 = st.columns(3)
  668. with col1:
  669. # 导出差异详情为Excel
  670. if comparison_result['differences']:
  671. diff_data = []
  672. for i, diff in enumerate(comparison_result['differences'], 1):
  673. diff_data.append({
  674. '序号': i,
  675. '位置': diff['position'],
  676. '类型': diff['type'],
  677. '原OCR结果': diff['file1_value'],
  678. 'VLM识别结果': diff['file2_value'],
  679. '描述': diff['description'],
  680. '严重程度': self._get_severity_level(diff)
  681. })
  682. df_export = pd.DataFrame(diff_data)
  683. excel_buffer = BytesIO()
  684. df_export.to_excel(excel_buffer, index=False, sheet_name='差异详情')
  685. st.download_button(
  686. label="📊 下载差异详情(Excel)",
  687. data=excel_buffer.getvalue(),
  688. file_name=f"vlm_comparison_differences_{pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')}.xlsx",
  689. mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
  690. key="download_differences_excel"
  691. )
  692. with col2:
  693. # 导出统计报告
  694. stats_data = {
  695. '统计项目': ['总差异数', '表格差异', '金额差异', '段落差异'],
  696. '数量': [
  697. comparison_result['statistics']['total_differences'],
  698. comparison_result['statistics']['table_differences'],
  699. comparison_result['statistics']['amount_differences'],
  700. comparison_result['statistics']['paragraph_differences']
  701. ]
  702. }
  703. df_stats = pd.DataFrame(stats_data)
  704. csv_stats = df_stats.to_csv(index=False)
  705. st.download_button(
  706. label="📈 下载统计报告(CSV)",
  707. data=csv_stats,
  708. file_name=f"vlm_comparison_stats_{pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')}.csv",
  709. mime="text/csv",
  710. key="download_stats_csv"
  711. )
  712. with col3:
  713. # 导出完整报告为JSON
  714. import json
  715. report_json = json.dumps(comparison_result, ensure_ascii=False, indent=2)
  716. st.download_button(
  717. label="📄 下载完整报告(JSON)",
  718. data=report_json,
  719. file_name=f"vlm_comparison_full_{pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')}.json",
  720. mime="application/json",
  721. key="download_full_json"
  722. )
  723. # 操作建议
  724. st.subheader("🚀 后续操作建议")
  725. total_diffs = comparison_result['statistics']['total_differences']
  726. if total_diffs == 0:
  727. st.success("✅ VLM识别结果与原OCR完全一致,可信度很高,无需人工校验")
  728. elif total_diffs <= 5:
  729. st.warning("⚠️ 发现少量差异,建议重点检查高严重程度的差异项")
  730. elif total_diffs <= 20:
  731. st.warning("🔍 发现中等数量差异,建议详细检查差异表格中标红的项目")
  732. else:
  733. st.error("❌ 发现大量差异,建议重新进行OCR识别或检查原始图片质量")
  734. @st.dialog("查看预校验结果", width="large", dismissible=True, on_dismiss="rerun")
  735. def show_comparison_results_dialog(self):
  736. """显示VLM预校验结果的对话框"""
  737. current_md_path = Path(self.file_paths[self.selected_file_index]).with_suffix('.md')
  738. pre_validation_dir = Path(self.config['pre_validation'].get('out_dir', './output/pre_validation/')).resolve()
  739. comparison_result_path = pre_validation_dir / f"{current_md_path.stem}_comparison_result.json"
  740. if 'comparison_result' in st.session_state and st.session_state.comparison_result:
  741. self.display_comparison_results(st.session_state.comparison_result['comparison_result'])
  742. elif comparison_result_path.exists():
  743. # 如果pre_validation_dir下有结果文件,提示用户加载
  744. if st.button("加载预校验结果"):
  745. with open(comparison_result_path, "r", encoding="utf-8") as f:
  746. comparison_json_result = json.load(f)
  747. comparison_result = {
  748. "image_path": self.image_path,
  749. "comparison_result_json": str(comparison_result_path),
  750. "comparison_result_md": str(comparison_result_path.with_suffix('.md')),
  751. "comparison_result": comparison_json_result
  752. }
  753. st.session_state.comparison_result = comparison_result
  754. self.display_comparison_results(comparison_json_result)
  755. else:
  756. st.info("暂无预校验结果,请先运行VLM预校验")
  757. def create_compact_layout(self, config):
  758. """创建滚动凑布局"""
  759. return self.layout_manager.create_compact_layout(config)
  760. @st.dialog("message", width="small", dismissible=True, on_dismiss="rerun")
  761. def message_box(msg: str, msg_type: str = "info"):
  762. if msg_type == "info":
  763. st.info(msg)
  764. elif msg_type == "warning":
  765. st.warning(msg)
  766. elif msg_type == "error":
  767. st.error(msg)
  768. def main():
  769. """主应用"""
  770. # 初始化应用
  771. if 'validator' not in st.session_state:
  772. validator = StreamlitOCRValidator()
  773. st.session_state.validator = validator
  774. st.session_state.validator.setup_page_config()
  775. # 页面标题
  776. config = st.session_state.validator.config
  777. st.title(config['ui']['page_title'])
  778. else:
  779. validator = st.session_state.validator
  780. config = st.session_state.validator.config
  781. if 'selected_text' not in st.session_state:
  782. st.session_state.selected_text = None
  783. if 'marked_errors' not in st.session_state:
  784. st.session_state.marked_errors = set()
  785. # 数据源选择器
  786. validator.create_data_source_selector()
  787. # 如果没有可用的数据源,提前返回
  788. if not validator.all_sources:
  789. st.stop()
  790. # 文件选择区域
  791. with st.container(height=75, horizontal=True, horizontal_alignment='left', gap="medium"):
  792. # 初始化session_state中的选择索引
  793. if 'selected_file_index' not in st.session_state:
  794. st.session_state.selected_file_index = 0
  795. if validator.display_options:
  796. # 文件选择下拉框
  797. selected_index = st.selectbox(
  798. "选择OCR结果文件",
  799. range(len(validator.display_options)),
  800. format_func=lambda i: validator.display_options[i],
  801. index=st.session_state.selected_file_index,
  802. key="selected_selectbox",
  803. label_visibility="collapsed"
  804. )
  805. # 更新session_state
  806. if selected_index != st.session_state.selected_file_index:
  807. st.session_state.selected_file_index = selected_index
  808. selected_file = validator.file_paths[selected_index]
  809. # 页码输入器
  810. current_page = validator.file_info[selected_index]['page']
  811. page_input = st.number_input(
  812. "输入页码",
  813. placeholder="输入页码",
  814. label_visibility="collapsed",
  815. min_value=1,
  816. max_value=len(validator.display_options),
  817. value=current_page,
  818. step=1,
  819. key="page_input"
  820. )
  821. # 当页码输入改变时,更新文件选择
  822. if page_input != current_page:
  823. for i, info in enumerate(validator.file_info):
  824. if info['page'] == page_input:
  825. st.session_state.selected_file_index = i
  826. selected_file = validator.file_paths[i]
  827. st.rerun()
  828. break
  829. # 自动加载文件
  830. if (st.session_state.selected_file_index >= 0
  831. and validator.selected_file_index != st.session_state.selected_file_index
  832. and selected_file):
  833. validator.selected_file_index = st.session_state.selected_file_index
  834. st.session_state.validator.load_ocr_data(selected_file)
  835. # 显示加载成功信息
  836. current_source_name = get_data_source_display_name(validator.current_source_config)
  837. st.success(f"✅ 已加载 {current_source_name} - 第{validator.file_info[st.session_state.selected_file_index]['page']}页")
  838. st.rerun()
  839. else:
  840. st.warning("当前数据源中未找到OCR结果文件")
  841. # VLM预校验按钮
  842. if st.button("VLM预校验", type="primary", icon=":material/compare_arrows:"):
  843. if validator.image_path and validator.md_content:
  844. validator.vlm_pre_validation()
  845. else:
  846. message_box("❌ 请先选择OCR数据文件", "error")
  847. # 查看预校验结果按钮
  848. if st.button("查看预校验结果", type="secondary", icon=":material/quick_reference_all:"):
  849. validator.show_comparison_results_dialog()
  850. # 显示当前数据源统计信息
  851. with st.expander("🔧 OCR工具统计信息", expanded=False):
  852. stats = validator.get_statistics()
  853. col1, col2, col3, col4, col5 = st.columns(5)
  854. with col1:
  855. st.metric("📊 总文本块", stats['total_texts'])
  856. with col2:
  857. st.metric("🔗 可点击文本", stats['clickable_texts'])
  858. with col3:
  859. st.metric("❌ 标记错误", stats['marked_errors'])
  860. with col4:
  861. st.metric("✅ 准确率", f"{stats['accuracy_rate']:.1f}%")
  862. with col5:
  863. # 显示当前数据源信息
  864. if validator.current_source_config:
  865. tool_display = validator.current_source_config['ocr_tool'].upper()
  866. st.metric("🔧 OCR工具", tool_display)
  867. # 详细工具信息
  868. if stats['tool_info']:
  869. st.write("**详细信息:**", stats['tool_info'])
  870. # 其余标签页保持不变...
  871. tab1, tab2, tab3 = st.tabs(["📄 内容校验", "📄 VLM预校验识别结果", "📊 表格分析"])
  872. with tab1:
  873. validator.create_compact_layout(config)
  874. with tab2:
  875. # st.header("📄 VLM预校验识别结果")
  876. current_md_path = Path(validator.file_paths[validator.selected_file_index]).with_suffix('.md')
  877. pre_validation_dir = Path(validator.config['pre_validation'].get('out_dir', './output/pre_validation/')).resolve()
  878. comparison_result_path = pre_validation_dir / f"{current_md_path.stem}_comparison_result.json"
  879. pre_validation_path = pre_validation_dir / f"{current_md_path.stem}.md"
  880. if comparison_result_path.exists():
  881. # 左边显示OCR结果,右边显示VLM结果
  882. col1, col2 = st.columns([1,1])
  883. with col1:
  884. st.subheader("🤖 原OCR识别结果")
  885. with open(current_md_path, "r", encoding="utf-8") as f:
  886. original_md_content = f.read()
  887. font_size = config['styles'].get('font_size', 10)
  888. height = config['styles']['layout'].get('default_height', 800)
  889. layout_type = "compact"
  890. validator.layout_manager.render_content_by_mode(original_md_content, "HTML渲染", font_size, height, layout_type)
  891. with col2:
  892. st.subheader("🤖 VLM识别结果")
  893. with open(pre_validation_path, "r", encoding="utf-8") as f:
  894. pre_validation_md_content = f.read()
  895. font_size = config['styles'].get('font_size', 10)
  896. height = config['styles']['layout'].get('default_height', 800)
  897. layout_type = "compact"
  898. validator.layout_manager.render_content_by_mode(pre_validation_md_content, "HTML渲染", font_size, height, layout_type)
  899. else:
  900. st.info("暂无预校验结果,请先运行VLM预校验")
  901. with tab3:
  902. # 表格分析页面 - 保持原有逻辑
  903. st.header("📊 表格数据分析")
  904. if validator.md_content and '<table' in validator.md_content.lower():
  905. st.subheader("🔍 表格数据预览")
  906. validator.display_html_table_as_dataframe(validator.md_content)
  907. else:
  908. st.info("当前OCR结果中没有检测到表格数据")
  909. # with tab4:
  910. # # 数据统计页面 - 保持原有逻辑
  911. # st.header("📈 OCR数据统计")
  912. # # 添加数据源特定的统计信息
  913. # if validator.current_source_config:
  914. # st.subheader(f"📊 {get_data_source_display_name(validator.current_source_config)} - 统计信息")
  915. # if stats['categories']:
  916. # st.subheader("📊 类别分布")
  917. # fig_pie = px.pie(
  918. # values=list(stats['categories'].values()),
  919. # names=list(stats['categories'].keys()),
  920. # title="文本类别分布"
  921. # )
  922. # st.plotly_chart(fig_pie, use_container_width=True)
  923. # # 错误率分析
  924. # st.subheader("📈 质量分析")
  925. # accuracy_data = {
  926. # '状态': ['正确', '错误'],
  927. # '数量': [stats['clickable_texts'] - stats['marked_errors'], stats['marked_errors']]
  928. # }
  929. # fig_bar = px.bar(
  930. # accuracy_data, x='状态', y='数量', title="识别质量分布",
  931. # color='状态', color_discrete_map={'正确': 'green', '错误': 'red'}
  932. # )
  933. # st.plotly_chart(fig_bar, use_container_width=True)
# Run the app when this file is executed as a script.
if __name__ == "__main__":
    main()