| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277 |
- """
- 批量模板应用控件
- """
- import streamlit as st
- import json
- from pathlib import Path
- from PIL import Image
- from typing import Dict, List
- import sys
- # 添加父目录到路径
- sys.path.insert(0, str(Path(__file__).parent.parent))
- from table_template_applier import TableTemplateApplier
- from table_line_generator import TableLineGenerator
- def create_batch_template_section(current_line_width: int, current_line_color: str):
- """
- 创建批量应用模板的控制区域
-
- Args:
- current_line_width: 当前页使用的线条宽度
- current_line_color: 当前页使用的线条颜色名称
-
- 要求:
- - 当前在目录模式
- - 已有标注(edit 模式)
- - 有可用的目录清单
- """
- # 检查前置条件
- if 'structure' not in st.session_state or not st.session_state.structure:
- return
-
- if 'current_catalog' not in st.session_state:
- return
-
- if 'current_output_config' not in st.session_state:
- return
-
- # 🔑 检查当前页是否有保存的结构文件
- if 'loaded_config_name' not in st.session_state or not st.session_state.loaded_config_name:
- st.warning("⚠️ 当前页未保存结构文件,请先保存后再批量应用")
- return
-
- st.divider()
- st.subheader("🔄 批量应用模板")
-
- catalog = st.session_state.current_catalog
- current_index = st.session_state.get('current_catalog_index', 0)
- current_entry = catalog[current_index]
-
- # 统计信息
- total_files = len(catalog)
- current_page = current_entry["index"]
-
- # 找出哪些页面还没有标注
- output_config = st.session_state.current_output_config
- output_dir = Path(output_config.get("directory", "output/table_structures"))
- structure_suffix = output_config.get("structure_suffix", "_structure.json")
-
- # 🔑 获取当前页的结构文件路径
- current_base_name = current_entry["json"].stem
- current_structure_file = output_dir / f"{current_base_name}{structure_suffix}"
-
- if not current_structure_file.exists():
- st.error("❌ 未找到当前页的结构文件,请先保存")
- st.info(f"期望文件: {current_structure_file}")
- return
-
- unlabeled_pages = []
- for entry in catalog:
- if entry["index"] == current_page:
- continue # 跳过当前页
- structure_file = output_dir / f"{entry['json'].stem}{structure_suffix}"
- if not structure_file.exists():
- unlabeled_pages.append(entry)
-
- st.info(
- f"📊 当前页: {current_page}/{total_files}\n\n"
- f"📄 模板文件: {current_structure_file.name}\n\n"
- f"✅ 已标注: {total_files - len(unlabeled_pages)} 页\n\n"
- f"⏳ 待处理: {len(unlabeled_pages)} 页"
- )
-
- if len(unlabeled_pages) == 0:
- st.success("🎉 所有页面都已标注!")
- return
-
- # 🔑 使用当前页的设置
- st.info(
- f"🎨 将使用当前页设置:\n\n"
- f"• 线条宽度: {current_line_width}px\n\n"
- f"• 线条颜色: {current_line_color}"
- )
-
- # 获取颜色配置
- line_colors = output_config.get("line_colors") or [
- {"name": "黑色", "rgb": [0, 0, 0]},
- {"name": "蓝色", "rgb": [0, 0, 255]},
- {"name": "红色", "rgb": [255, 0, 0]},
- ]
-
- # 🔑 从颜色名称映射到 RGB
- color_map = {c["name"]: tuple(c["rgb"]) for c in line_colors}
- line_color = color_map.get(current_line_color, (0, 0, 0))
-
- # 应用按钮
- if st.button("🚀 批量应用到所有未标注页面", type="primary"):
- _apply_template_batch(
- current_structure_file, # 🔑 直接使用保存的结构文件
- current_entry,
- unlabeled_pages,
- output_dir,
- structure_suffix,
- current_line_width,
- line_color
- )
- def _apply_template_batch(
- template_file: Path, # 🔑 改为直接传入模板文件路径
- template_entry: Dict,
- target_entries: List[Dict],
- output_dir: Path,
- structure_suffix: str,
- line_width: int,
- line_color: tuple
- ):
- """
- 执行批量应用模板
-
- Args:
- template_file: 模板结构文件路径
- template_entry: 模板页面条目
- target_entries: 目标页面列表
- output_dir: 输出目录
- structure_suffix: 结构文件后缀
- line_width: 线条宽度
- line_color: 线条颜色 (r, g, b)
- """
- try:
- # 🔑 直接使用保存的结构文件创建模板应用器
- applier = TableTemplateApplier(str(template_file))
-
- st.info(f"📋 使用模板: {template_file.name}")
-
- # 进度条
- progress_bar = st.progress(0)
- status_text = st.empty()
-
- success_count = 0
- failed_count = 0
- results = []
-
- for idx, entry in enumerate(target_entries):
- # 更新进度
- progress = (idx + 1) / len(target_entries)
- progress_bar.progress(progress)
- status_text.text(f"处理中: {entry['display']} ({idx + 1}/{len(target_entries)})")
-
- try:
- # 加载 OCR 数据
- with open(entry["json"], "r", encoding="utf-8") as fp:
- raw = json.load(fp)
-
- # 解析 OCR 数据
- if 'parsing_res_list' in raw and 'overall_ocr_res' in raw:
- table_bbox, ocr_data = TableLineGenerator.parse_ppstructure_result(raw)
- else:
- raise ValueError("不支持的 OCR 格式")
-
- # 加载图片
- if entry["image"] and entry["image"].exists():
- image = Image.open(entry["image"])
- else:
- st.warning(f"⚠️ 跳过 {entry['display']}: 未找到图片")
- failed_count += 1
- results.append({
- 'page': entry['index'],
- 'status': 'skipped',
- 'reason': 'no_image'
- })
- continue
-
- # 应用模板生成图片
- img_with_lines = applier.apply_to_image(
- image,
- ocr_data,
- line_width=line_width,
- line_color=line_color
- )
-
- # 生成结构配置
- structure = applier.generate_structure_for_image(ocr_data)
-
- # 保存图片
- base_name = entry["json"].stem
- image_suffix = st.session_state.current_output_config.get("image_suffix", ".png")
- output_image_path = output_dir / f"{base_name}{image_suffix}"
- img_with_lines.save(output_image_path)
-
- # 🔑 保存结构(确保 set 转为 list)
- structure_path = output_dir / f"{base_name}{structure_suffix}"
-
- with open(structure_path, 'w', encoding='utf-8') as f:
- json.dump(structure, f, indent=2, ensure_ascii=False)
-
- success_count += 1
- results.append({
- 'page': entry['index'],
- 'status': 'success',
- 'image': str(output_image_path),
- 'structure': str(structure_path)
- })
-
- except Exception as e:
- failed_count += 1
- results.append({
- 'page': entry['index'],
- 'status': 'error',
- 'error': str(e)
- })
- st.error(f"❌ 处理失败 {entry['display']}: {e}")
-
- # 完成
- progress_bar.progress(1.0)
- status_text.empty()
-
- # 保存批处理结果
- batch_result_path = output_dir / "batch_results.json"
- with open(batch_result_path, 'w', encoding='utf-8') as f:
- json.dump({
- 'template': template_entry['display'],
- 'template_file': str(template_file),
- 'total': len(target_entries),
- 'success': success_count,
- 'failed': failed_count,
- 'line_width': line_width,
- 'line_color': line_color,
- 'results': results
- }, f, indent=2, ensure_ascii=False)
-
- # 显示结果
- if success_count > 0:
- st.success(
- f"✅ 批量应用完成!\n\n"
- f"成功: {success_count} 页\n\n"
- f"失败: {failed_count} 页"
- )
-
- # 🔑 提供下载批处理结果
- with open(batch_result_path, 'r', encoding='utf-8') as f:
- st.download_button(
- "📥 下载批处理报告",
- f.read(),
- file_name="batch_results.json",
- mime="application/json"
- )
- else:
- st.error("❌ 批量应用失败,没有成功处理任何页面")
-
- # 显示详细结果
- with st.expander("📋 详细结果"):
- for result in results:
- if result['status'] == 'success':
- st.success(f"✅ 第 {result['page']} 页")
- elif result['status'] == 'error':
- st.error(f"❌ 第 {result['page']} 页: {result.get('error', '未知错误')}")
- else:
- st.warning(f"⚠️ 第 {result['page']} 页: {result.get('reason', '跳过')}")
-
- except Exception as e:
- st.error(f"❌ 批量应用过程中发生错误: {e}")
- import traceback
- with st.expander("🔍 详细错误信息"):
- st.code(traceback.format_exc())
|