| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418 |
- """
- UI 组件
- """
- import streamlit as st
- import json
- from pathlib import Path
- from PIL import Image
- import tempfile
- try:
- from ..table_line_generator import TableLineGenerator
- except ImportError:
- from table_line_generator import TableLineGenerator
- from .config_loader import load_structure_from_config
- from .drawing import clear_table_image_cache
def _find_invalid_ocr_item(items):
    """Return ``(index, item)`` for the first entry that is not a dict with a ``'bbox'`` key.

    Returns ``None`` when every entry is usable.
    """
    for index, item in enumerate(items):
        if not isinstance(item, dict) or 'bbox' not in item:
            return index, item
    return None


def parse_ocr_data(ocr_data):
    """Normalize uploaded OCR data into a list of text-box dicts.

    Accepted inputs:
      * a JSON string (decoded first);
      * a PPStructure V3 result dict (contains both ``parsing_res_list`` and
        ``overall_ocr_res``), delegated to
        ``TableLineGenerator.parse_ppstructure_result``;
      * a list of dicts, each carrying a ``'bbox'`` key.

    Problems are reported on the Streamlit page. On any failure, including an
    empty input list, an empty list is returned, so callers can treat a falsy
    result as "unusable".
    """
    if isinstance(ocr_data, str):
        try:
            ocr_data = json.loads(ocr_data)
        except json.JSONDecodeError:
            st.error("❌ JSON 格式错误,无法解析")
            return []

    # PPStructure V3 output: extract only the text boxes inside the table region.
    if isinstance(ocr_data, dict) and 'parsing_res_list' in ocr_data and 'overall_ocr_res' in ocr_data:
        st.info("🔍 检测到 PPStructure V3 格式")

        try:
            table_bbox, text_boxes = TableLineGenerator.parse_ppstructure_result(ocr_data)
            st.success(f"✅ 表格区域: {table_bbox}")
            st.success(f"✅ 表格内文本框: {len(text_boxes)} 个")
            return text_boxes
        except Exception as e:
            st.error(f"❌ 解析 PPStructure 结果失败: {e}")
            return []

    if not isinstance(ocr_data, list):
        st.error(f"❌ OCR 数据应该是列表,实际类型: {type(ocr_data)}")
        return []

    if not ocr_data:
        st.warning("⚠️ OCR 数据为空")
        return []

    # Validate every entry, not just the first one. A malformed item further
    # down would otherwise slip through here and only fail later, when the
    # boxes are consumed.
    invalid = _find_invalid_ocr_item(ocr_data)
    if invalid is not None:
        index, item = invalid
        if not isinstance(item, dict):
            st.error(f"❌ OCR 数据项应该是字典,实际类型: {type(item)} (第 {index + 1} 项)")
        else:
            st.error(f"❌ OCR 数据缺少 'bbox' 字段 (第 {index + 1} 项)")
            st.info("💡 支持的格式示例:\n```json\n[\n {\n \"text\": \"文本\",\n \"bbox\": [x1, y1, x2, y2]\n }\n]\n```")
        return []

    return ocr_data
def _preview_raw_data(raw_data):
    """Return a small, JSON-displayable preview of freshly uploaded data.

    For a dict: the first 5 entries, with non-scalar values replaced by a
    ``<typename>`` placeholder. For a list: the first 3 items. Any other
    top-level JSON value (number, string, bool, null) is returned as-is
    instead of being sliced.
    """
    if isinstance(raw_data, dict):
        scalar_types = (str, int, float, bool, type(None))
        return {
            k: v if isinstance(v, scalar_types) else f"<{type(v).__name__}>"
            for k, v in list(raw_data.items())[:5]
        }
    if isinstance(raw_data, list):
        return raw_data[:3]
    return raw_data


def _reset_annotation_state():
    """Drop the derived structure/generator, the undo/redo history and the cached table image."""
    for key in ('structure', 'generator'):
        if key in st.session_state:
            del st.session_state[key]
    st.session_state.undo_stack = []
    st.session_state.redo_stack = []
    clear_table_image_cache()


def _load_uploaded_image(uploaded_file):
    """Open an uploaded image and decode it immediately.

    ``Image.open`` only reads the header. Calling ``load()`` here makes a
    truncated or corrupt file fail inside the caller's ``try`` block rather
    than later, while drawing.
    """
    image = Image.open(uploaded_file)
    image.load()
    return image


def _load_structure_from_upload(uploaded_config):
    """Parse an uploaded ``*_structure.json`` with ``load_structure_from_config``.

    ``load_structure_from_config`` is called with a path, so the upload is
    first copied to a temporary file. That file is removed even when writing
    or parsing fails. ``utf-8-sig`` strips a UTF-8 BOM if the file has one.
    """
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False, encoding='utf-8') as tmp:
            tmp_path = Path(tmp.name)
            tmp.write(uploaded_config.getvalue().decode('utf-8-sig'))
        return load_structure_from_config(tmp_path)
    finally:
        if tmp_path is not None:
            tmp_path.unlink(missing_ok=True)


def _handle_new_annotation_uploads():
    """Sidebar uploaders for "new annotation" mode: OCR JSON plus its image."""
    st.sidebar.subheader("上传文件")
    uploaded_json = st.sidebar.file_uploader("上传OCR结果JSON", type=['json'], key="new_json")
    uploaded_image = st.sidebar.file_uploader("上传对应图片", type=['jpg', 'png'], key="new_image")

    # Only (re)process when a different file name shows up, so reruns do not reload.
    if uploaded_json is not None and st.session_state.loaded_json_name != uploaded_json.name:
        try:
            raw_data = json.load(uploaded_json)

            with st.expander("🔍 原始数据结构"):
                st.json(_preview_raw_data(raw_data))

            ocr_data = parse_ocr_data(raw_data)

            if not ocr_data:
                st.error("❌ 无法解析 OCR 数据,请检查 JSON 格式")
                st.stop()

            st.session_state.ocr_data = ocr_data
            st.session_state.loaded_json_name = uploaded_json.name
            st.session_state.loaded_config_name = None

            # New OCR data invalidates anything derived from the previous data.
            _reset_annotation_state()

            st.success(f"✅ 成功加载 {len(ocr_data)} 条 OCR 记录")

        except Exception as e:
            st.error(f"❌ 加载数据失败: {e}")
            st.stop()

    if uploaded_image is not None and st.session_state.loaded_image_name != uploaded_image.name:
        try:
            image = _load_uploaded_image(uploaded_image)
            st.session_state.image = image
            st.session_state.loaded_image_name = uploaded_image.name

            _reset_annotation_state()

            st.success(f"✅ 成功加载图片: {uploaded_image.name}")

        except Exception as e:
            st.error(f"❌ 加载图片失败: {e}")
            st.stop()


def _handle_existing_annotation_uploads():
    """Sidebar uploaders for "load existing annotation" mode: structure config plus optional image."""
    st.sidebar.subheader("加载已保存的标注")

    uploaded_config = st.sidebar.file_uploader(
        "上传配置文件 (*_structure.json)",
        type=['json'],
        key="load_config"
    )

    uploaded_image_for_config = st.sidebar.file_uploader(
        "上传对应图片(可选)",
        type=['jpg', 'png'],
        key="load_image"
    )

    if uploaded_config is not None and st.session_state.loaded_config_name != uploaded_config.name:
        try:
            structure = _load_structure_from_upload(uploaded_config)

            st.session_state.structure = structure
            st.session_state.loaded_config_name = uploaded_config.name

            # A freshly loaded structure starts with an empty history and no cached drawing.
            st.session_state.undo_stack = []
            st.session_state.redo_stack = []
            clear_table_image_cache()

            st.success(f"✅ 成功加载配置: {uploaded_config.name}")
            st.info(
                f"📊 表格结构: {len(structure['rows'])}行 x {len(structure['columns'])}列\n\n"
                f"📏 横线数: {len(structure.get('horizontal_lines', []))}\n\n"
                f"📏 竖线数: {len(structure.get('vertical_lines', []))}"
            )

            with st.expander("📋 配置详情"):
                st.json({
                    "行数": len(structure['rows']),
                    "列数": len(structure['columns']),
                    "横线数": len(structure.get('horizontal_lines', [])),
                    "竖线数": len(structure.get('vertical_lines', [])),
                    "行高": structure.get('row_height'),
                    "列宽": structure.get('col_widths'),
                    "已修改的横线": list(structure.get('modified_h_lines', set())),
                    "已修改的竖线": list(structure.get('modified_v_lines', set()))
                })

        except Exception as e:
            st.error(f"❌ 加载配置失败: {e}")
            import traceback
            st.code(traceback.format_exc())
            st.stop()

    if uploaded_image_for_config is not None and st.session_state.loaded_image_name != uploaded_image_for_config.name:
        try:
            image = _load_uploaded_image(uploaded_image_for_config)
            st.session_state.image = image
            st.session_state.loaded_image_name = uploaded_image_for_config.name

            # Keep the loaded structure; only the cached drawing depends on the image.
            clear_table_image_cache()

            st.success(f"✅ 成功加载图片: {uploaded_image_for_config.name}")

        except Exception as e:
            st.error(f"❌ 加载图片失败: {e}")
            st.stop()

    if 'structure' in st.session_state and st.session_state.image is None:
        st.warning("⚠️ 已加载配置,但未加载对应图片。请上传图片以查看效果。")
        st.info("💡 提示:配置文件已加载,您可以:\n1. 上传对应图片查看效果\n2. 直接编辑配置并保存")


def create_file_uploader_section(work_mode: str):
    """
    Render the sidebar upload area for the current work mode.

    Args:
        work_mode: "🆕 新建标注" shows the OCR-JSON and image uploaders. Any
            other value is treated as "📂 加载已有标注" and shows the
            structure-config and optional-image uploaders.
    """
    if work_mode == "🆕 新建标注":
        _handle_new_annotation_uploads()
    else:
        _handle_existing_annotation_uploads()
def create_display_settings_section():
    """Render the sidebar "display settings" controls.

    Returns:
        tuple: ``(line_width, display_mode, zoom_level, show_line_numbers)``,
        the current values of the four widgets, in that order.
    """
    sidebar = st.sidebar
    sidebar.divider()
    sidebar.subheader("🖼️ 显示设置")

    # Tuple items are evaluated left to right, so widgets are created in display order.
    settings = (
        sidebar.slider("线条宽度", min_value=1, max_value=5, value=2),
        sidebar.radio("显示模式", ["对比显示", "仅显示划线图", "仅显示原图"], index=1),
        sidebar.slider("图片缩放", min_value=0.25, max_value=2.0, value=1.0, step=0.25),
        sidebar.checkbox("显示线条编号", value=True),
    )
    return settings
def create_undo_redo_section():
    """Render the sidebar undo/redo buttons and the history counter.

    Each button is disabled while its stack in ``st.session_state`` is empty.
    When a click's action reports success, the cached table image is cleared,
    a confirmation is shown and the script is rerun.
    """
    from .state_manager import undo_last_action, redo_last_action
    from .drawing import clear_table_image_cache

    st.sidebar.divider()
    st.sidebar.subheader("↩️ 撤销/重做")

    undo_col, redo_col = st.sidebar.columns(2)
    # (column, button label, session-state stack key, action, success message)
    buttons = (
        (undo_col, "↩️ 撤销", "undo_stack", undo_last_action, "✅ 已撤销"),
        (redo_col, "↪️ 重做", "redo_stack", redo_last_action, "✅ 已重做"),
    )
    for column, label, stack_key, action, done_message in buttons:
        with column:
            # Read the stack lazily, right before its button is drawn.
            clicked = st.button(label, disabled=len(st.session_state[stack_key]) == 0)
            if clicked and action():
                clear_table_image_cache()
                st.success(done_message)
                st.rerun()

    st.sidebar.info(f"📚 历史记录: {len(st.session_state.undo_stack)} 条")
def create_analysis_section(y_tolerance, x_tolerance, min_row_height):
    """
    Render the "analyze table structure" button and run the analysis on click.

    The analysis is delegated to ``st.session_state.generator``. A successful
    result is stored as ``st.session_state.structure``, with empty
    modified-line sets and a fresh undo/redo history.

    Args:
        y_tolerance: Y-axis clustering tolerance, forwarded to the generator.
        x_tolerance: X-axis clustering tolerance, forwarded to the generator.
        min_row_height: Minimum row height, forwarded to the generator.
    """
    if not st.button("🔍 分析表格结构"):
        return

    with st.spinner("分析中..."):
        try:
            structure = st.session_state.generator.analyze_table_structure(
                y_tolerance=y_tolerance,
                x_tolerance=x_tolerance,
                min_row_height=min_row_height,
            )

            if not structure:
                st.warning("⚠️ 未检测到表格结构")
                st.stop()

            # A fresh analysis has no hand-edited lines yet.
            structure['modified_h_lines'] = set()
            structure['modified_v_lines'] = set()

            st.session_state.structure = structure
            st.session_state.undo_stack = []
            st.session_state.redo_stack = []
            clear_table_image_cache()

            row_count = len(structure['rows'])
            h_line_count = len(structure['horizontal_lines'])
            col_count = len(structure['columns'])
            v_line_count = len(structure['vertical_lines'])

            st.success(
                f"✅ 检测到 {row_count} 行({h_line_count} 条横线),"
                f"{col_count} 列({v_line_count} 条竖线)"
            )

            summary = (
                ("行数", row_count),
                ("横线数", h_line_count),
                ("列数", col_count),
                ("竖线数", v_line_count),
            )
            for column, (label, value) in zip(st.columns(4), summary):
                with column:
                    st.metric(label, value)

        except Exception as e:
            st.error(f"❌ 分析失败: {e}")
            import traceback
            st.code(traceback.format_exc())
            st.stop()
def _resolve_base_name(work_mode, json_name, config_name, image_name):
    """Pick the stem used for output file names.

    In "🆕 新建标注" mode the OCR JSON's stem is used. Otherwise the uploaded
    config's stem is used (minus a trailing ``_structure``), then the image's
    stem. Falls back to ``"table_structure"`` when nothing usable is available.
    """
    default = "table_structure"
    if work_mode == "🆕 新建标注":
        return Path(json_name).stem if json_name else default

    if config_name:
        stem = Path(config_name).stem
        suffix = '_structure'
        if stem.endswith(suffix):
            stem = stem[:-len(suffix)]
        # A config literally named "_structure.json" would otherwise give an empty stem.
        return stem or default

    if image_name:
        return Path(image_name).stem

    return default


def create_save_section(work_mode, structure, image, line_width):
    """
    Render the save controls and, when "💾 保存" is clicked, write the outputs.

    Outputs go to ``output/table_structures``:
    ``<base>_structure.json`` (via ``save_structure_to_config``) and/or
    ``<base>_with_lines.png`` (via ``draw_clean_table_lines``). Each saved file
    is also offered as a download.

    Args:
        work_mode: Current work mode; decides how the output base name is chosen.
        structure: Table structure to save and draw.
        image: Not read by this function; the line image is drawn on
            ``st.session_state.image``. NOTE(review): confirm whether callers
            expect this argument to be used.
        line_width: Line width passed to ``draw_clean_table_lines``.
    """
    from .config_loader import save_structure_to_config
    from .drawing import draw_clean_table_lines
    import io

    st.divider()

    save_col1, save_col2, save_col3 = st.columns(3)

    with save_col1:
        save_structure = st.checkbox("保存表格结构配置", value=True)

    with save_col2:
        save_image = st.checkbox("保存表格线图片", value=True)

    with save_col3:
        line_color_option = st.selectbox(
            "保存时线条颜色",
            ["黑色", "蓝色", "红色"],
            index=0
        )

    if not st.button("💾 保存", type="primary"):
        return

    output_dir = Path("output/table_structures")
    output_dir.mkdir(parents=True, exist_ok=True)

    base_name = _resolve_base_name(
        work_mode,
        st.session_state.get('loaded_json_name'),
        st.session_state.get('loaded_config_name'),
        st.session_state.get('loaded_image_name'),
    )

    saved_files = []

    if save_structure:
        structure_path = output_dir / f"{base_name}_structure.json"
        save_structure_to_config(structure, structure_path)
        saved_files.append(("配置文件", structure_path))

        # Serve the raw bytes. Re-reading in text mode without an explicit
        # encoding uses the platform default (e.g. GBK/cp1252 on Windows),
        # which can fail if the saved JSON contains non-ASCII text.
        st.download_button(
            "📥 下载配置文件",
            structure_path.read_bytes(),
            file_name=f"{base_name}_structure.json",
            mime="application/json"
        )

    if save_image:
        source_image = st.session_state.get('image')
        if source_image is None:
            st.warning("⚠️ 无法保存图片:未加载图片文件")
        else:
            color_map = {
                "黑色": (0, 0, 0),
                "蓝色": (0, 0, 255),
                "红色": (255, 0, 0)
            }
            selected_color = color_map[line_color_option]

            clean_img = draw_clean_table_lines(
                source_image,
                structure,
                line_width=line_width,
                line_color=selected_color
            )

            # Encode the PNG once and reuse the bytes for both the file on
            # disk and the download.
            buf = io.BytesIO()
            clean_img.save(buf, format='PNG')
            png_bytes = buf.getvalue()

            output_image_path = output_dir / f"{base_name}_with_lines.png"
            output_image_path.write_bytes(png_bytes)
            saved_files.append(("表格线图片", output_image_path))

            st.download_button(
                "📥 下载表格线图片",
                png_bytes,
                file_name=f"{base_name}_with_lines.png",
                mime="image/png"
            )

    if saved_files:
        st.success(f"✅ 已保存 {len(saved_files)} 个文件:")
        for file_type, file_path in saved_files:
            st.info(f" • {file_type}: {file_path}")
|