Browse Source

feat: 添加 ONNX Runtime 布局检测器,支持图像预处理和后处理功能

zhch158_admin 2 tuần trước cách đây
mục cha
commit
1a899019a8
1 tập tin đã thay đổi với 463 bổ sung0 xóa
  1. 463 0
      zhch/unified_pytorch_models/layout_detect_onnx.py

+ 463 - 0
zhch/unified_pytorch_models/layout_detect_onnx.py

@@ -0,0 +1,463 @@
+"""使用 ONNX Runtime 进行布局检测的统一接口"""
+
+import cv2
+import numpy as np
+import onnxruntime as ort
+from pathlib import Path
+from typing import Dict, List, Tuple
+
+
+class LayoutDetectorONNX:
+    """布局检测器 ONNX 版本"""
+    
+    # ⚠️ 修正:使用官方的 RT-DETR-H_layout_17cls 类别定义
+    CATEGORY_NAMES = {
+        0: 'paragraph_title',
+        1: 'image',
+        2: 'text',
+        3: 'number',
+        4: 'abstract',
+        5: 'content',
+        6: 'figure_title',
+        7: 'formula',
+        8: 'table',
+        9: 'table_title',
+        10: 'reference',
+        11: 'doc_title',
+        12: 'footnote',
+        13: 'header',
+        14: 'algorithm',
+        15: 'footer',
+        16: 'seal'
+    }
+    
+    def __init__(self, onnx_path: str, use_gpu: bool = False):
+        """
+        初始化 ONNX 模型
+        
+        Args:
+            onnx_path: ONNX 模型路径
+            use_gpu: 是否使用 GPU(Mac 不支持 CUDA)
+        """
+        # Mac 只支持 CPU 或 CoreML
+        if use_gpu:
+            providers = ['CoreMLExecutionProvider', 'CPUExecutionProvider']
+        else:
+            providers = ['CPUExecutionProvider']
+        
+        self.session = ort.InferenceSession(onnx_path, providers=providers)
+        
+        # 获取模型输入信息
+        self.inputs = {inp.name: inp for inp in self.session.get_inputs()}
+        self.outputs = {out.name: out for out in self.session.get_outputs()}
+        
+        print(f"📋 Model inputs: {list(self.inputs.keys())}")
+        print(f"📋 Model outputs: {list(self.outputs.keys())}")
+        
+        # 自动检测输入尺寸
+        self.target_size = self._detect_input_size()
+        print(f"🎯 Detected target size: {self.target_size}")
+        
+        # 检查输入形状
+        for name, inp in self.inputs.items():
+            print(f"   - {name}: shape={inp.shape}, dtype={inp.type}")
+    
+    def _detect_input_size(self) -> int:
+        """自动检测模型的输入尺寸"""
+        if 'image' in self.inputs:
+            shape = self.inputs['image'].shape
+            # shape 通常是 [batch, channels, height, width]
+            if len(shape) >= 3:
+                # 尝试从 shape[2] 或 shape[3] 获取尺寸
+                for dim in shape[2:]:
+                    if isinstance(dim, int) and dim > 0:
+                        return dim
+        
+        # 默认值
+        return 640
+    
+    def preprocess(
+        self, 
+        img: np.ndarray
+    ) -> Tuple[Dict[str, np.ndarray], float, Tuple[int, int]]:
+        """
+        预处理图像 (根据 inference.yml 配置)
+        
+        Args:
+            img: BGR 格式的输入图像
+            
+        Returns:
+            input_dict: 包含所有输入的字典
+            scale_factor: 缩放因子 (用于后处理)
+            orig_shape: 原始图像尺寸 (h, w)
+        """
+        orig_h, orig_w = img.shape[:2]
+        target_size = self.target_size  # 640
+        
+        # ✅ 修正 1: 直接 resize 到目标尺寸,不保持长宽比 (keep_ratio: false)
+        img_resized = cv2.resize(
+            img, 
+            (target_size, target_size), 
+            interpolation=cv2.INTER_LINEAR  # interp: 2
+        )
+        
+        # ✅ 修正 2: 转换为 RGB
+        img_rgb = cv2.cvtColor(img_resized, cv2.COLOR_BGR2RGB)
+        
+        # ✅ 修正 3: 归一化 (mean=[0,0,0], std=[1,1,1], norm_type=none)
+        # 只做 /255,不做均值减法和标准差除法
+        img_normalized = img_rgb.astype(np.float32) / 255.0
+        
+        # 4. 转换为 CHW 格式
+        img_chw = img_normalized.transpose(2, 0, 1)
+        img_tensor = img_chw[None, ...].astype(np.float32)  # [1, 3, 640, 640]
+        
+        # 5. 准备所有输入
+        input_dict = {}
+        
+        # 主图像输入
+        if 'image' in self.inputs:
+            input_dict['image'] = img_tensor
+        elif 'images' in self.inputs:
+            input_dict['images'] = img_tensor
+        else:
+            # 使用第一个输入
+            first_input_name = list(self.inputs.keys())[0]
+            input_dict[first_input_name] = img_tensor
+        
+        # ✅ 修正 4: 计算缩放因子 (实际图像尺寸 / 目标尺寸)
+        scale_h = orig_h / target_size
+        scale_w = orig_w / target_size
+        
+        # im_shape 输入 (原始图像尺寸)
+        if 'im_shape' in self.inputs:
+            im_shape = np.array([[float(orig_h), float(orig_w)]], dtype=np.float32)
+            input_dict['im_shape'] = im_shape
+        
+        # scale_factor 输入
+        if 'scale_factor' in self.inputs:
+            # ⚠️ 注意:这里是原始尺寸/目标尺寸的比例
+            scale_factor = np.array([[scale_h, scale_w]], dtype=np.float32)
+            input_dict['scale_factor'] = scale_factor
+        
+        # ✅ 返回的 scale 用于后处理坐标还原
+        # 因为不保持长宽比,所以需要分别记录 x 和 y 的缩放
+        return input_dict, (scale_h, scale_w), (orig_h, orig_w)
+    
+    def postprocess(
+        self, 
+        outputs: List[np.ndarray], 
+        scale: Tuple[float, float],  # (scale_h, scale_w)
+        orig_shape: Tuple[int, int],
+        conf_threshold: float = 0.5
+    ) -> List[Dict]:
+        """
+        后处理模型输出
+        
+        Args:
+            outputs: ONNX 模型输出
+            scale: 缩放因子 (scale_h, scale_w) = (原图高/640, 原图宽/640)
+            orig_shape: 原始图像尺寸 (h, w)
+            conf_threshold: 置信度阈值
+            
+        Returns:
+            检测结果列表
+        """
+        # 打印调试信息
+        print(f"   📊 Processing {len(outputs)} outputs")
+        for i, output in enumerate(outputs):
+            print(f"      Output[{i}] shape: {output.shape}, dtype: {output.dtype}, range: [{output.min():.2f}, {output.max():.2f}]")
+        
+        scale_h, scale_w = scale
+        orig_h, orig_w = orig_shape
+        
+        print(f"   🔄 Scale factors: scale_h={scale_h:.3f}, scale_w={scale_w:.3f}")
+        print(f"   📐 Original shape: {orig_h} x {orig_w}")
+        
+        # 根据输出形状判断格式
+        if len(outputs) >= 2:
+            output0_shape = outputs[0].shape
+            output1_shape = outputs[1].shape
+            
+            # RT-DETR ONNX 格式: (num_boxes, 6)
+            # 格式: [label_id, score, x1, y1, x2, y2]
+            if len(output0_shape) == 2 and output0_shape[1] == 6:
+                print(f"   ✅ Detected RT-DETR ONNX format: (num_boxes, 6) [label, score, x1, y1, x2, y2]")
+                pred = outputs[0]  # [num_boxes, 6]
+                
+                labels = pred[:, 0].astype(int)
+                scores = pred[:, 1]
+                bboxes = pred[:, 2:6].copy()  # [x1, y1, x2, y2] - 在 640×640 尺度上
+                
+            # 情况2: output0 是 (batch, num_boxes, 6) - 带batch的合并格式
+            elif len(output0_shape) == 3 and output0_shape[2] == 6:
+                print(f"   ✅ Detected batched RT-DETR format: (batch, num_boxes, 6)")
+                pred = outputs[0][0]  # Remove batch dimension
+                
+                labels = pred[:, 0].astype(int)
+                scores = pred[:, 1]
+                bboxes = pred[:, 2:6].copy()
+                
+            # 情况3: output0 是 bboxes, output1 是 scores (分离格式)
+            elif len(output0_shape) == 2 and output0_shape[1] == 4:
+                print(f"   ✅ Detected separate format: bboxes + scores")
+                bboxes = outputs[0].copy()  # [num_boxes, 4]
+                
+                if len(output1_shape) == 1:
+                    scores = outputs[1]
+                    labels = np.zeros(len(scores), dtype=int)
+                elif len(output1_shape) == 2:
+                    scores_all = outputs[1]
+                    scores = scores_all.max(axis=1)
+                    labels = scores_all.argmax(axis=1)
+                else:
+                    raise ValueError(f"Unexpected output1 shape: {output1_shape}")
+        
+            # 情况4: RT-DETR 格式 (batch, num_boxes, 4) + (batch, num_boxes, num_classes)
+            elif len(output0_shape) == 3 and output0_shape[2] == 4:
+                print(f"   ✅ Detected RT-DETR separate format")
+                bboxes = outputs[0][0].copy()
+                scores_all = outputs[1][0]
+                scores = scores_all.max(axis=1)
+                labels = scores_all.argmax(axis=1)
+            
+            else:
+                raise ValueError(f"Unexpected output format: {output0_shape}, {output1_shape}")
+        
+        elif len(outputs) == 1:
+            # 单一输出
+            output_shape = outputs[0].shape
+            
+            if len(output_shape) == 2 and output_shape[1] == 6:
+                print(f"   ✅ Detected single RT-DETR output: (num_boxes, 6)")
+                pred = outputs[0]
+                labels = pred[:, 0].astype(int)
+                scores = pred[:, 1]
+                bboxes = pred[:, 2:6].copy()
+            
+            elif len(output_shape) == 3 and output_shape[2] == 6:
+                print(f"   ✅ Detected single batched output: (batch, num_boxes, 6)")
+                pred = outputs[0][0]
+                labels = pred[:, 0].astype(int)
+                scores = pred[:, 1]
+                bboxes = pred[:, 2:6].copy()
+            
+            else:
+                raise ValueError(f"Unexpected single output shape: {output_shape}")
+        
+        else:
+            raise ValueError(f"Unexpected number of outputs: {len(outputs)}")
+        
+        print(f"   📦 Parsed: {len(bboxes)} boxes, score range: [{scores.min():.6f}, {scores.max():.6f}]")
+        print(f"   📏 Bbox range before scaling: x=[{bboxes[:, 0].min():.1f}, {bboxes[:, 2].max():.1f}], y=[{bboxes[:, 1].min():.1f}, {bboxes[:, 3].max():.1f}]")
+        
+        # ✅ 关键修复:将坐标从 640×640 还原到原图尺度
+        # bboxes 当前在 [0, 640] 范围内,需要乘以缩放因子
+        bboxes[:, [0, 2]] *= scale_w  # x1, x2 乘以 width scale
+        bboxes[:, [1, 3]] *= scale_h  # y1, y2 乘以 height scale
+        
+        print(f"   📏 Bbox range after scaling: x=[{bboxes[:, 0].min():.1f}, {bboxes[:, 2].max():.1f}], y=[{bboxes[:, 1].min():.1f}, {bboxes[:, 3].max():.1f}]")
+        
+        # ⚠️ 自适应阈值
+        max_score = scores.max() if len(scores) > 0 else 0
+        if max_score < conf_threshold:
+            adjusted_threshold = max(max_score * 0.5, 0.05)
+            print(f"   ⚙️  Auto-adjusting threshold: {conf_threshold:.3f} → {adjusted_threshold:.3f} (max_score={max_score:.3f})")
+            conf_threshold = adjusted_threshold
+        
+        # 过滤低分框
+        mask = scores > conf_threshold
+        bboxes = bboxes[mask]
+        scores = scores[mask]
+        labels = labels[mask]
+        
+        print(f"   ✂️  After filtering (score > {conf_threshold:.3f}): {len(bboxes)} boxes")
+        
+        # 过滤完全在图像外的框
+        valid_mask = (
+            (bboxes[:, 2] > 0) &  # x2 > 0
+            (bboxes[:, 3] > 0) &  # y2 > 0
+            (bboxes[:, 0] < orig_w) &  # x1 < width
+            (bboxes[:, 1] < orig_h)    # y1 < height
+        )
+        bboxes = bboxes[valid_mask]
+        scores = scores[valid_mask]
+        labels = labels[valid_mask]
+        
+        print(f"   🗺️  After spatial filtering: {len(bboxes)} boxes")
+        
+        # 裁剪坐标到图像范围
+        bboxes[:, [0, 2]] = np.clip(bboxes[:, [0, 2]], 0, orig_w)
+        bboxes[:, [1, 3]] = np.clip(bboxes[:, [1, 3]], 0, orig_h)
+        
+        # 构造结果
+        results = []
+        for box, score, label in zip(bboxes, scores, labels):
+            x1, y1, x2, y2 = box
+            
+            # 过滤无效框
+            width = x2 - x1
+            height = y2 - y1
+            
+            # 过滤太小的框
+            if width < 10 or height < 10:
+                continue
+            
+            # 过滤面积异常大的框
+            area = width * height
+            img_area = orig_w * orig_h
+            if area > img_area * 0.95:
+                continue
+                
+            results.append({
+                'category_id': int(label),
+                'category_name': self.CATEGORY_NAMES.get(int(label), f'unknown_{label}'),
+                'bbox': [int(x1), int(y1), int(x2), int(y2)],
+                'poly': [int(x1), int(y1), int(x2), int(y1), int(x2), int(y2), int(x1), int(y2)],
+                'score': float(score),
+                'width': int(width),
+                'height': int(height)
+            })
+        
+        print(f"   ✅ Final valid boxes: {len(results)}")
+        
+        return results
+
+    def predict(
+        self, 
+        img: np.ndarray, 
+        conf_threshold: float = 0.05  # 🔧 降低默认阈值
+    ) -> List[Dict]:
+        """
+        执行预测
+        
+        Args:
+            img: BGR 格式的输入图像
+            conf_threshold: 置信度阈值(默认 0.05,会自动调整)
+            
+        Returns:
+            检测结果列表
+        """
+        # 预处理
+        input_dict, scale, orig_shape = self.preprocess(img)
+        
+        # 打印输入形状(调试用)
+        for name, tensor in input_dict.items():
+            print(f"   Input '{name}' shape: {tensor.shape}")
+        
+        # ONNX 推理
+        output_names = [out.name for out in self.session.get_outputs()]
+        outputs = self.session.run(output_names, input_dict)
+        
+        # 打印输出形状(调试用)
+        for i, output in enumerate(outputs):
+            print(f"   Output {i} shape: {output.shape}")
+        
+        # 后处理
+        results = self.postprocess(outputs, scale, orig_shape, conf_threshold)
+        
+        return results
+    
+    def visualize(
+        self, 
+        img: np.ndarray, 
+        results: List[Dict],
+        output_path: str = None
+    ) -> np.ndarray:
+        """
+        可视化检测结果
+        
+        Args:
+            img: 输入图像
+            results: 检测结果
+            output_path: 输出路径(可选)
+            
+        Returns:
+            标注后的图像
+        """
+        img_vis = img.copy()
+        
+        # 颜色映射
+        colors = [
+            (255, 0, 0),    # text: 红色
+            (0, 255, 0),    # title: 绿色
+            (0, 0, 255),    # figure: 蓝色
+            (255, 255, 0),  # figure_caption: 青色
+            (255, 0, 255),  # table: 洋红
+            (0, 255, 255),  # table_caption: 黄色
+            (128, 0, 128),  # header: 紫色
+            (128, 128, 0),  # footer: 橄榄绿
+            (0, 128, 128),  # reference: 青绿
+            (255, 128, 0),  # equation: 橙色
+        ]
+        
+        for res in results:
+            x1, y1, x2, y2 = res['bbox']
+            category_id = res['category_id']
+            category_name = res['category_name']
+            score = res['score']
+            
+            # 选择颜色
+            color = colors[category_id % len(colors)]
+            
+            # 绘制边框
+            cv2.rectangle(img_vis, (x1, y1), (x2, y2), color, 2)
+            
+            # 绘制标签
+            label = f"{category_name}: {score:.2f}"
+            label_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
+            label_w, label_h = label_size
+            
+            # 标签背景
+            cv2.rectangle(img_vis, (x1, y1 - label_h - 10), (x1 + label_w, y1), color, -1)
+            # 标签文字
+            cv2.putText(img_vis, label, (x1, y1 - 5), 
+                       cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
+        
+        if output_path:
+            Path(output_path).parent.mkdir(parents=True, exist_ok=True)
+            cv2.imwrite(output_path, img_vis)
+            print(f"✅ Visualization saved to: {output_path}")
+        
+        return img_vis
+
+
+# 使用示例
+if __name__ == "__main__":
+    # 初始化检测器
+    onnx_model_path = "./Layout/RT-DETR-H_layout_17cls.onnx"
+    detector = LayoutDetectorONNX(onnx_model_path, use_gpu=False)
+    
+    # 读取图像
+    img_path = "/Users/zhch158/workspace/data/流水分析/B用户_扫描流水/PaddleOCR_VL_Results/B用户_扫描流水/B用户_扫描流水_page_001.png"
+    img = cv2.imread(img_path)
+    
+    if img is None:
+        print(f"❌ Failed to load image: {img_path}")
+        exit(1)
+    
+    # 执行检测
+    print(f"🔄 Processing image: {img_path}")
+    results = detector.predict(img, conf_threshold=0.3)
+    
+    print(f"\n✅ 检测到 {len(results)} 个区域:")
+    for i, res in enumerate(results, 1):
+        print(f"  [{i}] {res['category_name']}: "
+              f"score={res['score']:.3f}, "
+              f"bbox={res['bbox']}")
+    
+    # 可视化
+    output_path = "/Users/zhch158/workspace/repository.git/PaddleX/zhch/sample_data/single_model_output/RT-DETR-H_layout_17cls/B用户_扫描流水_page_001_layout_onnx.png"
+    img_vis = detector.visualize(img, results, output_path)
+    
+    print(f"\n📊 Detection Summary:")
+    print(f"  Total detections: {len(results)}")
+    
+    # 统计各类别数量
+    category_counts = {}
+    for res in results:
+        cat_name = res['category_name']
+        category_counts[cat_name] = category_counts.get(cat_name, 0) + 1
+    
+    for cat_name, count in sorted(category_counts.items()):
+        print(f"  - {cat_name}: {count}")