Эх сурвалжийг харах

feat(增强OCR布局管理): 在ocr_validator_layout.py中新增类别颜色转换函数,优化边界框收集逻辑,支持按类别着色,提升可视化效果与用户体验。

zhch158_admin 1 сар өмнө
parent
commit
87b0f0a6e8

+ 127 - 39
ocr_validator/ocr_validator_layout.py

@@ -7,7 +7,7 @@ OCR验证工具的布局管理模块
 import streamlit as st
 from pathlib import Path
 from PIL import Image
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional
 import plotly.graph_objects as go
 from typing import Tuple
 import re
@@ -29,6 +29,10 @@ if str(ocr_platform_root) not in sys.path:
 # 从 ocr_utils 导入通用工具
 from ocr_utils.html_utils import convert_html_table_to_markdown, parse_html_tables
 from ocr_utils.visualization_utils import VisualizationUtils
+from ocr_utils.module_debug_viz import (
+    OCR_BOX_LINE_THICKNESS,
+    ocr_box_color_rgb,
+)
 
 # BeautifulSoup用于精确HTML表格处理
 from bs4 import BeautifulSoup
@@ -39,6 +43,27 @@ from ocr_validator_file_utils import load_css_styles
 # 为了向后兼容,提供函数别名
 draw_bbox_on_image = VisualizationUtils.draw_bbox_on_image
 
+
+def category_to_plotly_rgba(category: str, alpha: float = 0.85) -> str:
+    """将 VisualizationUtils.COLOR_MAP 中的 RGB 转为 Plotly 线条颜色。"""
+    rgb = VisualizationUtils.COLOR_MAP.get(category)
+    if rgb is None:
+        rgb = (128, 128, 128)
+    r, g, b = rgb
+    return f"rgba({r}, {g}, {b}, {alpha})"
+
+
+def ocr_box_plotly_rgba(alpha: float = 0.85) -> str:
+    """OCR 亮蓝(与 module_debug_viz / *_ocr_spans 一致)。"""
+    r, g, b = ocr_box_color_rgb()
+    return f"rgba({r}, {g}, {b}, {alpha})"
+
+
+# 仅 layout 结构框按类别着色;其余按 OCR 亮蓝实线/虚线
+LAYOUT_STRUCTURE_CATEGORIES = frozenset({
+    'table_body', 'table', 'image_body', 'image', 'figure', 'chart',
+})
+
 # detect_image_orientation_by_opencv 保留在 ocr_validator_file_utils
 from ocr_validator_file_utils import detect_image_orientation_by_opencv
 
@@ -636,6 +661,14 @@ class OCRLayoutManager:
                             match_type = "exact"
                         else:
                             match_type = "no_bbox"
+
+                        if info_list[0].get('category') == 'seal':
+                            conf = info_list[0].get('confidence', 0)
+                            method = info_list[0].get('recognition_method', '')
+                            hint = f"🔖 **印章** | 置信度 {conf:.2f}"
+                            if method:
+                                hint += f" | 识别方式 `{method}`"
+                            st.info(hint)
                     
                     # 🎯 应用高亮
                     if len(selected_text) >= self.config.get('ocr', {}).get('min_text_length', 2):
@@ -772,8 +805,10 @@ class OCRLayoutManager:
 
     # st.markdown('</div>', unsafe_allow_html=True)
 
-    def zoom_image(self, image: Image.Image, current_zoom: float) -> Tuple[Image.Image, List[List[int]], List[List[int]]]:
-        """缩放图像"""
+    def zoom_image(
+        self, image: Image.Image, current_zoom: float
+    ) -> Tuple[Image.Image, List[Dict[str, Any]], List[List[int]]]:
+        """缩放图像;all_boxes 为带 category 的框列表,供按类着色。"""
         # 根据缩放级别调整图片大小
         new_width = int(image.width * current_zoom)
         new_height = int(image.height * current_zoom)
@@ -789,15 +824,19 @@ class OCRLayoutManager:
                     selected_box = [int(coord * current_zoom) for coord in bbox]
                     selected_boxes.append(selected_box)
 
-        # 收集所有框
-        all_boxes = []
+        # 收集所有框(含类别,用于按类着色)
+        all_boxes: List[Dict[str, Any]] = []
         if self.show_all_boxes:
             for text, info_list in self.validator.text_bbox_mapping.items():
                 for info in info_list:
-                    bbox = info['bbox']
+                    bbox = info.get('bbox', [])
                     if len(bbox) >= 4:
                         scaled_bbox = [coord * current_zoom for coord in bbox]
-                        all_boxes.append(scaled_bbox)
+                        all_boxes.append({
+                            'bbox': scaled_bbox,
+                            'category': info.get('category', 'text'),
+                            'has_text': bool(text and str(text).strip()),
+                        })
 
         return resized_image, all_boxes, selected_boxes
 
@@ -837,48 +876,56 @@ class OCRLayoutManager:
         # 🎯 一次性更新所有形状
         fig.update_layout(shapes=fig.layout.shapes + tuple(shapes))
 
-    def _add_bboxes_as_scatter(self, fig: go.Figure, bboxes: List[List[int]], 
-                          image_height: int,
-                          line_color: str = "blue", 
-                          line_width: int = 2,
-                          name: str = "boxes"):
-        """
-        使用 Scatter 绘制边界框(极致性能优化)
-        """
+    def _add_bboxes_as_scatter(
+        self,
+        fig: go.Figure,
+        bboxes: List[List[int]],
+        image_height: int,
+        line_color: str = "blue",
+        line_width: int = 2,
+        name: str = "boxes",
+        *,
+        dashed: bool = False,
+    ):
+        """使用 Scatter 绘制边界框(极致性能优化)。"""
         if not bboxes or len(bboxes) == 0:
             return
-        
-        # 🎯 收集所有矩形的边框线坐标
+
         x_coords = []
         y_coords = []
-        
+
         for bbox in bboxes:
             if len(bbox) < 4:
                 continue
-            
+
             x1, y1, x2, y2 = bbox[:4]
-            
-            # 转换坐标
             plot_y1 = image_height - y2
             plot_y2 = image_height - y1
-            
-            # 绘制矩形:5个点(闭合)
-            x_coords.extend([x1, x2, x2, x1, x1, None])  # None用于断开线段
+            x_coords.extend([x1, x2, x2, x1, x1, None])
             y_coords.extend([plot_y1, plot_y1, plot_y2, plot_y2, plot_y1, None])
-        
-        # 🎯 一次性添加所有边框
+
+        line_style = dict(
+            color=line_color,
+            width=line_width,
+            dash='dash' if dashed else 'solid',
+        )
         fig.add_trace(go.Scatter(
             x=x_coords,
             y=y_coords,
             mode='lines',
-            line=dict(color=line_color, width=line_width),
+            line=line_style,
             name=name,
             showlegend=False,
-            hoverinfo='skip'
+            hoverinfo='skip',
         ))
 
-    def create_resized_interactive_plot(self, image: Image.Image, selected_boxes: List[List[int]], 
-                                       zoom_level: float, all_boxes: List[List[int]]) -> go.Figure:
+    def create_resized_interactive_plot(
+        self,
+        image: Image.Image,
+        selected_boxes: List[List[int]],
+        zoom_level: float,
+        all_boxes: List[Dict[str, Any]],
+    ) -> go.Figure:
         """创建可调整大小的交互式图片 - 修复容器溢出问题"""
         fig = go.Figure()
         
@@ -897,16 +944,57 @@ class OCRLayoutManager:
             )
         )
         
-        # 显示所有bbox(淡蓝色)
+        # 显示所有框:layout 结构按 COLOR_MAP;OCR 文字亮蓝实线/无文字虚线
         if all_boxes:
-            self._add_bboxes_as_scatter(
-                fig=fig,
-                bboxes=all_boxes,
-                image_height=image.height,
-                line_color="rgba(0, 100, 200, 0.8)",
-                line_width=2,
-                name="all_boxes"
-            )
+            layout_by_category: Dict[str, List[List[float]]] = {}
+            ocr_solid: List[List[float]] = []
+            ocr_dashed: List[List[float]] = []
+            ocr_color = ocr_box_plotly_rgba()
+            ocr_width = OCR_BOX_LINE_THICKNESS
+
+            for box_item in all_boxes:
+                cat = box_item.get('category', 'text')
+                bbox = box_item.get('bbox', [])
+                if len(bbox) < 4:
+                    continue
+                if cat in LAYOUT_STRUCTURE_CATEGORIES:
+                    layout_by_category.setdefault(cat, []).append(bbox)
+                else:
+                    if box_item.get('has_text', True):
+                        ocr_solid.append(bbox)
+                    else:
+                        ocr_dashed.append(bbox)
+
+            for cat, bboxes in layout_by_category.items():
+                line_width = 4 if cat == 'seal' else 2
+                self._add_bboxes_as_scatter(
+                    fig=fig,
+                    bboxes=bboxes,
+                    image_height=image.height,
+                    line_color=category_to_plotly_rgba(cat),
+                    line_width=line_width,
+                    name=f"all_{cat}",
+                )
+            if ocr_solid:
+                self._add_bboxes_as_scatter(
+                    fig=fig,
+                    bboxes=ocr_solid,
+                    image_height=image.height,
+                    line_color=ocr_color,
+                    line_width=ocr_width,
+                    name="ocr_text",
+                    dashed=False,
+                )
+            if ocr_dashed:
+                self._add_bboxes_as_scatter(
+                    fig=fig,
+                    bboxes=ocr_dashed,
+                    image_height=image.height,
+                    line_color=ocr_color,
+                    line_width=ocr_width,
+                    name="ocr_detect_only",
+                    dashed=True,
+                )
 
         # 高亮显示选中的bbox(红色)
         if selected_boxes: