- """
- 交叉验证功能模块
- """
- import streamlit as st
- import pandas as pd
- import json
- from pathlib import Path
- from io import BytesIO
- import plotly.express as px
- from comparator import compare_ocr_results
- from ocr_validator_utils import get_data_source_display_name
@st.dialog("交叉验证", width="large", dismissible=True, on_dismiss="rerun")
def cross_validation_dialog(validator):
    """Dialog for configuring and launching a batch cross-validation run.

    Shows both data sources, lets the user pick comparison options, starts
    the batch run, and renders previously computed results if present.
    """
    # Comparing a source against itself is meaningless — bail out early.
    if validator.current_source_key == validator.verify_source_key:
        st.error("❌ OCR数据源和验证数据源不能相同")
        return

    if 'cross_validation_batch_result' not in st.session_state:
        st.session_state.cross_validation_batch_result = None

    st.header("🔄 批量交叉验证")

    left_col, right_col = st.columns(2)
    with left_col:
        st.info(f"**OCR数据源:** {get_data_source_display_name(validator.current_source_config)}")
        st.write(f"📁 文件数量: {len(validator.file_info)}")
    with right_col:
        st.info(f"**验证数据源:** {get_data_source_display_name(validator.verify_source_config)}")
        st.write(f"📁 文件数量: {len(validator.verify_file_info)}")

    with st.expander("⚙️ 验证选项", expanded=True):
        opt_left, opt_right = st.columns(2)
        with opt_left:
            table_mode = st.selectbox(
                "表格比对模式",
                options=['standard', 'flow_list'],
                index=1,
                format_func=lambda mode: '流水表格模式' if mode == 'flow_list' else '标准模式',
                help="选择表格比对算法"
            )
        with opt_right:
            similarity_algorithm = st.selectbox(
                "相似度算法",
                options=['ratio', 'partial_ratio', 'token_sort_ratio', 'token_set_ratio'],
                index=0,
                help="选择文本相似度计算算法"
            )

    if st.button("🚀 开始批量验证", type="primary", width='stretch'):
        run_batch_cross_validation(validator, table_mode, similarity_algorithm)

    # Render cached results from a previous run (or the run just finished).
    if st.session_state.get('cross_validation_batch_result'):
        st.markdown("---")
        display_batch_validation_results(st.session_state.cross_validation_batch_result)
def run_batch_cross_validation(validator, table_mode: str, similarity_algorithm: str):
    """Run page-by-page cross validation between the OCR source and the
    verification source, streaming progress and logs to the Streamlit UI.

    For every page number present in BOTH sources, the two markdown files
    are compared with compare_ocr_results; per-page results and summary
    counters are accumulated and finally persisted via _save_batch_results.

    Args:
        validator: Object holding config, file paths and per-file info for
            both data sources.
        table_mode: Table comparison mode ('standard' or 'flow_list').
        similarity_algorithm: Text-similarity algorithm name forwarded to
            compare_ocr_results.
    """
    # Hoisted out of the per-page loop (originally re-imported on every
    # iteration).
    import io
    import contextlib

    pre_validation_dir = Path(validator.config['pre_validation'].get('out_dir', './output/pre_validation/')).resolve()
    pre_validation_dir.mkdir(parents=True, exist_ok=True)

    batch_results = _initialize_batch_results(validator, table_mode, similarity_algorithm)

    progress_bar = st.progress(0)
    status_text = st.empty()

    # Align the two sources by page number rather than by list position.
    ocr_page_map = {info['page']: i for i, info in enumerate(validator.file_info)}
    verify_page_map = {info['page']: i for i, info in enumerate(validator.verify_file_info)}

    common_pages = sorted(set(ocr_page_map) & set(verify_page_map))

    if not common_pages:
        st.error("❌ 两个数据源没有共同的页码,无法进行对比")
        return

    batch_results['summary']['total_pages'] = len(common_pages)

    with st.expander("📋 详细对比日志", expanded=True):
        log_container = st.container()

    for idx, page_num in enumerate(common_pages):
        try:
            progress_bar.progress((idx + 1) / len(common_pages))
            status_text.text(f"正在对比第 {page_num} 页... ({idx + 1}/{len(common_pages)})")

            ocr_md_path = Path(validator.file_paths[ocr_page_map[page_num]]).with_suffix('.md')
            verify_md_path = Path(validator.verify_file_paths[verify_page_map[page_num]]).with_suffix('.md')

            if not ocr_md_path.exists() or not verify_md_path.exists():
                with log_container:
                    st.warning(f"⚠️ 第 {page_num} 页:文件不存在,跳过")
                # Record the skipped page so the final report can list WHICH
                # pages failed, not just how many (mirrors the except path).
                batch_results['pages'].append({
                    'page_num': page_num,
                    'status': 'failed',
                    'error': '文件不存在'
                })
                batch_results['summary']['failed_pages'] += 1
                continue

            comparison_result_path = pre_validation_dir / f"{ocr_md_path.stem}_cross_validation"

            # compare_ocr_results prints to stdout; redirect it so it does
            # not pollute the Streamlit page.
            with contextlib.redirect_stdout(io.StringIO()):
                comparison_result = compare_ocr_results(
                    file1_path=str(ocr_md_path),
                    file2_path=str(verify_md_path),
                    output_file=str(comparison_result_path),
                    output_format='both',
                    ignore_images=True,
                    table_mode=table_mode,
                    similarity_algorithm=similarity_algorithm
                )

            _process_comparison_result(batch_results, comparison_result, page_num,
                                       ocr_md_path, verify_md_path, comparison_result_path)

            with log_container:
                total_diffs = comparison_result['statistics']['total_differences']
                if total_diffs == 0:
                    st.success(f"✅ 第 {page_num} 页:完全匹配")
                else:
                    st.warning(f"⚠️ 第 {page_num} 页:发现 {total_diffs} 个差异")

        except Exception as e:
            with log_container:
                st.error(f"❌ 第 {page_num} 页:对比失败 - {str(e)}")

            batch_results['pages'].append({
                'page_num': page_num,
                'status': 'failed',
                'error': str(e)
            })
            batch_results['summary']['failed_pages'] += 1

    _save_batch_results(validator, batch_results, pre_validation_dir)

    progress_bar.progress(1.0)
    status_text.text("✅ 批量验证完成!")

    st.success(f"🎉 批量验证完成!成功: {batch_results['summary']['successful_pages']}, 失败: {batch_results['summary']['failed_pages']}")
def _initialize_batch_results(validator, table_mode: str, similarity_algorithm: str) -> dict:
    """Build the empty result container for one batch validation run.

    The 'summary' sub-dict holds every aggregate counter, all starting at 0;
    per-page records are appended to 'pages' as the run progresses.
    """
    # Order matters: it is the key order serialized into the JSON report.
    summary_counters = [
        'total_pages',
        'successful_pages',
        'failed_pages',
        'total_differences',
        'total_table_differences',
        'total_amount_differences',
        'total_datetime_differences',
        'total_text_differences',
        'total_paragraph_differences',
        'total_table_pre_header',
        'total_table_header_position',
        'total_table_header_critical',
        'total_table_row_missing',
        'total_high_severity',
        'total_medium_severity',
        'total_low_severity',
    ]
    return {
        'ocr_source': get_data_source_display_name(validator.current_source_config),
        'verify_source': get_data_source_display_name(validator.verify_source_config),
        'table_mode': table_mode,
        'similarity_algorithm': similarity_algorithm,
        'timestamp': pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S'),
        'pages': [],
        'summary': {counter: 0 for counter in summary_counters},
    }
- def _process_comparison_result(batch_results: dict, comparison_result: dict, page_num: int,
- ocr_md_path: Path, verify_md_path: Path, comparison_result_path: Path):
- """处理对比结果"""
- stats = comparison_result['statistics']
-
- page_result = {
- 'page_num': page_num,
- 'ocr_file': str(ocr_md_path.name),
- 'verify_file': str(verify_md_path.name),
- 'total_differences': stats['total_differences'],
- 'table_differences': stats['table_differences'],
- 'amount_differences': stats.get('amount_differences', 0),
- 'datetime_differences': stats.get('datetime_differences', 0),
- 'text_differences': stats.get('text_differences', 0),
- 'paragraph_differences': stats['paragraph_differences'],
- 'table_pre_header': stats.get('table_pre_header', 0),
- 'table_header_position': stats.get('table_header_position', 0),
- 'table_header_critical': stats.get('table_header_critical', 0),
- 'table_row_missing': stats.get('table_row_missing', 0),
- 'high_severity': stats.get('high_severity', 0),
- 'medium_severity': stats.get('medium_severity', 0),
- 'low_severity': stats.get('low_severity', 0),
- 'status': 'success',
- 'comparison_json': f"{comparison_result_path}.json",
- 'comparison_md': f"{comparison_result_path}.md"
- }
-
- batch_results['pages'].append(page_result)
- batch_results['summary']['successful_pages'] += 1
-
- # 更新汇总统计
- for key in stats:
- total_key = f'total_{key}'
- if total_key in batch_results['summary']:
- batch_results['summary'][total_key] += stats.get(key, 0)
def _save_batch_results(validator, batch_results: dict, pre_validation_dir: Path):
    """Persist the batch results as JSON and Markdown, then cache them in
    the Streamlit session state for later display.
    """
    ocr_cfg = validator.current_source_config
    verify_cfg = validator.verify_source_config
    stem = f"{ocr_cfg['name']}_{ocr_cfg['ocr_tool']}_vs_{verify_cfg['ocr_tool']}_batch_cross_validation"
    batch_result_path = pre_validation_dir / stem

    # Machine-readable report.
    with open(f"{batch_result_path}.json", "w", encoding="utf-8") as json_file:
        json.dump(batch_results, json_file, ensure_ascii=False, indent=2)

    # Human-readable report.
    generate_batch_validation_markdown(batch_results, f"{batch_result_path}.md")

    st.session_state.cross_validation_batch_result = batch_results
def generate_batch_validation_markdown(batch_results: dict, output_path: str):
    """Write the batch cross-validation results as a Markdown report.

    Args:
        batch_results: Result dict produced by run_batch_cross_validation.
        output_path: Destination path of the .md report file.
    """
    summary = batch_results['summary']
    pages = batch_results['pages']

    # Assemble the whole report in memory and write it with a single call.
    parts = []
    add = parts.append

    add("# 批量交叉验证报告\n\n")

    # --- basic info ---
    add("## 📋 基本信息\n\n")
    add(f"- **OCR数据源:** {batch_results['ocr_source']}\n")
    add(f"- **验证数据源:** {batch_results['verify_source']}\n")
    add(f"- **表格模式:** {batch_results['table_mode']}\n")
    add(f"- **相似度算法:** {batch_results['similarity_algorithm']}\n")
    add(f"- **验证时间:** {batch_results['timestamp']}\n\n")

    # --- aggregate statistics ---
    add("## 📊 汇总统计\n\n")
    add(f"- **总页数:** {summary['total_pages']}\n")
    add(f"- **成功页数:** {summary['successful_pages']}\n")
    add(f"- **失败页数:** {summary['failed_pages']}\n")
    add(f"- **总差异数:** {summary['total_differences']}\n")
    add(f"- **表格差异:** {summary['total_table_differences']}\n")
    add(f"  - 金额差异: {summary.get('total_amount_differences', 0)}\n")
    add(f"  - 日期差异: {summary.get('total_datetime_differences', 0)}\n")
    add(f"  - 文本差异: {summary.get('total_text_differences', 0)}\n")
    add(f"  - 表头前差异: {summary.get('total_table_pre_header', 0)}\n")
    add(f"  - 表头位置差异: {summary.get('total_table_header_position', 0)}\n")
    add(f"  - 表头严重错误: {summary.get('total_table_header_critical', 0)}\n")
    add(f"  - 行缺失: {summary.get('total_table_row_missing', 0)}\n")
    add(f"- **段落差异:** {summary['total_paragraph_differences']}\n")
    add("- **严重程度统计:**\n")
    add(f"  - 高严重度: {summary.get('total_high_severity', 0)}\n")
    add(f"  - 中严重度: {summary.get('total_medium_severity', 0)}\n")
    add(f"  - 低严重度: {summary.get('total_low_severity', 0)}\n\n")

    # --- per-page table ---
    add("## 📄 各页差异统计\n\n")
    add("| 页码 | 状态 | 总差异 | 表格差异 | 金额 | 日期 | 文本 | 段落 | 表头前 | 表头位置 | 表头错误 | 行缺失 | 高 | 中 | 低 |\n")
    add("|------|------|--------|----------|------|------|------|------|--------|----------|----------|--------|----|----|----|\n")

    for page in pages:
        if page['status'] == 'success':
            status_icon = "✅" if page['total_differences'] == 0 else "⚠️"
            cells = [
                str(page['page_num']),
                status_icon,
                str(page['total_differences']),
                str(page['table_differences']),
                str(page.get('amount_differences', 0)),
                str(page.get('datetime_differences', 0)),
                str(page.get('text_differences', 0)),
                str(page['paragraph_differences']),
                str(page.get('table_pre_header', 0)),
                str(page.get('table_header_position', 0)),
                str(page.get('table_header_critical', 0)),
                str(page.get('table_row_missing', 0)),
                str(page.get('high_severity', 0)),
                str(page.get('medium_severity', 0)),
                str(page.get('low_severity', 0)),
            ]
        else:
            # Failed pages carry no stats — show placeholders for all columns.
            cells = [str(page['page_num']), "❌"] + ["-"] * 13
        add("| " + " | ".join(cells) + " |\n")

    add("\n")

    # --- issue digests ---
    add("## 🔍 问题汇总\n\n")

    high_diff_pages = [p for p in pages
                       if p['status'] == 'success' and p['total_differences'] > 10]
    if high_diff_pages:
        add("### ⚠️ 高差异页面(差异>10)\n\n")
        for page in high_diff_pages:
            add(f"- 第 {page['page_num']} 页:{page['total_differences']} 个差异\n")
        add("\n")

    amount_error_pages = [p for p in pages
                          if p['status'] == 'success' and p.get('amount_differences', 0) > 0]
    if amount_error_pages:
        add("### 💰 金额差异页面\n\n")
        for page in amount_error_pages:
            add(f"- 第 {page['page_num']} 页:{page.get('amount_differences', 0)} 个金额差异\n")
        add("\n")

    header_error_pages = [p for p in pages
                          if p['status'] == 'success' and p.get('table_header_critical', 0) > 0]
    if header_error_pages:
        add("### ❌ 表头严重错误页面\n\n")
        for page in header_error_pages:
            add(f"- 第 {page['page_num']} 页:{page['table_header_critical']} 个表头错误\n")
        add("\n")

    failed_pages = [p for p in pages if p['status'] == 'failed']
    if failed_pages:
        add("### 💥 验证失败页面\n\n")
        for page in failed_pages:
            add(f"- 第 {page['page_num']} 页:{page.get('error', '未知错误')}\n")
        add("\n")

    with open(output_path, "w", encoding="utf-8") as report:
        report.write("".join(parts))
def display_batch_validation_results(batch_results: dict):
    """Render batch validation results in Streamlit: headline metrics,
    per-page detail table, and Excel/JSON export buttons.

    Args:
        batch_results: Result dict produced by run_batch_cross_validation
            (schema defined by _initialize_batch_results).
    """
    st.header("📊 批量验证结果")

    summary = batch_results['summary']

    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("总页数", summary['total_pages'])
    with col2:
        # Guard against a zero-page result (original divided unconditionally).
        success_pct = (summary['successful_pages'] / summary['total_pages'] * 100) if summary['total_pages'] else 0.0
        st.metric("成功页数", summary['successful_pages'],
                 delta=f"{success_pct:.1f}%")
    with col3:
        st.metric("失败页数", summary['failed_pages'],
                 delta=f"-{summary['failed_pages']}" if summary['failed_pages'] > 0 else "0")
    with col4:
        st.metric("总差异数", summary['total_differences'])

    # Breakdown of differences by category.
    st.subheader("📈 差异类型统计")

    col1, col2, col3 = st.columns(3)
    with col1:
        st.metric("表格差异", summary['total_table_differences'])
        st.caption(f"金额: {summary.get('total_amount_differences', 0)} | 日期: {summary.get('total_datetime_differences', 0)} | 文本: {summary.get('total_text_differences', 0)}")
    with col2:
        st.metric("段落差异", summary['total_paragraph_differences'])
    with col3:
        st.metric("严重度", f"高:{summary.get('total_high_severity', 0)} 中:{summary.get('total_medium_severity', 0)} 低:{summary.get('total_low_severity', 0)}")

    # Table-structure-specific counters, collapsed by default.
    with st.expander("📋 表格结构差异详情", expanded=False):
        col1, col2, col3, col4 = st.columns(4)
        with col1:
            st.metric("表头前", summary.get('total_table_pre_header', 0))
        with col2:
            st.metric("表头位置", summary.get('total_table_header_position', 0))
        with col3:
            st.metric("表头错误", summary.get('total_table_header_critical', 0))
        with col4:
            st.metric("行缺失", summary.get('total_table_row_missing', 0))

    # Per-page detail table.
    st.subheader("📄 各页详细结果")

    page_data = []
    for page in batch_results['pages']:
        if page['status'] == 'success':
            page_data.append({
                '页码': page['page_num'],
                '状态': '✅ 成功' if page['total_differences'] == 0 else '⚠️ 有差异',
                '总差异': page['total_differences'],
                '表格差异': page['table_differences'],
                '金额': page.get('amount_differences', 0),
                '日期': page.get('datetime_differences', 0),
                '文本': page.get('text_differences', 0),
                '段落': page['paragraph_differences'],
                '表头前': page.get('table_pre_header', 0),
                '表头位置': page.get('table_header_position', 0),
                '表头错误': page.get('table_header_critical', 0),
                '行缺失': page.get('table_row_missing', 0),
                '高': page.get('high_severity', 0),
                '中': page.get('medium_severity', 0),
                '低': page.get('low_severity', 0)
            })
        else:
            # Failed pages have no stats — show dashes across all columns.
            page_data.append({
                '页码': page['page_num'],
                '状态': '❌ 失败',
                '总差异': '-', '表格差异': '-', '金额': '-', '日期': '-',
                '文本': '-', '段落': '-', '表头前': '-', '表头位置': '-',
                '表头错误': '-', '行缺失': '-', '高': '-', '中': '-', '低': '-'
            })

    df_pages = pd.DataFrame(page_data)

    st.dataframe(
        df_pages,
        width='stretch',
        hide_index=True,
        column_config={
            "页码": st.column_config.NumberColumn("页码", width="small"),
            "状态": st.column_config.TextColumn("状态", width="small"),
            "总差异": st.column_config.NumberColumn("总差异", width="small"),
            "表格差异": st.column_config.NumberColumn("表格", width="small"),
            "金额": st.column_config.NumberColumn("金额", width="small"),
            "日期": st.column_config.NumberColumn("日期", width="small"),
            "文本": st.column_config.NumberColumn("文本", width="small"),
            "段落": st.column_config.NumberColumn("段落", width="small"),
        }
    )

    # Export buttons.
    st.subheader("📥 导出报告")

    col1, col2 = st.columns(2)

    with col1:
        # Excel export of the per-page detail table.
        excel_buffer = BytesIO()
        df_pages.to_excel(excel_buffer, index=False, sheet_name='验证结果')

        st.download_button(
            label="📊 下载Excel报告",
            data=excel_buffer.getvalue(),
            file_name=f"batch_validation_{pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')}.xlsx",
            mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
        )

    with col2:
        # JSON export of the full result structure.
        json_data = json.dumps(batch_results, ensure_ascii=False, indent=2)

        st.download_button(
            label="📄 下载JSON报告",
            data=json_data,
            file_name=f"batch_validation_{pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')}.json",
            mime="application/json"
        )
@st.dialog("查看交叉验证结果", width="large", dismissible=True, on_dismiss="rerun")
def show_batch_cross_validation_results_dialog():
    """Dialog that replays the cached batch cross-validation results."""
    cached_result = st.session_state.get('cross_validation_batch_result')
    if cached_result:
        display_batch_validation_results(cached_result)
    else:
        st.info("暂无交叉验证结果,请先运行交叉验证")