Parcourir la source

新增图像旋转配置功能,优化图像缓存管理,修复交互式图像显示问题

zhch158_admin il y a 2 mois
Parent
commit
ea3b4469a7
1 fichiers modifiés avec 215 ajouts et 88 suppressions
  1. 215 88
      ocr_validator_layout.py

+ 215 - 88
ocr_validator_layout.py

@@ -9,12 +9,14 @@ from pathlib import Path
 from PIL import Image
 from typing import Dict, List, Optional
 import plotly.graph_objects as go
+from typing import Tuple
 
 from ocr_validator_utils import (
     convert_html_table_to_markdown, 
     parse_html_tables,
     draw_bbox_on_image,
-    rotate_image_and_coordinates
+    rotate_image_and_coordinates,
+    get_ocr_tool_rotation_config  # 新增导入
 )
 
 
@@ -25,6 +27,32 @@ class OCRLayoutManager:
         self.validator = validator
         self.config = validator.config
         self._rotated_image_cache = {}  # 缓存旋转后的图像
+        self._cache_max_size = 10  # 最大缓存数量
+    
+    def clear_image_cache(self):
+        """清理所有图像缓存"""
+        self._rotated_image_cache.clear()
+        
+    def clear_cache_for_image(self, image_path: str):
+        """清理指定图像的所有缓存"""
+        keys_to_remove = [key for key in self._rotated_image_cache.keys() if key.startswith(image_path)]
+        for key in keys_to_remove:
+            del self._rotated_image_cache[key]
+    
+    def get_cache_info(self) -> dict:
+        """获取缓存信息"""
+        return {
+            'cache_size': len(self._rotated_image_cache),
+            'cached_images': list(self._rotated_image_cache.keys()),
+            'max_size': self._cache_max_size
+        }
+    
+    def _manage_cache_size(self):
+        """管理缓存大小,超出限制时清理最旧的缓存"""
+        if len(self._rotated_image_cache) > self._cache_max_size:
+            # 删除最旧的缓存项(FIFO策略)
+            oldest_key = next(iter(self._rotated_image_cache))
+            del self._rotated_image_cache[oldest_key]
     
     def get_rotation_angle(self) -> float:
         """从OCR数据中获取旋转角度"""
@@ -51,37 +79,64 @@ class OCRLayoutManager:
             
             # 如果需要旋转
             if rotation_angle != 0:
-                st.info(f"🔄 检测到文档旋转角度: {rotation_angle}°,正在自动旋转图像...")
-                
-                # 收集所有bbox坐标
-                all_bboxes = []
-                text_to_bbox_map = {}  # 记录文本到bbox索引的映射
+                # 获取OCR工具的旋转配置
+                rotation_config = get_ocr_tool_rotation_config(self.validator.ocr_data, self.config)
                 
-                bbox_index = 0
-                for text, info_list in self.validator.text_bbox_mapping.items():
-                    text_to_bbox_map[text] = []
-                    for info in info_list:
-                        all_bboxes.append(info['bbox'])
-                        text_to_bbox_map[text].append(bbox_index)
-                        bbox_index += 1
+                st.info(f"🔄 检测到文档旋转角度: {rotation_angle}°,正在处理图像和坐标...")
+                st.info(f"📋 OCR工具配置: 坐标{'已预旋转' if rotation_config['coordinates_are_pre_rotated'] else '需要旋转'}")
                 
-                # 旋转图像和坐标
-                rotated_image, rotated_bboxes = rotate_image_and_coordinates(
-                    image, rotation_angle, all_bboxes
-                )
-                
-                # 更新bbox映射 - 使用映射关系确保正确对应
-                for text, bbox_indices in text_to_bbox_map.items():
-                    for i, bbox_idx in enumerate(bbox_indices):
-                        if bbox_idx < len(rotated_bboxes) and i < len(self.validator.text_bbox_mapping[text]):
-                            self.validator.text_bbox_mapping[text][i]['bbox'] = rotated_bboxes[bbox_idx]
-                
-                # 缓存结果
-                self._rotated_image_cache[cache_key] = rotated_image
-                return rotated_image
+                # 判断是否需要旋转坐标
+                if rotation_config['coordinates_are_pre_rotated']:
+                    # PPStructV3: 坐标已经是旋转后的,只旋转图像
+                    if rotation_angle == 270:
+                        rotated_image = image.rotate(-90, expand=True)  # 顺时针90度
+                    elif rotation_angle == 90:
+                        rotated_image = image.rotate(90, expand=True)   # 逆时针90度
+                    elif rotation_angle == 180:
+                        rotated_image = image.rotate(180, expand=True)  # 180度
+                    else:
+                        rotated_image = image.rotate(-rotation_angle, expand=True)
+                    
+                    # 坐标不需要变换,因为JSON中已经是正确的坐标
+                    self._rotated_image_cache[cache_key] = rotated_image
+                    self._manage_cache_size()
+                    return rotated_image
+                    
+                else:
+                    # Dots OCR: 需要同时旋转图像和坐标
+                    # 收集所有bbox坐标
+                    all_bboxes = []
+                    text_to_bbox_map = {}  # 记录文本到bbox索引的映射
+                    
+                    bbox_index = 0
+                    for text, info_list in self.validator.text_bbox_mapping.items():
+                        text_to_bbox_map[text] = []
+                        for info in info_list:
+                            all_bboxes.append(info['bbox'])
+                            text_to_bbox_map[text].append(bbox_index)
+                            bbox_index += 1
+                    
+                    # 旋转图像和坐标
+                    rotated_image, rotated_bboxes = rotate_image_and_coordinates(
+                        image, rotation_angle, all_bboxes, 
+                        rotate_coordinates=rotation_config['coordinates_need_rotation']
+                    )
+                    
+                    # 更新bbox映射 - 使用映射关系确保正确对应
+                    for text, bbox_indices in text_to_bbox_map.items():
+                        for i, bbox_idx in enumerate(bbox_indices):
+                            if bbox_idx < len(rotated_bboxes) and i < len(self.validator.text_bbox_mapping[text]):
+                                self.validator.text_bbox_mapping[text][i]['bbox'] = rotated_bboxes[bbox_idx]
+                    
+                    # 缓存结果
+                    self._rotated_image_cache[cache_key] = rotated_image
+                    self._manage_cache_size()
+                    return rotated_image
+                    
             else:
                 # 无需旋转,直接缓存原图
                 self._rotated_image_cache[cache_key] = image
+                self._manage_cache_size()  # 检查并管理缓存大小
                 return image
                 
         except Exception as e:
@@ -300,7 +355,7 @@ class OCRLayoutManager:
             self.create_aligned_image_display(zoom_level, "compact")
     
     def create_aligned_image_display(self, zoom_level: float = 1.0, layout_type: str = "aligned"):
-        """创建与左侧对齐的图片显示"""
+        """创建与左侧对齐的图片显示 - 修复显示问题"""
         # 精确对齐CSS
         st.markdown(f"""
         <style>
@@ -312,18 +367,25 @@ class OCRLayoutManager:
             margin-top: 0px !important;
             padding-top: 0px !important;
         }}
+        /* 修复:确保Plotly图表容器没有额外边距 */
+        .js-plotly-plot, .plotly {{
+            margin: 0 !important;
+            padding: 0 !important;
+        }}
         </style>
         """, unsafe_allow_html=True)
         
         st.markdown(f'<div class="aligned-image-container-{layout_type}">', unsafe_allow_html=True)
         st.header("🖼️ 原图标注")
         
-        # 图片缩放控制
-        col1, col2 = st.columns(2)
+        # 图片控制选项
+        col1, col2, col3 = st.columns(3)
         with col1:
             current_zoom = st.slider("图片缩放", 0.3, 2.0, zoom_level, 0.1, key=f"{layout_type}_zoom_level")
         with col2:
             show_all_boxes = st.checkbox("显示所有框", value=False, key=f"{layout_type}_show_all_boxes")
+        with col3:
+            fit_to_container = st.checkbox("适应容器", value=True, key=f"{layout_type}_fit_container")
         
         # 使用新的图像加载方法
         image = self.load_and_rotate_image(self.validator.image_path)
@@ -342,10 +404,70 @@ class OCRLayoutManager:
                     # bbox已经是旋转后的坐标,只需要应用缩放
                     bbox = info['bbox']
                     selected_bbox = [int(coord * current_zoom) for coord in bbox]
+
+                # 收集所有框
+                all_boxes = []
+                if show_all_boxes:
+                    for text, info_list in self.validator.text_bbox_mapping.items():
+                        for info in info_list:
+                            bbox = info['bbox']
+                            if len(bbox) >= 4:
+                                scaled_bbox = [coord * current_zoom for coord in bbox]
+                                all_boxes.append(scaled_bbox)
+                
+                # 添加调试信息
+                with st.expander("🔍 图像和坐标调试信息", expanded=False):
+                    rotation_angle = self.get_rotation_angle()
+                    rotation_config = get_ocr_tool_rotation_config(self.validator.ocr_data, self.config)
+                    
+                    col_debug1, col_debug2 = st.columns(2)
+                    with col_debug1:
+                        st.write("**图像信息:**")
+                        st.write(f"原始尺寸: {image.width} x {image.height}")
+                        st.write(f"缩放后尺寸: {resized_image.width} x {resized_image.height}")
+                        st.write(f"旋转角度: {rotation_angle}°")
+                        
+                    with col_debug2:
+                        st.write("**坐标信息:**")
+                        if selected_bbox:
+                            st.write(f"选中框: {selected_bbox}")
+                        st.write(f"总框数: {len(all_boxes)}")
+                        st.write(f"工具配置: {'预旋转' if rotation_config.get('coordinates_are_pre_rotated') else '需旋转'}")
+                    
+                    if st.session_state.selected_text:
+                        info = self.validator.text_bbox_mapping[st.session_state.selected_text][0]
+                        original_bbox = info['bbox']
+                        
+                        # 验证坐标是否在图像范围内
+                        x1, y1, x2, y2 = original_bbox[:4]
+                        in_bounds = (0 <= x1 < image.width and 
+                                   0 <= x2 <= image.width and 
+                                   0 <= y1 < image.height and 
+                                   0 <= y2 <= image.height)
+                        
+                        color = "🟢" if in_bounds else "🔴"
+                        st.write(f"{color} 坐标范围检查: {in_bounds}")
+                        
+                        if not in_bounds:
+                            st.warning("⚠️ 坐标超出图像范围,可能存在坐标系问题")
                 
                 # 创建交互式图片
-                fig = self.create_resized_interactive_plot(resized_image, selected_bbox, current_zoom, show_all_boxes)
-                st.plotly_chart(fig, use_container_width=True, key=f"{layout_type}_plot")
+                fig = self.create_resized_interactive_plot(resized_image, selected_bbox, current_zoom, all_boxes)
+                
+                # 修复:使用合适的配置显示图表
+                plot_config = {
+                    'displayModeBar': True,
+                    'modeBarButtonsToRemove': ['zoom2d', 'select2d', 'lasso2d', 'autoScale2d'],
+                    'scrollZoom': True,
+                    'doubleClick': 'reset'
+                }
+                
+                st.plotly_chart(
+                    fig, 
+                    use_container_width=fit_to_container,
+                    config=plot_config,
+                    key=f"{layout_type}_plot"
+                )
                 
                 # 显示选中文本的详细信息
                 if st.session_state.selected_text and st.session_state.selected_text in self.validator.text_bbox_mapping:
@@ -382,74 +504,66 @@ class OCRLayoutManager:
                         
             except Exception as e:
                 st.error(f"❌ 图片处理失败: {e}")
-                st.error(f"详细错误: {str(e)}")
+                st.exception(e)  # 显示完整的错误堆栈
         else:
             st.error("未找到对应的图片文件")
             if self.validator.image_path:
                 st.write(f"期望路径: {self.validator.image_path}")
+
+    st.markdown('</div>', unsafe_allow_html=True)
     
-        st.markdown('</div>', unsafe_allow_html=True)
-    
-    def create_resized_interactive_plot(self, image: Image.Image, selected_bbox: Optional[List[int]], zoom_level: float, show_all_boxes: bool) -> go.Figure:
-        """创建可调整大小的交互式图片 - 修复图像显示和bbox对齐问题"""
+    def create_resized_interactive_plot(self, image: Image.Image, selected_bbox: Optional[List[int]], zoom_level: float, all_boxes: list[tuple]) -> go.Figure:
+        """
+        创建可调整大小的交互式图片 - 修复图像显示和bbox对齐问题
+        图片,box坐标全部是已缩放,旋转后的坐标
+        """
         fig = go.Figure()
         
-        # 添加图片 - 修正图像定位,确保与工具栏距离一致
+        # 添加图片 - Plotly坐标系,原点在左下角
         fig.add_layout_image(
             dict(
                 source=image,
                 xref="x", yref="y",
-                x=0, y=image.height * zoom_level,  # 修正:图片左上角位置
-                sizex=image.width * zoom_level, 
-                sizey=image.height * zoom_level,
+                x=0, y=image.height,  # 图片左下角在Plotly坐标系中的位置
+                sizex=image.width, 
+                sizey=image.height,
                 sizing="stretch", 
                 opacity=1.0, 
                 layer="below"
             )
         )
         
-        # 显示所有bbox
-        if show_all_boxes:
-            for text, info_list in self.validator.text_bbox_mapping.items():
-                for info in info_list:
-                    bbox = info['bbox']
-                    if len(bbox) >= 4:
-                        # bbox已经是旋转后的坐标,需要应用缩放并转换坐标系
-                        x1, y1, x2, y2 = bbox[:4]
-                        
-                        # 应用缩放
-                        scaled_x1 = x1 * zoom_level
-                        scaled_y1 = y1 * zoom_level
-                        scaled_x2 = x2 * zoom_level
-                        scaled_y2 = y2 * zoom_level
-                        
-                        # 转换为plotly坐标系(原点在左下角)
-                        plot_x1 = scaled_x1
-                        plot_y1 = (image.height * zoom_level) - scaled_y2  # 翻转Y坐标
-                        plot_x2 = scaled_x2
-                        plot_y2 = (image.height * zoom_level) - scaled_y1  # 翻转Y坐标
-                        
-                        color = "rgba(0, 100, 200, 0.2)"
-                        if text in self.validator.marked_errors:
-                            color = "rgba(255, 0, 0, 0.3)"
-                        
-                        fig.add_shape(
-                            type="rect",
-                            x0=plot_x1, y0=plot_y1,
-                            x1=plot_x2, y1=plot_y2,
-                            line=dict(color=color.replace('0.2', '0.8').replace('0.3', '1.0'), width=1),
-                            fillcolor=color,
-                        )
+        # 显示所有bbox - 需要坐标转换
+        if len(all_boxes) > 0:
+            for bbox in all_boxes:
+                if len(bbox) >= 4:
+                    x1, y1, x2, y2 = bbox[:4]
+                    
+                    # 转换为Plotly坐标系(翻转Y轴)
+                    plot_x1 = x1
+                    plot_x2 = x2
+                    plot_y1 = image.height - y2  # JSON的y2 -> Plotly的底部
+                    plot_y2 = image.height - y1  # JSON的y1 -> Plotly的顶部
+                    
+                    color = "rgba(0, 100, 200, 0.2)"
+                    
+                    fig.add_shape(
+                        type="rect",
+                        x0=plot_x1, y0=plot_y1,
+                        x1=plot_x2, y1=plot_y2,
+                        line=dict(color="blue", width=1),
+                        fillcolor=color,
+                    )
 
         # 高亮显示选中的bbox
         if selected_bbox and len(selected_bbox) >= 4:
             x1, y1, x2, y2 = selected_bbox[:4]
             
-            # 转换为plotly坐标系(selected_bbox已经是缩放后的坐标)
+            # 转换为Plotly坐标系
             plot_x1 = x1
-            plot_y1 = (image.height * zoom_level) - y2  # 翻转Y坐标
             plot_x2 = x2
-            plot_y2 = (image.height * zoom_level) - y1  # 翻转Y坐标
+            plot_y1 = image.height - y2  # 翻转Y坐坐标
+            plot_y2 = image.height - y1  # 翻转Y坐标
             
             fig.add_shape(
                 type="rect",
@@ -458,40 +572,53 @@ class OCRLayoutManager:
                 line=dict(color="red", width=3),
                 fillcolor="rgba(255, 0, 0, 0.3)",
             )
+    
+        # 修复:优化显示尺寸计算
+        max_display_width = 800
+        max_display_height = 600
         
-        # 计算合适的显示尺寸
+        # 计算合适的显示尺寸,保持宽高比
         aspect_ratio = image.width / image.height
-        display_height = min(800, max(400, image.height // 2))
-        display_width = int(display_height * aspect_ratio)
         
-        # 设置布局 - 确保图像完全可见,使用缩放后的尺寸
+        if aspect_ratio > 1:  # 宽图
+            display_width = min(max_display_width, image.width)
+            display_height = int(display_width / aspect_ratio)
+        else:  # 高图
+            display_height = min(max_display_height, image.height)
+            display_width = int(display_height * aspect_ratio)
+        
+        # 修复:设置合理的布局参数
         fig.update_layout(
             width=display_width,
             height=display_height,
-            margin=dict(l=0, r=0, t=0, b=0),
+            margin=dict(l=0, r=0, t=0, b=0),  # 移除所有边距
             showlegend=False,
             plot_bgcolor='white',
             dragmode="pan",
             
-            # X轴设置 - 使用缩放后的图像尺寸
+            # 修复:X轴设置
             xaxis=dict(
                 visible=False,
-                range=[0, image.width * zoom_level],
+                range=[0, image.width],
                 constrain="domain",
                 fixedrange=False,
-                autorange=False
+                autorange=False,
+                showgrid=False,
+                zeroline=False
             ),
             
-            # Y轴设置 - plotly坐标系(原点在左下角)
+            # 修复:Y轴设置,确保范围正确
             yaxis=dict(
                 visible=False,
-                range=[0, image.height * zoom_level],
+                range=[0, image.height],  # 确保Y轴范围从0到图片高度
                 constrain="domain",
                 scaleanchor="x",
                 scaleratio=1,
                 fixedrange=False,
-                autorange=False
+                autorange=False,
+                showgrid=False,
+                zeroline=False
             )
         )
         
-        return fig
+        return fig