Ver código fonte

新增图像旋转和加载功能,优化交互式图片显示及bbox对齐逻辑
旋转图片标注不对

zhch158_admin 2 meses atrás
pai
commit
bb9339c410
1 arquivos alterados com 151 adições e 45 exclusões
  1. 151 45
      ocr_validator_layout.py

+ 151 - 45
ocr_validator_layout.py

@@ -13,7 +13,8 @@ import plotly.graph_objects as go
 from ocr_validator_utils import (
     convert_html_table_to_markdown, 
     parse_html_tables,
-    draw_bbox_on_image
+    draw_bbox_on_image,
+    rotate_image_and_coordinates
 )
 
 
@@ -23,11 +24,79 @@ class OCRLayoutManager:
     def __init__(self, validator):
         self.validator = validator
         self.config = validator.config
+        self._rotated_image_cache = {}  # 缓存旋转后的图像
+    
+    def get_rotation_angle(self) -> float:
+        """从OCR数据中获取旋转角度"""
+        if self.validator.ocr_data:
+            for item in self.validator.ocr_data:
+                if isinstance(item, dict) and 'rotation_angle' in item:
+                    return item['rotation_angle']
+        return 0.0
+    
+    def load_and_rotate_image(self, image_path: str) -> Optional[Image.Image]:
+        """加载并根据需要旋转图像"""
+        if not image_path or not Path(image_path).exists():
+            return None
+            
+        # 检查缓存
+        rotation_angle = self.get_rotation_angle()
+        cache_key = f"{image_path}_{rotation_angle}"
+        
+        if cache_key in self._rotated_image_cache:
+            return self._rotated_image_cache[cache_key]
+        
+        try:
+            image = Image.open(image_path)
+            
+            # 如果需要旋转
+            if rotation_angle != 0:
+                st.info(f"🔄 检测到文档旋转角度: {rotation_angle}°,正在自动旋转图像...")
+                
+                # 收集所有bbox坐标
+                all_bboxes = []
+                text_to_bbox_map = {}  # 记录文本到bbox索引的映射
+                
+                bbox_index = 0
+                for text, info_list in self.validator.text_bbox_mapping.items():
+                    text_to_bbox_map[text] = []
+                    for info in info_list:
+                        all_bboxes.append(info['bbox'])
+                        text_to_bbox_map[text].append(bbox_index)
+                        bbox_index += 1
+                
+                # 旋转图像和坐标
+                rotated_image, rotated_bboxes = rotate_image_and_coordinates(
+                    image, rotation_angle, all_bboxes
+                )
+                
+                # 更新bbox映射 - 使用映射关系确保正确对应
+                for text, bbox_indices in text_to_bbox_map.items():
+                    for i, bbox_idx in enumerate(bbox_indices):
+                        if bbox_idx < len(rotated_bboxes) and i < len(self.validator.text_bbox_mapping[text]):
+                            self.validator.text_bbox_mapping[text][i]['bbox'] = rotated_bboxes[bbox_idx]
+                
+                # 缓存结果
+                self._rotated_image_cache[cache_key] = rotated_image
+                return rotated_image
+            else:
+                # 无需旋转,直接缓存原图
+                self._rotated_image_cache[cache_key] = image
+                return image
+                
+        except Exception as e:
+            st.error(f"❌ 图像加载失败: {e}")
+            return None
     
     def render_content_section(self, layout_type: str = "standard"):
         """渲染内容区域 - 统一方法"""
         st.header("📄 OCR识别内容")
         
+        # 显示旋转信息
+        # rotation_angle = self.get_rotation_angle()
+        # if rotation_angle != 0:
+        #     st.info(f"📐 文档旋转角度: {rotation_angle}°")
+        
         # 文本选择器
         if self.validator.text_bbox_mapping:
             text_options = ["请选择文本..."] + list(self.validator.text_bbox_mapping.keys())
@@ -256,24 +325,25 @@ class OCRLayoutManager:
         with col2:
             show_all_boxes = st.checkbox("显示所有框", value=False, key=f"{layout_type}_show_all_boxes")
         
-        if self.validator.image_path and Path(self.validator.image_path).exists():
+        # 使用新的图像加载方法
+        image = self.load_and_rotate_image(self.validator.image_path)
+        
+        if image:
             try:
-                image = Image.open(self.validator.image_path)
-                
                 # 根据缩放级别调整图片大小
                 new_width = int(image.width * current_zoom)
                 new_height = int(image.height * current_zoom)
                 resized_image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
                 
-                # 在固定容器中显示图片
+                # 计算选中的bbox - 注意bbox已经是旋转后的坐标
                 selected_bbox = None
                 if st.session_state.selected_text and st.session_state.selected_text in self.validator.text_bbox_mapping:
                     info = self.validator.text_bbox_mapping[st.session_state.selected_text][0]
-                    # 根据缩放级别调整bbox坐标
+                    # bbox已经是旋转后的坐标,只需要应用缩放
                     bbox = info['bbox']
                     selected_bbox = [int(coord * current_zoom) for coord in bbox]
                 
-                # 创建交互式图片 - 确保从顶部开始显示
+                # 创建交互式图片
                 fig = self.create_resized_interactive_plot(resized_image, selected_bbox, current_zoom, show_all_boxes)
                 st.plotly_chart(fig, use_container_width=True, key=f"{layout_type}_plot")
                 
@@ -288,54 +358,76 @@ class OCRLayoutManager:
                     with info_col1:
                         st.write(f"**文本内容:** {st.session_state.selected_text[:30]}...")
                         st.write(f"**类别:** {info['category']}")
-                    
+                        # 显示旋转信息
+                        rotation_angle = self.get_rotation_angle()
+                        if rotation_angle != 0:
+                            st.write(f"**旋转角度:** {rotation_angle}°")
+                
                     with info_col2:
                         st.write(f"**位置:** [{', '.join(map(str, bbox))}]")
                         if len(bbox) >= 4:
                             st.write(f"**大小:** {bbox[2] - bbox[0]} x {bbox[3] - bbox[1]} px")
-                    
-                    # 错误标记功能
-                    col1, col2 = st.columns(2)
-                    with col1:
-                        if st.button("❌ 标记为错误", key=f"{layout_type}_mark_error"):
-                            st.session_state.marked_errors.add(st.session_state.selected_text)
-                            st.rerun()
-                    
-                    with col2:
-                        if st.button("✅ 取消错误标记", key=f"{layout_type}_unmark_error"):
-                            st.session_state.marked_errors.discard(st.session_state.selected_text)
-                            st.rerun()
-                            
+                
+                # 错误标记功能
+                col1, col2 = st.columns(2)
+                with col1:
+                    if st.button("❌ 标记为错误", key=f"{layout_type}_mark_error"):
+                        st.session_state.marked_errors.add(st.session_state.selected_text)
+                        st.rerun()
+                
+                with col2:
+                    if st.button("✅ 取消错误标记", key=f"{layout_type}_unmark_error"):
+                        st.session_state.marked_errors.discard(st.session_state.selected_text)
+                        st.rerun()
+                        
             except Exception as e:
                 st.error(f"❌ 图片处理失败: {e}")
+                st.error(f"详细错误: {str(e)}")
         else:
             st.error("未找到对应的图片文件")
             if self.validator.image_path:
                 st.write(f"期望路径: {self.validator.image_path}")
-        
+    
         st.markdown('</div>', unsafe_allow_html=True)
     
     def create_resized_interactive_plot(self, image: Image.Image, selected_bbox: Optional[List[int]], zoom_level: float, show_all_boxes: bool) -> go.Figure:
-        """创建可调整大小的交互式图片 - 优化显示和定位"""
+        """创建可调整大小的交互式图片 - 修复图像显示和bbox对齐问题"""
         fig = go.Figure()
         
+        # 添加图片 - 修正图像定位,确保与工具栏距离一致
         fig.add_layout_image(
             dict(
                 source=image,
                 xref="x", yref="y",
-                x=0, y=0,  # 改为从底部开始,这样图片会从顶部显示
-                sizex=image.width, sizey=image.height,
-                sizing="stretch", opacity=1.0, layer="below"
+                x=0, y=image.height * zoom_level,  # 修正:图片左上角位置
+                sizex=image.width * zoom_level, 
+                sizey=image.height * zoom_level,
+                sizing="stretch", 
+                opacity=1.0, 
+                layer="below"
             )
         )
         
-        # 显示所有bbox(如果开启)
+        # 显示所有bbox
         if show_all_boxes:
             for text, info_list in self.validator.text_bbox_mapping.items():
                 for info in info_list:
                     bbox = info['bbox']
                     if len(bbox) >= 4:
-                        x1, y1, x2, y2 = [coord * zoom_level for coord in bbox[:4]]
+                        # bbox已经是旋转后的坐标,需要应用缩放并转换坐标系
+                        x1, y1, x2, y2 = bbox[:4]
+                        
+                        # 应用缩放
+                        scaled_x1 = x1 * zoom_level
+                        scaled_y1 = y1 * zoom_level
+                        scaled_x2 = x2 * zoom_level
+                        scaled_y2 = y2 * zoom_level
+                        
+                        # 转换为plotly坐标系(原点在左下角)
+                        plot_x1 = scaled_x1
+                        plot_y1 = (image.height * zoom_level) - scaled_y2  # 翻转Y坐标
+                        plot_x2 = scaled_x2
+                        plot_y2 = (image.height * zoom_level) - scaled_y1  # 翻转Y坐标
                         
                         color = "rgba(0, 100, 200, 0.2)"
                         if text in self.validator.marked_errors:
@@ -343,48 +435,62 @@ class OCRLayoutManager:
                         
                         fig.add_shape(
                             type="rect",
-                            x0=x1, y0=y1,  # 调整坐标系,不再翻转
-                            x1=x2, y1=y2,
+                            x0=plot_x1, y0=plot_y1,
+                            x1=plot_x2, y1=plot_y2,
                             line=dict(color=color.replace('0.2', '0.8').replace('0.3', '1.0'), width=1),
                             fillcolor=color,
                         )
-        
+
         # 高亮显示选中的bbox
         if selected_bbox and len(selected_bbox) >= 4:
             x1, y1, x2, y2 = selected_bbox[:4]
+            
+            # 转换为plotly坐标系(selected_bbox已经是缩放后的坐标)
+            plot_x1 = x1
+            plot_y1 = (image.height * zoom_level) - y2  # 翻转Y坐标
+            plot_x2 = x2
+            plot_y2 = (image.height * zoom_level) - y1  # 翻转Y坐标
+            
             fig.add_shape(
                 type="rect",
-                x0=x1, y0=y1,  # 调整坐标系
-                x1=x2, y1=y2,
-                line=dict(color="red", width=2),
+                x0=plot_x1, y0=plot_y1,
+                x1=plot_x2, y1=plot_y2,
+                line=dict(color="red", width=3),
                 fillcolor="rgba(255, 0, 0, 0.3)",
             )
         
-        # 设置坐标轴范围 - 让图片从顶部开始显示
-        fig.update_xaxes(visible=False, range=[0, image.width])
-        fig.update_yaxes(visible=False, range=[image.height, 0], scaleanchor="x")  # 翻转Y轴让图片从顶部开始
-        
-        # 计算更大的显示尺寸
+        # 计算合适的显示尺寸
         aspect_ratio = image.width / image.height
-        display_height = min(1000, max(600, image.height))  # 动态调整高度
+        display_height = min(800, max(400, image.height // 2))
         display_width = int(display_height * aspect_ratio)
         
+        # 设置布局 - 确保图像完全可见,使用缩放后的尺寸
         fig.update_layout(
             width=display_width,
             height=display_height,
             margin=dict(l=0, r=0, t=0, b=0),
             showlegend=False,
             plot_bgcolor='white',
-            # 设置初始视图让图片从顶部开始显示
+            dragmode="pan",
+            
+            # X轴设置 - 使用缩放后的图像尺寸
             xaxis=dict(
-                range=[0, image.width],
-                constrain="domain"
+                visible=False,
+                range=[0, image.width * zoom_level],
+                constrain="domain",
+                fixedrange=False,
+                autorange=False
             ),
+            
+            # Y轴设置 - plotly坐标系(原点在左下角)
             yaxis=dict(
-                range=[image.height, 0],  # 翻转范围
+                visible=False,
+                range=[0, image.height * zoom_level],
                 constrain="domain",
                 scaleanchor="x",
-                scaleratio=1
+                scaleratio=1,
+                fixedrange=False,
+                autorange=False
             )
         )