Browse Source

feat: 优化OCR内容渲染,增强文本选择和高亮显示功能

zhch158_admin 1 month ago
parent
commit
cc9b643fcb
1 changed files with 96 additions and 112 deletions
  1. 96 112
      ocr_validator_layout.py

+ 96 - 112
ocr_validator_layout.py

@@ -178,54 +178,6 @@ class OCRLayoutManager:
             st.error(f"❌ 图像加载失败: {e}")
             st.error(f"❌ 图像加载失败: {e}")
             return None
             return None
 
 
-    def render_content_section(self, layout_type: str = "compact"):
-        """渲染内容区域 - 统一方法"""
-        st.header("📄 OCR识别内容")
-        
-        # 显示旋转信息
-        # rotation_angle = self.get_rotation_angle()
-        # if rotation_angle != 0:
-        #     st.info(f"📐 文档旋转角度: {rotation_angle}°")
-        
-        # 文本选择器
-        if self.validator.text_bbox_mapping:
-            text_options = ["请选择文本..."] + list(self.validator.text_bbox_mapping.keys())
-            selected_index = st.selectbox(
-                "选择要校验的文本",
-                range(len(text_options)),
-                format_func=lambda x: text_options[x][:50] + "..." if len(text_options[x]) > 50 else text_options[x],
-                key=f"{layout_type}_text_selector"
-            )
-            
-            if selected_index > 0:
-                st.session_state.selected_text = text_options[selected_index]
-        else:
-            st.warning("没有找到可点击的文本")
-    
-    def render_md_content(self, layout_type: str):
-        """渲染Markdown内容 - 统一方法"""
-        if not self.validator.md_content:
-            return None, None
-            
-        # 搜索功能
-        search_term = st.text_input(
-            "🔍 搜索文本内容", 
-            placeholder="输入关键词搜索...", 
-            key=f"{layout_type}_search"
-        )
-        
-        display_content = self.validator.md_content
-        if search_term:
-            lines = display_content.split('\n')
-            filtered_lines = [line for line in lines if search_term.lower() in line.lower()]
-            display_content = '\n'.join(filtered_lines)
-            if filtered_lines:
-                st.success(f"找到 {len(filtered_lines)} 行包含 '{search_term}'")
-            else:
-                st.warning(f"未找到包含 '{search_term}' 的内容")
-        
-        return display_content
-
     def render_content_by_mode(self, content: str, render_mode: str, font_size: int, container_height: int, layout_type: str):
     def render_content_by_mode(self, content: str, render_mode: str, font_size: int, container_height: int, layout_type: str):
         """根据渲染模式显示内容 - 增强版本"""
         """根据渲染模式显示内容 - 增强版本"""
         if content is None or render_mode is None:
         if content is None or render_mode is None:
@@ -349,42 +301,49 @@ class OCRLayoutManager:
         # 主要内容区域
         # 主要内容区域
         layout = config['styles']['layout']
         layout = config['styles']['layout']
         font_size = config['styles'].get('font_size', 10)
         font_size = config['styles'].get('font_size', 10)
-        container_height = layout.get('default_height', 600)  # 默认高度
-        zoom_level = layout.get('default_zoom', 1.0)  # 默认缩放级别
+        container_height = layout.get('default_height', 600)
+        zoom_level = layout.get('default_zoom', 1.0)
         layout_type = "compact"
         layout_type = "compact"
 
 
         left_col, right_col = st.columns([layout['content_width'], layout['sidebar_width']], vertical_alignment='top', border=True)
         left_col, right_col = st.columns([layout['content_width'], layout['sidebar_width']], vertical_alignment='top', border=True)
 
 
         with left_col:
         with left_col:
-            # self.render_content_section(layout_type)
-            # 快速定位文本选择器(使用不同的key)
+            # 快速定位文本选择器
             if self.validator.text_bbox_mapping:
             if self.validator.text_bbox_mapping:
-                text_options = ["请选择文本..."] + list(self.validator.text_bbox_mapping.keys())
+                text_options = ["请选择文本..."]
+                text_display = ["请选择文本..."]
+                
+                for text, info_list in self.validator.text_bbox_mapping.items():
+                    text_options.append(text)
+                    display_text = text[:47] + "..." if len(text) > 50 else text
+                    text_display.append(display_text)
+                
                 selected_index = st.selectbox(
                 selected_index = st.selectbox(
                     "快速定位文本",
                     "快速定位文本",
                     range(len(text_options)),
                     range(len(text_options)),
-                    format_func=lambda x: text_options[x][:30] + "..." if len(text_options[x]) > 30 else text_options[x],
+                    format_func=lambda x: text_display[x],
                     label_visibility="collapsed",
                     label_visibility="collapsed",
-                    key="compact_quick_text_selector"  # 使用不同的key
+                    key="compact_quick_text_selector"
                 )
                 )
                 
                 
                 if selected_index > 0:
                 if selected_index > 0:
                     st.session_state.selected_text = text_options[selected_index]
                     st.session_state.selected_text = text_options[selected_index]
             
             
-            # 处理并显示OCR内容
+            # 处理并显示OCR内容 - 只高亮选中的文本
             if self.validator.md_content:
             if self.validator.md_content:
-                # 高亮可点击文本
                 highlighted_content = self.validator.md_content
                 highlighted_content = self.validator.md_content
-                for text in self.validator.text_bbox_mapping.keys():
-                    if len(text) > 2:  # 避免高亮过短的文本
-                        css_class = "highlight-text selected-highlight" if text == st.session_state.selected_text else "highlight-text"
+                
+                # 只高亮选中的文本
+                if st.session_state.selected_text:
+                    selected_text = st.session_state.selected_text
+                    if len(selected_text) > 2:
                         highlighted_content = highlighted_content.replace(
                         highlighted_content = highlighted_content.replace(
-                            text, 
-                            # f'<span class="{css_class}" title="{text[:50]}...">{text}</span>'
-                            f'<span class="{css_class}" title="{text}">{text}</span>'
+                            selected_text,
+                            f'<span class="highlight-text selected-highlight" title="{selected_text}">{selected_text}</span>'
                         )
                         )
+                
                 self.render_content_by_mode(highlighted_content, "HTML渲染", font_size, container_height, layout_type)
                 self.render_content_by_mode(highlighted_content, "HTML渲染", font_size, container_height, layout_type)
-            
+    
         with right_col:
         with right_col:
             # 修复的对齐图片显示
             # 修复的对齐图片显示
             self.create_aligned_image_display(zoom_level, "compact")
             self.create_aligned_image_display(zoom_level, "compact")
@@ -445,9 +404,9 @@ class OCRLayoutManager:
         
         
         if image:
         if image:
             try:
             try:
-                resized_image, all_boxes, selected_bbox = self.zoom_image(image, self.zoom_level)
+                resized_image, all_boxes, selected_boxes = self.zoom_image(image, self.zoom_level)
                 # 创建交互式图片
                 # 创建交互式图片
-                fig = self.create_resized_interactive_plot(resized_image, selected_bbox, self.zoom_level, all_boxes)
+                fig = self.create_resized_interactive_plot(resized_image, selected_boxes, self.zoom_level, all_boxes)
 
 
                 plot_config = {
                 plot_config = {
                     'displayModeBar': True,
                     'displayModeBar': True,
@@ -482,7 +441,7 @@ class OCRLayoutManager:
 
 
     # st.markdown('</div>', unsafe_allow_html=True)
     # st.markdown('</div>', unsafe_allow_html=True)
 
 
-    def zoom_image(self, image: Image.Image, current_zoom: float) -> Tuple[Image.Image, List[List[int]], Optional[List[int]]]:
+    def zoom_image(self, image: Image.Image, current_zoom: float) -> Tuple[Image.Image, List[List[int]], List[List[int]]]:
         """缩放图像"""
         """缩放图像"""
         # 根据缩放级别调整图片大小
         # 根据缩放级别调整图片大小
         new_width = int(image.width * current_zoom)
         new_width = int(image.width * current_zoom)
@@ -490,11 +449,14 @@ class OCRLayoutManager:
         resized_image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
         resized_image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
 
 
         # 计算选中的bbox
         # 计算选中的bbox
-        selected_bbox = None
+        selected_boxes = []
         if st.session_state.selected_text and st.session_state.selected_text in self.validator.text_bbox_mapping:
         if st.session_state.selected_text and st.session_state.selected_text in self.validator.text_bbox_mapping:
-            info = self.validator.text_bbox_mapping[st.session_state.selected_text][0]
-            bbox = info['bbox']
-            selected_bbox = [int(coord * current_zoom) for coord in bbox]
+            info_list = self.validator.text_bbox_mapping[st.session_state.selected_text]
+            for info in info_list:
+                if 'bbox' in info:
+                    bbox = info['bbox']
+                    selected_box = [int(coord * current_zoom) for coord in bbox]
+                    selected_boxes.append(selected_box)
 
 
         # 收集所有框
         # 收集所有框
         all_boxes = []
         all_boxes = []
@@ -506,9 +468,49 @@ class OCRLayoutManager:
                         scaled_bbox = [coord * current_zoom for coord in bbox]
                         scaled_bbox = [coord * current_zoom for coord in bbox]
                         all_boxes.append(scaled_bbox)
                         all_boxes.append(scaled_bbox)
 
 
-        return resized_image, all_boxes, selected_bbox
+        return resized_image, all_boxes, selected_boxes
+
+    def _add_bboxes_to_plot(self, fig: go.Figure, bboxes: List[List[int]], image_height: int, 
+                           line_color: str = "blue", line_width: int = 1, 
+                           fill_color: str = "rgba(0, 100, 200, 0.2)"):
+        """
+        在plotly图表上添加边界框
         
         
-    def create_resized_interactive_plot(self, image: Image.Image, selected_bbox: Optional[List[int]], zoom_level: float, all_boxes: List[List[int]]) -> go.Figure:
+        Args:
+            fig: plotly图表对象
+            bboxes: 边界框列表,每个bbox格式为[x1, y1, x2, y2]
+            image_height: 图片高度,用于Y轴坐标转换
+            line_color: 边框线颜色
+            line_width: 边框线宽度
+            fill_color: 填充颜色(RGBA格式)
+        """
+        if not bboxes or len(bboxes) == 0:
+            return
+            
+        for bbox in bboxes:
+            if len(bbox) < 4:
+                continue
+                
+            x1, y1, x2, y2 = bbox[:4]
+            
+            # 转换为Plotly坐标系(翻转Y轴)
+            # JSON格式: 原点在左上角, y向下增加
+            # Plotly格式: 原点在左下角, y向上增加
+            plot_x1 = x1
+            plot_x2 = x2
+            plot_y1 = image_height - y2  # JSON的y2(底部) -> Plotly的底部
+            plot_y2 = image_height - y1  # JSON的y1(顶部) -> Plotly的顶部
+            
+            fig.add_shape(
+                type="rect",
+                x0=plot_x1, y0=plot_y1,
+                x1=plot_x2, y1=plot_y2,
+                line=dict(color=line_color, width=line_width),
+                fillcolor=fill_color,
+            )
+
+    def create_resized_interactive_plot(self, image: Image.Image, selected_boxes: List[List[int]], 
+                                       zoom_level: float, all_boxes: List[List[int]]) -> go.Figure:
         """创建可调整大小的交互式图片 - 修复容器溢出问题"""
         """创建可调整大小的交互式图片 - 修复容器溢出问题"""
         fig = go.Figure()
         fig = go.Figure()
         
         
@@ -527,44 +529,26 @@ class OCRLayoutManager:
             )
             )
         )
         )
         
         
-        # 显示所有bbox - 需要坐标转换
-        if len(all_boxes) > 0:
-            for bbox in all_boxes:
-                if len(bbox) >= 4:
-                    x1, y1, x2, y2 = bbox[:4]
-                    
-                    # 转换为Plotly坐标系(翻转Y轴)
-                    plot_x1 = x1
-                    plot_x2 = x2
-                    plot_y1 = image.height - y2  # JSON的y2 -> Plotly的底部
-                    plot_y2 = image.height - y1  # JSON的y1 -> Plotly的顶部
-                    
-                    color = "rgba(0, 100, 200, 0.2)"
-                    
-                    fig.add_shape(
-                        type="rect",
-                        x0=plot_x1, y0=plot_y1,
-                        x1=plot_x2, y1=plot_y2,
-                        line=dict(color="blue", width=1),
-                        fillcolor=color,
-                    )
+        # 显示所有bbox(淡蓝色)
+        if all_boxes:
+            self._add_bboxes_to_plot(
+                fig=fig,
+                bboxes=all_boxes,
+                image_height=image.height,
+                line_color="blue",
+                line_width=1,
+                fill_color="rgba(0, 100, 200, 0.2)"
+            )
 
 
-        # 高亮显示选中的bbox
-        if selected_bbox and len(selected_bbox) >= 4:
-            x1, y1, x2, y2 = selected_bbox[:4]
-            
-            # 转换为Plotly坐标系
-            plot_x1 = x1
-            plot_x2 = x2
-            plot_y1 = image.height - y2  # 翻转Y坐坐标
-            plot_y2 = image.height - y1  # 翻转Y坐标
-            
-            fig.add_shape(
-                type="rect",
-                x0=plot_x1, y0=plot_y1,
-                x1=plot_x2, y1=plot_y2,
-                line=dict(color="red", width=3),
-                fillcolor="rgba(255, 0, 0, 0.3)",
+        # 高亮显示选中的bbox(红色)
+        if selected_boxes:
+            self._add_bboxes_to_plot(
+                fig=fig,
+                bboxes=selected_boxes,
+                image_height=image.height,
+                line_color="red",
+                line_width=3,
+                fill_color="rgba(255, 0, 0, 0.3)"
             )
             )
     
     
         # 修复:优化显示尺寸计算
         # 修复:优化显示尺寸计算
@@ -593,8 +577,8 @@ class OCRLayoutManager:
         
         
         # 设置布局 - 关键修改
         # 设置布局 - 关键修改
         fig.update_layout(
         fig.update_layout(
-            width=display_width,    # 注释掉固定宽度
-            height=display_height,  # 注释掉固定高度
+            width=display_width,
+            height=display_height,
             
             
             margin=dict(l=0, r=0, t=0, b=0),
             margin=dict(l=0, r=0, t=0, b=0),
             showlegend=False,
             showlegend=False,