"""
Batch template application widget.

Provides Streamlit controls that take the structure file saved for the current
catalog page and apply it, as a template, to every other catalog page that
does not yet have a structure file in the output directory.
"""
import json
import sys
import traceback
from pathlib import Path
from typing import Any, Dict, List

import streamlit as st
from PIL import Image

# Add the parent directory to sys.path so the sibling project modules below
# can be imported when this file is loaded from its subdirectory.
sys.path.insert(0, str(Path(__file__).parent.parent))

from table_template_applier import TableTemplateApplier  # noqa: E402
from table_line_generator import TableLineGenerator  # noqa: E402

# Fallback palette used when the output config has no (or an empty) "line_colors".
_DEFAULT_LINE_COLORS = (
    {"name": "黑色", "rgb": [0, 0, 0]},
    {"name": "蓝色", "rgb": [0, 0, 255]},
    {"name": "红色", "rgb": [255, 0, 0]},
)


def _to_jsonable(obj: Any) -> Any:
    """
    ``json.dumps`` fallback for values the encoder cannot handle natively.

    Sets/frozensets become lists (sorted when their elements are orderable, so
    the output is deterministic). Objects exposing ``tolist()`` (presumably
    numpy arrays/scalars coming out of the template applier -- not confirmed
    here) are converted via that method. Anything else raises ``TypeError``,
    exactly as the default encoder would.
    """
    if isinstance(obj, (set, frozenset)):
        try:
            return sorted(obj)
        except TypeError:
            # Mixed / unorderable element types: keep iteration order.
            return list(obj)
    tolist = getattr(obj, "tolist", None)
    if callable(tolist):
        return tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def create_batch_template_section(current_line_width: int, current_line_color: str):
    """
    Render the "batch apply template" control area.

    Args:
        current_line_width: line width (px) used on the current page.
        current_line_color: name of the line color used on the current page;
            resolved to RGB through the output config's ``line_colors``
            (falls back to black when the name is unknown).

    Renders nothing unless all of these hold:
        - a structure is loaded (``st.session_state.structure`` is truthy),
        - a catalog and an output config are loaded (directory mode),
        - the current page's structure has been saved
          (``st.session_state.loaded_config_name`` is set).
    """
    # Preconditions: nothing to offer without a structure, a catalog and an output config.
    if 'structure' not in st.session_state or not st.session_state.structure:
        return
    if 'current_catalog' not in st.session_state:
        return
    if 'current_output_config' not in st.session_state:
        return

    # The template is read back from disk, so the current page must have been saved.
    if not st.session_state.get('loaded_config_name'):
        st.warning("⚠️ 当前页未保存结构文件,请先保存后再批量应用")
        return

    catalog = st.session_state.current_catalog
    current_index = st.session_state.get('current_catalog_index', 0)
    # An empty catalog or a stale index would otherwise raise IndexError below.
    if not catalog or not 0 <= current_index < len(catalog):
        return

    st.divider()
    st.subheader("🔄 批量应用模板")

    current_entry = catalog[current_index]
    total_files = len(catalog)
    current_page = current_entry["index"]

    output_config = st.session_state.current_output_config
    output_dir = Path(output_config.get("directory", "output/table_structures"))
    structure_suffix = output_config.get("structure_suffix", "_structure.json")

    # The structure file saved for the current page is the template.
    current_structure_file = output_dir / f"{current_entry['json'].stem}{structure_suffix}"
    if not current_structure_file.exists():
        st.error("❌ 未找到当前页的结构文件,请先保存")
        st.info(f"期望文件: {current_structure_file}")
        return

    # Every other page whose structure file does not exist yet.
    unlabeled_pages = [
        entry
        for entry in catalog
        if entry["index"] != current_page
        and not (output_dir / f"{entry['json'].stem}{structure_suffix}").exists()
    ]

    st.info(
        f"📊 当前页: {current_page}/{total_files}\n\n"
        f"📄 模板文件: {current_structure_file.name}\n\n"
        f"✅ 已标注: {total_files - len(unlabeled_pages)} 页\n\n"
        f"⏳ 待处理: {len(unlabeled_pages)} 页"
    )

    if not unlabeled_pages:
        st.success("🎉 所有页面都已标注!")
        return

    # The batch reuses the current page's line settings.
    st.info(
        f"🎨 将使用当前页设置:\n\n"
        f"• 线条宽度: {current_line_width}px\n\n"
        f"• 线条颜色: {current_line_color}"
    )

    # Map the color name to an RGB tuple; unknown names fall back to black.
    line_colors = output_config.get("line_colors") or _DEFAULT_LINE_COLORS
    color_map = {c["name"]: tuple(c["rgb"]) for c in line_colors}
    line_color = color_map.get(current_line_color, (0, 0, 0))

    if st.button("🚀 批量应用到所有未标注页面", type="primary"):
        _apply_template_batch(
            current_structure_file,
            current_entry,
            unlabeled_pages,
            output_dir,
            structure_suffix,
            current_line_width,
            line_color,
        )


def _apply_template_to_page(
    applier: "TableTemplateApplier",
    entry: Dict,
    output_dir: Path,
    structure_suffix: str,
    image_suffix: str,
    line_width: int,
    line_color: tuple,
) -> Dict[str, Any]:
    """
    Apply the template to a single catalog page and write its outputs.

    Writes ``<stem><image_suffix>`` (image with table lines) and
    ``<stem><structure_suffix>`` (structure JSON) into ``output_dir``, where
    ``<stem>`` is the stem of the page's OCR JSON file.

    Returns:
        A result record: ``status`` is ``'success'`` (with output paths) or
        ``'skipped'`` (``reason='no_image'``) when the page image is missing.

    Raises:
        ValueError: the OCR JSON is not in the supported PP-Structure format.
        Any I/O, JSON or applier error is propagated to the caller.
    """
    with open(entry["json"], "r", encoding="utf-8") as fp:
        raw = json.load(fp)

    # Only PP-Structure output (with both keys) is supported.
    if 'parsing_res_list' not in raw or 'overall_ocr_res' not in raw:
        raise ValueError("不支持的 OCR 格式")
    _table_bbox, ocr_data = TableLineGenerator.parse_ppstructure_result(raw)

    image_path = entry.get("image")
    if not image_path or not image_path.exists():
        return {
            'page': entry['index'],
            'status': 'skipped',
            'reason': 'no_image'
        }

    # Use a context manager so the source image's file handle is released.
    with Image.open(image_path) as image:
        img_with_lines = applier.apply_to_image(
            image,
            ocr_data,
            line_width=line_width,
            line_color=line_color
        )
        # Make sure pixel data is in memory before the source file is closed.
        img_with_lines.load()

    structure = applier.generate_structure_for_image(ocr_data)

    base_name = entry["json"].stem
    output_image_path = output_dir / f"{base_name}{image_suffix}"
    img_with_lines.save(output_image_path)

    # Serialize fully *before* touching the file: a failure mid-dump would
    # otherwise leave a truncated structure file that marks the page as
    # already labeled. Sets (and similar) are converted by _to_jsonable.
    structure_text = json.dumps(
        structure, indent=2, ensure_ascii=False, default=_to_jsonable
    )
    structure_path = output_dir / f"{base_name}{structure_suffix}"
    structure_path.write_text(structure_text, encoding='utf-8')

    return {
        'page': entry['index'],
        'status': 'success',
        'image': str(output_image_path),
        'structure': str(structure_path)
    }


def _apply_template_batch(
    template_file: Path,
    template_entry: Dict,
    target_entries: List[Dict],
    output_dir: Path,
    structure_suffix: str,
    line_width: int,
    line_color: tuple
):
    """
    Apply a saved template structure to a list of pages, with progress UI.

    Args:
        template_file: path of the template structure file.
        template_entry: catalog entry of the template page (its ``display``
            name is recorded in the report).
        target_entries: catalog entries to process.
        output_dir: directory receiving the generated images/structures and
            ``batch_results.json``.
        structure_suffix: file-name suffix of structure files.
        line_width: line width in pixels.
        line_color: line color as an ``(r, g, b)`` tuple.

    Errors are reported in the UI rather than raised: per-page failures are
    recorded and processing continues; any other failure is shown together
    with its traceback. Pages skipped for a missing image count as failed.
    """
    try:
        applier = TableTemplateApplier(str(template_file))
        st.info(f"📋 使用模板: {template_file.name}")

        # Loop-invariant: read once instead of per page.
        image_suffix = st.session_state.current_output_config.get("image_suffix", ".png")

        progress_bar = st.progress(0)
        status_text = st.empty()

        success_count = 0
        failed_count = 0
        results = []
        total = len(target_entries)

        for position, entry in enumerate(target_entries, start=1):
            progress_bar.progress(position / total)
            status_text.text(f"处理中: {entry['display']} ({position}/{total})")

            try:
                result = _apply_template_to_page(
                    applier,
                    entry,
                    output_dir,
                    structure_suffix,
                    image_suffix,
                    line_width,
                    line_color,
                )
            except Exception as e:
                failed_count += 1
                results.append({
                    'page': entry['index'],
                    'status': 'error',
                    'error': str(e)
                })
                st.error(f"❌ 处理失败 {entry['display']}: {e}")
                continue

            if result['status'] == 'skipped':
                st.warning(f"⚠️ 跳过 {entry['display']}: 未找到图片")
                failed_count += 1
            else:
                success_count += 1
            results.append(result)

        progress_bar.progress(1.0)
        status_text.empty()

        # Build the report once; the same text is written to disk and offered
        # for download (no need to read the file back).
        report_text = json.dumps({
            'template': template_entry['display'],
            'template_file': str(template_file),
            'total': len(target_entries),
            'success': success_count,
            'failed': failed_count,
            'line_width': line_width,
            'line_color': line_color,
            'results': results
        }, indent=2, ensure_ascii=False)
        batch_result_path = output_dir / "batch_results.json"
        batch_result_path.write_text(report_text, encoding='utf-8')

        if success_count > 0:
            st.success(
                f"✅ 批量应用完成!\n\n"
                f"成功: {success_count} 页\n\n"
                f"失败: {failed_count} 页"
            )
            st.download_button(
                "📥 下载批处理报告",
                report_text,
                file_name="batch_results.json",
                mime="application/json"
            )
        else:
            st.error("❌ 批量应用失败,没有成功处理任何页面")

        with st.expander("📋 详细结果"):
            for result in results:
                if result['status'] == 'success':
                    st.success(f"✅ 第 {result['page']} 页")
                elif result['status'] == 'error':
                    st.error(f"❌ 第 {result['page']} 页: {result.get('error', '未知错误')}")
                else:
                    st.warning(f"⚠️ 第 {result['page']} 页: {result.get('reason', '跳过')}")

    except Exception as e:
        st.error(f"❌ 批量应用过程中发生错误: {e}")
        with st.expander("🔍 详细错误信息"):
            st.code(traceback.format_exc())