"""Cross-validation feature module."""
import contextlib
import json
from io import BytesIO, StringIO
from pathlib import Path

import pandas as pd
import streamlit as st

from compare_ocr_results import compare_ocr_results
from ocr_validator_utils import get_data_source_display_name


@st.dialog("交叉验证", width="large", dismissible=True, on_dismiss="rerun")
def cross_validation_dialog(validator):
    """Cross-validation dialog: pick options and run a batch comparison."""
    if validator.current_source_key == validator.verify_source_key:
        st.error("❌ OCR数据源和验证数据源不能相同")
        return

    if 'cross_validation_batch_result' not in st.session_state:
        st.session_state.cross_validation_batch_result = None

    st.header("🔄 批量交叉验证")

    col1, col2 = st.columns(2)
    with col1:
        st.info(f"**OCR数据源:** {get_data_source_display_name(validator.current_source_config)}")
        st.write(f"📁 文件数量: {len(validator.file_info)}")
    with col2:
        st.info(f"**验证数据源:** {get_data_source_display_name(validator.verify_source_config)}")
        st.write(f"📁 文件数量: {len(validator.verify_file_info)}")

    with st.expander("⚙️ 验证选项", expanded=True):
        col1, col2 = st.columns(2)
        with col1:
            table_mode = st.selectbox(
                "表格比对模式",
                options=['standard', 'flow_list'],
                index=1,
                format_func=lambda x: '流水表格模式' if x == 'flow_list' else '标准模式',
                help="选择表格比对算法"
            )
        with col2:
            similarity_algorithm = st.selectbox(
                "相似度算法",
                options=['ratio', 'partial_ratio', 'token_sort_ratio', 'token_set_ratio'],
                index=0,
                help="选择文本相似度计算算法"
            )

    if st.button("🚀 开始批量验证", type="primary", use_container_width=True):
        run_batch_cross_validation(validator, table_mode, similarity_algorithm)

    if st.session_state.get('cross_validation_batch_result'):
        st.markdown("---")
        display_batch_validation_results(st.session_state.cross_validation_batch_result)


def run_batch_cross_validation(validator, table_mode: str, similarity_algorithm: str):
    """Compare every page present in both data sources and save a batch report."""
    pre_validation_dir = Path(
        validator.config['pre_validation'].get('out_dir', './output/pre_validation/')
    ).resolve()
    pre_validation_dir.mkdir(parents=True, exist_ok=True)

    batch_results = _initialize_batch_results(validator, table_mode, similarity_algorithm)

    progress_bar = st.progress(0)
    status_text = st.empty()

    # Map page number -> file index in each source; only pages found in both are compared
    ocr_page_map = {info['page']: i for i, info in enumerate(validator.file_info)}
    verify_page_map = {info['page']: i for i, info in enumerate(validator.verify_file_info)}
    common_pages = sorted(set(ocr_page_map) & set(verify_page_map))

    if not common_pages:
        st.error("❌ 两个数据源没有共同的页码,无法进行对比")
        return

    batch_results['summary']['total_pages'] = len(common_pages)

    with st.expander("📋 详细对比日志", expanded=True):
        log_container = st.container()

    for idx, page_num in enumerate(common_pages):
        try:
            progress_bar.progress((idx + 1) / len(common_pages))
            status_text.text(f"正在对比第 {page_num} 页... ({idx + 1}/{len(common_pages)})")

            ocr_md_path = Path(validator.file_paths[ocr_page_map[page_num]]).with_suffix('.md')
            verify_md_path = Path(validator.verify_file_paths[verify_page_map[page_num]]).with_suffix('.md')

            if not ocr_md_path.exists() or not verify_md_path.exists():
                with log_container:
                    st.warning(f"⚠️ 第 {page_num} 页:文件不存在,跳过")
                # Record the skipped page so it also appears in the per-page tables
                batch_results['pages'].append({
                    'page_num': page_num,
                    'status': 'failed',
                    'error': '文件不存在'
                })
                batch_results['summary']['failed_pages'] += 1
                continue

            comparison_result_path = pre_validation_dir / f"{ocr_md_path.stem}_cross_validation"

            # Discard the console output printed by compare_ocr_results
            with contextlib.redirect_stdout(StringIO()):
                comparison_result = compare_ocr_results(
                    file1_path=str(ocr_md_path),
                    file2_path=str(verify_md_path),
                    output_file=str(comparison_result_path),
                    output_format='both',
                    ignore_images=True,
                    table_mode=table_mode,
                    similarity_algorithm=similarity_algorithm
                )

            _process_comparison_result(batch_results, comparison_result, page_num,
                                       ocr_md_path, verify_md_path, comparison_result_path)

            total_differences = comparison_result['statistics']['total_differences']
            with log_container:
                if total_differences == 0:
                    st.success(f"✅ 第 {page_num} 页:完全匹配")
                else:
                    st.warning(f"⚠️ 第 {page_num} 页:发现 {total_differences} 个差异")

        except Exception as e:
            with log_container:
                st.error(f"❌ 第 {page_num} 页:对比失败 - {e}")
            batch_results['pages'].append({
                'page_num': page_num,
                'status': 'failed',
                'error': str(e)
            })
            batch_results['summary']['failed_pages'] += 1

    _save_batch_results(validator, batch_results, pre_validation_dir)

    progress_bar.progress(1.0)
    status_text.text("✅ 批量验证完成!")
    st.success(
        f"🎉 批量验证完成!成功: {batch_results['summary']['successful_pages']}, "
        f"失败: {batch_results['summary']['failed_pages']}"
    )


def _initialize_batch_results(validator, table_mode: str, similarity_algorithm: str) -> dict:
    """Create the empty batch-result structure."""
    return {
        'ocr_source': get_data_source_display_name(validator.current_source_config),
        'verify_source': get_data_source_display_name(validator.verify_source_config),
        'table_mode': table_mode,
        'similarity_algorithm': similarity_algorithm,
        'timestamp': pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S'),
        'pages': [],
        'summary': {
            'total_pages': 0,
            'successful_pages': 0,
            'failed_pages': 0,
            'total_differences': 0,
            'total_table_differences': 0,
            'total_amount_differences': 0,
            'total_datetime_differences': 0,
            'total_text_differences': 0,
            'total_paragraph_differences': 0,
            'total_table_pre_header': 0,
            'total_table_header_position': 0,
            'total_table_header_critical': 0,
            'total_table_row_missing': 0,
            'total_high_severity': 0,
            'total_medium_severity': 0,
            'total_low_severity': 0
        }
    }


def _process_comparison_result(batch_results: dict, comparison_result: dict, page_num: int,
                               ocr_md_path: Path, verify_md_path: Path,
                               comparison_result_path: Path):
    """Record one page's comparison result and add its counts to the summary."""
    stats = comparison_result['statistics']

    page_result = {
        'page_num': page_num,
        'ocr_file': str(ocr_md_path.name),
        'verify_file': str(verify_md_path.name),
        'total_differences': stats['total_differences'],
        'table_differences': stats['table_differences'],
        'amount_differences': stats.get('amount_differences', 0),
        'datetime_differences': stats.get('datetime_differences', 0),
        'text_differences': stats.get('text_differences', 0),
        'paragraph_differences': stats['paragraph_differences'],
        'table_pre_header': stats.get('table_pre_header', 0),
        'table_header_position': stats.get('table_header_position', 0),
        'table_header_critical': stats.get('table_header_critical', 0),
        'table_row_missing': stats.get('table_row_missing', 0),
        'high_severity': stats.get('high_severity', 0),
        'medium_severity': stats.get('medium_severity', 0),
        'low_severity': stats.get('low_severity', 0),
        'status': 'success',
        'comparison_json': f"{comparison_result_path}.json",
        'comparison_md': f"{comparison_result_path}.md"
    }

    batch_results['pages'].append(page_result)

    summary = batch_results['summary']
    summary['successful_pages'] += 1

    # Update summary totals: a per-page key "x" accumulates into "total_x".
    # "total_differences" already carries the prefix ("total_total_differences"
    # never matches a summary key), so it is added explicitly.
    summary['total_differences'] += stats['total_differences']
    for key, value in stats.items():
        total_key = f'total_{key}'
        if total_key in summary:
            summary[total_key] += value


def _save_batch_results(validator, batch_results: dict, pre_validation_dir: Path):
    """Save the batch result as JSON and Markdown and keep it in the session."""
    current = validator.current_source_config
    batch_result_path = pre_validation_dir / (
        f"{current['name']}_{current['ocr_tool']}"
        f"_vs_{validator.verify_source_config['ocr_tool']}_batch_cross_validation"
    )

    with open(f"{batch_result_path}.json", "w", encoding="utf-8") as f:
        json.dump(batch_results, f, ensure_ascii=False, indent=2)

    generate_batch_validation_markdown(batch_results, f"{batch_result_path}.md")

    st.session_state.cross_validation_batch_result = batch_results


def generate_batch_validation_markdown(batch_results: dict, output_path: str):
    """Write the batch validation report as Markdown."""
    with open(output_path, "w", encoding="utf-8") as f:
        f.write("# 批量交叉验证报告\n\n")

        # Basic information
        f.write("## 📋 基本信息\n\n")
        f.write(f"- **OCR数据源:** {batch_results['ocr_source']}\n")
        f.write(f"- **验证数据源:** {batch_results['verify_source']}\n")
        f.write(f"- **表格模式:** {batch_results['table_mode']}\n")
        f.write(f"- **相似度算法:** {batch_results['similarity_algorithm']}\n")
        f.write(f"- **验证时间:** {batch_results['timestamp']}\n\n")

        # Summary statistics
        summary = batch_results['summary']
        f.write("## 📊 汇总统计\n\n")
        f.write(f"- **总页数:** {summary['total_pages']}\n")
        f.write(f"- **成功页数:** {summary['successful_pages']}\n")
        f.write(f"- **失败页数:** {summary['failed_pages']}\n")
        f.write(f"- **总差异数:** {summary['total_differences']}\n")
        f.write(f"- **表格差异:** {summary['total_table_differences']}\n")
        f.write(f"  - 金额差异: {summary.get('total_amount_differences', 0)}\n")
        f.write(f"  - 日期差异: {summary.get('total_datetime_differences', 0)}\n")
        f.write(f"  - 文本差异: {summary.get('total_text_differences', 0)}\n")
        f.write(f"  - 表头前差异: {summary.get('total_table_pre_header', 0)}\n")
        f.write(f"  - 表头位置差异: {summary.get('total_table_header_position', 0)}\n")
        f.write(f"  - 表头严重错误: {summary.get('total_table_header_critical', 0)}\n")
        f.write(f"  - 行缺失: {summary.get('total_table_row_missing', 0)}\n")
        f.write(f"- **段落差异:** {summary['total_paragraph_differences']}\n")
        f.write("- **严重程度统计:**\n")
        f.write(f"  - 高严重度: {summary.get('total_high_severity', 0)}\n")
        f.write(f"  - 中严重度: {summary.get('total_medium_severity', 0)}\n")
        f.write(f"  - 低严重度: {summary.get('total_low_severity', 0)}\n\n")

        # Per-page differences table
        f.write("## 📄 各页差异统计\n\n")
        f.write("| 页码 | 状态 | 总差异 | 表格差异 | 金额 | 日期 | 文本 | 段落 | 表头前 | 表头位置 | 表头错误 | 行缺失 | 高 | 中 | 低 |\n")
        f.write("|------|------|--------|----------|------|------|------|------|--------|----------|----------|--------|----|----|----|\n")

        for page in batch_results['pages']:
            if page['status'] == 'success':
                status_icon = "✅" if page['total_differences'] == 0 else "⚠️"
                f.write(f"| {page['page_num']} | {status_icon} | ")
                f.write(f"{page['total_differences']} | ")
                f.write(f"{page['table_differences']} | ")
                f.write(f"{page.get('amount_differences', 0)} | ")
                f.write(f"{page.get('datetime_differences', 0)} | ")
                f.write(f"{page.get('text_differences', 0)} | ")
                f.write(f"{page['paragraph_differences']} | ")
                f.write(f"{page.get('table_pre_header', 0)} | ")
                f.write(f"{page.get('table_header_position', 0)} | ")
                f.write(f"{page.get('table_header_critical', 0)} | ")
                f.write(f"{page.get('table_row_missing', 0)} | ")
                f.write(f"{page.get('high_severity', 0)} | ")
                f.write(f"{page.get('medium_severity', 0)} | ")
                f.write(f"{page.get('low_severity', 0)} |\n")
            else:
                # Failed page: a dash in each of the 13 count columns
                f.write(f"| {page['page_num']} | ❌ |" + " - |" * 13 + "\n")

        f.write("\n")

        # Issue summary
        f.write("## 🔍 问题汇总\n\n")

        high_diff_pages = [p for p in batch_results['pages']
                           if p['status'] == 'success' and p['total_differences'] > 10]
        if high_diff_pages:
            f.write("### ⚠️ 高差异页面(差异>10)\n\n")
            for page in high_diff_pages:
                f.write(f"- 第 {page['page_num']} 页:{page['total_differences']} 个差异\n")
            f.write("\n")

        amount_error_pages = [p for p in batch_results['pages']
                              if p['status'] == 'success' and p.get('amount_differences', 0) > 0]
        if amount_error_pages:
            f.write("### 💰 金额差异页面\n\n")
            for page in amount_error_pages:
                f.write(f"- 第 {page['page_num']} 页:{page.get('amount_differences', 0)} 个金额差异\n")
            f.write("\n")

        header_error_pages = [p for p in batch_results['pages']
                              if p['status'] == 'success' and p.get('table_header_critical', 0) > 0]
        if header_error_pages:
            f.write("### ❌ 表头严重错误页面\n\n")
            for page in header_error_pages:
                f.write(f"- 第 {page['page_num']} 页:{page['table_header_critical']} 个表头错误\n")
            f.write("\n")

        failed_pages = [p for p in batch_results['pages'] if p['status'] == 'failed']
        if failed_pages:
            f.write("### 💥 验证失败页面\n\n")
            for page in failed_pages:
                f.write(f"- 第 {page['page_num']} 页:{page.get('error', '未知错误')}\n")
            f.write("\n")


def display_batch_validation_results(batch_results: dict):
    """Render the batch validation results."""
    st.header("📊 批量验证结果")

    summary = batch_results['summary']
    total_pages = summary['total_pages']
    # Guard against division by zero when no page was compared
    success_rate = summary['successful_pages'] / total_pages * 100 if total_pages else 0.0

    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("总页数", total_pages)
    with col2:
        st.metric("成功页数", summary['successful_pages'], delta=f"{success_rate:.1f}%")
    with col3:
        st.metric("失败页数", summary['failed_pages'],
                  delta=f"-{summary['failed_pages']}" if summary['failed_pages'] > 0 else "0")
    with col4:
        st.metric("总差异数", summary['total_differences'])

    # Statistics by difference type
    st.subheader("📈 差异类型统计")

    col1, col2, col3 = st.columns(3)
    with col1:
        st.metric("表格差异", summary['total_table_differences'])
        st.caption(
            f"金额: {summary.get('total_amount_differences', 0)} | "
            f"日期: {summary.get('total_datetime_differences', 0)} | "
            f"文本: {summary.get('total_text_differences', 0)}"
        )
    with col2:
        st.metric("段落差异", summary['total_paragraph_differences'])
    with col3:
        st.metric(
            "严重度",
            f"高:{summary.get('total_high_severity', 0)} "
            f"中:{summary.get('total_medium_severity', 0)} "
            f"低:{summary.get('total_low_severity', 0)}"
        )

    # Table-structure difference breakdown
    with st.expander("📋 表格结构差异详情", expanded=False):
        col1, col2, col3, col4 = st.columns(4)
        with col1:
            st.metric("表头前", summary.get('total_table_pre_header', 0))
        with col2:
            st.metric("表头位置", summary.get('total_table_header_position', 0))
        with col3:
            st.metric("表头错误", summary.get('total_table_header_critical', 0))
        with col4:
            st.metric("行缺失", summary.get('total_table_row_missing', 0))

    # Per-page results table
    st.subheader("📄 各页详细结果")

    # Display column -> per-page result key
    count_columns = {
        '总差异': 'total_differences',
        '表格差异': 'table_differences',
        '金额': 'amount_differences',
        '日期': 'datetime_differences',
        '文本': 'text_differences',
        '段落': 'paragraph_differences',
        '表头前': 'table_pre_header',
        '表头位置': 'table_header_position',
        '表头错误': 'table_header_critical',
        '行缺失': 'table_row_missing',
        '高': 'high_severity',
        '中': 'medium_severity',
        '低': 'low_severity',
    }

    page_data = []
    for page in batch_results['pages']:
        succeeded = page['status'] == 'success'
        if succeeded:
            status = '✅ 成功' if page['total_differences'] == 0 else '⚠️ 有差异'
        else:
            status = '❌ 失败'
        row = {'页码': page['page_num'], '状态': status}
        for column, key in count_columns.items():
            # Failed pages get None (rendered as an empty cell) instead of '-',
            # so every count column keeps a numeric dtype for st.dataframe
            row[column] = page.get(key, 0) if succeeded else None
        page_data.append(row)

    df_pages = pd.DataFrame(page_data)
    if not df_pages.empty:
        # Nullable integers keep whole-number display despite missing values
        df_pages[list(count_columns)] = df_pages[list(count_columns)].astype('Int64')

    st.dataframe(
        df_pages,
        use_container_width=True,
        hide_index=True,
        column_config={
            "页码": st.column_config.NumberColumn("页码", width="small"),
            "状态": st.column_config.TextColumn("状态", width="small"),
            "总差异": st.column_config.NumberColumn("总差异", width="small"),
            "表格差异": st.column_config.NumberColumn("表格", width="small"),
            "金额": st.column_config.NumberColumn("金额", width="small"),
            "日期": st.column_config.NumberColumn("日期", width="small"),
            "文本": st.column_config.NumberColumn("文本", width="small"),
            "段落": st.column_config.NumberColumn("段落", width="small"),
        }
    )

    # Export options
    st.subheader("📥 导出报告")
    timestamp = pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')

    col1, col2 = st.columns(2)
    with col1:
        # Excel export (pandas needs an .xlsx writer engine such as openpyxl)
        excel_buffer = BytesIO()
        df_pages.to_excel(excel_buffer, index=False, sheet_name='验证结果')
        st.download_button(
            label="📊 下载Excel报告",
            data=excel_buffer.getvalue(),
            file_name=f"batch_validation_{timestamp}.xlsx",
            mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
        )
    with col2:
        # JSON export of the full batch result
        json_data = json.dumps(batch_results, ensure_ascii=False, indent=2)
        st.download_button(
            label="📄 下载JSON报告",
            data=json_data,
            file_name=f"batch_validation_{timestamp}.json",
            mime="application/json"
        )


@st.dialog("查看交叉验证结果", width="large", dismissible=True, on_dismiss="rerun")
def show_batch_cross_validation_results_dialog():
    """Dialog that shows the most recent batch validation result."""
    if st.session_state.get('cross_validation_batch_result'):
        display_batch_validation_results(st.session_state.cross_validation_batch_result)
    else:
        st.info("暂无交叉验证结果,请先运行交叉验证")
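

# ---------------------------------------------------------------------------
# Sketch (hypothetical helper, not called by the dialogs above): rebuild the
# summary counters of a batch report from its per-page records, e.g. to
# sanity-check a saved "*_batch_cross_validation.json" file. It assumes the
# per-page schema written by _process_comparison_result(): a numeric field
# "x" sums into "total_x", while "total_differences" keeps its own name.
# Here "total_pages" is simply the number of page records in the report.
# ---------------------------------------------------------------------------
def recompute_batch_summary(batch_results: dict) -> dict:
    """Recompute summary totals from batch_results['pages'] (illustrative sketch)."""
    pages = batch_results.get('pages', [])
    summary = {'total_pages': len(pages), 'successful_pages': 0, 'failed_pages': 0}
    for page in pages:
        if page.get('status') != 'success':
            summary['failed_pages'] += 1
            continue
        summary['successful_pages'] += 1
        for key, value in page.items():
            # Skip the page number and non-numeric fields (file names, paths, status)
            if key == 'page_num' or isinstance(value, bool) or not isinstance(value, (int, float)):
                continue
            total_key = key if key.startswith('total_') else f'total_{key}'
            summary[total_key] = summary.get(total_key, 0) + value
    return summary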