|
|
@@ -7,61 +7,16 @@ import json
|
|
|
from pathlib import Path
|
|
|
from PIL import Image
|
|
|
import tempfile
|
|
|
+from typing import Dict, List
|
|
|
|
|
|
try:
|
|
|
from ..table_line_generator import TableLineGenerator
|
|
|
except ImportError:
|
|
|
from table_line_generator import TableLineGenerator
|
|
|
|
|
|
-from .config_loader import load_structure_from_config
|
|
|
+from .config_loader import load_structure_from_config, build_data_source_catalog
|
|
|
from .drawing import clear_table_image_cache
|
|
|
|
|
|
-
|
|
|
-def parse_ocr_data(ocr_data):
|
|
|
- """解析OCR数据,支持多种格式"""
|
|
|
- # 如果是字符串,尝试解析
|
|
|
- if isinstance(ocr_data, str):
|
|
|
- try:
|
|
|
- ocr_data = json.loads(ocr_data)
|
|
|
- except json.JSONDecodeError:
|
|
|
- st.error("❌ JSON 格式错误,无法解析")
|
|
|
- return []
|
|
|
-
|
|
|
- # 检查是否为 PPStructure V3 格式
|
|
|
- if isinstance(ocr_data, dict) and 'parsing_res_list' in ocr_data and 'overall_ocr_res' in ocr_data:
|
|
|
- st.info("🔍 检测到 PPStructure V3 格式")
|
|
|
-
|
|
|
- try:
|
|
|
- table_bbox, text_boxes = TableLineGenerator.parse_ppstructure_result(ocr_data)
|
|
|
- st.success(f"✅ 表格区域: {table_bbox}")
|
|
|
- st.success(f"✅ 表格内文本框: {len(text_boxes)} 个")
|
|
|
- return text_boxes
|
|
|
- except Exception as e:
|
|
|
- st.error(f"❌ 解析 PPStructure 结果失败: {e}")
|
|
|
- return []
|
|
|
-
|
|
|
- # 确保是列表
|
|
|
- if not isinstance(ocr_data, list):
|
|
|
- st.error(f"❌ OCR 数据应该是列表,实际类型: {type(ocr_data)}")
|
|
|
- return []
|
|
|
-
|
|
|
- if not ocr_data:
|
|
|
- st.warning("⚠️ OCR 数据为空")
|
|
|
- return []
|
|
|
-
|
|
|
- first_item = ocr_data[0]
|
|
|
- if not isinstance(first_item, dict):
|
|
|
- st.error(f"❌ OCR 数据项应该是字典,实际类型: {type(first_item)}")
|
|
|
- return []
|
|
|
-
|
|
|
- if 'bbox' not in first_item:
|
|
|
- st.error("❌ OCR 数据缺少 'bbox' 字段")
|
|
|
- st.info("💡 支持的格式示例:\n```json\n[\n {\n \"text\": \"文本\",\n \"bbox\": [x1, y1, x2, y2]\n }\n]\n```")
|
|
|
- return []
|
|
|
-
|
|
|
- return ocr_data
|
|
|
-
|
|
|
-
|
|
|
def create_file_uploader_section(work_mode: str):
|
|
|
"""
|
|
|
创建文件上传区域
|
|
|
@@ -221,16 +176,34 @@ def create_file_uploader_section(work_mode: str):
|
|
|
st.info("💡 提示:配置文件已加载,您可以:\n1. 上传对应图片查看效果\n2. 直接编辑配置并保存")
|
|
|
|
|
|
|
|
|
-def create_display_settings_section():
|
|
|
- """创建显示设置区域"""
|
|
|
+def create_display_settings_section(display_config: Dict):
|
|
|
+ """显示设置(由配置驱动)"""
|
|
|
st.sidebar.divider()
|
|
|
st.sidebar.subheader("🖼️ 显示设置")
|
|
|
-
|
|
|
- line_width = st.sidebar.slider("线条宽度", 1, 5, 2)
|
|
|
- display_mode = st.sidebar.radio("显示模式", ["对比显示", "仅显示划线图", "仅显示原图"], index=1)
|
|
|
- zoom_level = st.sidebar.slider("图片缩放", 0.25, 2.0, 1.0, 0.25)
|
|
|
- show_line_numbers = st.sidebar.checkbox("显示线条编号", value=True)
|
|
|
-
|
|
|
+
|
|
|
+ line_width = st.sidebar.slider(
|
|
|
+ "线条宽度",
|
|
|
+ int(display_config.get("line_width_min", 1)),
|
|
|
+ int(display_config.get("line_width_max", 5)),
|
|
|
+ int(display_config.get("default_line_width", 2)),
|
|
|
+ )
|
|
|
+ display_mode = st.sidebar.radio(
|
|
|
+ "显示模式",
|
|
|
+ ["对比显示", "仅显示划线图", "仅显示原图"],
|
|
|
+ index=1,
|
|
|
+ )
|
|
|
+ zoom_level = st.sidebar.slider(
|
|
|
+ "图片缩放",
|
|
|
+ float(display_config.get("zoom_min", 0.25)),
|
|
|
+ float(display_config.get("zoom_max", 2.0)),
|
|
|
+ float(display_config.get("default_zoom", 1.0)),
|
|
|
+ float(display_config.get("zoom_step", 0.25)),
|
|
|
+ )
|
|
|
+ show_line_numbers = st.sidebar.checkbox(
|
|
|
+ "显示线条编号",
|
|
|
+ value=bool(display_config.get("show_line_numbers", True)),
|
|
|
+ )
|
|
|
+
|
|
|
return line_width, display_mode, zoom_level, show_line_numbers
|
|
|
|
|
|
|
|
|
@@ -313,41 +286,56 @@ def create_analysis_section(y_tolerance, x_tolerance, min_row_height):
|
|
|
st.stop()
|
|
|
|
|
|
|
|
|
-def create_save_section(work_mode, structure, image, line_width):
|
|
|
+def create_save_section(work_mode, structure, image, line_width, output_config: Dict):
|
|
|
"""
|
|
|
- 创建保存区域
|
|
|
-
|
|
|
- Args:
|
|
|
- work_mode: 工作模式
|
|
|
- structure: 表格结构
|
|
|
- image: 图片
|
|
|
- line_width: 线条宽度
|
|
|
+ 保存设置(目录/命名来自配置)
|
|
|
"""
|
|
|
from .config_loader import save_structure_to_config
|
|
|
from .drawing import draw_clean_table_lines
|
|
|
import io
|
|
|
-
|
|
|
+
|
|
|
st.divider()
|
|
|
-
|
|
|
+
|
|
|
+ defaults = output_config.get("defaults", {})
|
|
|
+ line_colors = output_config.get("line_colors") or [
|
|
|
+ {"name": "黑色", "rgb": [0, 0, 0]},
|
|
|
+ {"name": "蓝色", "rgb": [0, 0, 255]},
|
|
|
+ {"name": "红色", "rgb": [255, 0, 0]},
|
|
|
+ ]
|
|
|
+
|
|
|
save_col1, save_col2, save_col3 = st.columns(3)
|
|
|
-
|
|
|
+
|
|
|
with save_col1:
|
|
|
- save_structure = st.checkbox("保存表格结构配置", value=True)
|
|
|
-
|
|
|
+ save_structure = st.checkbox(
|
|
|
+ "保存表格结构配置",
|
|
|
+ value=bool(defaults.get("save_structure", True)),
|
|
|
+ )
|
|
|
+
|
|
|
with save_col2:
|
|
|
- save_image = st.checkbox("保存表格线图片", value=True)
|
|
|
-
|
|
|
+ save_image = st.checkbox(
|
|
|
+ "保存表格线图片",
|
|
|
+ value=bool(defaults.get("save_image", True)),
|
|
|
+ )
|
|
|
+
|
|
|
+ color_names = [c["name"] for c in line_colors]
|
|
|
+ default_color = defaults.get("line_color", color_names[0])
|
|
|
+ default_index = color_names.index(default_color) if default_color in color_names else 0
|
|
|
+
|
|
|
with save_col3:
|
|
|
line_color_option = st.selectbox(
|
|
|
"保存时线条颜色",
|
|
|
- ["黑色", "蓝色", "红色"],
|
|
|
- index=0
|
|
|
+ color_names,
|
|
|
+ label_visibility="collapsed",
|
|
|
+ index=default_index,
|
|
|
)
|
|
|
-
|
|
|
+
|
|
|
if st.button("💾 保存", type="primary"):
|
|
|
- output_dir = Path("output/table_structures")
|
|
|
+ output_dir = Path(output_config.get("directory", "output/table_structures"))
|
|
|
output_dir.mkdir(parents=True, exist_ok=True)
|
|
|
-
|
|
|
+
|
|
|
+ structure_suffix = output_config.get("structure_suffix", "_structure.json")
|
|
|
+ image_suffix = output_config.get("image_suffix", "_with_lines.png")
|
|
|
+
|
|
|
# 确定文件名
|
|
|
if work_mode == "🆕 新建标注":
|
|
|
if st.session_state.loaded_json_name:
|
|
|
@@ -367,7 +355,8 @@ def create_save_section(work_mode, structure, image, line_width):
|
|
|
saved_files = []
|
|
|
|
|
|
if save_structure:
|
|
|
- structure_path = output_dir / f"{base_name}_structure.json"
|
|
|
+ structure_filename = f"{base_name}{structure_suffix}"
|
|
|
+ structure_path = output_dir / structure_filename
|
|
|
save_structure_to_config(structure, structure_path)
|
|
|
saved_files.append(("配置文件", structure_path))
|
|
|
|
|
|
@@ -383,21 +372,18 @@ def create_save_section(work_mode, structure, image, line_width):
|
|
|
if st.session_state.image is None:
|
|
|
st.warning("⚠️ 无法保存图片:未加载图片文件")
|
|
|
else:
|
|
|
- color_map = {
|
|
|
- "黑色": (0, 0, 0),
|
|
|
- "蓝色": (0, 0, 255),
|
|
|
- "红色": (255, 0, 0)
|
|
|
- }
|
|
|
- selected_color = color_map[line_color_option]
|
|
|
-
|
|
|
+ selected_color_rgb = next(
|
|
|
+ (tuple(c["rgb"]) for c in line_colors if c["name"] == line_color_option),
|
|
|
+ (0, 0, 0),
|
|
|
+ )
|
|
|
clean_img = draw_clean_table_lines(
|
|
|
st.session_state.image,
|
|
|
structure,
|
|
|
line_width=line_width,
|
|
|
- line_color=selected_color
|
|
|
+ line_color=selected_color_rgb,
|
|
|
)
|
|
|
-
|
|
|
- output_image_path = output_dir / f"{base_name}_with_lines.png"
|
|
|
+ image_filename = f"{base_name}{image_suffix}"
|
|
|
+ output_image_path = output_dir / image_filename
|
|
|
clean_img.save(output_image_path)
|
|
|
saved_files.append(("表格线图片", output_image_path))
|
|
|
|
|
|
@@ -415,4 +401,234 @@ def create_save_section(work_mode, structure, image, line_width):
|
|
|
if saved_files:
|
|
|
st.success(f"✅ 已保存 {len(saved_files)} 个文件:")
|
|
|
for file_type, file_path in saved_files:
|
|
|
- st.info(f" • {file_type}: {file_path}")
|
|
|
+ st.info(f" • {file_type}: {file_path}")
|
|
|
+
|
|
|
+def setup_new_annotation_mode(ocr_data, image, config: Dict):
|
|
|
+ """
|
|
|
+ 设置新建标注模式的通用逻辑
|
|
|
+
|
|
|
+ Args:
|
|
|
+ ocr_data: OCR 数据
|
|
|
+ image: 图片对象
|
|
|
+ config: 显示配置
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ tuple: (y_tolerance, x_tolerance, min_row_height, line_width, display_mode, zoom_level, show_line_numbers)
|
|
|
+ """
|
|
|
+ # 参数调整
|
|
|
+ st.sidebar.header("🔧 参数调整")
|
|
|
+ y_tolerance = st.sidebar.slider("Y轴聚类容差(像素)", 1, 20, 5, key="new_y_tol")
|
|
|
+ x_tolerance = st.sidebar.slider("X轴聚类容差(像素)", 5, 50, 10, key="new_x_tol")
|
|
|
+ min_row_height = st.sidebar.slider("最小行高(像素)", 10, 100, 20, key="new_min_h")
|
|
|
+
|
|
|
+ # 显示设置
|
|
|
+ line_width, display_mode, zoom_level, show_line_numbers = create_display_settings_section(config)
|
|
|
+ create_undo_redo_section()
|
|
|
+
|
|
|
+ # 初始化生成器
|
|
|
+ if 'generator' not in st.session_state or st.session_state.generator is None:
|
|
|
+ try:
|
|
|
+ generator = TableLineGenerator(image, ocr_data)
|
|
|
+ st.session_state.generator = generator
|
|
|
+ except Exception as e:
|
|
|
+ st.error(f"❌ 初始化生成器失败: {e}")
|
|
|
+ st.stop()
|
|
|
+
|
|
|
+ # 分析按钮
|
|
|
+ create_analysis_section(y_tolerance, x_tolerance, min_row_height)
|
|
|
+
|
|
|
+ return y_tolerance, x_tolerance, min_row_height, line_width, display_mode, zoom_level, show_line_numbers
|
|
|
+
|
|
|
+
|
|
|
+def setup_edit_annotation_mode(structure, image, config: Dict):
|
|
|
+ """
|
|
|
+ 设置编辑标注模式的通用逻辑
|
|
|
+
|
|
|
+ Args:
|
|
|
+ structure: 表格结构
|
|
|
+ image: 图片对象(可为 None)
|
|
|
+ config: 显示配置
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ tuple: (image, line_width, display_mode, zoom_level, show_line_numbers)
|
|
|
+ """
|
|
|
+ # 如果没有图片,创建虚拟画布
|
|
|
+ if image is None:
|
|
|
+ if 'table_bbox' in structure:
|
|
|
+ bbox = structure['table_bbox']
|
|
|
+ dummy_width = bbox[2] + 100
|
|
|
+ dummy_height = bbox[3] + 100
|
|
|
+ else:
|
|
|
+ dummy_width = 2000
|
|
|
+ dummy_height = 2000
|
|
|
+ image = Image.new('RGB', (dummy_width, dummy_height), color='white')
|
|
|
+ st.info(f"💡 使用虚拟画布 ({dummy_width}x{dummy_height})")
|
|
|
+
|
|
|
+ # 显示设置
|
|
|
+ line_width, display_mode, zoom_level, show_line_numbers = create_display_settings_section(config)
|
|
|
+ create_undo_redo_section()
|
|
|
+
|
|
|
+ return image, line_width, display_mode, zoom_level, show_line_numbers
|
|
|
+
|
|
|
+
|
|
|
+def render_table_structure_view(structure, image, line_width, display_mode, zoom_level, show_line_numbers,
|
|
|
+ viewport_width, viewport_height):
|
|
|
+ """
|
|
|
+ 渲染表格结构视图(统一三种模式的显示逻辑)
|
|
|
+
|
|
|
+ Args:
|
|
|
+ structure: 表格结构
|
|
|
+ image: 图片对象
|
|
|
+ line_width: 线条宽度
|
|
|
+ display_mode: 显示模式
|
|
|
+ zoom_level: 缩放级别
|
|
|
+ show_line_numbers: 是否显示线条编号
|
|
|
+ viewport_width: 视口宽度
|
|
|
+ viewport_height: 视口高度
|
|
|
+ """
|
|
|
+ # 绘制表格线
|
|
|
+ img_with_lines = get_cached_table_lines_image(
|
|
|
+ image, structure, line_width=line_width, show_numbers=show_line_numbers
|
|
|
+ )
|
|
|
+
|
|
|
+ # 根据显示模式显示图片
|
|
|
+ if display_mode == "对比显示":
|
|
|
+ col1, col2 = st.columns(2)
|
|
|
+ with col1:
|
|
|
+ show_image_with_scroll(image, "原图", viewport_width, viewport_height, zoom_level)
|
|
|
+ with col2:
|
|
|
+ show_image_with_scroll(img_with_lines, "表格线", viewport_width, viewport_height, zoom_level)
|
|
|
+ elif display_mode == "仅显示划线图":
|
|
|
+ show_image_with_scroll(
|
|
|
+ img_with_lines,
|
|
|
+ f"表格线图 (缩放: {zoom_level:.0%})",
|
|
|
+ viewport_width,
|
|
|
+ viewport_height,
|
|
|
+ zoom_level
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ show_image_with_scroll(
|
|
|
+ image,
|
|
|
+ f"原图 (缩放: {zoom_level:.0%})",
|
|
|
+ viewport_width,
|
|
|
+ viewport_height,
|
|
|
+ zoom_level
|
|
|
+ )
|
|
|
+
|
|
|
+ # 手动调整区域
|
|
|
+ create_adjustment_section(structure)
|
|
|
+
|
|
|
+ # 显示详细信息
|
|
|
+ with st.expander("📊 表格结构详情"):
|
|
|
+ st.json({
|
|
|
+ "行数": len(structure['rows']),
|
|
|
+ "列数": len(structure['columns']),
|
|
|
+ "横线数": len(structure.get('horizontal_lines', [])),
|
|
|
+ "竖线数": len(structure.get('vertical_lines', [])),
|
|
|
+ "横线坐标": structure.get('horizontal_lines', []),
|
|
|
+ "竖线坐标": structure.get('vertical_lines', []),
|
|
|
+ "标准行高": structure.get('row_height'),
|
|
|
+ "列宽度": structure.get('col_widths'),
|
|
|
+ "修改的横线": list(structure.get('modified_h_lines', set())),
|
|
|
+ "修改的竖线": list(structure.get('modified_v_lines', set()))
|
|
|
+ })
|
|
|
+
|
|
|
+
|
|
|
+def create_directory_selector(data_sources: List[Dict], global_output_config: Dict):
|
|
|
+ """目录模式选择器(优化:避免重复加载)"""
|
|
|
+ st.sidebar.subheader("目录模式")
|
|
|
+ source_names = [src["name"] for src in data_sources]
|
|
|
+ selected_name = st.sidebar.selectbox("选择数据源", source_names, key="dir_mode_source")
|
|
|
+ source_cfg = next(src for src in data_sources if src["name"] == selected_name)
|
|
|
+
|
|
|
+ output_cfg = source_cfg.get("output", global_output_config)
|
|
|
+ output_dir = Path(output_cfg.get("directory", "output/table_structures"))
|
|
|
+ structure_suffix = output_cfg.get("structure_suffix", "_structure.json")
|
|
|
+
|
|
|
+ catalog_key = f"catalog::{selected_name}"
|
|
|
+ if catalog_key not in st.session_state:
|
|
|
+ st.session_state[catalog_key] = build_data_source_catalog(source_cfg)
|
|
|
+ catalog = st.session_state[catalog_key]
|
|
|
+
|
|
|
+ if not catalog:
|
|
|
+ st.sidebar.warning("目录中没有 JSON 文件")
|
|
|
+ return
|
|
|
+
|
|
|
+ if 'dir_selected_index' not in st.session_state:
|
|
|
+ st.session_state.dir_selected_index = 0
|
|
|
+
|
|
|
+ selected = st.sidebar.selectbox(
|
|
|
+ "选择文件",
|
|
|
+ range(len(catalog)),
|
|
|
+ format_func=lambda i: catalog[i]["display"],
|
|
|
+ index=st.session_state.dir_selected_index,
|
|
|
+ key="dir_select_box"
|
|
|
+ )
|
|
|
+
|
|
|
+ page_input = st.sidebar.number_input(
|
|
|
+ "页码跳转",
|
|
|
+ min_value=1,
|
|
|
+ max_value=len(catalog),
|
|
|
+ value=catalog[selected]["index"],
|
|
|
+ step=1,
|
|
|
+ key="dir_page_input"
|
|
|
+ )
|
|
|
+
|
|
|
+ # 🔑 关键优化:只在切换文件时才重新加载
|
|
|
+ current_entry_key = f"{selected_name}::{catalog[selected]['json']}"
|
|
|
+
|
|
|
+ if 'last_loaded_entry' not in st.session_state or st.session_state.last_loaded_entry != current_entry_key:
|
|
|
+ # 文件切换,重新加载
|
|
|
+ entry = catalog[selected]
|
|
|
+ base_name = entry["json"].stem
|
|
|
+ structure_file = output_dir / f"{base_name}{structure_suffix}"
|
|
|
+ has_structure = structure_file.exists()
|
|
|
+
|
|
|
+ # 📂 加载 JSON
|
|
|
+ with open(entry["json"], "r", encoding="utf-8") as fp:
|
|
|
+ raw = json.load(fp)
|
|
|
+ st.session_state.ocr_data = parse_ocr_data(raw)
|
|
|
+ st.session_state.loaded_json_name = entry["json"].name
|
|
|
+
|
|
|
+ # 🖼️ 加载图片
|
|
|
+ if entry["image"] and entry["image"].exists():
|
|
|
+ st.session_state.image = Image.open(entry["image"])
|
|
|
+ st.session_state.loaded_image_name = entry["image"].name
|
|
|
+ else:
|
|
|
+ st.session_state.image = None
|
|
|
+
|
|
|
+ # 🎯 自动模式判断
|
|
|
+ if has_structure:
|
|
|
+ st.session_state.dir_auto_mode = "edit"
|
|
|
+ st.session_state.loaded_config_name = structure_file.name
|
|
|
+
|
|
|
+ try:
|
|
|
+ structure = load_structure_from_config(structure_file)
|
|
|
+ st.session_state.structure = structure
|
|
|
+ st.session_state.undo_stack = []
|
|
|
+ st.session_state.redo_stack = []
|
|
|
+ clear_table_image_cache()
|
|
|
+ st.sidebar.success(f"✅ 编辑模式")
|
|
|
+ except Exception as e:
|
|
|
+ st.error(f"❌ 加载标注失败: {e}")
|
|
|
+ st.session_state.dir_auto_mode = "new"
|
|
|
+ else:
|
|
|
+ st.session_state.dir_auto_mode = "new"
|
|
|
+ if 'structure' in st.session_state:
|
|
|
+ del st.session_state.structure
|
|
|
+ if 'generator' in st.session_state:
|
|
|
+ del st.session_state.generator
|
|
|
+ st.sidebar.info(f"🆕 新建模式")
|
|
|
+
|
|
|
+ # 标记已加载
|
|
|
+ st.session_state.last_loaded_entry = current_entry_key
|
|
|
+ st.info(f"📂 已加载: {entry['json'].name}")
|
|
|
+
|
|
|
+ # 页码跳转处理
|
|
|
+ if page_input != catalog[selected]["index"]:
|
|
|
+ target = next((i for i, item in enumerate(catalog) if item["index"] == page_input), None)
|
|
|
+ if target is not None:
|
|
|
+ st.session_state.dir_selected_index = target
|
|
|
+ st.rerun()
|
|
|
+
|
|
|
+ return st.session_state.get('dir_auto_mode', 'new')
|