Преглед изворни кода

feat: 添加批量模板控件,优化目录模式和保存功能

zhch158_admin пре 3 дана
родитељ
комит
8f5ffb2a25

+ 13 - 0
table_line_generator/editor/__init__.py

@@ -1,6 +1,13 @@
 """
 表格线编辑器核心模块
 """
+import sys
+from pathlib import Path
+
+# ✅ 确保父目录在路径中
+_parent_dir = Path(__file__).parent.parent
+if str(_parent_dir) not in sys.path:
+    sys.path.insert(0, str(_parent_dir))
 
 # 文件处理
 from .file_handlers import create_file_uploader_section
@@ -17,6 +24,9 @@ from .analysis_controls import create_analysis_section
 # 保存控件
 from .save_controls import create_save_section
 
+# 🆕 批量模板控件
+from .batch_template_controls import create_batch_template_section
+
 # 模式设置
 from .mode_setup import (
     setup_new_annotation_mode,
@@ -77,6 +87,9 @@ __all__ = [
     # 保存控件
     'create_save_section',
     
+    # 🆕 批量模板控件
+    'create_batch_template_section',
+    
     # 模式设置
     'setup_new_annotation_mode',
     'setup_edit_annotation_mode',

+ 277 - 0
table_line_generator/editor/batch_template_controls.py

@@ -0,0 +1,277 @@
+"""
+批量模板应用控件
+"""
+import streamlit as st
+import json
+from pathlib import Path
+from PIL import Image
+from typing import Dict, List
+import sys
+
+# 添加父目录到路径
+sys.path.insert(0, str(Path(__file__).parent.parent))
+
+from table_template_applier import TableTemplateApplier
+from table_line_generator import TableLineGenerator
+
+
+def create_batch_template_section(current_line_width: int, current_line_color: str):
+    """
+    创建批量应用模板的控制区域
+    
+    Args:
+        current_line_width: 当前页使用的线条宽度
+        current_line_color: 当前页使用的线条颜色名称
+        
+    要求:
+    - 当前在目录模式
+    - 已有标注(edit 模式)
+    - 有可用的目录清单
+    """
+    # 检查前置条件
+    if 'structure' not in st.session_state or not st.session_state.structure:
+        return
+    
+    if 'current_catalog' not in st.session_state:
+        return
+    
+    if 'current_output_config' not in st.session_state:
+        return
+    
+    # 🔑 检查当前页是否有保存的结构文件
+    if 'loaded_config_name' not in st.session_state or not st.session_state.loaded_config_name:
+        st.warning("⚠️ 当前页未保存结构文件,请先保存后再批量应用")
+        return
+    
+    st.divider()
+    st.subheader("🔄 批量应用模板")
+    
+    catalog = st.session_state.current_catalog
+    current_index = st.session_state.get('current_catalog_index', 0)
+    current_entry = catalog[current_index]
+    
+    # 统计信息
+    total_files = len(catalog)
+    current_page = current_entry["index"]
+    
+    # 找出哪些页面还没有标注
+    output_config = st.session_state.current_output_config
+    output_dir = Path(output_config.get("directory", "output/table_structures"))
+    structure_suffix = output_config.get("structure_suffix", "_structure.json")
+    
+    # 🔑 获取当前页的结构文件路径
+    current_base_name = current_entry["json"].stem
+    current_structure_file = output_dir / f"{current_base_name}{structure_suffix}"
+    
+    if not current_structure_file.exists():
+        st.error("❌ 未找到当前页的结构文件,请先保存")
+        st.info(f"期望文件: {current_structure_file}")
+        return
+    
+    unlabeled_pages = []
+    for entry in catalog:
+        if entry["index"] == current_page:
+            continue  # 跳过当前页
+        structure_file = output_dir / f"{entry['json'].stem}{structure_suffix}"
+        if not structure_file.exists():
+            unlabeled_pages.append(entry)
+    
+    st.info(
+        f"📊 当前页: {current_page}/{total_files}\n\n"
+        f"📄 模板文件: {current_structure_file.name}\n\n"
+        f"✅ 已标注: {total_files - len(unlabeled_pages)} 页\n\n"
+        f"⏳ 待处理: {len(unlabeled_pages)} 页"
+    )
+    
+    if len(unlabeled_pages) == 0:
+        st.success("🎉 所有页面都已标注!")
+        return
+    
+    # 🔑 使用当前页的设置
+    st.info(
+        f"🎨 将使用当前页设置:\n\n"
+        f"• 线条宽度: {current_line_width}px\n\n"
+        f"• 线条颜色: {current_line_color}"
+    )
+    
+    # 获取颜色配置
+    line_colors = output_config.get("line_colors") or [
+        {"name": "黑色", "rgb": [0, 0, 0]},
+        {"name": "蓝色", "rgb": [0, 0, 255]},
+        {"name": "红色", "rgb": [255, 0, 0]},
+    ]
+    
+    # 🔑 从颜色名称映射到 RGB
+    color_map = {c["name"]: tuple(c["rgb"]) for c in line_colors}
+    line_color = color_map.get(current_line_color, (0, 0, 0))
+    
+    # 应用按钮
+    if st.button("🚀 批量应用到所有未标注页面", type="primary"):
+        _apply_template_batch(
+            current_structure_file,  # 🔑 直接使用保存的结构文件
+            current_entry,
+            unlabeled_pages,
+            output_dir,
+            structure_suffix,
+            current_line_width,
+            line_color
+        )
+
+
+def _apply_template_batch(
+    template_file: Path,  # 🔑 改为直接传入模板文件路径
+    template_entry: Dict,
+    target_entries: List[Dict],
+    output_dir: Path,
+    structure_suffix: str,
+    line_width: int,
+    line_color: tuple
+):
+    """
+    执行批量应用模板
+    
+    Args:
+        template_file: 模板结构文件路径
+        template_entry: 模板页面条目
+        target_entries: 目标页面列表
+        output_dir: 输出目录
+        structure_suffix: 结构文件后缀
+        line_width: 线条宽度
+        line_color: 线条颜色 (r, g, b)
+    """
+    try:
+        # 🔑 直接使用保存的结构文件创建模板应用器
+        applier = TableTemplateApplier(str(template_file))
+        
+        st.info(f"📋 使用模板: {template_file.name}")
+        
+        # 进度条
+        progress_bar = st.progress(0)
+        status_text = st.empty()
+        
+        success_count = 0
+        failed_count = 0
+        results = []
+        
+        for idx, entry in enumerate(target_entries):
+            # 更新进度
+            progress = (idx + 1) / len(target_entries)
+            progress_bar.progress(progress)
+            status_text.text(f"处理中: {entry['display']} ({idx + 1}/{len(target_entries)})")
+            
+            try:
+                # 加载 OCR 数据
+                with open(entry["json"], "r", encoding="utf-8") as fp:
+                    raw = json.load(fp)
+                
+                # 解析 OCR 数据
+                if 'parsing_res_list' in raw and 'overall_ocr_res' in raw:
+                    table_bbox, ocr_data = TableLineGenerator.parse_ppstructure_result(raw)
+                else:
+                    raise ValueError("不支持的 OCR 格式")
+                
+                # 加载图片
+                if entry["image"] and entry["image"].exists():
+                    image = Image.open(entry["image"])
+                else:
+                    st.warning(f"⚠️ 跳过 {entry['display']}: 未找到图片")
+                    failed_count += 1
+                    results.append({
+                        'page': entry['index'],
+                        'status': 'skipped',
+                        'reason': 'no_image'
+                    })
+                    continue
+                
+                # 应用模板生成图片
+                img_with_lines = applier.apply_to_image(
+                    image,
+                    ocr_data,
+                    line_width=line_width,
+                    line_color=line_color
+                )
+                
+                # 生成结构配置
+                structure = applier.generate_structure_for_image(ocr_data)
+                
+                # 保存图片
+                base_name = entry["json"].stem
+                image_suffix = st.session_state.current_output_config.get("image_suffix", ".png")
+                output_image_path = output_dir / f"{base_name}{image_suffix}"
+                img_with_lines.save(output_image_path)
+                
+                # 🔑 保存结构(确保 set 转为 list)
+                structure_path = output_dir / f"{base_name}{structure_suffix}"
+                
+                with open(structure_path, 'w', encoding='utf-8') as f:
+                    json.dump(structure, f, indent=2, ensure_ascii=False)
+                
+                success_count += 1
+                results.append({
+                    'page': entry['index'],
+                    'status': 'success',
+                    'image': str(output_image_path),
+                    'structure': str(structure_path)
+                })
+                
+            except Exception as e:
+                failed_count += 1
+                results.append({
+                    'page': entry['index'],
+                    'status': 'error',
+                    'error': str(e)
+                })
+                st.error(f"❌ 处理失败 {entry['display']}: {e}")
+        
+        # 完成
+        progress_bar.progress(1.0)
+        status_text.empty()
+        
+        # 保存批处理结果
+        batch_result_path = output_dir / "batch_results.json"
+        with open(batch_result_path, 'w', encoding='utf-8') as f:
+            json.dump({
+                'template': template_entry['display'],
+                'template_file': str(template_file),
+                'total': len(target_entries),
+                'success': success_count,
+                'failed': failed_count,
+                'line_width': line_width,
+                'line_color': line_color,
+                'results': results
+            }, f, indent=2, ensure_ascii=False)
+        
+        # 显示结果
+        if success_count > 0:
+            st.success(
+                f"✅ 批量应用完成!\n\n"
+                f"成功: {success_count} 页\n\n"
+                f"失败: {failed_count} 页"
+            )
+            
+            # 🔑 提供下载批处理结果
+            with open(batch_result_path, 'r', encoding='utf-8') as f:
+                st.download_button(
+                    "📥 下载批处理报告",
+                    f.read(),
+                    file_name="batch_results.json",
+                    mime="application/json"
+                )
+        else:
+            st.error("❌ 批量应用失败,没有成功处理任何页面")
+        
+        # 显示详细结果
+        with st.expander("📋 详细结果"):
+            for result in results:
+                if result['status'] == 'success':
+                    st.success(f"✅ 第 {result['page']} 页")
+                elif result['status'] == 'error':
+                    st.error(f"❌ 第 {result['page']} 页: {result.get('error', '未知错误')}")
+                else:
+                    st.warning(f"⚠️ 第 {result['page']} 页: {result.get('reason', '跳过')}")
+    
+    except Exception as e:
+        st.error(f"❌ 批量应用过程中发生错误: {e}")
+        import traceback
+        with st.expander("🔍 详细错误信息"):
+            st.code(traceback.format_exc())

+ 26 - 9
table_line_generator/editor/directory_selector.py

@@ -36,8 +36,13 @@ def create_directory_selector(
     )
     source_cfg = next(src for src in data_sources if src["name"] == selected_name)
     
-    # 获取输出配置
+    # 🔑 保存当前选择的数据源配置到 session_state
+    st.session_state.current_data_source = source_cfg
+    
+    # 获取输出配置(优先使用数据源自己的 output)
     output_cfg = source_cfg.get("output", global_output_config)
+    st.session_state.current_output_config = output_cfg
+    
     output_dir = Path(output_cfg.get("directory", "output/table_structures"))
     structure_suffix = output_cfg.get("structure_suffix", "_structure.json")
     
@@ -74,6 +79,10 @@ def create_directory_selector(
         key="dir_page_input"
     )
     
+    # 🔑 保存当前选择的目录清单到 session_state(供批量应用使用)
+    st.session_state.current_catalog = catalog
+    st.session_state.current_catalog_index = selected
+    
     # 🔑 关键优化:只在切换文件时才重新加载
     current_entry_key = f"{selected_name}::{catalog[selected]['json']}"
     
@@ -107,16 +116,24 @@ def _load_catalog_entry(entry: Dict, output_dir: Path, structure_suffix: str, en
     has_structure = structure_file.exists()
     
     # 📂 加载 JSON
-    with open(entry["json"], "r", encoding="utf-8") as fp:
-        raw = json.load(fp)
-    st.session_state.ocr_data = parse_ocr_data(raw)
-    st.session_state.loaded_json_name = entry["json"].name
+    try:
+        with open(entry["json"], "r", encoding="utf-8") as fp:
+            raw = json.load(fp)
+        st.session_state.ocr_data = parse_ocr_data(raw)
+        st.session_state.loaded_json_name = entry["json"].name
+    except Exception as e:
+        st.error(f"❌ 加载 JSON 失败: {e}")
+        return
 
     # 🖼️ 加载图片
-    if entry["image"] and entry["image"].exists():
-        st.session_state.image = Image.open(entry["image"])
-        st.session_state.loaded_image_name = entry["image"].name
-    else:
+    try:
+        if entry["image"] and entry["image"].exists():
+            st.session_state.image = Image.open(entry["image"])
+            st.session_state.loaded_image_name = entry["image"].name
+        else:
+            st.session_state.image = None
+    except Exception as e:
+        st.error(f"❌ 加载图片失败: {e}")
         st.session_state.image = None
 
     # 🎯 自动模式判断

+ 27 - 8
table_line_generator/editor/save_controls.py

@@ -19,12 +19,19 @@ def create_save_section(work_mode: str, structure: Dict, image, line_width: int,
         structure: 表格结构
         image: 图片对象
         line_width: 线条宽度
-        output_config: 输出配置
+        output_config: 输出配置(兜底用)
     """
     st.divider()
 
-    defaults = output_config.get("defaults", {})
-    line_colors = output_config.get("line_colors") or [
+    # 🔑 优先使用当前数据源的输出配置
+    if 'current_output_config' in st.session_state:
+        active_output_config = st.session_state.current_output_config
+        st.info(f"📂 保存位置:{active_output_config.get('directory', 'N/A')}")
+    else:
+        active_output_config = output_config
+
+    defaults = active_output_config.get("defaults", {})
+    line_colors = active_output_config.get("line_colors") or [
         {"name": "黑色", "rgb": [0, 0, 0]},
         {"name": "蓝色", "rgb": [0, 0, 255]},
         {"name": "红色", "rgb": [255, 0, 0]},
@@ -54,18 +61,19 @@ def create_save_section(work_mode: str, structure: Dict, image, line_width: int,
 
     with save_col3:
         line_color_option = st.selectbox(
-            "保存时线条颜色",
+            "线条颜色",
             color_names,
-            label_visibility="collapsed",
             index=default_index,
+			label_visibility="collapsed",
+            key="save_line_color"
         )
 
     if st.button("💾 保存", type="primary"):
-        output_dir = Path(output_config.get("directory", "output/table_structures"))
+        output_dir = Path(active_output_config.get("directory", "output/table_structures"))
         output_dir.mkdir(parents=True, exist_ok=True)
 
-        structure_suffix = output_config.get("structure_suffix", "_structure.json")
-        image_suffix = output_config.get("image_suffix", "_with_lines.png")
+        structure_suffix = active_output_config.get("structure_suffix", "_structure.json")
+        image_suffix = active_output_config.get("image_suffix", "_with_lines.png")
 
         # 确定文件名
         base_name = _determine_base_name(work_mode)
@@ -98,6 +106,17 @@ def create_save_section(work_mode: str, structure: Dict, image, line_width: int,
             st.success(f"✅ 已保存 {len(saved_files)} 个文件:")
             for file_type, file_path in saved_files:
                 st.info(f"  • {file_type}: {file_path}")
+            
+            # 显示当前数据源信息
+            if 'current_data_source' in st.session_state:
+                ds = st.session_state.current_data_source
+                with st.expander("📋 数据源信息"):
+                    st.json({
+                        "名称": ds.get("name"),
+                        "JSON目录": str(ds.get("json_dir")),
+                        "图片目录": str(ds.get("image_dir")),
+                        "输出目录": str(output_dir),
+                    })
 
 
 def _determine_base_name(work_mode: str) -> str: