Forráskód Böngészése

feat: 优化分析控件和数据处理,支持多种工具类型解析

zhch158_admin 2 napja
szülő
commit
8c92209400

+ 116 - 51
table_line_generator/editor/analysis_controls.py

@@ -1,60 +1,125 @@
 """
-表格结构分析控件
+分析功能控件
 """
 import streamlit as st
-from .drawing import clear_table_image_cache
+from typing import Dict, Optional
+import json
 
 
-def create_analysis_section(y_tolerance: int, x_tolerance: int, min_row_height: int):
+def create_analysis_section(generator, tool: str = "ppstructv3") -> Optional[Dict]:
     """
-    创建分析区域
+    创建分析控件
     
     Args:
-        y_tolerance: Y轴聚类容差
-        x_tolerance: X轴聚类容差
-        min_row_height: 最小行高
+        generator: TableLineGenerator 实例
+        tool: 工具类型
+    
+    Returns:
+        分析后的表格结构(如果点击了分析按钮)
     """
-    if st.button("🔍 分析表格结构"):
-        with st.spinner("分析中..."):
-            try:
-                generator = st.session_state.generator
-                structure = generator.analyze_table_structure(
-                    y_tolerance=y_tolerance,
-                    x_tolerance=x_tolerance,
-                    min_row_height=min_row_height
-                )
-                
-                if not structure:
-                    st.warning("⚠️ 未检测到表格结构")
-                    st.stop()
-                
-                structure['modified_h_lines'] = set()
-                structure['modified_v_lines'] = set()
-                
-                st.session_state.structure = structure
-                st.session_state.undo_stack = []
-                st.session_state.redo_stack = []
-                clear_table_image_cache()
-                
-                st.success(
-                    f"✅ 检测到 {len(structure['rows'])} 行"
-                    f"({len(structure['horizontal_lines'])} 条横线),"
-                    f"{len(structure['columns'])} 列"
-                    f"({len(structure['vertical_lines'])} 条竖线)"
-                )
-                
-                col1, col2, col3, col4 = st.columns(4)
-                with col1:
-                    st.metric("行数", len(structure['rows']))
-                with col2:
-                    st.metric("横线数", len(structure['horizontal_lines']))
-                with col3:
-                    st.metric("列数", len(structure['columns']))
-                with col4:
-                    st.metric("竖线数", len(structure['vertical_lines']))
-            
-            except Exception as e:
-                st.error(f"❌ 分析失败: {e}")
-                import traceback
-                st.code(traceback.format_exc())
-                st.stop()
+    st.sidebar.subheader("🔍 表格结构分析")
+    
+    # 🔑 根据工具类型显示不同的参数
+    if tool.lower() == "mineru":
+        st.sidebar.info("📋 MinerU 格式:直接使用 table_cells 生成结构")
+        
+        if st.sidebar.button("🚀 生成表格结构", type="primary"):
+            with st.spinner("正在分析表格结构..."):
+                try:
+                    # 🔑 MinerU 格式:从原始 JSON 重新解析
+                    current_catalog = st.session_state.get('current_catalog', [])
+                    current_index = st.session_state.get('current_catalog_index', 0)
+                    
+                    if not current_catalog or current_index >= len(current_catalog):
+                        st.error("❌ 未找到当前文件")
+                        return None
+                    
+                    entry = current_catalog[current_index]
+                    
+                    # 加载原始 JSON
+                    with open(entry["json"], "r", encoding="utf-8") as fp:
+                        raw = json.load(fp)
+                    
+                    # 重新解析获取完整结构
+                    from .data_processor import get_structure_from_ocr
+                    
+                    table_bbox, structure = get_structure_from_ocr(raw, tool)
+                    
+                    # 保存到 session_state
+                    st.session_state.structure = structure
+                    st.session_state.table_bbox = table_bbox
+                    st.session_state.undo_stack = []
+                    st.session_state.redo_stack = []
+                    
+                    # 清除缓存的图片
+                    from .drawing import clear_table_image_cache
+                    clear_table_image_cache()
+                    
+                    st.success(
+                        f"✅ 表格结构生成成功!\n\n"
+                        f"检测到 {structure['total_rows']} 行,{structure['total_cols']} 列"
+                    )
+                    return structure
+                    
+                except Exception as e:
+                    st.error(f"❌ 分析失败: {e}")
+                    import traceback
+                    with st.expander("🔍 详细错误"):
+                        st.code(traceback.format_exc())
+    
+    else:
+        # 🔑 PPStructure V3 格式:使用参数调整
+        y_tolerance = st.sidebar.slider(
+            "Y轴聚类容差(行检测)",
+            min_value=1,
+            max_value=20,
+            value=5,
+            help="相邻文本框Y坐标差小于此值时合并为同一行"
+        )
+        
+        x_tolerance = st.sidebar.slider(
+            "X轴聚类容差(列检测)",
+            min_value=5,
+            max_value=30,
+            value=10,
+            help="相邻文本框X坐标差小于此值时合并为同一列"
+        )
+        
+        min_row_height = st.sidebar.slider(
+            "最小行高",
+            min_value=10,
+            max_value=50,
+            value=20,
+            help="行高小于此值的将被过滤"
+        )
+        
+        if st.sidebar.button("🚀 分析表格结构", type="primary"):
+            with st.spinner("正在分析表格结构..."):
+                try:
+                    structure = generator.analyze_table_structure(
+                        y_tolerance=y_tolerance,
+                        x_tolerance=x_tolerance,
+                        min_row_height=min_row_height
+                    )
+                    
+                    st.session_state.structure = structure
+                    st.session_state.undo_stack = []
+                    st.session_state.redo_stack = []
+                    
+                    # 清除缓存的图片
+                    from .drawing import clear_table_image_cache
+                    clear_table_image_cache()
+                    
+                    st.success(
+                        f"✅ 分析完成!\n\n"
+                        f"检测到 {len(structure['rows'])} 行,{len(structure['columns'])} 列"
+                    )
+                    return structure
+                    
+                except Exception as e:
+                    st.error(f"❌ 分析失败: {e}")
+                    import traceback
+                    with st.expander("🔍 详细错误"):
+                        st.code(traceback.format_exc())
+    
+    return None

+ 77 - 47
table_line_generator/editor/data_processor.py

@@ -1,53 +1,83 @@
-import streamlit as st
-import json
-
-# 当直接运行时
+"""
+OCR 数据处理
+"""
 import sys
 from pathlib import Path
+from typing import List, Dict, Tuple
+
+# 添加父目录到路径
 sys.path.insert(0, str(Path(__file__).parent.parent))
-from table_line_generator import TableLineGenerator  # 上级目录
 
-def parse_ocr_data(ocr_data):
-    """解析OCR数据,支持多种格式"""
-    # 如果是字符串,尝试解析
-    if isinstance(ocr_data, str):
-        try:
-            ocr_data = json.loads(ocr_data)
-        except json.JSONDecodeError:
-            st.error("❌ JSON 格式错误,无法解析")
-            return []
-    
-    # 检查是否为 PPStructure V3 格式
-    if isinstance(ocr_data, dict) and 'parsing_res_list' in ocr_data and 'overall_ocr_res' in ocr_data:
-        st.info("🔍 检测到 PPStructure V3 格式")
+try:
+	from table_line_generator import TableLineGenerator
+except ImportError:
+	from ..table_line_generator import TableLineGenerator
+
+
+def parse_ocr_data(raw_data: Dict, tool: str = "ppstructv3") -> Tuple[List[int], List[Dict]]:
+    """
+    解析 OCR 数据(根据工具类型自动选择解析方法)
+    
+    Args:
+        raw_data: 原始 OCR 结果
+        tool: 工具类型 ("ppstructv3" 或 "mineru")
+    
+    Returns:
+        (table_bbox, ocr_data): 表格边界框和文本框列表
+    """
+    if tool.lower() == "mineru":
+        # 使用 MinerU 专用解析方法
+        table_bbox, structure = TableLineGenerator.parse_mineru_table_result(raw_data)
+        
+        # 🔑 将结构转换为 ocr_data 格式(兼容现有逻辑)
+        ocr_data = []
+        for row in structure['rows']:
+            for bbox in row['bboxes']:
+                ocr_data.append({
+                    'bbox': bbox,
+                    'text': ''  # MinerU 可能没有文本,或需要从 table_cells 提取
+                })
         
-        try:
-            table_bbox, text_boxes = TableLineGenerator.parse_ppstructure_result(ocr_data)
-            st.success(f"✅ 表格区域: {table_bbox}")
-            st.success(f"✅ 表格内文本框: {len(text_boxes)} 个")
-            return text_boxes
-        except Exception as e:
-            st.error(f"❌ 解析 PPStructure 结果失败: {e}")
-            return []
-    
-    # 确保是列表
-    if not isinstance(ocr_data, list):
-        st.error(f"❌ OCR 数据应该是列表,实际类型: {type(ocr_data)}")
-        return []
-    
-    if not ocr_data:
-        st.warning("⚠️ OCR 数据为空")
-        return []
-    
-    first_item = ocr_data[0]
-    if not isinstance(first_item, dict):
-        st.error(f"❌ OCR 数据项应该是字典,实际类型: {type(first_item)}")
-        return []
-    
-    if 'bbox' not in first_item:
-        st.error("❌ OCR 数据缺少 'bbox' 字段")
-        st.info("💡 支持的格式示例:\n```json\n[\n  {\n    \"text\": \"文本\",\n    \"bbox\": [x1, y1, x2, y2]\n  }\n]\n```")
-        return []
-    
-    return ocr_data
+        return table_bbox, ocr_data
+    
+    elif tool.lower() == "ppstructv3":
+        # 🔑 PPStructure V3 格式
+        return TableLineGenerator.parse_ppstructure_result(raw_data)
+    
+    else:
+        raise ValueError(f"不支持的工具类型: {tool},支持的类型: ppstructv3, mineru")
+
+
+def get_structure_from_ocr(
+    raw_data: Dict, 
+    tool: str = "ppstructv3"
+) -> Tuple[List[int], Dict]:
+    """
+    从 OCR 数据直接生成表格结构
+    
+    Args:
+        raw_data: 原始 OCR 结果
+        tool: 工具类型
+    
+    Returns:
+        (table_bbox, structure): 表格边界框和结构信息
+    """
+    if tool.lower() == "mineru":
+        # 🔑 MinerU:直接传入完整 JSON,方法内部会提取 table
+        return TableLineGenerator.parse_mineru_table_result(raw_data)
+    
+    elif tool.lower() == "ppstructv3" or tool.lower() == "ppstructure":
+        # 🔑 PPStructure V3:需要先解析再分析
+        table_bbox, ocr_data = TableLineGenerator.parse_ppstructure_result(raw_data)
+        
+        # 使用临时生成器分析结构
+        from PIL import Image
+        dummy_image = Image.new('RGB', (2000, 3000), 'white')
+        generator = TableLineGenerator(dummy_image, ocr_data)
+        structure = generator.analyze_table_structure()
+        
+        return table_bbox, structure
+    
+    else:
+        raise ValueError(f"不支持的工具类型: {tool}")
 

+ 36 - 5
table_line_generator/editor/directory_selector.py

@@ -8,7 +8,7 @@ from PIL import Image
 from typing import Dict, List
 
 from .config_loader import load_structure_from_config, build_data_source_catalog
-from .data_processor import parse_ocr_data
+from .data_processor import parse_ocr_data, get_structure_from_ocr
 from .drawing import clear_table_image_cache
 
 
@@ -46,6 +46,11 @@ def create_directory_selector(
     output_dir = Path(output_cfg.get("directory", "output/table_structures"))
     structure_suffix = output_cfg.get("structure_suffix", "_structure.json")
     
+    # 🔑 获取工具类型
+    tool = source_cfg.get("tool", "ppstructv3")
+    st.session_state.current_tool = tool
+    st.sidebar.info(f"🔧 工具: {tool.upper()}")
+    
     # 构建/缓存目录清单
     catalog_key = f"catalog::{selected_name}"
     if catalog_key not in st.session_state:
@@ -93,7 +98,8 @@ def create_directory_selector(
             catalog[selected], 
             output_dir, 
             structure_suffix, 
-            current_entry_key
+            current_entry_key,
+            tool  # 🔑 传入工具类型
         )
     
     # 页码跳转处理
@@ -109,8 +115,23 @@ def create_directory_selector(
     return st.session_state.get('dir_auto_mode', 'new')
 
 
-def _load_catalog_entry(entry: Dict, output_dir: Path, structure_suffix: str, entry_key: str):
-    """加载目录条目(JSON + 图片 + 结构)"""
+def _load_catalog_entry(
+    entry: Dict, 
+    output_dir: Path, 
+    structure_suffix: str, 
+    entry_key: str,
+    tool: str = "ppstructv3"  # 🔑 新增参数
+):
+    """
+    加载目录条目(JSON + 图片 + 结构)
+    
+    Args:
+        entry: 目录条目
+        output_dir: 输出目录
+        structure_suffix: 结构文件后缀
+        entry_key: 条目唯一键
+        tool: 工具类型
+    """
     base_name = entry["json"].stem
     structure_file = output_dir / f"{base_name}{structure_suffix}"
     has_structure = structure_file.exists()
@@ -119,10 +140,20 @@ def _load_catalog_entry(entry: Dict, output_dir: Path, structure_suffix: str, en
     try:
         with open(entry["json"], "r", encoding="utf-8") as fp:
             raw = json.load(fp)
-        st.session_state.ocr_data = parse_ocr_data(raw)
+        
+        # 🔑 根据工具类型解析数据
+        table_bbox, ocr_data = parse_ocr_data(raw, tool)
+        
+        st.session_state.ocr_data = ocr_data
+        st.session_state.table_bbox = table_bbox
         st.session_state.loaded_json_name = entry["json"].name
+        st.info(f"🔧 使用 {tool.upper()} 解析 JSON")
+        
     except Exception as e:
         st.error(f"❌ 加载 JSON 失败: {e}")
+        import traceback
+        with st.expander("🔍 详细错误"):
+            st.code(traceback.format_exc())
         return
 
     # 🖼️ 加载图片

+ 57 - 65
table_line_generator/editor/mode_setup.py

@@ -4,6 +4,11 @@
 import streamlit as st
 from PIL import Image
 from typing import Dict, Tuple
+import sys
+from pathlib import Path
+
+# 添加父目录到路径
+sys.path.insert(0, str(Path(__file__).parent.parent))
 
 try:
     from ..table_line_generator import TableLineGenerator
@@ -14,85 +19,72 @@ from .display_controls import create_display_settings_section, create_undo_redo_
 from .analysis_controls import create_analysis_section
 
 
-def setup_new_annotation_mode(ocr_data, image, config: Dict) -> Tuple:
+def setup_new_annotation_mode(
+    ocr_data: list,
+    image: Image.Image,
+    display_config: Dict
+) -> Tuple:
     """
-    设置新建标注模式的通用逻辑
-    
-    Args:
-        ocr_data: OCR 数据
-        image: 图片对象
-        config: 显示配置
+    设置新建标注模式
     
     Returns:
-        tuple: (y_tolerance, x_tolerance, min_row_height, line_width, 
-                display_mode, zoom_level, show_line_numbers)
+        (generator, structure, undo_stack, line_width, display_mode, zoom_level, show_line_numbers)
     """
-    # 参数调整
-    st.sidebar.header("🔧 参数调整")
-    y_tolerance = st.sidebar.slider(
-        "Y轴聚类容差(像素)", 
-        1, 20, 5, 
-        key="new_y_tol"
-    )
-    x_tolerance = st.sidebar.slider(
-        "X轴聚类容差(像素)", 
-        5, 50, 10, 
-        key="new_x_tol"
-    )
-    min_row_height = st.sidebar.slider(
-        "最小行高(像素)", 
-        10, 100, 20, 
-        key="new_min_h"
-    )
-    
-    # 显示设置
-    line_width, display_mode, zoom_level, show_line_numbers = \
-        create_display_settings_section(config)
-    create_undo_redo_section()
+    # 🔑 获取当前工具类型
+    tool = st.session_state.get('current_tool', 'ppstructv3')
     
     # 初始化生成器
-    if 'generator' not in st.session_state or st.session_state.generator is None:
-        try:
-            generator = TableLineGenerator(image, ocr_data)
-            st.session_state.generator = generator
-        except Exception as e:
-            st.error(f"❌ 初始化生成器失败: {e}")
-            st.stop()
+    if 'generator' not in st.session_state:
+        st.session_state.generator = TableLineGenerator(image, ocr_data)
     
-    # 分析按钮
-    create_analysis_section(y_tolerance, x_tolerance, min_row_height)
+    # 分析控件
+    structure = create_analysis_section(
+        st.session_state.generator,
+        tool=tool  # 🔑 传入工具类型
+    )
+    
+    # 显示控件
+    line_width, display_mode, zoom_level, show_line_numbers = create_display_settings_section(
+        display_config
+    )
+    
+    # 撤销/重做
+    undo_stack = []
     
-    return (y_tolerance, x_tolerance, min_row_height, 
-            line_width, display_mode, zoom_level, show_line_numbers)
+    return (
+        st.session_state.generator,
+        structure,
+        undo_stack,
+        line_width,
+        display_mode,
+        zoom_level,
+        show_line_numbers
+    )
 
 
-def setup_edit_annotation_mode(structure: Dict, image, config: Dict) -> Tuple:
+def setup_edit_annotation_mode(
+    structure: Dict,
+    image: Image.Image,
+    display_config: Dict
+) -> Tuple:
     """
-    设置编辑标注模式的通用逻辑
-    
-    Args:
-        structure: 表格结构
-        image: 图片对象(可为 None)
-        config: 显示配置
+    设置编辑标注模式
     
     Returns:
-        tuple: (image, line_width, display_mode, zoom_level, show_line_numbers)
+        (image, line_width, display_mode, zoom_level, show_line_numbers)
     """
-    # 如果没有图片,创建虚拟画布
-    if image is None:
-        if 'table_bbox' in structure:
-            bbox = structure['table_bbox']
-            dummy_width = bbox[2] + 100
-            dummy_height = bbox[3] + 100
-        else:
-            dummy_width = 2000
-            dummy_height = 2000
-        image = Image.new('RGB', (dummy_width, dummy_height), color='white')
-        st.info(f"💡 使用虚拟画布 ({dummy_width}x{dummy_height})")
+    # 显示控件
+    line_width, display_mode, zoom_level, show_line_numbers = create_display_settings_section(
+        display_config
+    )
     
-    # 显示设置
-    line_width, display_mode, zoom_level, show_line_numbers = \
-        create_display_settings_section(config)
+    # 撤销/重做控件
     create_undo_redo_section()
     
-    return image, line_width, display_mode, zoom_level, show_line_numbers
+    return (
+        image,
+        line_width,
+        display_mode,
+        zoom_level,
+        show_line_numbers
+    )