|
|
@@ -0,0 +1,277 @@
|
|
|
+"""
|
|
|
+批量模板应用控件
|
|
|
+"""
|
|
|
+import streamlit as st
|
|
|
+import json
|
|
|
+from pathlib import Path
|
|
|
+from PIL import Image
|
|
|
+from typing import Dict, List
|
|
|
+import sys
|
|
|
+
|
|
|
+# 添加父目录到路径
|
|
|
+sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
|
+
|
|
|
+from table_template_applier import TableTemplateApplier
|
|
|
+from table_line_generator import TableLineGenerator
|
|
|
+
|
|
|
+
|
|
|
def create_batch_template_section(current_line_width: int, current_line_color: str):
    """
    Render the "batch apply template" control area.

    Shows how many catalog pages already have a saved structure file and
    offers a button that applies the current page's saved structure file,
    as a template, to every page that does not have one yet.

    Args:
        current_line_width: Line width used on the current page.
        current_line_color: Name of the line colour used on the current page.
            It is resolved to an RGB tuple through the ``line_colors`` list
            of the output config (falling back to black when unknown).

    Requirements (the section is not rendered otherwise):
        - ``st.session_state.structure`` is present and non-empty.
        - ``st.session_state.current_catalog`` and
          ``st.session_state.current_output_config`` are present.
        - ``st.session_state.loaded_config_name`` is set, and the current
          page's structure file exists in the output directory.
    """
    state = st.session_state

    # Silent preconditions: without these there is nothing to offer.
    if 'structure' not in state or not state.structure:
        return
    if 'current_catalog' not in state:
        return
    if 'current_output_config' not in state:
        return

    # The current page must have been saved before it can act as a template.
    if 'loaded_config_name' not in state or not state.loaded_config_name:
        st.warning("⚠️ 当前页未保存结构文件,请先保存后再批量应用")
        return

    catalog = state.current_catalog
    current_index = state.get('current_catalog_index', 0)

    # Guard against an empty catalog or a stale index left in session state;
    # indexing blindly would raise IndexError in the middle of rendering.
    current_entry = None
    if catalog:
        try:
            current_entry = catalog[current_index]
        except IndexError:
            current_entry = None
    if current_entry is None:
        st.warning("⚠️ 当前目录清单为空或页面索引无效,无法批量应用")
        return

    st.divider()
    st.subheader("🔄 批量应用模板")

    total_files = len(catalog)
    current_page = current_entry["index"]

    output_config = state.current_output_config
    output_dir = Path(output_config.get("directory", "output/table_structures"))
    structure_suffix = output_config.get("structure_suffix", "_structure.json")

    # The saved structure file of the current page is the template.
    current_structure_file = output_dir / f"{current_entry['json'].stem}{structure_suffix}"
    if not current_structure_file.exists():
        st.error("❌ 未找到当前页的结构文件,请先保存")
        st.info(f"期望文件: {current_structure_file}")
        return

    # Every other page whose structure file does not exist yet is a target.
    unlabeled_pages = [
        entry
        for entry in catalog
        if entry["index"] != current_page
        and not (output_dir / f"{entry['json'].stem}{structure_suffix}").exists()
    ]

    st.info(
        f"📊 当前页: {current_page}/{total_files}\n\n"
        f"📄 模板文件: {current_structure_file.name}\n\n"
        f"✅ 已标注: {total_files - len(unlabeled_pages)} 页\n\n"
        f"⏳ 待处理: {len(unlabeled_pages)} 页"
    )

    if not unlabeled_pages:
        st.success("🎉 所有页面都已标注!")
        return

    # The batch run reuses the current page's drawing settings.
    st.info(
        f"🎨 将使用当前页设置:\n\n"
        f"• 线条宽度: {current_line_width}px\n\n"
        f"• 线条颜色: {current_line_color}"
    )

    # Colour palette from the config, or a built-in default palette.
    line_colors = output_config.get("line_colors") or [
        {"name": "黑色", "rgb": [0, 0, 0]},
        {"name": "蓝色", "rgb": [0, 0, 255]},
        {"name": "红色", "rgb": [255, 0, 0]},
    ]

    # Map the colour name to an RGB tuple; unknown names fall back to black.
    color_map = {c["name"]: tuple(c["rgb"]) for c in line_colors}
    line_color = color_map.get(current_line_color, (0, 0, 0))

    if st.button("🚀 批量应用到所有未标注页面", type="primary"):
        _apply_template_batch(
            current_structure_file,  # use the saved structure file directly
            current_entry,
            unlabeled_pages,
            output_dir,
            structure_suffix,
            current_line_width,
            line_color
        )
|
|
|
+
|
|
|
+
|
|
|
def _json_default(obj):
    """Convert sets to lists for ``json.dump``; reject every other type."""
    if isinstance(obj, (set, frozenset)):
        try:
            # Sorted output keeps the written file stable between runs.
            return sorted(obj)
        except TypeError:
            # Elements are not mutually orderable; keep iteration order.
            return list(obj)
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _apply_template_batch(
    template_file: Path,
    template_entry: Dict,
    target_entries: List[Dict],
    output_dir: Path,
    structure_suffix: str,
    line_width: int,
    line_color: tuple
):
    """
    Apply a saved structure template to a list of catalog pages.

    For each target page the OCR JSON is loaded and parsed, the template is
    applied to the page image, and both the rendered image and the generated
    structure are written to ``output_dir``. A summary is written to
    ``batch_results.json`` in the same directory and shown in the UI.

    Args:
        template_file: Path of the structure file used as the template.
        template_entry: Catalog entry of the template page; its ``display``
            value is recorded in the batch report.
        target_entries: Catalog entries to process. Each entry is read for
            the keys ``index``, ``display``, ``json`` and ``image``
            (``json`` / ``image`` are used as path-like objects).
        output_dir: Directory receiving images, structures and the report.
        structure_suffix: Suffix appended to the page stem for the
            structure file name.
        line_width: Line width passed to the template applier.
        line_color: Line colour as an ``(r, g, b)`` tuple.

    Errors on a single page are reported and recorded without stopping the
    batch; any other error is reported together with a traceback.
    """
    try:
        # Build the applier straight from the saved structure file.
        applier = TableTemplateApplier(str(template_file))

        st.info(f"📋 使用模板: {template_file.name}")

        progress_bar = st.progress(0)
        status_text = st.empty()

        # Loop invariants, looked up once instead of once per page.
        total = len(target_entries)
        image_suffix = st.session_state.current_output_config.get("image_suffix", ".png")

        # Make sure the destination exists before writing anything into it.
        output_dir.mkdir(parents=True, exist_ok=True)

        success_count = 0
        failed_count = 0
        results = []

        for position, entry in enumerate(target_entries, start=1):
            progress_bar.progress(position / total)
            status_text.text(f"处理中: {entry['display']} ({position}/{total})")

            try:
                # Load the OCR data of this page.
                with open(entry["json"], "r", encoding="utf-8") as fp:
                    raw = json.load(fp)

                # Only the PP-Structure style result layout is supported.
                if 'parsing_res_list' in raw and 'overall_ocr_res' in raw:
                    _, ocr_data = TableLineGenerator.parse_ppstructure_result(raw)
                else:
                    raise ValueError("不支持的 OCR 格式")

                # Pages without an image are skipped (counted as failed).
                if not (entry["image"] and entry["image"].exists()):
                    st.warning(f"⚠️ 跳过 {entry['display']}: 未找到图片")
                    failed_count += 1
                    results.append({
                        'page': entry['index'],
                        'status': 'skipped',
                        'reason': 'no_image'
                    })
                    continue

                base_name = entry["json"].stem
                output_image_path = output_dir / f"{base_name}{image_suffix}"
                structure_path = output_dir / f"{base_name}{structure_suffix}"

                # The context manager releases the source image's file handle
                # after every page; the rendered image is saved while the
                # source is still open in case it shares data with it.
                with Image.open(entry["image"]) as image:
                    img_with_lines = applier.apply_to_image(
                        image,
                        ocr_data,
                        line_width=line_width,
                        line_color=line_color
                    )
                    structure = applier.generate_structure_for_image(ocr_data)
                    img_with_lines.save(output_image_path)

                # Save the structure; sets inside it are written as lists.
                with open(structure_path, 'w', encoding='utf-8') as f:
                    json.dump(
                        structure,
                        f,
                        indent=2,
                        ensure_ascii=False,
                        default=_json_default
                    )

                success_count += 1
                results.append({
                    'page': entry['index'],
                    'status': 'success',
                    'image': str(output_image_path),
                    'structure': str(structure_path)
                })

            except Exception as e:
                # Per-page boundary: record the failure and keep going.
                failed_count += 1
                results.append({
                    'page': entry['index'],
                    'status': 'error',
                    'error': str(e)
                })
                st.error(f"❌ 处理失败 {entry['display']}: {e}")

        # Done.
        progress_bar.progress(1.0)
        status_text.empty()

        # Serialize the report once; the same text is written to disk and
        # offered for download.
        report_json = json.dumps({
            'template': template_entry['display'],
            'template_file': str(template_file),
            'total': total,
            'success': success_count,
            'failed': failed_count,
            'line_width': line_width,
            'line_color': line_color,
            'results': results
        }, indent=2, ensure_ascii=False)

        batch_result_path = output_dir / "batch_results.json"
        with open(batch_result_path, 'w', encoding='utf-8') as f:
            f.write(report_json)

        if success_count > 0:
            st.success(
                f"✅ 批量应用完成!\n\n"
                f"成功: {success_count} 页\n\n"
                f"失败: {failed_count} 页"
            )

            st.download_button(
                "📥 下载批处理报告",
                report_json,
                file_name="batch_results.json",
                mime="application/json"
            )
        else:
            st.error("❌ 批量应用失败,没有成功处理任何页面")

        # Per-page details.
        with st.expander("📋 详细结果"):
            for result in results:
                if result['status'] == 'success':
                    st.success(f"✅ 第 {result['page']} 页")
                elif result['status'] == 'error':
                    st.error(f"❌ 第 {result['page']} 页: {result.get('error', '未知错误')}")
                else:
                    st.warning(f"⚠️ 第 {result['page']} 页: {result.get('reason', '跳过')}")

    except Exception as e:
        # Top-level UI boundary: show the error and its traceback rather
        # than letting the exception abort the Streamlit script run.
        st.error(f"❌ 批量应用过程中发生错误: {e}")
        import traceback
        with st.expander("🔍 详细错误信息"):
            st.code(traceback.format_exc())
|