Jelajahi Sumber

feat(layout_detector): enhance layout detection with debug visualization and improved overlapping box handling

- Added debug mode initialization and configuration options for layout detection, allowing for better tracking of detection results.
- Implemented a visualization method to save layout detection results as images and JSON files, aiding in debugging and analysis.
- Improved the overlapping box removal algorithm with a more sophisticated scoring system, enhancing the accuracy of layout results.
zhch158_admin 1 bulan lalu
induk
melakukan
4a9c9d1114
1 mengubah file dengan 347 tambahan dan 83 penghapusan
  1. 347 83
      ocr_tools/universal_doc_parser/models/adapters/base.py

+ 347 - 83
ocr_tools/universal_doc_parser/models/adapters/base.py

@@ -2,6 +2,10 @@ from abc import ABC, abstractmethod
 from typing import Dict, Any, List, Union, Optional, Tuple
 from typing import Dict, Any, List, Union, Optional, Tuple
 import numpy as np
 import numpy as np
 from PIL import Image
 from PIL import Image
+from loguru import logger
+from pathlib import Path
+import cv2
+import json
 
 
 class BaseAdapter(ABC):
 class BaseAdapter(ABC):
     """基础适配器接口"""
     """基础适配器接口"""
@@ -44,6 +48,18 @@ class BasePreprocessor(BaseAdapter):
 class BaseLayoutDetector(BaseAdapter):
 class BaseLayoutDetector(BaseAdapter):
     """版式检测器基类"""
     """版式检测器基类"""
     
     
+    def __init__(self, config: Dict[str, Any]):
+        """初始化版式检测器
+        
+        Args:
+            config: 配置字典
+        """
+        super().__init__(config)
+        # 初始化 debug 相关属性(支持从配置或运行时设置)
+        self.debug_mode = None  # 将在 detect() 方法中从配置读取
+        self.output_dir = None  # 将在 detect() 方法中从配置读取
+        self.page_name = None   # 将在 detect() 方法中从配置读取
+    
     def detect(
     def detect(
         self, 
         self, 
         image: Union[np.ndarray, Image.Image],
         image: Union[np.ndarray, Image.Image],
@@ -66,6 +82,58 @@ class BaseLayoutDetector(BaseAdapter):
         # 调用子类实现的原始检测方法
         # 调用子类实现的原始检测方法
         layout_results = self._detect_raw(image, ocr_spans)
         layout_results = self._detect_raw(image, ocr_spans)
         
         
+        # Debug 模式:打印和可视化后处理前的检测结果
+        # 优先从实例属性读取(如果存在),否则从配置读取
+        # 支持两种配置方式:debug_mode 或 debug_options.enabled
+        debug_mode = getattr(self, 'debug_mode', None)
+        if debug_mode is None:
+            if hasattr(self, 'config'):
+                # 优先从 debug_mode 读取
+                debug_mode = self.config.get('debug_mode', False)
+                # 如果没有 debug_mode,尝试从 debug_options.enabled 读取
+                if not debug_mode:
+                    debug_options = self.config.get('debug_options', {})
+                    if isinstance(debug_options, dict):
+                        debug_mode = debug_options.get('enabled', False)
+            else:
+                debug_mode = False
+        
+        if debug_mode:
+            logger.debug(f"🔍 Layout detection raw results (before post-processing): {len(layout_results)} elements")
+            # logger.debug(f"Raw layout_results: {layout_results}")
+            # 可视化 layout 结果
+            output_dir = getattr(self, 'output_dir', None)
+            if output_dir is None:
+                if hasattr(self, 'config'):
+                    # 优先从 output_dir 读取
+                    output_dir = self.config.get('output_dir', None)
+                    # 如果没有 output_dir,尝试从 debug_options.output_dir 读取
+                    if output_dir is None:
+                        debug_options = self.config.get('debug_options', {})
+                        if isinstance(debug_options, dict):
+                            output_dir = debug_options.get('output_dir', None)
+                else:
+                    output_dir = None
+            
+            page_name = getattr(self, 'page_name', None)
+            if page_name is None:
+                if hasattr(self, 'config'):
+                    # 优先从 page_name 读取
+                    page_name = self.config.get('page_name', None)
+                    # 如果没有 page_name,尝试从 debug_options.prefix 读取
+                    if page_name is None:
+                        debug_options = self.config.get('debug_options', {})
+                        if isinstance(debug_options, dict):
+                            prefix = debug_options.get('prefix', '')
+                            page_name = prefix if prefix else 'layout_detection'
+                    if page_name is None:
+                        page_name = 'layout_detection'
+                else:
+                    page_name = 'layout_detection'
+            
+            if output_dir:
+                self._visualize_layout_results(image, layout_results, output_dir, page_name, suffix='raw')
+        
         # 自动执行后处理
         # 自动执行后处理
         if layout_results:
         if layout_results:
             layout_config = self.config.get('post_process', {}) if hasattr(self, 'config') else {}
             layout_config = self.config.get('post_process', {}) if hasattr(self, 'config') else {}
@@ -132,7 +200,7 @@ class BaseLayoutDetector(BaseAdapter):
                 return layout_results
                 return layout_results
         
         
         # 1. 去除重叠框
         # 1. 去除重叠框
-        layout_results = self._remove_overlapping_boxes(layout_results, CoordinateUtils)
+        layout_results_removed_overlapping = self._remove_overlapping_boxes(layout_results, CoordinateUtils)
         
         
         # 2. 将大面积文本块转换为表格(如果配置启用)
         # 2. 将大面积文本块转换为表格(如果配置启用)
         layout_config = config if config is not None else {}
         layout_config = config if config is not None else {}
@@ -143,94 +211,15 @@ class BaseLayoutDetector(BaseAdapter):
             else:
             else:
                 h, w = image.shape[:2] if len(image.shape) >= 2 else (0, 0)
                 h, w = image.shape[:2] if len(image.shape) >= 2 else (0, 0)
             
             
-            layout_results = self._convert_large_text_to_table(
-                layout_results,
+            layout_results_converted_large_text = self._convert_large_text_to_table(
+                layout_results_removed_overlapping,
                 (h, w),
                 (h, w),
                 min_area_ratio=layout_config.get('min_text_area_ratio', 0.25),
                 min_area_ratio=layout_config.get('min_text_area_ratio', 0.25),
                 min_width_ratio=layout_config.get('min_text_width_ratio', 0.4),
                 min_width_ratio=layout_config.get('min_text_width_ratio', 0.4),
                 min_height_ratio=layout_config.get('min_text_height_ratio', 0.3)
                 min_height_ratio=layout_config.get('min_text_height_ratio', 0.3)
             )
             )
         
         
-        return layout_results
-    
-    def _remove_overlapping_boxes(
-        self,
-        layout_results: List[Dict[str, Any]],
-        coordinate_utils: Any,
-        iou_threshold: float = 0.8,
-        overlap_ratio_threshold: float = 0.8
-    ) -> List[Dict[str, Any]]:
-        """
-        处理重叠的布局框(参考 MinerU 的去重策略)
-        
-        策略:
-        1. 高 IoU 重叠:保留置信度高的框
-        2. 包含关系:小框被大框高度包含时,保留大框并扩展边界
-        """
-        if not layout_results or len(layout_results) <= 1:
-            return layout_results
-        
-        # 复制列表避免修改原数据
-        results = [item.copy() for item in layout_results]
-        need_remove = set()
-        
-        for i in range(len(results)):
-            if i in need_remove:
-                continue
-                
-            for j in range(i + 1, len(results)):
-                if j in need_remove:
-                    continue
-                
-                bbox1 = results[i].get('bbox', [0, 0, 0, 0])
-                bbox2 = results[j].get('bbox', [0, 0, 0, 0])
-                
-                if len(bbox1) < 4 or len(bbox2) < 4:
-                    continue
-                
-                # 计算 IoU
-                iou = coordinate_utils.calculate_iou(bbox1, bbox2)
-                
-                if iou > iou_threshold:
-                    # 高度重叠,保留置信度高的
-                    score1 = results[i].get('confidence', results[i].get('score', 0))
-                    score2 = results[j].get('confidence', results[j].get('score', 0))
-                    
-                    if score1 >= score2:
-                        need_remove.add(j)
-                    else:
-                        need_remove.add(i)
-                        break  # i 被移除,跳出内层循环
-                else:
-                    # 检查包含关系
-                    overlap_ratio = coordinate_utils.calculate_overlap_ratio(bbox1, bbox2)
-                    
-                    if overlap_ratio > overlap_ratio_threshold:
-                        # 小框被大框高度包含
-                        area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
-                        area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
-                        
-                        if area1 <= area2:
-                            small_idx, large_idx = i, j
-                        else:
-                            small_idx, large_idx = j, i
-                        
-                        # 扩展大框的边界
-                        small_bbox = results[small_idx]['bbox']
-                        large_bbox = results[large_idx]['bbox']
-                        results[large_idx]['bbox'] = [
-                            min(small_bbox[0], large_bbox[0]),
-                            min(small_bbox[1], large_bbox[1]),
-                            max(small_bbox[2], large_bbox[2]),
-                            max(small_bbox[3], large_bbox[3])
-                        ]
-                        need_remove.add(small_idx)
-                        
-                        if small_idx == i:
-                            break  # i 被移除,跳出内层循环
-        
-        # 返回去重后的结果
-        return [results[i] for i in range(len(results)) if i not in need_remove]
+        return layout_results_converted_large_text
     
     
     def _convert_large_text_to_table(
     def _convert_large_text_to_table(
         self,
         self,
@@ -324,6 +313,281 @@ class BaseLayoutDetector(BaseAdapter):
             101: 'image_footnote'
             101: 'image_footnote'
         }
         }
         return category_map.get(category_id, f'unknown_{category_id}')
         return category_map.get(category_id, f'unknown_{category_id}')
+    
+    def _visualize_layout_results(
+        self,
+        image: Union[np.ndarray, Image.Image],
+        layout_results: List[Dict[str, Any]],
+        output_dir: str,
+        page_name: str,
+        suffix: str = 'raw'
+    ) -> None:
+        """
+        可视化 layout 检测结果
+        
+        Args:
+            image: 输入图像
+            layout_results: 布局检测结果
+            output_dir: 输出目录
+            page_name: 页面名称
+            suffix: 文件名后缀(如 'raw', 'postprocessed')
+        """
+        if not layout_results:
+            return
+        
+        try:
+            # 转换为 numpy 数组
+            if isinstance(image, Image.Image):
+                vis_image = np.array(image)
+                if len(vis_image.shape) == 3 and vis_image.shape[2] == 3:
+                    # PIL RGB -> OpenCV BGR
+                    vis_image = cv2.cvtColor(vis_image, cv2.COLOR_RGB2BGR)
+            else:
+                vis_image = image.copy()
+                if len(vis_image.shape) == 3 and vis_image.shape[2] == 3:
+                    # 如果是 RGB,转换为 BGR
+                    vis_image = cv2.cvtColor(vis_image, cv2.COLOR_RGB2BGR)
+            
+            # 定义类别颜色映射 (BGR格式)
+            category_colors = {
+                'table_body': (0, 0, 255),      # 红色
+                'table_caption': (0, 0, 200),   # 暗红色
+                'table_footnote': (0, 0, 150),  # 更暗的红色
+                'text': (255, 0, 0),            # 蓝色
+                'title': (0, 255, 255),         # 黄色
+                'header': (255, 0, 255),        # 紫色
+                'footer': (0, 165, 255),        # 橙色
+                'image_body': (0, 255, 0),      # 绿色
+                'image_caption': (0, 200, 0),   # 暗绿色
+                'image_footnote': (0, 150, 0),  # 更暗的绿色
+                'abandon': (128, 128, 128),     # 灰色
+            }
+            
+            # 绘制检测框
+            for result in layout_results:
+                bbox = result.get('bbox', [])
+                if not bbox or len(bbox) < 4:
+                    continue
+                
+                category = result.get('category', 'unknown')
+                color = category_colors.get(category, (128, 128, 128))  # 默认灰色
+                thickness = 2
+                
+                x1, y1, x2, y2 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
+                cv2.rectangle(vis_image, (x1, y1), (x2, y2), color, thickness)
+                
+                # 添加类别标签
+                label = f"{category}"
+                confidence = result.get('confidence', result.get('score', 0))
+                if confidence:
+                    label += f":{confidence:.2f}"
+                
+                # 计算文本大小
+                font = cv2.FONT_HERSHEY_SIMPLEX
+                font_scale = 0.4
+                text_thickness = 1
+                (text_width, text_height), baseline = cv2.getTextSize(label, font, font_scale, text_thickness)
+                
+                # 在框的上方绘制文本背景
+                text_y = max(y1 - baseline - 1, text_height + baseline)
+                cv2.rectangle(vis_image, (x1, text_y - text_height - baseline - 2), 
+                            (x1 + text_width, text_y), color, -1)
+                # 绘制文本
+                cv2.putText(vis_image, label, (x1, text_y - baseline - 1), 
+                          font, font_scale, (255, 255, 255), text_thickness)
+            
+            # 保存图像
+            debug_dir = Path(output_dir) / "debug_comparison" / "layout_detection"
+            debug_dir.mkdir(parents=True, exist_ok=True)
+            output_path = debug_dir / f"{page_name}_layout_{suffix}.jpg"
+            cv2.imwrite(str(output_path), vis_image)
+            logger.info(f"📊 Saved layout detection image ({suffix}): {output_path}")
+            
+            # 保存 JSON 数据
+            json_data = {
+                'page_name': page_name,
+                'suffix': suffix,
+                'count': len(layout_results),
+                'results': [
+                    {
+                        'category': r.get('category'),
+                        'bbox': r.get('bbox'),
+                        'confidence': r.get('confidence', r.get('score', 0.0))
+                    }
+                    for r in layout_results
+                ]
+            }
+            json_path = debug_dir / f"{page_name}_layout_{suffix}.json"
+            with open(json_path, 'w', encoding='utf-8') as f:
+                json.dump(json_data, f, ensure_ascii=False, indent=2)
+            logger.info(f"📊 Saved layout detection JSON ({suffix}): {json_path}")
+            
+        except Exception as e:
+            logger.warning(f"⚠️ Failed to visualize layout results: {e}")
+    
+    def _remove_overlapping_boxes(
+        self,
+        layout_results: List[Dict[str, Any]],
+        coordinate_utils: Any,
+        iou_threshold: float = 0.8,
+        overlap_ratio_threshold: float = 0.8
+    ) -> List[Dict[str, Any]]:
+        """
+        改进版重叠框处理算法(基于优先级和决策规则的清晰算法)
+        
+        策略:
+        1. 定义类别优先级(abandon < text/image < table_body)
+        2. 使用统一的决策规则
+        3. 按综合评分排序处理,优先保留大的聚合框
+        
+        Args:
+            layout_results: 布局检测结果
+            coordinate_utils: 坐标工具类
+            iou_threshold: IoU阈值(默认0.8)
+            overlap_ratio_threshold: 重叠比例阈值(默认0.8)
+            
+        Returns:
+            去重后的布局结果
+        """
+        if not layout_results or len(layout_results) <= 1:
+            return layout_results
+        
+        # 常量定义
+        CATEGORY_PRIORITY = {
+            'abandon': 0,
+            'text': 1,
+            'image_body': 1,
+            'title': 2,
+            'footer': 2,
+            'header': 2,
+            'table_body': 3,
+        }
+        AGGREGATE_LABELS = {'key-value region', 'form'}
+        MAX_AREA = 4000000.0  # 用于面积归一化
+        AREA_WEIGHT = 0.5
+        CONFIDENCE_WEIGHT = 0.5
+        AGGREGATE_BONUS = 0.1
+        AREA_RATIO_THRESHOLD = 3.0  # 大框面积需大于小框的倍数
+        
+        def get_bbox_area(bbox: List[float]) -> float:
+            """计算bbox面积"""
+            if len(bbox) < 4:
+                return 0.0
+            return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
+        
+        def is_aggregate_type(box: Dict[str, Any]) -> bool:
+            """检查是否是聚合类型"""
+            original_label = box.get('raw', {}).get('original_label', '').lower()
+            return original_label in AGGREGATE_LABELS
+        
+        def is_bbox_inside(inner: List[float], outer: List[float]) -> bool:
+            """检查inner是否完全包含在outer内"""
+            if len(inner) < 4 or len(outer) < 4:
+                return False
+            return (inner[0] >= outer[0] and inner[1] >= outer[1] and
+                    inner[2] <= outer[2] and inner[3] <= outer[3])
+        
+        def calculate_composite_score(box: Dict[str, Any], area: float) -> float:
+            """计算text类型的综合评分(面积+置信度)"""
+            if box.get('category') != 'text':
+                return box.get('confidence', box.get('score', 0))
+            
+            normalized_area = min(area / MAX_AREA, 1.0)
+            area_score = (normalized_area ** 0.5) * AREA_WEIGHT
+            confidence_score = box.get('confidence', box.get('score', 0)) * CONFIDENCE_WEIGHT
+            bonus = AGGREGATE_BONUS if is_aggregate_type(box) else 0.0
+            return area_score + confidence_score + bonus
+        
+        def should_keep_box1(box1: Dict[str, Any], box2: Dict[str, Any],
+                             iou: float, overlap_ratio: float,
+                             contained_1_in_2: bool, contained_2_in_1: bool) -> bool:
+            """判断是否应该保留box1"""
+            # 提取基本信息
+            cat1, cat2 = box1.get('category', ''), box2.get('category', '')
+            score1 = box1.get('confidence', box1.get('score', 0))
+            score2 = box2.get('confidence', box2.get('score', 0))
+            bbox1, bbox2 = box1.get('bbox', [0, 0, 0, 0]), box2.get('bbox', [0, 0, 0, 0])
+            area1, area2 = get_bbox_area(bbox1), get_bbox_area(bbox2)
+            is_agg1, is_agg2 = is_aggregate_type(box1), is_aggregate_type(box2)
+            
+            # 规则1: 类别优先级
+            priority1 = CATEGORY_PRIORITY.get(cat1, 1)
+            priority2 = CATEGORY_PRIORITY.get(cat2, 1)
+            if priority1 != priority2:
+                return priority1 > priority2
+            
+            # 规则2: 包含关系 + 聚合类型优先
+            if contained_2_in_1 and is_agg1 and not is_agg2:
+                return True
+            if contained_1_in_2 and is_agg2 and not is_agg1:
+                return False
+            
+            # 规则3: 包含关系 + 面积比例
+            if contained_2_in_1 and area1 > area2 * AREA_RATIO_THRESHOLD:
+                return True
+            if contained_1_in_2 and area2 > area1 * AREA_RATIO_THRESHOLD:
+                return False
+            
+            # 规则4: text类型使用综合评分
+            if cat1 == 'text' or cat2 == 'text':
+                comp_score1 = calculate_composite_score(box1, area1)
+                comp_score2 = calculate_composite_score(box2, area2)
+                if abs(comp_score1 - comp_score2) > 0.05:
+                    return comp_score1 > comp_score2
+            
+            # 规则5: 置信度比较
+            if abs(score1 - score2) > 0.1:
+                return score1 > score2
+            
+            # 规则6: 面积比较
+            return area1 >= area2
+        
+        # 主处理逻辑
+        results = [item.copy() for item in layout_results]
+        need_remove = set()
+        
+        # 按综合评分排序(高分优先)
+        def get_sort_key(i: int) -> float:
+            item = results[i]
+            if item.get('category') == 'text':
+                return -calculate_composite_score(item, get_bbox_area(item.get('bbox', [])))
+            return -item.get('confidence', item.get('score', 0))
+        
+        sorted_indices = sorted(range(len(results)), key=get_sort_key)
+        
+        # 比较每对框
+        for idx_i, i in enumerate(sorted_indices):
+            if i in need_remove:
+                continue
+            
+            for idx_j, j in enumerate(sorted_indices):
+                if j == i or j in need_remove or idx_j >= idx_i:
+                    continue
+                
+                bbox1, bbox2 = results[i].get('bbox', []), results[j].get('bbox', [])
+                if len(bbox1) < 4 or len(bbox2) < 4:
+                    continue
+                
+                # 计算重叠指标
+                iou = coordinate_utils.calculate_iou(bbox1, bbox2)
+                overlap_ratio = coordinate_utils.calculate_overlap_ratio(bbox1, bbox2)
+                contained_1_in_2 = is_bbox_inside(bbox1, bbox2)
+                contained_2_in_1 = is_bbox_inside(bbox2, bbox1)
+                
+                # 检查是否有显著重叠
+                if not (iou > iou_threshold or overlap_ratio > overlap_ratio_threshold or
+                       contained_1_in_2 or contained_2_in_1):
+                    continue
+                
+                # 应用决策规则
+                if should_keep_box1(results[i], results[j], iou, overlap_ratio,
+                                   contained_1_in_2, contained_2_in_1):
+                    need_remove.add(j)
+                else:
+                    need_remove.add(i)
+                    break
+        
+        return [results[i] for i in range(len(results)) if i not in need_remove]
 
 
 class BaseVLRecognizer(BaseAdapter):
 class BaseVLRecognizer(BaseAdapter):
     """VL识别器基类"""
     """VL识别器基类"""