Browse Source

feat(layout_detector): enhance layout detection with debug visualization and improved overlapping box handling

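- Added debug attributes (`debug_mode`, `output_dir`, `page_name`) to `BaseLayoutDetector.__init__`; when they are left as `None`, `detect()` falls back to the `debug_mode` / `output_dir` / `page_name` config keys and then to `debug_options.enabled` / `debug_options.output_dir` / `debug_options.prefix`, allowing for better tracking of detection results.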
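- Implemented a visualization method that saves the raw (pre-post-processing) layout detection results as an annotated JPEG and a JSON file under `<output_dir>/debug_comparison/layout_detection/`, aiding in debugging and analysis.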
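- Improved the overlapping-box removal algorithm by replacing it with an ordered set of decision rules (category priority, containment combined with aggregate-label and area-ratio checks, a composite area/confidence score for text boxes, then confidence and area tie-breakers), enhancing the accuracy of layout results. Note that the previous behaviour of expanding the larger box to cover a removed smaller box is no longer performed.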
zhch158_admin 1 month ago
parent
commit
4a9c9d1114
1 changed files with 347 additions and 83 deletions
  1. 347 83
      ocr_tools/universal_doc_parser/models/adapters/base.py

+ 347 - 83
ocr_tools/universal_doc_parser/models/adapters/base.py

@@ -2,6 +2,10 @@ from abc import ABC, abstractmethod
 from typing import Dict, Any, List, Union, Optional, Tuple
 import numpy as np
 from PIL import Image
+from loguru import logger
+from pathlib import Path
+import cv2
+import json
 
 class BaseAdapter(ABC):
     """基础适配器接口"""
@@ -44,6 +48,18 @@ class BasePreprocessor(BaseAdapter):
 class BaseLayoutDetector(BaseAdapter):
     """版式检测器基类"""
     
+    def __init__(self, config: Dict[str, Any]):
+        """Initialize the layout detector.
+
+        Args:
+            config: Configuration dictionary, forwarded to the parent initializer.
+        """
+        super().__init__(config)
+        # Debug-related attributes. They start as None so that detect() falls back
+        # to the configuration; assigning a value at runtime overrides the config.
+        self.debug_mode = None  # None -> detect() reads config 'debug_mode', then 'debug_options.enabled'
+        self.output_dir = None  # None -> detect() reads config 'output_dir', then 'debug_options.output_dir'
+        self.page_name = None   # None -> detect() reads config 'page_name', then 'debug_options.prefix'
+    
     def detect(
         self, 
         image: Union[np.ndarray, Image.Image],
@@ -66,6 +82,58 @@ class BaseLayoutDetector(BaseAdapter):
         # 调用子类实现的原始检测方法
         layout_results = self._detect_raw(image, ocr_spans)
         
+        # Debug 模式:打印和可视化后处理前的检测结果
+        # 优先从实例属性读取(如果存在),否则从配置读取
+        # 支持两种配置方式:debug_mode 或 debug_options.enabled
+        debug_mode = getattr(self, 'debug_mode', None)
+        if debug_mode is None:
+            if hasattr(self, 'config'):
+                # 优先从 debug_mode 读取
+                debug_mode = self.config.get('debug_mode', False)
+                # 如果没有 debug_mode,尝试从 debug_options.enabled 读取
+                if not debug_mode:
+                    debug_options = self.config.get('debug_options', {})
+                    if isinstance(debug_options, dict):
+                        debug_mode = debug_options.get('enabled', False)
+            else:
+                debug_mode = False
+        
+        if debug_mode:
+            logger.debug(f"🔍 Layout detection raw results (before post-processing): {len(layout_results)} elements")
+            # logger.debug(f"Raw layout_results: {layout_results}")
+            # 可视化 layout 结果
+            output_dir = getattr(self, 'output_dir', None)
+            if output_dir is None:
+                if hasattr(self, 'config'):
+                    # 优先从 output_dir 读取
+                    output_dir = self.config.get('output_dir', None)
+                    # 如果没有 output_dir,尝试从 debug_options.output_dir 读取
+                    if output_dir is None:
+                        debug_options = self.config.get('debug_options', {})
+                        if isinstance(debug_options, dict):
+                            output_dir = debug_options.get('output_dir', None)
+                else:
+                    output_dir = None
+            
+            page_name = getattr(self, 'page_name', None)
+            if page_name is None:
+                if hasattr(self, 'config'):
+                    # 优先从 page_name 读取
+                    page_name = self.config.get('page_name', None)
+                    # 如果没有 page_name,尝试从 debug_options.prefix 读取
+                    if page_name is None:
+                        debug_options = self.config.get('debug_options', {})
+                        if isinstance(debug_options, dict):
+                            prefix = debug_options.get('prefix', '')
+                            page_name = prefix if prefix else 'layout_detection'
+                    if page_name is None:
+                        page_name = 'layout_detection'
+                else:
+                    page_name = 'layout_detection'
+            
+            if output_dir:
+                self._visualize_layout_results(image, layout_results, output_dir, page_name, suffix='raw')
+        
         # 自动执行后处理
         if layout_results:
             layout_config = self.config.get('post_process', {}) if hasattr(self, 'config') else {}
@@ -132,7 +200,7 @@ class BaseLayoutDetector(BaseAdapter):
                 return layout_results
         
         # 1. 去除重叠框
-        layout_results = self._remove_overlapping_boxes(layout_results, CoordinateUtils)
+        layout_results_removed_overlapping = self._remove_overlapping_boxes(layout_results, CoordinateUtils)
         
         # 2. 将大面积文本块转换为表格(如果配置启用)
         layout_config = config if config is not None else {}
@@ -143,94 +211,15 @@ class BaseLayoutDetector(BaseAdapter):
             else:
                 h, w = image.shape[:2] if len(image.shape) >= 2 else (0, 0)
             
-            layout_results = self._convert_large_text_to_table(
-                layout_results,
+            layout_results_converted_large_text = self._convert_large_text_to_table(
+                layout_results_removed_overlapping,
                 (h, w),
                 min_area_ratio=layout_config.get('min_text_area_ratio', 0.25),
                 min_width_ratio=layout_config.get('min_text_width_ratio', 0.4),
                 min_height_ratio=layout_config.get('min_text_height_ratio', 0.3)
             )
         
-        return layout_results
-    
-    def _remove_overlapping_boxes(
-        self,
-        layout_results: List[Dict[str, Any]],
-        coordinate_utils: Any,
-        iou_threshold: float = 0.8,
-        overlap_ratio_threshold: float = 0.8
-    ) -> List[Dict[str, Any]]:
-        """
-        处理重叠的布局框(参考 MinerU 的去重策略)
-        
-        策略:
-        1. 高 IoU 重叠:保留置信度高的框
-        2. 包含关系:小框被大框高度包含时,保留大框并扩展边界
-        """
-        if not layout_results or len(layout_results) <= 1:
-            return layout_results
-        
-        # 复制列表避免修改原数据
-        results = [item.copy() for item in layout_results]
-        need_remove = set()
-        
-        for i in range(len(results)):
-            if i in need_remove:
-                continue
-                
-            for j in range(i + 1, len(results)):
-                if j in need_remove:
-                    continue
-                
-                bbox1 = results[i].get('bbox', [0, 0, 0, 0])
-                bbox2 = results[j].get('bbox', [0, 0, 0, 0])
-                
-                if len(bbox1) < 4 or len(bbox2) < 4:
-                    continue
-                
-                # 计算 IoU
-                iou = coordinate_utils.calculate_iou(bbox1, bbox2)
-                
-                if iou > iou_threshold:
-                    # 高度重叠,保留置信度高的
-                    score1 = results[i].get('confidence', results[i].get('score', 0))
-                    score2 = results[j].get('confidence', results[j].get('score', 0))
-                    
-                    if score1 >= score2:
-                        need_remove.add(j)
-                    else:
-                        need_remove.add(i)
-                        break  # i 被移除,跳出内层循环
-                else:
-                    # 检查包含关系
-                    overlap_ratio = coordinate_utils.calculate_overlap_ratio(bbox1, bbox2)
-                    
-                    if overlap_ratio > overlap_ratio_threshold:
-                        # 小框被大框高度包含
-                        area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
-                        area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
-                        
-                        if area1 <= area2:
-                            small_idx, large_idx = i, j
-                        else:
-                            small_idx, large_idx = j, i
-                        
-                        # 扩展大框的边界
-                        small_bbox = results[small_idx]['bbox']
-                        large_bbox = results[large_idx]['bbox']
-                        results[large_idx]['bbox'] = [
-                            min(small_bbox[0], large_bbox[0]),
-                            min(small_bbox[1], large_bbox[1]),
-                            max(small_bbox[2], large_bbox[2]),
-                            max(small_bbox[3], large_bbox[3])
-                        ]
-                        need_remove.add(small_idx)
-                        
-                        if small_idx == i:
-                            break  # i 被移除,跳出内层循环
-        
-        # 返回去重后的结果
-        return [results[i] for i in range(len(results)) if i not in need_remove]
+        return layout_results_converted_large_text
     
     def _convert_large_text_to_table(
         self,
@@ -324,6 +313,281 @@ class BaseLayoutDetector(BaseAdapter):
             101: 'image_footnote'
         }
         return category_map.get(category_id, f'unknown_{category_id}')
+    
+    def _visualize_layout_results(
+        self,
+        image: Union[np.ndarray, Image.Image],
+        layout_results: List[Dict[str, Any]],
+        output_dir: str,
+        page_name: str,
+        suffix: str = 'raw'
+    ) -> None:
+        """
+        Draw the layout detection results on a copy of the image and save them for debugging.
+
+        Two files are written to ``<output_dir>/debug_comparison/layout_detection/``:
+        ``<page_name>_layout_<suffix>.jpg`` (boxes and labels drawn on the image) and
+        ``<page_name>_layout_<suffix>.json`` (category, bbox and confidence per result).
+        Nothing is written when ``layout_results`` is empty. Any exception raised while
+        drawing or saving is caught and logged as a warning, so this method never
+        propagates an error to the caller.
+
+        Args:
+            image: Input image (PIL image or NumPy array).
+            layout_results: Layout detection results; each item is read via the keys
+                'bbox', 'category' and 'confidence' (falling back to 'score').
+            output_dir: Base output directory.
+            page_name: Page name, used as the file-name prefix.
+            suffix: File-name suffix (e.g. 'raw', 'postprocessed').
+        """
+        if not layout_results:
+            return
+
+        try:
+            # Convert to a NumPy array
+            if isinstance(image, Image.Image):
+                vis_image = np.array(image)
+                if len(vis_image.shape) == 3 and vis_image.shape[2] == 3:
+                    # PIL RGB -> OpenCV BGR
+                    vis_image = cv2.cvtColor(vis_image, cv2.COLOR_RGB2BGR)
+            else:
+                # Copy so the caller's array is not drawn on
+                vis_image = image.copy()
+                if len(vis_image.shape) == 3 and vis_image.shape[2] == 3:
+                    # Every 3-channel array is converted RGB -> BGR.
+                    # NOTE(review): assumes ndarray input is RGB; a BGR array would end up
+                    # with swapped channels in the saved image - confirm against callers.
+                    vis_image = cv2.cvtColor(vis_image, cv2.COLOR_RGB2BGR)
+
+            # Category -> color mapping (BGR order)
+            category_colors = {
+                'table_body': (0, 0, 255),      # red
+                'table_caption': (0, 0, 200),   # dark red
+                'table_footnote': (0, 0, 150),  # darker red
+                'text': (255, 0, 0),            # blue
+                'title': (0, 255, 255),         # yellow
+                'header': (255, 0, 255),        # purple
+                'footer': (0, 165, 255),        # orange
+                'image_body': (0, 255, 0),      # green
+                'image_caption': (0, 200, 0),   # dark green
+                'image_footnote': (0, 150, 0),  # darker green
+                'abandon': (128, 128, 128),     # gray
+            }
+
+            # Draw the detection boxes
+            for result in layout_results:
+                bbox = result.get('bbox', [])
+                # Skip results without a usable 4-value bbox
+                if not bbox or len(bbox) < 4:
+                    continue
+
+                category = result.get('category', 'unknown')
+                color = category_colors.get(category, (128, 128, 128))  # default: gray
+                thickness = 2
+
+                x1, y1, x2, y2 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
+                cv2.rectangle(vis_image, (x1, y1), (x2, y2), color, thickness)
+
+                # Build the label: category, plus the score when it is truthy
+                # (a missing or zero confidence is not shown)
+                label = f"{category}"
+                confidence = result.get('confidence', result.get('score', 0))
+                if confidence:
+                    label += f":{confidence:.2f}"
+
+                # Measure the label text
+                font = cv2.FONT_HERSHEY_SIMPLEX
+                font_scale = 0.4
+                text_thickness = 1
+                (text_width, text_height), baseline = cv2.getTextSize(label, font, font_scale, text_thickness)
+
+                # Draw a filled background for the label above the box; the max() keeps
+                # the label inside the image when the box touches the top edge
+                text_y = max(y1 - baseline - 1, text_height + baseline)
+                cv2.rectangle(vis_image, (x1, text_y - text_height - baseline - 2), 
+                            (x1 + text_width, text_y), color, -1)
+                # Draw the label text in white
+                cv2.putText(vis_image, label, (x1, text_y - baseline - 1), 
+                          font, font_scale, (255, 255, 255), text_thickness)
+
+            # Save the annotated image
+            debug_dir = Path(output_dir) / "debug_comparison" / "layout_detection"
+            debug_dir.mkdir(parents=True, exist_ok=True)
+            output_path = debug_dir / f"{page_name}_layout_{suffix}.jpg"
+            cv2.imwrite(str(output_path), vis_image)
+            logger.info(f"📊 Saved layout detection image ({suffix}): {output_path}")
+
+            # Save the results as JSON
+            # NOTE(review): 'bbox' and 'confidence' are written as-is. If detectors return
+            # NumPy scalars/arrays here, json.dump would raise and be swallowed by the
+            # except below after the image was already written - confirm they are plain
+            # Python numbers.
+            json_data = {
+                'page_name': page_name,
+                'suffix': suffix,
+                'count': len(layout_results),
+                'results': [
+                    {
+                        'category': r.get('category'),
+                        'bbox': r.get('bbox'),
+                        'confidence': r.get('confidence', r.get('score', 0.0))
+                    }
+                    for r in layout_results
+                ]
+            }
+            json_path = debug_dir / f"{page_name}_layout_{suffix}.json"
+            with open(json_path, 'w', encoding='utf-8') as f:
+                json.dump(json_data, f, ensure_ascii=False, indent=2)
+            logger.info(f"📊 Saved layout detection JSON ({suffix}): {json_path}")
+
+        except Exception as e:
+            # Best-effort debug output: a visualization failure must not break detection
+            logger.warning(f"⚠️ Failed to visualize layout results: {e}")
+    
+    def _remove_overlapping_boxes(
+        self,
+        layout_results: List[Dict[str, Any]],
+        coordinate_utils: Any,
+        iou_threshold: float = 0.8,
+        overlap_ratio_threshold: float = 0.8
+    ) -> List[Dict[str, Any]]:
+        """
+        Remove overlapping layout boxes using category priorities and ordered decision rules.
+
+        Strategy:
+        1. Category priority: abandon (0) < text / image_body (1) < title / footer /
+           header (2) < table_body (3); categories not listed default to 1.
+        2. One shared rule chain (``should_keep_box1``) decides which box of a
+           significantly overlapping pair survives.
+        3. Boxes are processed in descending score order (composite area+confidence
+           score for 'text', plain confidence otherwise), which favors large
+           aggregate boxes.
+
+        A pair is considered only when its IoU exceeds ``iou_threshold``, its overlap
+        ratio exceeds ``overlap_ratio_threshold``, or one bbox lies fully inside the
+        other. The losing box is dropped; surviving boxes keep their bbox unchanged.
+
+        Args:
+            layout_results: Layout detection results.
+            coordinate_utils: Coordinate helper; must provide ``calculate_iou(b1, b2)``
+                and ``calculate_overlap_ratio(b1, b2)``.
+            iou_threshold: IoU threshold (default 0.8).
+            overlap_ratio_threshold: Overlap-ratio threshold (default 0.8).
+
+        Returns:
+            The de-duplicated results in their original input order, as shallow copies
+            of the input dicts. Inputs with fewer than two items are returned as-is.
+        """
+        if not layout_results or len(layout_results) <= 1:
+            return layout_results
+
+        # Constants
+        CATEGORY_PRIORITY = {
+            'abandon': 0,
+            'text': 1,
+            'image_body': 1,
+            'title': 2,
+            'footer': 2,
+            'header': 2,
+            'table_body': 3,
+        }
+        AGGREGATE_LABELS = {'key-value region', 'form'}
+        MAX_AREA = 4000000.0  # used to normalize box areas
+        AREA_WEIGHT = 0.5
+        CONFIDENCE_WEIGHT = 0.5
+        AGGREGATE_BONUS = 0.1
+        AREA_RATIO_THRESHOLD = 3.0  # the containing box must be more than this many times larger
+
+        def get_bbox_area(bbox: List[float]) -> float:
+            """Return the bbox area (0.0 when fewer than 4 values are given)."""
+            if len(bbox) < 4:
+                return 0.0
+            return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
+
+        def is_aggregate_type(box: Dict[str, Any]) -> bool:
+            """Return True when the box's raw original label is an aggregate label."""
+            # NOTE(review): assumes box['raw'] is a dict and 'original_label' a str when
+            # present; a None value in either would raise AttributeError - confirm.
+            original_label = box.get('raw', {}).get('original_label', '').lower()
+            return original_label in AGGREGATE_LABELS
+
+        def is_bbox_inside(inner: List[float], outer: List[float]) -> bool:
+            """Return True when inner lies entirely within outer (edges may touch)."""
+            if len(inner) < 4 or len(outer) < 4:
+                return False
+            return (inner[0] >= outer[0] and inner[1] >= outer[1] and
+                    inner[2] <= outer[2] and inner[3] <= outer[3])
+
+        def calculate_composite_score(box: Dict[str, Any], area: float) -> float:
+            """Return the composite (area + confidence) score for 'text' boxes.
+
+            Boxes of any other category get their plain confidence.
+            """
+            if box.get('category') != 'text':
+                return box.get('confidence', box.get('score', 0))
+
+            # Area is capped at MAX_AREA; the square root dampens the effect of very large boxes
+            normalized_area = min(area / MAX_AREA, 1.0)
+            area_score = (normalized_area ** 0.5) * AREA_WEIGHT
+            confidence_score = box.get('confidence', box.get('score', 0)) * CONFIDENCE_WEIGHT
+            bonus = AGGREGATE_BONUS if is_aggregate_type(box) else 0.0
+            return area_score + confidence_score + bonus
+
+        def should_keep_box1(box1: Dict[str, Any], box2: Dict[str, Any],
+                             iou: float, overlap_ratio: float,
+                             contained_1_in_2: bool, contained_2_in_1: bool) -> bool:
+            """Return True to keep box1 (and drop box2), False for the opposite.
+
+            The rules are tried in order and the first decisive one wins.
+            NOTE(review): ``iou`` and ``overlap_ratio`` are accepted but not used here.
+            """
+            # Gather the basic facts about both boxes
+            cat1, cat2 = box1.get('category', ''), box2.get('category', '')
+            score1 = box1.get('confidence', box1.get('score', 0))
+            score2 = box2.get('confidence', box2.get('score', 0))
+            bbox1, bbox2 = box1.get('bbox', [0, 0, 0, 0]), box2.get('bbox', [0, 0, 0, 0])
+            area1, area2 = get_bbox_area(bbox1), get_bbox_area(bbox2)
+            is_agg1, is_agg2 = is_aggregate_type(box1), is_aggregate_type(box2)
+
+            # Rule 1: category priority (unlisted categories default to priority 1)
+            priority1 = CATEGORY_PRIORITY.get(cat1, 1)
+            priority2 = CATEGORY_PRIORITY.get(cat2, 1)
+            if priority1 != priority2:
+                return priority1 > priority2
+
+            # Rule 2: containment + aggregate type wins over a contained non-aggregate box
+            if contained_2_in_1 and is_agg1 and not is_agg2:
+                return True
+            if contained_1_in_2 and is_agg2 and not is_agg1:
+                return False
+
+            # Rule 3: containment + area ratio (the container is much larger)
+            if contained_2_in_1 and area1 > area2 * AREA_RATIO_THRESHOLD:
+                return True
+            if contained_1_in_2 and area2 > area1 * AREA_RATIO_THRESHOLD:
+                return False
+
+            # Rule 4: when either box is 'text', compare composite scores;
+            # differences of 0.05 or less are treated as a tie
+            if cat1 == 'text' or cat2 == 'text':
+                comp_score1 = calculate_composite_score(box1, area1)
+                comp_score2 = calculate_composite_score(box2, area2)
+                if abs(comp_score1 - comp_score2) > 0.05:
+                    return comp_score1 > comp_score2
+
+            # Rule 5: confidence comparison; differences of 0.1 or less are treated as a tie
+            if abs(score1 - score2) > 0.1:
+                return score1 > score2
+
+            # Rule 6: area comparison (box1 wins ties)
+            return area1 >= area2
+
+        # Main processing logic
+        # Shallow-copy each item so the caller's list is not modified
+        results = [item.copy() for item in layout_results]
+        need_remove = set()
+
+        # Sort by score, highest first (keys are negated for ascending sorted())
+        def get_sort_key(i: int) -> float:
+            item = results[i]
+            if item.get('category') == 'text':
+                return -calculate_composite_score(item, get_bbox_area(item.get('bbox', [])))
+            return -item.get('confidence', item.get('score', 0))
+
+        sorted_indices = sorted(range(len(results)), key=get_sort_key)
+
+        # Compare each pair of boxes
+        for idx_i, i in enumerate(sorted_indices):
+            if i in need_remove:
+                continue
+
+            # Box i is only compared against boxes ranked before it (idx_j < idx_i) that
+            # are still alive. NOTE(review): the loop scans the whole list and skips the
+            # rest, which is equivalent to iterating sorted_indices[:idx_i].
+            for idx_j, j in enumerate(sorted_indices):
+                if j == i or j in need_remove or idx_j >= idx_i:
+                    continue
+
+                bbox1, bbox2 = results[i].get('bbox', []), results[j].get('bbox', [])
+                if len(bbox1) < 4 or len(bbox2) < 4:
+                    continue
+
+                # Compute the overlap metrics
+                iou = coordinate_utils.calculate_iou(bbox1, bbox2)
+                overlap_ratio = coordinate_utils.calculate_overlap_ratio(bbox1, bbox2)
+                contained_1_in_2 = is_bbox_inside(bbox1, bbox2)
+                contained_2_in_1 = is_bbox_inside(bbox2, bbox1)
+
+                # Skip pairs without significant overlap
+                if not (iou > iou_threshold or overlap_ratio > overlap_ratio_threshold or
+                       contained_1_in_2 or contained_2_in_1):
+                    continue
+
+                # Apply the decision rules
+                if should_keep_box1(results[i], results[j], iou, overlap_ratio,
+                                   contained_1_in_2, contained_2_in_1):
+                    need_remove.add(j)
+                else:
+                    need_remove.add(i)
+                    break  # box i is removed; stop comparing it
+
+        # Return survivors in the original input order
+        return [results[i] for i in range(len(results)) if i not in need_remove]
 
 class BaseVLRecognizer(BaseAdapter):
     """VL识别器基类"""