"""交叉验证功能模块 (cross-validation feature module for the OCR validator UI)."""
import json
from io import BytesIO
from pathlib import Path

import pandas as pd
import plotly.express as px
import streamlit as st

from compare_ocr_results import compare_ocr_results
from ocr_validator_utils import get_data_source_display_name
  12. @st.dialog("交叉验证", width="large", dismissible=True, on_dismiss="rerun")
  13. def cross_validation_dialog(validator):
  14. """交叉验证对话框"""
  15. if validator.current_source_key == validator.verify_source_key:
  16. st.error("❌ OCR数据源和验证数据源不能相同")
  17. return
  18. if 'cross_validation_batch_result' not in st.session_state:
  19. st.session_state.cross_validation_batch_result = None
  20. st.header("🔄 批量交叉验证")
  21. col1, col2 = st.columns(2)
  22. with col1:
  23. st.info(f"**OCR数据源:** {get_data_source_display_name(validator.current_source_config)}")
  24. st.write(f"📁 文件数量: {len(validator.file_info)}")
  25. with col2:
  26. st.info(f"**验证数据源:** {get_data_source_display_name(validator.verify_source_config)}")
  27. st.write(f"📁 文件数量: {len(validator.verify_file_info)}")
  28. with st.expander("⚙️ 验证选项", expanded=True):
  29. col1, col2 = st.columns(2)
  30. with col1:
  31. table_mode = st.selectbox(
  32. "表格比对模式",
  33. options=['standard', 'flow_list'],
  34. index=1,
  35. format_func=lambda x: '流水表格模式' if x == 'flow_list' else '标准模式',
  36. help="选择表格比对算法"
  37. )
  38. with col2:
  39. similarity_algorithm = st.selectbox(
  40. "相似度算法",
  41. options=['ratio', 'partial_ratio', 'token_sort_ratio', 'token_set_ratio'],
  42. index=0,
  43. help="选择文本相似度计算算法"
  44. )
  45. if st.button("🚀 开始批量验证", type="primary", use_container_width=True):
  46. run_batch_cross_validation(validator, table_mode, similarity_algorithm)
  47. if 'cross_validation_batch_result' in st.session_state and st.session_state.cross_validation_batch_result:
  48. st.markdown("---")
  49. display_batch_validation_results(st.session_state.cross_validation_batch_result)
  50. def run_batch_cross_validation(validator, table_mode: str, similarity_algorithm: str):
  51. """执行批量交叉验证"""
  52. pre_validation_dir = Path(validator.config['pre_validation'].get('out_dir', './output/pre_validation/')).resolve()
  53. pre_validation_dir.mkdir(parents=True, exist_ok=True)
  54. batch_results = _initialize_batch_results(validator, table_mode, similarity_algorithm)
  55. progress_bar = st.progress(0)
  56. status_text = st.empty()
  57. ocr_page_map = {info['page']: i for i, info in enumerate(validator.file_info)}
  58. verify_page_map = {info['page']: i for i, info in enumerate(validator.verify_file_info)}
  59. common_pages = sorted(set(ocr_page_map.keys()) & set(verify_page_map.keys()))
  60. if not common_pages:
  61. st.error("❌ 两个数据源没有共同的页码,无法进行对比")
  62. return
  63. batch_results['summary']['total_pages'] = len(common_pages)
  64. with st.expander("📋 详细对比日志", expanded=True):
  65. log_container = st.container()
  66. for idx, page_num in enumerate(common_pages):
  67. try:
  68. progress = (idx + 1) / len(common_pages)
  69. progress_bar.progress(progress)
  70. status_text.text(f"正在对比第 {page_num} 页... ({idx + 1}/{len(common_pages)})")
  71. ocr_file_index = ocr_page_map[page_num]
  72. verify_file_index = verify_page_map[page_num]
  73. ocr_md_path = Path(validator.file_paths[ocr_file_index]).with_suffix('.md')
  74. verify_md_path = Path(validator.verify_file_paths[verify_file_index]).with_suffix('.md')
  75. if not ocr_md_path.exists() or not verify_md_path.exists():
  76. with log_container:
  77. st.warning(f"⚠️ 第 {page_num} 页:文件不存在,跳过")
  78. batch_results['summary']['failed_pages'] += 1
  79. continue
  80. comparison_result_path = pre_validation_dir / f"{ocr_md_path.stem}_cross_validation"
  81. import io
  82. import contextlib
  83. output_buffer = io.StringIO()
  84. with contextlib.redirect_stdout(output_buffer):
  85. comparison_result = compare_ocr_results(
  86. file1_path=str(ocr_md_path),
  87. file2_path=str(verify_md_path),
  88. output_file=str(comparison_result_path),
  89. output_format='both',
  90. ignore_images=True,
  91. table_mode=table_mode,
  92. similarity_algorithm=similarity_algorithm
  93. )
  94. _process_comparison_result(batch_results, comparison_result, page_num,
  95. ocr_md_path, verify_md_path, comparison_result_path)
  96. with log_container:
  97. if comparison_result['statistics']['total_differences'] == 0:
  98. st.success(f"✅ 第 {page_num} 页:完全匹配")
  99. else:
  100. st.warning(f"⚠️ 第 {page_num} 页:发现 {comparison_result['statistics']['total_differences']} 个差异")
  101. except Exception as e:
  102. with log_container:
  103. st.error(f"❌ 第 {page_num} 页:对比失败 - {str(e)}")
  104. batch_results['pages'].append({
  105. 'page_num': page_num,
  106. 'status': 'failed',
  107. 'error': str(e)
  108. })
  109. batch_results['summary']['failed_pages'] += 1
  110. _save_batch_results(validator, batch_results, pre_validation_dir)
  111. progress_bar.progress(1.0)
  112. status_text.text("✅ 批量验证完成!")
  113. st.success(f"🎉 批量验证完成!成功: {batch_results['summary']['successful_pages']}, 失败: {batch_results['summary']['failed_pages']}")
  114. def _initialize_batch_results(validator, table_mode: str, similarity_algorithm: str) -> dict:
  115. """初始化批量结果存储"""
  116. return {
  117. 'ocr_source': get_data_source_display_name(validator.current_source_config),
  118. 'verify_source': get_data_source_display_name(validator.verify_source_config),
  119. 'table_mode': table_mode,
  120. 'similarity_algorithm': similarity_algorithm,
  121. 'timestamp': pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S'),
  122. 'pages': [],
  123. 'summary': {
  124. 'total_pages': 0,
  125. 'successful_pages': 0,
  126. 'failed_pages': 0,
  127. 'total_differences': 0,
  128. 'total_table_differences': 0,
  129. 'total_amount_differences': 0,
  130. 'total_datetime_differences': 0,
  131. 'total_text_differences': 0,
  132. 'total_paragraph_differences': 0,
  133. 'total_table_pre_header': 0,
  134. 'total_table_header_position': 0,
  135. 'total_table_header_critical': 0,
  136. 'total_table_row_missing': 0,
  137. 'total_high_severity': 0,
  138. 'total_medium_severity': 0,
  139. 'total_low_severity': 0
  140. }
  141. }
  142. def _process_comparison_result(batch_results: dict, comparison_result: dict, page_num: int,
  143. ocr_md_path: Path, verify_md_path: Path, comparison_result_path: Path):
  144. """处理对比结果"""
  145. stats = comparison_result['statistics']
  146. page_result = {
  147. 'page_num': page_num,
  148. 'ocr_file': str(ocr_md_path.name),
  149. 'verify_file': str(verify_md_path.name),
  150. 'total_differences': stats['total_differences'],
  151. 'table_differences': stats['table_differences'],
  152. 'amount_differences': stats.get('amount_differences', 0),
  153. 'datetime_differences': stats.get('datetime_differences', 0),
  154. 'text_differences': stats.get('text_differences', 0),
  155. 'paragraph_differences': stats['paragraph_differences'],
  156. 'table_pre_header': stats.get('table_pre_header', 0),
  157. 'table_header_position': stats.get('table_header_position', 0),
  158. 'table_header_critical': stats.get('table_header_critical', 0),
  159. 'table_row_missing': stats.get('table_row_missing', 0),
  160. 'high_severity': stats.get('high_severity', 0),
  161. 'medium_severity': stats.get('medium_severity', 0),
  162. 'low_severity': stats.get('low_severity', 0),
  163. 'status': 'success',
  164. 'comparison_json': f"{comparison_result_path}.json",
  165. 'comparison_md': f"{comparison_result_path}.md"
  166. }
  167. batch_results['pages'].append(page_result)
  168. batch_results['summary']['successful_pages'] += 1
  169. # 更新汇总统计
  170. for key in stats:
  171. total_key = f'total_{key}'
  172. if total_key in batch_results['summary']:
  173. batch_results['summary'][total_key] += stats.get(key, 0)
  174. def _save_batch_results(validator, batch_results: dict, pre_validation_dir: Path):
  175. """保存批量结果"""
  176. batch_result_path = pre_validation_dir / f"{validator.current_source_config['name']}_{validator.current_source_config['ocr_tool']}_vs_{validator.verify_source_config['ocr_tool']}_batch_cross_validation"
  177. with open(f"{batch_result_path}.json", "w", encoding="utf-8") as f:
  178. json.dump(batch_results, f, ensure_ascii=False, indent=2)
  179. generate_batch_validation_markdown(batch_results, f"{batch_result_path}.md")
  180. st.session_state.cross_validation_batch_result = batch_results
  181. def generate_batch_validation_markdown(batch_results: dict, output_path: str):
  182. """生成批量验证的Markdown报告"""
  183. with open(output_path, "w", encoding="utf-8") as f:
  184. f.write("# 批量交叉验证报告\n\n")
  185. # 基本信息
  186. f.write("## 📋 基本信息\n\n")
  187. f.write(f"- **OCR数据源:** {batch_results['ocr_source']}\n")
  188. f.write(f"- **验证数据源:** {batch_results['verify_source']}\n")
  189. f.write(f"- **表格模式:** {batch_results['table_mode']}\n")
  190. f.write(f"- **相似度算法:** {batch_results['similarity_algorithm']}\n")
  191. f.write(f"- **验证时间:** {batch_results['timestamp']}\n\n")
  192. # 汇总统计
  193. summary = batch_results['summary']
  194. f.write("## 📊 汇总统计\n\n")
  195. f.write(f"- **总页数:** {summary['total_pages']}\n")
  196. f.write(f"- **成功页数:** {summary['successful_pages']}\n")
  197. f.write(f"- **失败页数:** {summary['failed_pages']}\n")
  198. f.write(f"- **总差异数:** {summary['total_differences']}\n")
  199. f.write(f"- **表格差异:** {summary['total_table_differences']}\n")
  200. f.write(f" - 金额差异: {summary.get('total_amount_differences', 0)}\n")
  201. f.write(f" - 日期差异: {summary.get('total_datetime_differences', 0)}\n")
  202. f.write(f" - 文本差异: {summary.get('total_text_differences', 0)}\n")
  203. f.write(f" - 表头前差异: {summary.get('total_table_pre_header', 0)}\n")
  204. f.write(f" - 表头位置差异: {summary.get('total_table_header_position', 0)}\n")
  205. f.write(f" - 表头严重错误: {summary.get('total_table_header_critical', 0)}\n")
  206. f.write(f" - 行缺失: {summary.get('total_table_row_missing', 0)}\n")
  207. f.write(f"- **段落差异:** {summary['total_paragraph_differences']}\n")
  208. f.write(f"- **严重程度统计:**\n")
  209. f.write(f" - 高严重度: {summary.get('total_high_severity', 0)}\n")
  210. f.write(f" - 中严重度: {summary.get('total_medium_severity', 0)}\n")
  211. f.write(f" - 低严重度: {summary.get('total_low_severity', 0)}\n\n")
  212. # 详细结果表格
  213. f.write("## 📄 各页差异统计\n\n")
  214. f.write("| 页码 | 状态 | 总差异 | 表格差异 | 金额 | 日期 | 文本 | 段落 | 表头前 | 表头位置 | 表头错误 | 行缺失 | 高 | 中 | 低 |\n")
  215. f.write("|------|------|--------|----------|------|------|------|------|--------|----------|----------|--------|----|----|----|\n")
  216. for page in batch_results['pages']:
  217. if page['status'] == 'success':
  218. status_icon = "✅" if page['total_differences'] == 0 else "⚠️"
  219. f.write(f"| {page['page_num']} | {status_icon} | ")
  220. f.write(f"{page['total_differences']} | ")
  221. f.write(f"{page['table_differences']} | ")
  222. f.write(f"{page.get('amount_differences', 0)} | ")
  223. f.write(f"{page.get('datetime_differences', 0)} | ")
  224. f.write(f"{page.get('text_differences', 0)} | ")
  225. f.write(f"{page['paragraph_differences']} | ")
  226. f.write(f"{page.get('table_pre_header', 0)} | ")
  227. f.write(f"{page.get('table_header_position', 0)} | ")
  228. f.write(f"{page.get('table_header_critical', 0)} | ")
  229. f.write(f"{page.get('table_row_missing', 0)} | ")
  230. f.write(f"{page.get('high_severity', 0)} | ")
  231. f.write(f"{page.get('medium_severity', 0)} | ")
  232. f.write(f"{page.get('low_severity', 0)} |\n")
  233. else:
  234. f.write(f"| {page['page_num']} | ❌ | - | - | - | - | - | - | - | - | - | - | - | - | - |\n")
  235. f.write("\n")
  236. # 问题汇总
  237. f.write("## 🔍 问题汇总\n\n")
  238. high_diff_pages = [p for p in batch_results['pages']
  239. if p['status'] == 'success' and p['total_differences'] > 10]
  240. if high_diff_pages:
  241. f.write("### ⚠️ 高差异页面(差异>10)\n\n")
  242. for page in high_diff_pages:
  243. f.write(f"- 第 {page['page_num']} 页:{page['total_differences']} 个差异\n")
  244. f.write("\n")
  245. amount_error_pages = [p for p in batch_results['pages']
  246. if p['status'] == 'success' and p.get('amount_differences', 0) > 0]
  247. if amount_error_pages:
  248. f.write("### 💰 金额差异页面\n\n")
  249. for page in amount_error_pages:
  250. f.write(f"- 第 {page['page_num']} 页:{page.get('amount_differences', 0)} 个金额差异\n")
  251. f.write("\n")
  252. header_error_pages = [p for p in batch_results['pages']
  253. if p['status'] == 'success' and p.get('table_header_critical', 0) > 0]
  254. if header_error_pages:
  255. f.write("### ❌ 表头严重错误页面\n\n")
  256. for page in header_error_pages:
  257. f.write(f"- 第 {page['page_num']} 页:{page['table_header_critical']} 个表头错误\n")
  258. f.write("\n")
  259. failed_pages = [p for p in batch_results['pages'] if p['status'] == 'failed']
  260. if failed_pages:
  261. f.write("### 💥 验证失败页面\n\n")
  262. for page in failed_pages:
  263. f.write(f"- 第 {page['page_num']} 页:{page.get('error', '未知错误')}\n")
  264. f.write("\n")
  265. def display_batch_validation_results(batch_results: dict):
  266. """显示批量验证结果"""
  267. st.header("📊 批量验证结果")
  268. summary = batch_results['summary']
  269. col1, col2, col3, col4 = st.columns(4)
  270. with col1:
  271. st.metric("总页数", summary['total_pages'])
  272. with col2:
  273. st.metric("成功页数", summary['successful_pages'],
  274. delta=f"{summary['successful_pages']/summary['total_pages']*100:.1f}%")
  275. with col3:
  276. st.metric("失败页数", summary['failed_pages'],
  277. delta=f"-{summary['failed_pages']}" if summary['failed_pages'] > 0 else "0")
  278. with col4:
  279. st.metric("总差异数", summary['total_differences'])
  280. # ✅ 详细差异类型统计 - 更新展示
  281. st.subheader("📈 差异类型统计")
  282. col1, col2, col3 = st.columns(3)
  283. with col1:
  284. st.metric("表格差异", summary['total_table_differences'])
  285. st.caption(f"金额: {summary.get('total_amount_differences', 0)} | 日期: {summary.get('total_datetime_differences', 0)} | 文本: {summary.get('total_text_differences', 0)}")
  286. with col2:
  287. st.metric("段落差异", summary['total_paragraph_differences'])
  288. with col3:
  289. st.metric("严重度", f"高:{summary.get('total_high_severity', 0)} 中:{summary.get('total_medium_severity', 0)} 低:{summary.get('total_low_severity', 0)}")
  290. # 表格结构差异统计
  291. with st.expander("📋 表格结构差异详情", expanded=False):
  292. col1, col2, col3, col4 = st.columns(4)
  293. with col1:
  294. st.metric("表头前", summary.get('total_table_pre_header', 0))
  295. with col2:
  296. st.metric("表头位置", summary.get('total_table_header_position', 0))
  297. with col3:
  298. st.metric("表头错误", summary.get('total_table_header_critical', 0))
  299. with col4:
  300. st.metric("行缺失", summary.get('total_table_row_missing', 0))
  301. # ✅ 各页详细结果表格 - 更新列
  302. st.subheader("📄 各页详细结果")
  303. # 准备DataFrame
  304. page_data = []
  305. for page in batch_results['pages']:
  306. if page['status'] == 'success':
  307. page_data.append({
  308. '页码': page['page_num'],
  309. '状态': '✅ 成功' if page['total_differences'] == 0 else '⚠️ 有差异',
  310. '总差异': page['total_differences'],
  311. '表格差异': page['table_differences'],
  312. '金额': page.get('amount_differences', 0),
  313. '日期': page.get('datetime_differences', 0),
  314. '文本': page.get('text_differences', 0),
  315. '段落': page['paragraph_differences'],
  316. '表头前': page.get('table_pre_header', 0),
  317. '表头位置': page.get('table_header_position', 0),
  318. '表头错误': page.get('table_header_critical', 0),
  319. '行缺失': page.get('table_row_missing', 0),
  320. '高': page.get('high_severity', 0),
  321. '中': page.get('medium_severity', 0),
  322. '低': page.get('low_severity', 0)
  323. })
  324. else:
  325. page_data.append({
  326. '页码': page['page_num'],
  327. '状态': '❌ 失败',
  328. '总差异': '-', '表格差异': '-', '金额': '-', '日期': '-',
  329. '文本': '-', '段落': '-', '表头前': '-', '表头位置': '-',
  330. '表头错误': '-', '行缺失': '-', '高': '-', '中': '-', '低': '-'
  331. })
  332. df_pages = pd.DataFrame(page_data)
  333. # 显示表格
  334. st.dataframe(
  335. df_pages,
  336. use_container_width=True,
  337. hide_index=True,
  338. column_config={
  339. "页码": st.column_config.NumberColumn("页码", width="small"),
  340. "状态": st.column_config.TextColumn("状态", width="small"),
  341. "总差异": st.column_config.NumberColumn("总差异", width="small"),
  342. "表格差异": st.column_config.NumberColumn("表格", width="small"),
  343. "金额": st.column_config.NumberColumn("金额", width="small"),
  344. "日期": st.column_config.NumberColumn("日期", width="small"),
  345. "文本": st.column_config.NumberColumn("文本", width="small"),
  346. "段落": st.column_config.NumberColumn("段落", width="small"),
  347. }
  348. )
  349. # 下载选项
  350. st.subheader("📥 导出报告")
  351. col1, col2 = st.columns(2)
  352. with col1:
  353. # 导出Excel
  354. excel_buffer = BytesIO()
  355. df_pages.to_excel(excel_buffer, index=False, sheet_name='验证结果')
  356. st.download_button(
  357. label="📊 下载Excel报告",
  358. data=excel_buffer.getvalue(),
  359. file_name=f"batch_validation_{pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')}.xlsx",
  360. mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
  361. )
  362. with col2:
  363. # 导出JSON
  364. json_data = json.dumps(batch_results, ensure_ascii=False, indent=2)
  365. st.download_button(
  366. label="📄 下载JSON报告",
  367. data=json_data,
  368. file_name=f"batch_validation_{pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')}.json",
  369. mime="application/json"
  370. )
  371. @st.dialog("查看交叉验证结果", width="large", dismissible=True, on_dismiss="rerun")
  372. def show_batch_cross_validation_results_dialog():
  373. """显示批量验证结果对话框"""
  374. if 'cross_validation_batch_result' in st.session_state and st.session_state.cross_validation_batch_result:
  375. display_batch_validation_results(st.session_state.cross_validation_batch_result)
  376. else:
  377. st.info("暂无交叉验证结果,请先运行交叉验证")