""" UI 组件 """ import streamlit as st import json from pathlib import Path from PIL import Image import tempfile from typing import Dict, List try: from ..table_line_generator import TableLineGenerator except ImportError: from table_line_generator import TableLineGenerator from .config_loader import load_structure_from_config, build_data_source_catalog from .drawing import clear_table_image_cache def create_file_uploader_section(work_mode: str): """ 创建文件上传区域 Args: work_mode: 工作模式("🆕 新建标注" 或 "📂 加载已有标注") """ if work_mode == "🆕 新建标注": st.sidebar.subheader("上传文件") uploaded_json = st.sidebar.file_uploader("上传OCR结果JSON", type=['json'], key="new_json") uploaded_image = st.sidebar.file_uploader("上传对应图片", type=['jpg', 'png'], key="new_image") # 处理 JSON 上传 if uploaded_json is not None: if st.session_state.loaded_json_name != uploaded_json.name: try: raw_data = json.load(uploaded_json) with st.expander("🔍 原始数据结构"): if isinstance(raw_data, dict): st.json({k: f"<{type(v).__name__}>" if not isinstance(v, (str, int, float, bool, type(None))) else v for k, v in list(raw_data.items())[:5]}) else: st.json(raw_data[:3] if len(raw_data) > 3 else raw_data) ocr_data = parse_ocr_data(raw_data) if not ocr_data: st.error("❌ 无法解析 OCR 数据,请检查 JSON 格式") st.stop() st.session_state.ocr_data = ocr_data st.session_state.loaded_json_name = uploaded_json.name st.session_state.loaded_config_name = None # 清除旧数据 if 'structure' in st.session_state: del st.session_state.structure if 'generator' in st.session_state: del st.session_state.generator st.session_state.undo_stack = [] st.session_state.redo_stack = [] clear_table_image_cache() st.success(f"✅ 成功加载 {len(ocr_data)} 条 OCR 记录") except Exception as e: st.error(f"❌ 加载数据失败: {e}") st.stop() # 处理图片上传 if uploaded_image is not None: if st.session_state.loaded_image_name != uploaded_image.name: try: image = Image.open(uploaded_image) st.session_state.image = image st.session_state.loaded_image_name = uploaded_image.name # 清除旧数据 if 'structure' in st.session_state: del st.session_state.structure if 'generator' in st.session_state: del st.session_state.generator st.session_state.undo_stack = [] st.session_state.redo_stack = [] clear_table_image_cache() st.success(f"✅ 成功加载图片: {uploaded_image.name}") except Exception as e: st.error(f"❌ 加载图片失败: {e}") st.stop() else: # 加载已有标注 st.sidebar.subheader("加载已保存的标注") uploaded_config = st.sidebar.file_uploader( "上传配置文件 (*_structure.json)", type=['json'], key="load_config" ) uploaded_image_for_config = st.sidebar.file_uploader( "上传对应图片(可选)", type=['jpg', 'png'], key="load_image" ) # 处理配置文件加载 if uploaded_config is not None: if st.session_state.loaded_config_name != uploaded_config.name: try: # 创建临时文件 with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False, encoding='utf-8') as tmp: tmp.write(uploaded_config.getvalue().decode('utf-8')) tmp_path = tmp.name # 加载结构 structure = load_structure_from_config(Path(tmp_path)) # 清理临时文件 Path(tmp_path).unlink() st.session_state.structure = structure st.session_state.loaded_config_name = uploaded_config.name # 清除历史记录和缓存 st.session_state.undo_stack = [] st.session_state.redo_stack = [] clear_table_image_cache() st.success(f"✅ 成功加载配置: {uploaded_config.name}") st.info( f"📊 表格结构: {len(structure['rows'])}行 x {len(structure['columns'])}列\n\n" f"📏 横线数: {len(structure.get('horizontal_lines', []))}\n\n" f"📏 竖线数: {len(structure.get('vertical_lines', []))}" ) # 显示配置文件详情 with st.expander("📋 配置详情"): st.json({ "行数": len(structure['rows']), "列数": len(structure['columns']), "横线数": len(structure.get('horizontal_lines', [])), "竖线数": len(structure.get('vertical_lines', [])), "行高": structure.get('row_height'), "列宽": structure.get('col_widths'), "已修改的横线": list(structure.get('modified_h_lines', set())), "已修改的竖线": list(structure.get('modified_v_lines', set())) }) except Exception as e: st.error(f"❌ 加载配置失败: {e}") import traceback st.code(traceback.format_exc()) st.stop() # 处理图片加载 if uploaded_image_for_config is not None: if st.session_state.loaded_image_name != uploaded_image_for_config.name: try: image = Image.open(uploaded_image_for_config) st.session_state.image = image st.session_state.loaded_image_name = uploaded_image_for_config.name clear_table_image_cache() st.success(f"✅ 成功加载图片: {uploaded_image_for_config.name}") except Exception as e: st.error(f"❌ 加载图片失败: {e}") st.stop() # 提示信息 if 'structure' in st.session_state and st.session_state.image is None: st.warning("⚠️ 已加载配置,但未加载对应图片。请上传图片以查看效果。") st.info("💡 提示:配置文件已加载,您可以:\n1. 上传对应图片查看效果\n2. 直接编辑配置并保存") def create_display_settings_section(display_config: Dict): """显示设置(由配置驱动)""" st.sidebar.divider() st.sidebar.subheader("🖼️ 显示设置") line_width = st.sidebar.slider( "线条宽度", int(display_config.get("line_width_min", 1)), int(display_config.get("line_width_max", 5)), int(display_config.get("default_line_width", 2)), ) display_mode = st.sidebar.radio( "显示模式", ["对比显示", "仅显示划线图", "仅显示原图"], index=1, ) zoom_level = st.sidebar.slider( "图片缩放", float(display_config.get("zoom_min", 0.25)), float(display_config.get("zoom_max", 2.0)), float(display_config.get("default_zoom", 1.0)), float(display_config.get("zoom_step", 0.25)), ) show_line_numbers = st.sidebar.checkbox( "显示线条编号", value=bool(display_config.get("show_line_numbers", True)), ) return line_width, display_mode, zoom_level, show_line_numbers def create_undo_redo_section(): """创建撤销/重做区域""" from .state_manager import undo_last_action, redo_last_action from .drawing import clear_table_image_cache st.sidebar.divider() st.sidebar.subheader("↩️ 撤销/重做") col1, col2 = st.sidebar.columns(2) with col1: if st.button("↩️ 撤销", disabled=len(st.session_state.undo_stack) == 0): if undo_last_action(): clear_table_image_cache() st.success("✅ 已撤销") st.rerun() with col2: if st.button("↪️ 重做", disabled=len(st.session_state.redo_stack) == 0): if redo_last_action(): clear_table_image_cache() st.success("✅ 已重做") st.rerun() st.sidebar.info(f"📚 历史记录: {len(st.session_state.undo_stack)} 条") def create_analysis_section(y_tolerance, x_tolerance, min_row_height): """ 创建分析区域 Args: y_tolerance: Y轴聚类容差 x_tolerance: X轴聚类容差 min_row_height: 最小行高 """ if st.button("🔍 分析表格结构"): with st.spinner("分析中..."): try: generator = st.session_state.generator structure = generator.analyze_table_structure( y_tolerance=y_tolerance, x_tolerance=x_tolerance, min_row_height=min_row_height ) if not structure: st.warning("⚠️ 未检测到表格结构") st.stop() structure['modified_h_lines'] = set() structure['modified_v_lines'] = set() st.session_state.structure = structure st.session_state.undo_stack = [] st.session_state.redo_stack = [] clear_table_image_cache() st.success( f"✅ 检测到 {len(structure['rows'])} 行({len(structure['horizontal_lines'])} 条横线)," f"{len(structure['columns'])} 列({len(structure['vertical_lines'])} 条竖线)" ) col1, col2, col3, col4 = st.columns(4) with col1: st.metric("行数", len(structure['rows'])) with col2: st.metric("横线数", len(structure['horizontal_lines'])) with col3: st.metric("列数", len(structure['columns'])) with col4: st.metric("竖线数", len(structure['vertical_lines'])) except Exception as e: st.error(f"❌ 分析失败: {e}") import traceback st.code(traceback.format_exc()) st.stop() def create_save_section(work_mode, structure, image, line_width, output_config: Dict): """ 保存设置(目录/命名来自配置) """ from .config_loader import save_structure_to_config from .drawing import draw_clean_table_lines import io st.divider() defaults = output_config.get("defaults", {}) line_colors = output_config.get("line_colors") or [ {"name": "黑色", "rgb": [0, 0, 0]}, {"name": "蓝色", "rgb": [0, 0, 255]}, {"name": "红色", "rgb": [255, 0, 0]}, ] save_col1, save_col2, save_col3 = st.columns(3) with save_col1: save_structure = st.checkbox( "保存表格结构配置", value=bool(defaults.get("save_structure", True)), ) with save_col2: save_image = st.checkbox( "保存表格线图片", value=bool(defaults.get("save_image", True)), ) color_names = [c["name"] for c in line_colors] default_color = defaults.get("line_color", color_names[0]) default_index = color_names.index(default_color) if default_color in color_names else 0 with save_col3: line_color_option = st.selectbox( "保存时线条颜色", color_names, label_visibility="collapsed", index=default_index, ) if st.button("💾 保存", type="primary"): output_dir = Path(output_config.get("directory", "output/table_structures")) output_dir.mkdir(parents=True, exist_ok=True) structure_suffix = output_config.get("structure_suffix", "_structure.json") image_suffix = output_config.get("image_suffix", "_with_lines.png") # 确定文件名 if work_mode == "🆕 新建标注": if st.session_state.loaded_json_name: base_name = Path(st.session_state.loaded_json_name).stem else: base_name = "table_structure" else: if st.session_state.loaded_config_name: base_name = Path(st.session_state.loaded_config_name).stem if base_name.endswith('_structure'): base_name = base_name[:-10] elif st.session_state.loaded_image_name: base_name = Path(st.session_state.loaded_image_name).stem else: base_name = "table_structure" saved_files = [] if save_structure: structure_filename = f"{base_name}{structure_suffix}" structure_path = output_dir / structure_filename save_structure_to_config(structure, structure_path) saved_files.append(("配置文件", structure_path)) with open(structure_path, 'r') as f: st.download_button( "📥 下载配置文件", f.read(), file_name=f"{base_name}_structure.json", mime="application/json" ) if save_image: if st.session_state.image is None: st.warning("⚠️ 无法保存图片:未加载图片文件") else: selected_color_rgb = next( (tuple(c["rgb"]) for c in line_colors if c["name"] == line_color_option), (0, 0, 0), ) clean_img = draw_clean_table_lines( st.session_state.image, structure, line_width=line_width, line_color=selected_color_rgb, ) image_filename = f"{base_name}{image_suffix}" output_image_path = output_dir / image_filename clean_img.save(output_image_path) saved_files.append(("表格线图片", output_image_path)) buf = io.BytesIO() clean_img.save(buf, format='PNG') buf.seek(0) st.download_button( "📥 下载表格线图片", buf, file_name=f"{base_name}_with_lines.png", mime="image/png" ) if saved_files: st.success(f"✅ 已保存 {len(saved_files)} 个文件:") for file_type, file_path in saved_files: st.info(f" • {file_type}: {file_path}") def setup_new_annotation_mode(ocr_data, image, config: Dict): """ 设置新建标注模式的通用逻辑 Args: ocr_data: OCR 数据 image: 图片对象 config: 显示配置 Returns: tuple: (y_tolerance, x_tolerance, min_row_height, line_width, display_mode, zoom_level, show_line_numbers) """ # 参数调整 st.sidebar.header("🔧 参数调整") y_tolerance = st.sidebar.slider("Y轴聚类容差(像素)", 1, 20, 5, key="new_y_tol") x_tolerance = st.sidebar.slider("X轴聚类容差(像素)", 5, 50, 10, key="new_x_tol") min_row_height = st.sidebar.slider("最小行高(像素)", 10, 100, 20, key="new_min_h") # 显示设置 line_width, display_mode, zoom_level, show_line_numbers = create_display_settings_section(config) create_undo_redo_section() # 初始化生成器 if 'generator' not in st.session_state or st.session_state.generator is None: try: generator = TableLineGenerator(image, ocr_data) st.session_state.generator = generator except Exception as e: st.error(f"❌ 初始化生成器失败: {e}") st.stop() # 分析按钮 create_analysis_section(y_tolerance, x_tolerance, min_row_height) return y_tolerance, x_tolerance, min_row_height, line_width, display_mode, zoom_level, show_line_numbers def setup_edit_annotation_mode(structure, image, config: Dict): """ 设置编辑标注模式的通用逻辑 Args: structure: 表格结构 image: 图片对象(可为 None) config: 显示配置 Returns: tuple: (image, line_width, display_mode, zoom_level, show_line_numbers) """ # 如果没有图片,创建虚拟画布 if image is None: if 'table_bbox' in structure: bbox = structure['table_bbox'] dummy_width = bbox[2] + 100 dummy_height = bbox[3] + 100 else: dummy_width = 2000 dummy_height = 2000 image = Image.new('RGB', (dummy_width, dummy_height), color='white') st.info(f"💡 使用虚拟画布 ({dummy_width}x{dummy_height})") # 显示设置 line_width, display_mode, zoom_level, show_line_numbers = create_display_settings_section(config) create_undo_redo_section() return image, line_width, display_mode, zoom_level, show_line_numbers def render_table_structure_view(structure, image, line_width, display_mode, zoom_level, show_line_numbers, viewport_width, viewport_height): """ 渲染表格结构视图(统一三种模式的显示逻辑) Args: structure: 表格结构 image: 图片对象 line_width: 线条宽度 display_mode: 显示模式 zoom_level: 缩放级别 show_line_numbers: 是否显示线条编号 viewport_width: 视口宽度 viewport_height: 视口高度 """ # 绘制表格线 img_with_lines = get_cached_table_lines_image( image, structure, line_width=line_width, show_numbers=show_line_numbers ) # 根据显示模式显示图片 if display_mode == "对比显示": col1, col2 = st.columns(2) with col1: show_image_with_scroll(image, "原图", viewport_width, viewport_height, zoom_level) with col2: show_image_with_scroll(img_with_lines, "表格线", viewport_width, viewport_height, zoom_level) elif display_mode == "仅显示划线图": show_image_with_scroll( img_with_lines, f"表格线图 (缩放: {zoom_level:.0%})", viewport_width, viewport_height, zoom_level ) else: show_image_with_scroll( image, f"原图 (缩放: {zoom_level:.0%})", viewport_width, viewport_height, zoom_level ) # 手动调整区域 create_adjustment_section(structure) # 显示详细信息 with st.expander("📊 表格结构详情"): st.json({ "行数": len(structure['rows']), "列数": len(structure['columns']), "横线数": len(structure.get('horizontal_lines', [])), "竖线数": len(structure.get('vertical_lines', [])), "横线坐标": structure.get('horizontal_lines', []), "竖线坐标": structure.get('vertical_lines', []), "标准行高": structure.get('row_height'), "列宽度": structure.get('col_widths'), "修改的横线": list(structure.get('modified_h_lines', set())), "修改的竖线": list(structure.get('modified_v_lines', set())) }) def create_directory_selector(data_sources: List[Dict], global_output_config: Dict): """目录模式选择器(优化:避免重复加载)""" st.sidebar.subheader("目录模式") source_names = [src["name"] for src in data_sources] selected_name = st.sidebar.selectbox("选择数据源", source_names, key="dir_mode_source") source_cfg = next(src for src in data_sources if src["name"] == selected_name) output_cfg = source_cfg.get("output", global_output_config) output_dir = Path(output_cfg.get("directory", "output/table_structures")) structure_suffix = output_cfg.get("structure_suffix", "_structure.json") catalog_key = f"catalog::{selected_name}" if catalog_key not in st.session_state: st.session_state[catalog_key] = build_data_source_catalog(source_cfg) catalog = st.session_state[catalog_key] if not catalog: st.sidebar.warning("目录中没有 JSON 文件") return if 'dir_selected_index' not in st.session_state: st.session_state.dir_selected_index = 0 selected = st.sidebar.selectbox( "选择文件", range(len(catalog)), format_func=lambda i: catalog[i]["display"], index=st.session_state.dir_selected_index, key="dir_select_box" ) page_input = st.sidebar.number_input( "页码跳转", min_value=1, max_value=len(catalog), value=catalog[selected]["index"], step=1, key="dir_page_input" ) # 🔑 关键优化:只在切换文件时才重新加载 current_entry_key = f"{selected_name}::{catalog[selected]['json']}" if 'last_loaded_entry' not in st.session_state or st.session_state.last_loaded_entry != current_entry_key: # 文件切换,重新加载 entry = catalog[selected] base_name = entry["json"].stem structure_file = output_dir / f"{base_name}{structure_suffix}" has_structure = structure_file.exists() # 📂 加载 JSON with open(entry["json"], "r", encoding="utf-8") as fp: raw = json.load(fp) st.session_state.ocr_data = parse_ocr_data(raw) st.session_state.loaded_json_name = entry["json"].name # 🖼️ 加载图片 if entry["image"] and entry["image"].exists(): st.session_state.image = Image.open(entry["image"]) st.session_state.loaded_image_name = entry["image"].name else: st.session_state.image = None # 🎯 自动模式判断 if has_structure: st.session_state.dir_auto_mode = "edit" st.session_state.loaded_config_name = structure_file.name try: structure = load_structure_from_config(structure_file) st.session_state.structure = structure st.session_state.undo_stack = [] st.session_state.redo_stack = [] clear_table_image_cache() st.sidebar.success(f"✅ 编辑模式") except Exception as e: st.error(f"❌ 加载标注失败: {e}") st.session_state.dir_auto_mode = "new" else: st.session_state.dir_auto_mode = "new" if 'structure' in st.session_state: del st.session_state.structure if 'generator' in st.session_state: del st.session_state.generator st.sidebar.info(f"🆕 新建模式") # 标记已加载 st.session_state.last_loaded_entry = current_entry_key st.info(f"📂 已加载: {entry['json'].name}") # 页码跳转处理 if page_input != catalog[selected]["index"]: target = next((i for i, item in enumerate(catalog) if item["index"] == page_input), None) if target is not None: st.session_state.dir_selected_index = target st.rerun() return st.session_state.get('dir_auto_mode', 'new')