浏览代码

feat: 添加目录模式支持,优化配置加载和状态管理

zhch158_admin 4 天之前
父节点
当前提交
244e079db9
共有 1 个文件被更改,包括 149 次插入107 次删除
  1. 149 107
      table_line_generator/streamlit_table_line_editor.py

+ 149 - 107
table_line_generator/streamlit_table_line_editor.py

@@ -6,6 +6,10 @@
 import streamlit as st
 from pathlib import Path
 from PIL import Image
+import yaml
+from typing import Dict, List, Optional, Tuple
+import argparse
+import sys
 
 try:
     from table_line_generator import TableLineGenerator
@@ -20,6 +24,11 @@ from editor import (
     create_undo_redo_section,
     create_analysis_section,
     create_save_section,
+    create_directory_selector,
+    # 新增的模块功能
+    setup_new_annotation_mode,
+    setup_edit_annotation_mode,
+    render_table_structure_view,
     
     # 绘图
     get_cached_table_lines_image,
@@ -29,8 +38,31 @@ from editor import (
     
     # 调整
     create_adjustment_section,
+    show_image_with_scroll,
+
+    # 配置
+    load_table_editor_config,
+    build_data_source_catalog,
+    parse_table_editor_cli_args,
 )
 
+DEFAULT_CONFIG_PATH = Path(__file__).with_name("table_line_generator.yaml")
+
+@st.cache_resource
+def get_cli_args():
+    return parse_table_editor_cli_args()
+
+@st.cache_resource
+def get_table_editor_config():
+    """缓存配置加载(整个 session 共享)"""
+    cli_args = get_cli_args()
+    config_path = (
+        Path(cli_args.config).expanduser()
+        if cli_args.config
+        else DEFAULT_CONFIG_PATH
+    )
+    return load_table_editor_config(config_path)
+
 
 def create_table_line_editor():
     """创建表格线编辑器界面"""
@@ -44,7 +76,13 @@ def create_table_line_editor():
     
     st.title("📏 表格线编辑器")
     
-    # 初始化 session_state
+    # 🎯 从缓存获取配置
+    TABLE_EDITOR_CONFIG = get_table_editor_config()
+    VIEWPORT_WIDTH = TABLE_EDITOR_CONFIG["viewport"]["width"]
+    VIEWPORT_HEIGHT = TABLE_EDITOR_CONFIG["viewport"]["height"]
+    DATA_SOURCES = TABLE_EDITOR_CONFIG.get("data_sources", [])
+    
+    # 初始化 session_state(集中管理)
     if 'loaded_json_name' not in st.session_state:
         st.session_state.loaded_json_name = None
     if 'loaded_image_name' not in st.session_state:
@@ -56,6 +94,14 @@ def create_table_line_editor():
     if 'image' not in st.session_state:
         st.session_state.image = None
     
+    # 🆕 目录模式专用状态
+    if 'dir_selected_index' not in st.session_state:
+        st.session_state.dir_selected_index = 0
+    if 'last_loaded_entry' not in st.session_state:
+        st.session_state.last_loaded_entry = None
+    if 'dir_auto_mode' not in st.session_state:
+        st.session_state.dir_auto_mode = None
+    
     # 初始化撤销/重做栈
     init_undo_stack()
     
@@ -63,18 +109,64 @@ def create_table_line_editor():
     st.sidebar.header("📂 工作模式")
     work_mode = st.sidebar.radio(
         "选择模式",
-        ["🆕 新建标注", "📂 加载已有标注"],
+        ["🆕 新建标注", "📂 加载已有标注", "📁 目录模式"],
         index=0
     )
     
-    # 文件上传区域
-    create_file_uploader_section(work_mode)
+    # 📁 目录模式
+    if work_mode == "📁 目录模式":
+        if not DATA_SOURCES:
+            st.sidebar.warning("未配置 data_sources")
+            return
+        
+        auto_mode = create_directory_selector(DATA_SOURCES, TABLE_EDITOR_CONFIG["output"])
+        
+        if auto_mode == "new":
+            if not (st.session_state.ocr_data and st.session_state.image):
+                st.warning("⚠️ 缺少必要数据")
+                return
+            setup_new_annotation_mode(
+                st.session_state.ocr_data,
+                st.session_state.image,
+                TABLE_EDITOR_CONFIG["display"]
+            )
+        else:  # edit
+            if 'structure' not in st.session_state:
+                st.warning("⚠️ 结构加载失败")
+                return
+            image, line_width, display_mode, zoom_level, show_line_numbers = setup_edit_annotation_mode(
+                st.session_state.structure,
+                st.session_state.image,
+                TABLE_EDITOR_CONFIG["display"]
+            )
+        
+        # 统一渲染
+        if 'structure' in st.session_state and st.session_state.structure:
+            render_table_structure_view(
+                st.session_state.structure,
+                st.session_state.image or Image.new('RGB', (2000, 2000), 'white'),
+                line_width if auto_mode == "edit" else st.session_state.get('line_width', 2),
+                display_mode if auto_mode == "edit" else st.session_state.get('display_mode', "仅显示划线图"),
+                zoom_level if auto_mode == "edit" else st.session_state.get('zoom_level', 1.0),
+                show_line_numbers if auto_mode == "edit" else True,
+                VIEWPORT_WIDTH,
+                VIEWPORT_HEIGHT
+            )
+            create_save_section(
+                auto_mode,
+                st.session_state.structure,
+                st.session_state.image,
+                line_width if auto_mode == "edit" else 2,
+                TABLE_EDITOR_CONFIG["output"]
+            )
+        return
     
-    # 检查必要条件
+    # 🆕 新建标注模式
     if work_mode == "🆕 新建标注":
-        if st.session_state.ocr_data is None or st.session_state.image is None:
+        create_file_uploader_section(work_mode)
+        
+        if not (st.session_state.ocr_data and st.session_state.image):
             st.info("👆 请在左侧上传 OCR 结果 JSON 文件和对应的图片")
-            
             with st.expander("📖 使用说明"):
                 st.markdown("""
                 ### 🆕 新建标注模式
@@ -111,24 +203,40 @@ def create_table_line_editor():
                 """)
             return
         
-        ocr_data = st.session_state.ocr_data
-        image = st.session_state.image
-        
         st.info(f"📂 已加载: {st.session_state.loaded_json_name} + {st.session_state.loaded_image_name}")
         
-        # 初始化生成器
-        if 'generator' not in st.session_state or st.session_state.generator is None:
-            try:
-                generator = TableLineGenerator(image, ocr_data)
-                st.session_state.generator = generator
-            except Exception as e:
-                st.error(f"❌ 初始化失败: {e}")
-                st.stop()
+        _, _, _, line_width, display_mode, zoom_level, show_line_numbers = setup_new_annotation_mode(
+            st.session_state.ocr_data,
+            st.session_state.image,
+            TABLE_EDITOR_CONFIG["display"]
+        )
+        
+        if 'structure' in st.session_state and st.session_state.structure:
+            render_table_structure_view(
+                st.session_state.structure,
+                st.session_state.image,
+                line_width,
+                display_mode,
+                zoom_level,
+                show_line_numbers,
+                VIEWPORT_WIDTH,
+                VIEWPORT_HEIGHT
+            )
+            create_save_section(
+                work_mode,
+                st.session_state.structure,
+                st.session_state.image,
+                line_width,
+                TABLE_EDITOR_CONFIG["output"]
+            )
+        return
     
-    else:  # 加载已有标注模式
+    # 📂 加载已有标注模式
+    if work_mode == "📂 加载已有标注":
+        create_file_uploader_section(work_mode)
+        
         if 'structure' not in st.session_state:
             st.info("👆 请在左侧上传配置文件 (*_structure.json)")
-            
             with st.expander("📖 使用说明"):
                 st.markdown("""
                 ### 📂 加载已有标注
@@ -150,95 +258,29 @@ def create_table_line_editor():
         if st.session_state.image is None:
             st.warning("⚠️ 仅加载了配置,未加载图片。部分功能受限。")
         
-        structure = st.session_state.structure
-        image = st.session_state.image
-        
-        # 如果没有图片,创建虚拟画布
-        if image is None:
-            if 'table_bbox' in structure:
-                bbox = structure['table_bbox']
-                dummy_width = bbox[2] + 100
-                dummy_height = bbox[3] + 100
-            else:
-                dummy_width = 2000
-                dummy_height = 2000
-            
-            image = Image.new('RGB', (dummy_width, dummy_height), color='white')
-            st.info(f"💡 使用虚拟画布 ({dummy_width}x{dummy_height}) 显示表格结构")
-    
-    # 参数调整(仅在新建模式显示)
-    if work_mode == "🆕 新建标注":
-        st.sidebar.header("🔧 参数调整")
-        
-        y_tolerance = st.sidebar.slider("Y轴聚类容差(像素)", 1, 20, 5)
-        x_tolerance = st.sidebar.slider("X轴聚类容差(像素)", 5, 50, 10)
-        min_row_height = st.sidebar.slider("最小行高(像素)", 10, 100, 20)
-    
-    # 显示设置
-    line_width, display_mode, zoom_level, show_line_numbers = create_display_settings_section()
-    
-    # 撤销/重做
-    create_undo_redo_section()
-    
-    # 分析表格结构(仅在新建模式显示)
-    if work_mode == "🆕 新建标注":
-        create_analysis_section(y_tolerance, x_tolerance, min_row_height)
-    
-    # 显示结果
-    if 'structure' in st.session_state and st.session_state.structure:
-        structure = st.session_state.structure
-        
-        # 使用缓存机制绘制表格线
-        img_with_lines = get_cached_table_lines_image(
-            image, 
-            structure, 
-            line_width=line_width,
-            show_numbers=show_line_numbers
+        image, line_width, display_mode, zoom_level, show_line_numbers = setup_edit_annotation_mode(
+            st.session_state.structure,
+            st.session_state.image,
+            TABLE_EDITOR_CONFIG["display"]
         )
         
-        # 根据显示模式显示图片
-        if display_mode == "对比显示":
-            col1, col2 = st.columns(2)
-            with col1:
-                st.subheader("原图")
-                st.image(image, use_container_width=True)
-            
-            with col2:
-                st.subheader("添加表格线")
-                st.image(img_with_lines, use_container_width=True)
-                
-        elif display_mode == "仅显示划线图":
-            display_width = int(img_with_lines.width * zoom_level)
-            
-            st.subheader(f"表格线图 (缩放: {zoom_level:.0%})")
-            st.image(img_with_lines, width=display_width)
-            
-        else:
-            display_width = int(image.width * zoom_level)
-            
-            st.subheader(f"原图 (缩放: {zoom_level:.0%})")
-            st.image(image, width=display_width)
-        
-        # 手动调整区域
-        create_adjustment_section(structure)
-        
-        # 显示详细信息
-        with st.expander("📊 表格结构详情"):
-            st.json({
-                "行数": len(structure['rows']),
-                "列数": len(structure['columns']),
-                "横线数": len(structure.get('horizontal_lines', [])),
-                "竖线数": len(structure.get('vertical_lines', [])),
-                "横线坐标": structure.get('horizontal_lines', []),
-                "竖线坐标": structure.get('vertical_lines', []),
-                "标准行高": structure.get('row_height'),
-                "列宽度": structure.get('col_widths'),
-                "修改的横线": list(structure.get('modified_h_lines', set())),
-                "修改的竖线": list(structure.get('modified_v_lines', set()))
-            })
-        
-        # 保存区域
-        create_save_section(work_mode, structure, image, line_width)
+        render_table_structure_view(
+            st.session_state.structure,
+            image,
+            line_width,
+            display_mode,
+            zoom_level,
+            show_line_numbers,
+            VIEWPORT_WIDTH,
+            VIEWPORT_HEIGHT
+        )
+        create_save_section(
+            work_mode,
+            st.session_state.structure,
+            image,
+            line_width,
+            TABLE_EDITOR_CONFIG["output"]
+        )
 
 
 if __name__ == "__main__":