Sfoglia il codice sorgente

feat: 新增 DiT Layout Detector 适配器及其核心功能

- 添加 DiT (Document Image Transformer) 布局检测器,支持 PubLayNet 数据集的布局检测。
- 实现布局处理工具类,包含 IoU 计算、重叠框处理等功能。
- 提供模型初始化、检测和可视化功能,支持批量检测和结果输出。
- 更新配置选项,允许用户自定义模型权重、设备和置信度阈值等参数。
zhch158_admin 1 settimana fa
parent
commit
66103ab214

+ 733 - 0
ocr_tools/universal_doc_parser/models/adapters/dit_layout_adapter.py

@@ -0,0 +1,733 @@
+"""DiT Layout Detector 适配器
+
+基于 DiT (Document Image Transformer) 的布局检测适配器,参考 docling_layout_adapter 的实现方式。
+支持 PubLayNet 数据集的 5 个类别:text, title, list, table, figure。
+
+支持的配置:
+- config_file: DiT 配置文件路径
+- model_weights: 模型权重路径或 URL
+- device: 运行设备 ('cpu', 'cuda', 'mps')
+- conf: 置信度阈值 (默认 0.3)
+- remove_overlap: 是否启用重叠框处理 (默认 True)
+- iou_threshold: IoU 阈值 (默认 0.8)
+- overlap_ratio_threshold: 重叠比例阈值 (默认 0.8)
+"""
+
+import cv2
+import numpy as np
+import threading
+from pathlib import Path
+from typing import Dict, List, Union, Any, Optional
+from PIL import Image
+
+try:
+    from .base import BaseLayoutDetector
+except ImportError:
+    from base import BaseLayoutDetector
+
+# 全局锁,防止模型初始化时的线程问题
+_model_init_lock = threading.Lock()
+
+
+class LayoutUtils:
+    """布局处理工具类(简化版,不依赖 external 模块)"""
+    
+    @staticmethod
+    def calculate_iou(bbox1: List[float], bbox2: List[float]) -> float:
+        """计算两个 bbox 的 IoU(交并比)"""
+        x1_1, y1_1, x2_1, y2_1 = bbox1
+        x1_2, y1_2, x2_2, y2_2 = bbox2
+        
+        # 计算交集
+        x1_i = max(x1_1, x1_2)
+        y1_i = max(y1_1, y1_2)
+        x2_i = min(x2_1, x2_2)
+        y2_i = min(y2_1, y2_2)
+        
+        if x2_i <= x1_i or y2_i <= y1_i:
+            return 0.0
+        
+        intersection = (x2_i - x1_i) * (y2_i - y1_i)
+        
+        # 计算并集
+        area1 = (x2_1 - x1_1) * (y2_1 - y1_1)
+        area2 = (x2_2 - x1_2) * (y2_2 - y1_2)
+        union = area1 + area2 - intersection
+        
+        if union == 0:
+            return 0.0
+        
+        return intersection / union
+    
+    @staticmethod
+    def calculate_overlap_ratio(bbox1: List[float], bbox2: List[float]) -> float:
+        """计算重叠面积占小框面积的比例"""
+        x1_1, y1_1, x2_1, y2_1 = bbox1
+        x1_2, y1_2, x2_2, y2_2 = bbox2
+        
+        # 计算交集
+        x1_i = max(x1_1, x1_2)
+        y1_i = max(y1_1, y1_2)
+        x2_i = min(x2_1, x2_2)
+        y2_i = min(y2_1, y2_2)
+        
+        if x2_i <= x1_i or y2_i <= y1_i:
+            return 0.0
+        
+        intersection = (x2_i - x1_i) * (y2_i - y1_i)
+        
+        # 计算两个框的面积
+        area1 = (x2_1 - x1_1) * (y2_1 - y1_1)
+        area2 = (x2_2 - x1_2) * (y2_2 - y1_2)
+        
+        # 返回交集占小框面积的比例
+        min_area = min(area1, area2)
+        if min_area == 0:
+            return 0.0
+        
+        return intersection / min_area
+    
+    @staticmethod
+    def remove_overlapping_boxes(
+        layout_results: List[Dict[str, Any]],
+        iou_threshold: float = 0.8,
+        overlap_ratio_threshold: float = 0.8
+    ) -> List[Dict[str, Any]]:
+        """
+        处理重叠的布局框(参考 MinerU 的去重策略)
+        
+        策略:
+        1. 高 IoU 重叠:保留置信度高的框
+        2. 包含关系:小框被大框高度包含时,保留大框并扩展边界
+        
+        Args:
+            layout_results: Layout 检测结果列表
+            iou_threshold: IoU 阈值,超过此值认为高度重叠
+            overlap_ratio_threshold: 重叠面积占小框面积的比例阈值
+            
+        Returns:
+            去重后的布局结果列表
+        """
+        if not layout_results or len(layout_results) <= 1:
+            return layout_results
+        
+        # 复制列表避免修改原数据
+        results = [item.copy() for item in layout_results]
+        need_remove = set()
+        
+        for i in range(len(results)):
+            if i in need_remove:
+                continue
+                
+            for j in range(i + 1, len(results)):
+                if j in need_remove:
+                    continue
+                
+                bbox1 = results[i].get('bbox', [0, 0, 0, 0])
+                bbox2 = results[j].get('bbox', [0, 0, 0, 0])
+                
+                if len(bbox1) < 4 or len(bbox2) < 4:
+                    continue
+                
+                # 计算 IoU
+                iou = LayoutUtils.calculate_iou(bbox1, bbox2)
+                
+                if iou > iou_threshold:
+                    # 高度重叠,保留置信度高的
+                    score1 = results[i].get('confidence', results[i].get('score', 0))
+                    score2 = results[j].get('confidence', results[j].get('score', 0))
+                    
+                    if score1 >= score2:
+                        need_remove.add(j)
+                    else:
+                        need_remove.add(i)
+                        break  # i 被移除,跳出内层循环
+                else:
+                    # 检查包含关系
+                    overlap_ratio = LayoutUtils.calculate_overlap_ratio(bbox1, bbox2)
+                    
+                    if overlap_ratio > overlap_ratio_threshold:
+                        # 小框被大框高度包含
+                        area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
+                        area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
+                        
+                        if area1 <= area2:
+                            small_idx, large_idx = i, j
+                        else:
+                            small_idx, large_idx = j, i
+                        
+                        # 扩展大框的边界
+                        small_bbox = results[small_idx]['bbox']
+                        large_bbox = results[large_idx]['bbox']
+                        results[large_idx]['bbox'] = [
+                            min(small_bbox[0], large_bbox[0]),
+                            min(small_bbox[1], large_bbox[1]),
+                            max(small_bbox[2], large_bbox[2]),
+                            max(small_bbox[3], large_bbox[3])
+                        ]
+                        need_remove.add(small_idx)
+                        
+                        if small_idx == i:
+                            break  # i 被移除,跳出内层循环
+        
+        # 返回去重后的结果
+        return [results[i] for i in range(len(results)) if i not in need_remove]
+
+
+class DitLayoutDetector(BaseLayoutDetector):
+    """DiT Layout Detector 适配器
+    
+    基于 DiT (Document Image Transformer) 的布局检测器,使用 detectron2 + DiT backbone。
+    支持 PubLayNet 数据集的布局检测。
+    """
+    
+    # DiT/PubLayNet 原始类别定义
+    DIT_LABELS = {
+        0: 'text',
+        1: 'title',
+        2: 'list',
+        3: 'table',
+        4: 'figure',
+    }
+    
+    # 类别映射:PubLayNet → MinerU/EnhancedDocPipeline 类别体系
+    # 参考:
+    # - Pipeline: universal_doc_parser/core/pipeline_manager_v2.py (EnhancedDocPipeline 类别定义)
+    CATEGORY_MAP = {
+        'text': 'text',                    # Text -> text (TEXT_CATEGORIES)
+        'title': 'title',                  # Title -> title (TEXT_CATEGORIES)
+        'list': 'text',                    # List-item -> text (TEXT_CATEGORIES)
+        'table': 'table_body',             # Table -> table_body (TABLE_BODY_CATEGORIES)
+        'figure': 'image_body',            # Figure -> image_body (IMAGE_BODY_CATEGORIES)
+    }
+    
+    def __init__(self, config: Dict[str, Any]):
+        """
+        初始化 DiT Layout 检测器
+        
+        Args:
+            config: 配置字典,支持以下参数:
+                - config_file: DiT 配置文件路径(默认使用 cascade_dit_large.yaml)
+                - model_weights: 模型权重路径或 URL
+                - device: 运行设备 ('cpu', 'cuda', 'mps')
+                - conf: 置信度阈值 (默认 0.3)
+                - remove_overlap: 是否启用重叠框处理 (默认 True)
+                - iou_threshold: IoU 阈值 (默认 0.8)
+                - overlap_ratio_threshold: 重叠比例阈值 (默认 0.8)
+        """
+        super().__init__(config)
+        self.predictor = None
+        self.cfg = None
+        self._device = None
+        self._threshold = 0.3
+        self._remove_overlap = True
+        self._iou_threshold = 0.8
+        self._overlap_ratio_threshold = 0.8
+    
+    def initialize(self):
+        """初始化模型"""
+        import os
+        import sys
+        
+        try:
+            import torch
+            from detectron2.config import get_cfg
+            from detectron2.engine import DefaultPredictor
+            from detectron2.data import MetadataCatalog
+            
+            # PyTorch 2.6+ 兼容性修复
+            if hasattr(torch, '__version__'):
+                torch_version = tuple(map(int, torch.__version__.split('.')[:2]))
+                if torch_version >= (2, 6):
+                    _original_torch_load = torch.load
+                    def _patched_torch_load(f, map_location=None, pickle_module=None, 
+                                            weights_only=None, **kwargs):
+                        if weights_only is None:
+                            weights_only = False
+                        return _original_torch_load(f, map_location=map_location, 
+                                                  pickle_module=pickle_module,
+                                                  weights_only=weights_only, **kwargs)
+                    torch.load = _patched_torch_load
+            
+            # 添加 dit_support 路径(适配到 universal_doc_parser)
+            current_dir = os.path.dirname(os.path.abspath(__file__))
+            dit_support_path = os.path.join(current_dir, '..', 'dit_support')
+            if dit_support_path not in sys.path:
+                sys.path.insert(0, dit_support_path)
+            
+            from ditod import add_vit_config
+            
+            # 获取配置参数
+            config_file = self.config.get(
+                'config_file',
+                os.path.join(current_dir, '..', 'dit_support', 'configs',
+                           'cascade', 'cascade_dit_large.yaml')
+            )
+            model_weights = self.config.get(
+                'model_weights',
+                'https://huggingface.co/HYPJUDY/dit/resolve/main/dit-fts/publaynet_dit-l_cascade.pth'
+            )
+            device = self.config.get('device', 'cpu')
+            self._threshold = self.config.get('conf', 0.3)
+            self._remove_overlap = self.config.get('remove_overlap', True)
+            self._iou_threshold = self.config.get('iou_threshold', 0.8)
+            self._overlap_ratio_threshold = self.config.get('overlap_ratio_threshold', 0.8)
+            
+            # 设置设备
+            self._device = torch.device(device)
+            
+            # 验证配置文件存在
+            if not os.path.exists(config_file):
+                raise FileNotFoundError(f"Config file not found: {config_file}")
+            
+            # 加载配置
+            self.cfg = get_cfg()
+            add_vit_config(self.cfg)
+            self.cfg.merge_from_file(config_file)
+            self.cfg.merge_from_list(["MODEL.WEIGHTS", model_weights])
+            self.cfg.MODEL.DEVICE = str(self._device)
+            
+            # 设置元数据
+            dataset_name = self.cfg.DATASETS.TEST[0]
+            md = MetadataCatalog.get(dataset_name)
+            if dataset_name == 'icdar2019_test':
+                md.set(thing_classes=["table"])
+            else:
+                md.set(thing_classes=["text", "title", "list", "table", "figure"])
+            
+            # 创建预测器(使用锁防止线程问题)
+            with _model_init_lock:
+                self.predictor = DefaultPredictor(self.cfg)
+            
+            print(f"✅ DiT Layout Detector initialized")
+            print(f"   - Config: {config_file}")
+            print(f"   - Device: {self._device}")
+            print(f"   - Threshold: {self._threshold}")
+            print(f"   - Remove overlap: {self._remove_overlap}")
+            
+        except ImportError as e:
+            print(f"❌ Failed to import required libraries: {e}")
+            print("   Please ensure detectron2 and ditod are installed")
+            raise
+        except Exception as e:
+            print(f"❌ Failed to initialize DiT Layout Detector: {e}")
+            raise
+    
+    def cleanup(self):
+        """清理资源"""
+        self.predictor = None
+        self.cfg = None
+        self._device = None
+    
+    def detect(self, image: Union[np.ndarray, Image.Image]) -> List[Dict[str, Any]]:
+        """
+        检测布局
+        
+        Args:
+            image: 输入图像 (numpy数组或PIL图像)
+            
+        Returns:
+            检测结果列表,每个元素包含:
+            - category: MinerU类别名称
+            - bbox: [x1, y1, x2, y2]
+            - confidence: 置信度
+            - raw: 原始检测结果
+        """
+        if self.predictor is None:
+            raise RuntimeError("Model not initialized. Call initialize() first.")
+        
+        # 转换为 numpy 数组 (BGR 格式)
+        if isinstance(image, Image.Image):
+            image = np.array(image)
+            if len(image.shape) == 3 and image.shape[2] == 3:
+                # PIL RGB -> OpenCV BGR
+                image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+        
+        # 确保是 BGR 格式
+        if isinstance(image, np.ndarray):
+            if len(image.shape) == 2:
+                image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
+            elif len(image.shape) == 3 and image.shape[2] == 3:
+                # 假设是 RGB,转换为 BGR
+                image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) if image.dtype == np.uint8 else image
+        
+        orig_h, orig_w = image.shape[:2]
+        
+        # 运行推理
+        outputs = self.predictor(image)
+        instances = outputs["instances"]
+        
+        # 解析结果
+        formatted_results = []
+        for i in range(len(instances)):
+            score = float(instances.scores[i].cpu().item())
+            
+            # 过滤低置信度
+            if score < self._threshold:
+                continue
+            
+            # 获取类别
+            class_id = int(instances.pred_classes[i].cpu().item())
+            original_label = self.DIT_LABELS.get(class_id, f'unknown_{class_id}')
+            
+            # 映射到 MinerU 类别
+            mineru_category = self.CATEGORY_MAP.get(original_label, 'text')
+            
+            # 提取边界框
+            bbox_tensor = instances.pred_boxes[i].tensor[0].cpu().numpy()
+            x1 = max(0, min(orig_w, float(bbox_tensor[0])))
+            y1 = max(0, min(orig_h, float(bbox_tensor[1])))
+            x2 = max(0, min(orig_w, float(bbox_tensor[2])))
+            y2 = max(0, min(orig_h, float(bbox_tensor[3])))
+            
+            bbox = [int(x1), int(y1), int(x2), int(y2)]
+            
+            # 计算宽高
+            width = bbox[2] - bbox[0]
+            height = bbox[3] - bbox[1]
+            
+            # 过滤太小的框
+            if width < 10 or height < 10:
+                continue
+            
+            # 过滤面积异常大的框
+            area = width * height
+            img_area = orig_w * orig_h
+            if area > img_area * 0.95:
+                continue
+            
+            # 生成多边形坐标
+            poly = [
+                bbox[0], bbox[1],  # 左上
+                bbox[2], bbox[1],  # 右上
+                bbox[2], bbox[3],  # 右下
+                bbox[0], bbox[3],  # 左下
+            ]
+            
+            formatted_results.append({
+                'category': mineru_category,
+                'bbox': bbox,
+                'confidence': score,
+                'raw': {
+                    'original_label': original_label,
+                    'original_label_id': class_id,
+                    'poly': poly,
+                    'width': width,
+                    'height': height
+                }
+            })
+        
+        # 应用重叠框处理
+        if self._remove_overlap and len(formatted_results) > 1:
+            formatted_results = LayoutUtils.remove_overlapping_boxes(
+                formatted_results,
+                iou_threshold=self._iou_threshold,
+                overlap_ratio_threshold=self._overlap_ratio_threshold
+            )
+        
+        return formatted_results
+    
+    def detect_batch(
+        self, 
+        images: List[Union[np.ndarray, Image.Image]]
+    ) -> List[List[Dict[str, Any]]]:
+        """
+        批量检测布局
+        
+        Args:
+            images: 输入图像列表
+            
+        Returns:
+            每个图像的检测结果列表
+        """
+        if self.predictor is None:
+            raise RuntimeError("Model not initialized. Call initialize() first.")
+        
+        if not images:
+            return []
+        
+        all_results = []
+        for image in images:
+            results = self.detect(image)
+            all_results.append(results)
+        
+        return all_results
+    
+    def visualize(
+        self, 
+        img: np.ndarray, 
+        results: List[Dict],
+        output_path: str = None,
+        show_confidence: bool = True,
+        min_confidence: float = 0.0
+    ) -> np.ndarray:
+        """
+        可视化检测结果
+        
+        Args:
+            img: 输入图像 (BGR 格式)
+            results: 检测结果 (MinerU 格式)
+            output_path: 输出路径(可选)
+            show_confidence: 是否显示置信度
+            min_confidence: 最小置信度阈值
+            
+        Returns:
+            标注后的图像
+        """
+        import random
+        
+        vis_img = img.copy()
+        
+        # 预定义类别颜色(与 EnhancedDocPipeline 保持一致)
+        predefined_colors = {
+            # 文本类
+            'text': (153, 0, 76),
+            'title': (102, 102, 255),
+            'header': (128, 128, 128),
+            'footer': (128, 128, 128),
+            'page_footnote': (200, 200, 200),
+            # 表格类
+            'table_body': (204, 204, 0),
+            'table_caption': (255, 255, 102),
+            # 图片类
+            'image_body': (153, 255, 51),
+            'image_caption': (102, 178, 255),
+            # 公式类
+            'interline_equation': (0, 255, 0),
+            # 代码类
+            'code': (102, 0, 204),
+            # 丢弃类
+            'abandon': (100, 100, 100),
+        }
+        
+        # 过滤低置信度结果
+        filtered_results = [
+            res for res in results 
+            if res['confidence'] >= min_confidence
+        ]
+        
+        if not filtered_results:
+            print(f"⚠️ No results to visualize (min_confidence={min_confidence})")
+            return vis_img
+        
+        # 为每个出现的类别分配颜色
+        category_colors = {}
+        for res in filtered_results:
+            cat = res['category']
+            if cat not in category_colors:
+                if cat in predefined_colors:
+                    category_colors[cat] = predefined_colors[cat]
+                else:
+                    category_colors[cat] = (
+                        random.randint(50, 255),
+                        random.randint(50, 255),
+                        random.randint(50, 255)
+                    )
+        
+        # 绘制检测框
+        for res in filtered_results:
+            bbox = res['bbox']
+            x1, y1, x2, y2 = bbox
+            cat = res['category']
+            confidence = res['confidence']
+            color = category_colors[cat]
+            
+            # 获取原始标签
+            original_label = res.get('raw', {}).get('original_label', cat)
+            
+            # 绘制矩形边框
+            cv2.rectangle(vis_img, (x1, y1), (x2, y2), color, 2)
+            
+            # 构造标签文本
+            if show_confidence:
+                label = f"{original_label}->{cat} {confidence:.2f}"
+            else:
+                label = f"{original_label}->{cat}"
+            
+            # 计算标签尺寸
+            label_size, baseline = cv2.getTextSize(
+                label, 
+                cv2.FONT_HERSHEY_SIMPLEX, 
+                0.4, 
+                1
+            )
+            label_w, label_h = label_size
+            
+            # 绘制标签背景
+            cv2.rectangle(
+                vis_img,
+                (x1, y1 - label_h - 4),
+                (x1 + label_w, y1),
+                color,
+                -1
+            )
+            
+            # 绘制标签文字
+            cv2.putText(
+                vis_img,
+                label,
+                (x1, y1 - 2),
+                cv2.FONT_HERSHEY_SIMPLEX,
+                0.4,
+                (255, 255, 255),
+                1,
+                cv2.LINE_AA
+            )
+        
+        # 添加图例
+        if category_colors:
+            self._draw_legend(vis_img, category_colors, len(filtered_results))
+        
+        # 保存可视化结果
+        if output_path:
+            output_path_obj = Path(output_path)
+            output_path_obj.parent.mkdir(parents=True, exist_ok=True)
+            cv2.imwrite(str(output_path_obj), vis_img)
+            print(f"💾 Visualization saved to: {output_path_obj}")
+        
+        return vis_img
+    
+    def _draw_legend(
+        self, 
+        img: np.ndarray, 
+        category_colors: Dict[str, tuple],
+        total_count: int
+    ):
+        """在图像上绘制图例"""
+        legend_x = img.shape[1] - 200
+        legend_y = 20
+        line_height = 25
+        
+        # 绘制半透明背景
+        overlay = img.copy()
+        cv2.rectangle(
+            overlay,
+            (legend_x - 10, legend_y - 10),
+            (img.shape[1] - 10, legend_y + len(category_colors) * line_height + 30),
+            (255, 255, 255),
+            -1
+        )
+        cv2.addWeighted(overlay, 0.7, img, 0.3, 0, img)
+        
+        # 绘制标题
+        cv2.putText(
+            img,
+            f"Legend ({total_count} total)",
+            (legend_x, legend_y),
+            cv2.FONT_HERSHEY_SIMPLEX,
+            0.5,
+            (0, 0, 0),
+            1,
+            cv2.LINE_AA
+        )
+        
+        # 绘制每个类别
+        y_offset = legend_y + line_height
+        for cat, color in sorted(category_colors.items()):
+            cv2.rectangle(
+                img,
+                (legend_x, y_offset - 10),
+                (legend_x + 15, y_offset),
+                color,
+                -1
+            )
+            cv2.rectangle(
+                img,
+                (legend_x, y_offset - 10),
+                (legend_x + 15, y_offset),
+                (0, 0, 0),
+                1
+            )
+            
+            cv2.putText(
+                img,
+                cat,
+                (legend_x + 20, y_offset - 2),
+                cv2.FONT_HERSHEY_SIMPLEX,
+                0.4,
+                (0, 0, 0),
+                1,
+                cv2.LINE_AA
+            )
+            
+            y_offset += line_height
+
+
+# 测试代码
+if __name__ == "__main__":
+    import sys
+    import os
+    
+    # 测试配置
+    project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
+    config = {
+        'config_file': os.path.join(project_root, 'dit', 'object_detection',
+                                   'publaynet_configs', 'cascade', 'cascade_dit_large.yaml'),
+        'model_weights': 'https://huggingface.co/HYPJUDY/dit/resolve/main/dit-fts/publaynet_dit-l_cascade.pth',
+        'device': 'cpu',
+        'conf': 0.3,
+        'remove_overlap': True,
+        'iou_threshold': 0.8,
+        'overlap_ratio_threshold': 0.8
+    }
+    
+    # 初始化检测器
+    print("🔧 Initializing DiT Layout Detector...")
+    detector = DitLayoutDetector(config)
+    detector.initialize()
+    
+    # 读取测试图像
+    img_path = "/Users/zhch158/workspace/data/流水分析/2023年度报告母公司/paddleocr_vl_results/2023年度报告母公司/2023年度报告母公司_page_021.png"
+    
+    print(f"\n📖 Loading image: {img_path}")
+    img = cv2.imread(img_path)
+    
+    if img is None:
+        print(f"❌ Failed to load image: {img_path}")
+        sys.exit(1)
+    
+    print(f"   Image shape: {img.shape}")
+    
+    # 执行检测
+    print("\n🔍 Detecting layout...")
+    results = detector.detect(img)
+    
+    print(f"\n✅ 检测到 {len(results)} 个区域:")
+    for i, res in enumerate(results, 1):
+        print(f"  [{i}] {res['category']}: "
+              f"score={res['confidence']:.3f}, "
+              f"bbox={res['bbox']}, "
+              f"original={res['raw']['original_label']}")
+    
+    # 统计各类别
+    category_counts = {}
+    for res in results:
+        cat = res['category']
+        category_counts[cat] = category_counts.get(cat, 0) + 1
+    
+    print(f"\n📊 类别统计 (MinerU格式):")
+    for cat, count in sorted(category_counts.items()):
+        print(f"  - {cat}: {count}")
+    
+    # 可视化
+    if len(results) > 0:
+        print("\n🎨 Generating visualization...")
+        
+        output_dir = Path(__file__).parent / "output"
+        output_dir.mkdir(parents=True, exist_ok=True)
+        output_path = output_dir / f"{Path(img_path).stem}_dit_layout_vis.jpg"
+        
+        vis_img = detector.visualize(
+            img, 
+            results, 
+            output_path=str(output_path),
+            show_confidence=True,
+            min_confidence=0.0
+        )
+        
+        print(f"💾 Visualization saved to: {output_path}")
+    
+    # 清理
+    detector.cleanup()
+    print("\n✅ 测试完成!")
+