Browse Source

feat: 增强布局处理工具类,新增类别合并限制和误检过滤功能

- 在 LayoutUtils 中添加类别合并限制和优先级处理,确保高优先级类别框不被低优先级框覆盖。
- 实现过滤误检图片框的功能,依据文本面积比例判断是否移除图片框。
- 更新 remove_overlapping_boxes 方法,支持面积限制和类别优先级的合并策略。
- 更新 DitLayoutDetector 类,增加新配置选项以启用上述功能,提升布局检测的准确性和可靠性。
zhch158_admin 1 week ago
parent
commit
23326cb1b6
1 changed files with 224 additions and 16 deletions
  1. 224 16
      ocr_tools/universal_doc_parser/models/adapters/dit_layout_adapter.py

+ 224 - 16
ocr_tools/universal_doc_parser/models/adapters/dit_layout_adapter.py

@@ -17,7 +17,7 @@ import cv2
 import numpy as np
 import threading
 from pathlib import Path
-from typing import Dict, List, Union, Any, Optional
+from typing import Dict, List, Union, Any, Optional, Tuple
 from PIL import Image
 
 try:
@@ -87,23 +87,47 @@ class LayoutUtils:
         
         return intersection / min_area
     
+    # 不允许合并的类别组合
+    FORBIDDEN_MERGE = {
+        'image_body': ['text', 'title', 'table_body', 'table'],
+        'figure': ['text', 'title', 'table_body', 'table'],
+    }
+    
+    # 类别优先级(数字越大优先级越高)
+    CATEGORY_PRIORITY = {
+        'text': 3,
+        'title': 3,
+        'table_body': 3,
+        'table': 3,
+        'image_body': 1,
+        'figure': 1,
+    }
+    
     @staticmethod
     def remove_overlapping_boxes(
         layout_results: List[Dict[str, Any]],
         iou_threshold: float = 0.8,
-        overlap_ratio_threshold: float = 0.8
+        overlap_ratio_threshold: float = 0.8,
+        image_size: Optional[Tuple[int, int]] = None,
+        max_area_ratio: float = 0.8,
+        enable_category_restriction: bool = True,
+        enable_category_priority: bool = True
     ) -> List[Dict[str, Any]]:
         """
         处理重叠的布局框(参考 MinerU 的去重策略)
         
         策略:
-        1. 高 IoU 重叠:保留置信度高的框
-        2. 包含关系:小框被大框高度包含时,保留大框并扩展边界
+        1. 高 IoU 重叠:保留置信度高的框(考虑类别优先级)
+        2. 包含关系:小框被大框高度包含时,检查类别限制和面积限制后决定是否合并
         
         Args:
             layout_results: Layout 检测结果列表
             iou_threshold: IoU 阈值,超过此值认为高度重叠
             overlap_ratio_threshold: 重叠面积占小框面积的比例阈值
+            image_size: 图像尺寸 (width, height),用于计算面积限制
+            max_area_ratio: 最大面积比例,合并后的框超过此比例则拒绝合并(默认0.8)
+            enable_category_restriction: 是否启用类别限制(默认True)
+            enable_category_priority: 是否启用类别优先级(默认True)
             
         Returns:
             去重后的布局结果列表
@@ -115,6 +139,44 @@ class LayoutUtils:
         results = [item.copy() for item in layout_results]
         need_remove = set()
         
+        # 计算图像总面积(如果提供了图像尺寸)
+        img_area = None
+        if image_size is not None:
+            img_width, img_height = image_size
+            img_area = img_width * img_height
+        
+        def can_merge(cat1: str, cat2: str) -> bool:
+            """Tell whether boxes of the two given categories may be merged.
+
+            Always True when category restriction is disabled. Otherwise the
+            pair is rejected if either category lists the other one in
+            LayoutUtils.FORBIDDEN_MERGE.
+            """
+            if enable_category_restriction:
+                blocked = LayoutUtils.FORBIDDEN_MERGE
+                # The table is consulted in both directions, so the argument
+                # order does not matter.
+                if cat2 in blocked.get(cat1, []) or cat1 in blocked.get(cat2, []):
+                    return False
+            return True
+        
+        def get_priority(category: str) -> int:
+            """Look up the priority of a category.
+
+            Returns 0 when priority handling is disabled or when the category
+            has no entry in LayoutUtils.CATEGORY_PRIORITY.
+            """
+            if enable_category_priority:
+                return LayoutUtils.CATEGORY_PRIORITY.get(category, 0)
+            return 0
+        
+        def check_area_limit(merged_bbox: List[float]) -> bool:
+            """Return True when a merged box stays within the allowed image share.
+
+            The box area is divided by the image area and compared against
+            max_area_ratio. The check is skipped (True) when no image size
+            was supplied.
+            """
+            # Without an image size there is nothing to compare against.
+            if img_area is None:
+                return True
+
+            box_w = merged_bbox[2] - merged_bbox[0]
+            box_h = merged_bbox[3] - merged_bbox[1]
+            if img_area > 0:
+                share = box_w * box_h / img_area
+            else:
+                # A zero-area image yields a ratio of 0, so the merge is allowed.
+                share = 0
+
+            return share <= max_area_ratio
+        
         for i in range(len(results)):
             if i in need_remove:
                 continue
@@ -129,24 +191,41 @@ class LayoutUtils:
                 if len(bbox1) < 4 or len(bbox2) < 4:
                     continue
                 
+                cat1 = results[i].get('category', 'unknown')
+                cat2 = results[j].get('category', 'unknown')
+                
                 # 计算 IoU
                 iou = LayoutUtils.calculate_iou(bbox1, bbox2)
                 
                 if iou > iou_threshold:
-                    # 高度重叠,保留置信度高的
+                    # 高度重叠,保留置信度高的框(考虑类别优先级)
                     score1 = results[i].get('confidence', results[i].get('score', 0))
                     score2 = results[j].get('confidence', results[j].get('score', 0))
+                    priority1 = get_priority(cat1)
+                    priority2 = get_priority(cat2)
                     
-                    if score1 >= score2:
+                    # 如果类别优先级不同,优先保留高优先级
+                    if priority1 != priority2:
+                        if priority1 > priority2:
+                            need_remove.add(j)
+                        else:
+                            need_remove.add(i)
+                            break
+                    # 如果类别优先级相同,保留置信度高的
+                    elif score1 >= score2:
                         need_remove.add(j)
                     else:
                         need_remove.add(i)
-                        break  # i 被移除,跳出内层循环
+                        break
                 else:
                     # 检查包含关系
                     overlap_ratio = LayoutUtils.calculate_overlap_ratio(bbox1, bbox2)
                     
                     if overlap_ratio > overlap_ratio_threshold:
+                        # 检查类别是否允许合并
+                        if not can_merge(cat1, cat2):
+                            continue  # 不允许合并,跳过
+                        
                         # 小框被大框高度包含
                         area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
                         area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
@@ -156,15 +235,31 @@ class LayoutUtils:
                         else:
                             small_idx, large_idx = j, i
                         
-                        # 扩展大框的边界
+                        # 计算合并后的框
                         small_bbox = results[small_idx]['bbox']
                         large_bbox = results[large_idx]['bbox']
-                        results[large_idx]['bbox'] = [
+                        merged_bbox = [
                             min(small_bbox[0], large_bbox[0]),
                             min(small_bbox[1], large_bbox[1]),
                             max(small_bbox[2], large_bbox[2]),
                             max(small_bbox[3], large_bbox[3])
                         ]
+                        
+                        # 检查合并后的面积是否超过限制
+                        if not check_area_limit(merged_bbox):
+                            continue  # 超过面积限制,拒绝合并
+                        
+                        # 检查类别优先级:如果小框优先级更高,不应该被大框合并
+                        small_cat = results[small_idx].get('category', 'unknown')
+                        large_cat = results[large_idx].get('category', 'unknown')
+                        small_priority = get_priority(small_cat)
+                        large_priority = get_priority(large_cat)
+                        
+                        if small_priority > large_priority:
+                            continue  # 小框优先级更高,不应该被合并
+                        
+                        # 执行合并:扩展大框的边界
+                        results[large_idx]['bbox'] = merged_bbox
                         need_remove.add(small_idx)
                         
                         if small_idx == i:
@@ -172,6 +267,90 @@ class LayoutUtils:
         
         # 返回去重后的结果
         return [results[i] for i in range(len(results)) if i not in need_remove]
+    
+    @staticmethod
+    def filter_false_positive_images(
+        layout_results: List[Dict[str, Any]],
+        min_text_area_ratio: float = 0.3
+    ) -> List[Dict[str, Any]]:
+        """
+        Drop image boxes that are mostly covered by non-image layout boxes.
+
+        An image box ('image_body' or 'figure') is treated as a false positive
+        when the non-image boxes lying inside it cover more than
+        ``min_text_area_ratio`` of its area. A non-image box counts as lying
+        inside the image box when more than half of its OWN area overlaps the
+        image box; only the overlapping part is added to the covered area.
+
+        Measuring containment against the other box's own area (rather than
+        the smaller of the two boxes) keeps a small figure that sits inside a
+        large table or text block from being discarded.
+
+        Args:
+            layout_results: Layout detection results; each item is a dict
+                read through its 'category' and 'bbox' ([x1, y1, x2, y2]) keys.
+            min_text_area_ratio: Coverage threshold; an image box whose
+                covered share exceeds this value is removed (default 0.3).
+
+        Returns:
+            The results with false-positive image boxes removed, in the
+            original order. The input list is returned unchanged when empty.
+        """
+        if not layout_results:
+            return layout_results
+
+        image_categories = ('image_body', 'figure')
+
+        # Split the detections once: image boxes keep their index so they can
+        # be dropped later; for the others only the bbox is needed.
+        image_boxes = []
+        other_bboxes = []
+        for idx, result in enumerate(layout_results):
+            if result.get('category', 'unknown') in image_categories:
+                image_boxes.append((idx, result))
+            else:
+                other_bboxes.append(result.get('bbox', [0, 0, 0, 0]))
+
+        # Indices of image boxes judged to be false positives.
+        need_remove = set()
+
+        for img_idx, img_result in image_boxes:
+            img_bbox = img_result.get('bbox', [0, 0, 0, 0])
+            if len(img_bbox) < 4:
+                continue
+
+            img_area = (img_bbox[2] - img_bbox[0]) * (img_bbox[3] - img_bbox[1])
+            if img_area <= 0:
+                continue  # degenerate box: no meaningful coverage ratio
+
+            # Area of the image box covered by contained non-image boxes.
+            # NOTE: overlapping non-image boxes are summed independently, so
+            # shared regions are counted more than once.
+            total_contained_area = 0.0
+
+            for other_bbox in other_bboxes:
+                if len(other_bbox) < 4:
+                    continue
+
+                other_area = (other_bbox[2] - other_bbox[0]) * (other_bbox[3] - other_bbox[1])
+                if other_area <= 0:
+                    continue
+
+                inter_w = min(img_bbox[2], other_bbox[2]) - max(img_bbox[0], other_bbox[0])
+                inter_h = min(img_bbox[3], other_bbox[3]) - max(img_bbox[1], other_bbox[1])
+                if inter_w <= 0 or inter_h <= 0:
+                    continue  # no overlap at all
+
+                intersection_area = inter_w * inter_h
+
+                # Contained means: most (>50%) of the other box is inside the
+                # image box, regardless of which of the two is larger.
+                if intersection_area / other_area > 0.5:
+                    total_contained_area += intersection_area
+
+            if total_contained_area / img_area > min_text_area_ratio:
+                need_remove.add(img_idx)
+
+        return [result for i, result in enumerate(layout_results) if i not in need_remove]
 
 
 class DitLayoutDetector(BaseLayoutDetector):
@@ -214,6 +393,11 @@ class DitLayoutDetector(BaseLayoutDetector):
                 - remove_overlap: 是否启用重叠框处理 (默认 True)
                 - iou_threshold: IoU 阈值 (默认 0.8)
                 - overlap_ratio_threshold: 重叠比例阈值 (默认 0.8)
+                - max_area_ratio: 最大面积比例 (默认 0.8)
+                - enable_category_restriction: 是否启用类别限制 (默认 True)
+                - enable_category_priority: 是否启用类别优先级 (默认 True)
+                - filter_false_positive_images: 是否过滤误检的图片框 (默认 True)
+                - min_text_area_ratio: 最小文本面积比例阈值,图片框内文本面积占比超过此值则移除 (默认 0.3)
         """
         super().__init__(config)
         self.predictor = None
@@ -223,6 +407,11 @@ class DitLayoutDetector(BaseLayoutDetector):
         self._remove_overlap = True
         self._iou_threshold = 0.8
         self._overlap_ratio_threshold = 0.8
+        self._max_area_ratio = 0.8 # 最大面积比例,合并后的框超过此比例则拒绝合并(默认0.8)
+        self._enable_category_restriction = True
+        self._enable_category_priority = True
+        self._filter_false_positive_images = True
+        self._min_text_area_ratio = 0.3
     
     def initialize(self):
         """初始化模型"""
@@ -251,17 +440,16 @@ class DitLayoutDetector(BaseLayoutDetector):
             
             # 添加 dit_support 路径(适配到 universal_doc_parser)
             current_dir = os.path.dirname(os.path.abspath(__file__))
-            dit_support_path = os.path.join(current_dir, '..', 'dit_support')
+            dit_support_path = Path(__file__).parents[2] / 'dit_support'
             if dit_support_path not in sys.path:
-                sys.path.insert(0, dit_support_path)
+                sys.path.insert(0, str(dit_support_path))
             
             from ditod import add_vit_config
             
             # 获取配置参数
             config_file = self.config.get(
                 'config_file',
-                os.path.join(current_dir, '..', 'dit_support', 'configs',
-                           'cascade', 'cascade_dit_large.yaml')
+                dit_support_path / 'configs' / 'cascade' / 'cascade_dit_large.yaml'
             )
             model_weights = self.config.get(
                 'model_weights',
@@ -272,6 +460,11 @@ class DitLayoutDetector(BaseLayoutDetector):
             self._remove_overlap = self.config.get('remove_overlap', True)
             self._iou_threshold = self.config.get('iou_threshold', 0.8)
             self._overlap_ratio_threshold = self.config.get('overlap_ratio_threshold', 0.8)
+            self._max_area_ratio = self.config.get('max_area_ratio', 0.8)
+            self._enable_category_restriction = self.config.get('enable_category_restriction', True)
+            self._enable_category_priority = self.config.get('enable_category_priority', True)
+            self._filter_false_positive_images = self.config.get('filter_false_positive_images', True)
+            self._min_text_area_ratio = self.config.get('min_text_area_ratio', 0.3)
             
             # 设置设备
             self._device = torch.device(device)
@@ -393,7 +586,7 @@ class DitLayoutDetector(BaseLayoutDetector):
             # 过滤面积异常大的框
             area = width * height
             img_area = orig_w * orig_h
-            if area > img_area * 0.95:
+            if area > img_area:
                 continue
             
             # 生成多边形坐标
@@ -422,8 +615,23 @@ class DitLayoutDetector(BaseLayoutDetector):
             formatted_results = LayoutUtils.remove_overlapping_boxes(
                 formatted_results,
                 iou_threshold=self._iou_threshold,
-                overlap_ratio_threshold=self._overlap_ratio_threshold
+                overlap_ratio_threshold=self._overlap_ratio_threshold,
+                image_size=(orig_w, orig_h),
+                max_area_ratio=self._max_area_ratio,
+                enable_category_restriction=self._enable_category_restriction,
+                enable_category_priority=self._enable_category_priority
+            )
+        
+        # 过滤误检的图片框(包含过多文本内容的图片框)
+        if self._filter_false_positive_images and len(formatted_results) > 1:
+            before_count = len(formatted_results)
+            formatted_results = LayoutUtils.filter_false_positive_images(
+                formatted_results,
+                min_text_area_ratio=self._min_text_area_ratio
             )
+            removed_count = before_count - len(formatted_results)
+            if removed_count > 0:
+                print(f"🔄 Filtered {removed_count} false positive image boxes")
         
         return formatted_results
     
@@ -457,7 +665,7 @@ class DitLayoutDetector(BaseLayoutDetector):
         self, 
         img: np.ndarray, 
         results: List[Dict],
-        output_path: str = None,
+        output_path: Optional[str] = None,
         show_confidence: bool = True,
         min_confidence: float = 0.0
     ) -> np.ndarray: