Parcourir la source

feat(paddle_wired_table_cells_detector): 添加 ONNX 版本的有线表格单元格检测器

zhch158_admin il y a 3 semaines
Parent
commit
7db42f9eb6

+ 339 - 0
ocr_tools/universal_doc_parser/models/adapters/paddle_wired_table_cells_detector.py

@@ -0,0 +1,339 @@
+"""使用 ONNX Runtime 进行有线表格单元格检测的适配器"""
+
+import cv2
+import numpy as np
+import onnxruntime as ort
+from pathlib import Path
+from typing import Dict, List, Tuple, Any, Optional
+
+from loguru import logger
+
+
+class PaddleWiredTableCellsDetector:
+    """
+    PaddleX RT-DETR 有线表格单元格检测器 (ONNX 版本)
+    
+    专门用于检测有线表格中的单元格边界框,配合 UNet 线检测使用。
+    """
+    
+    # 单元格检测只有一个类别
+    CATEGORY_NAMES = {
+        0: 'cell'
+    }
+    
+    def __init__(self, config: Dict[str, Any]):
+        """
+        初始化检测器
+        
+        Args:
+            config: 配置字典,必须包含:
+                - model_dir: ONNX 模型路径
+                - device: 'cpu' 或 'gpu' (Mac 只支持 CPU/CoreML)
+                - conf: 置信度阈值 (默认 0.5)
+        """
+        self.config = config
+        self.session = None
+        self.inputs = {}
+        self.outputs = {}
+        self.target_size = 640  # RT-DETR 固定输入尺寸
+        self.conf_threshold = config.get('conf', 0.5)
+    
+    def initialize(self):
+        """初始化 ONNX 模型"""
+        try:
+            onnx_path = self.config.get('model_dir')
+            if not onnx_path:
+                raise ValueError("model_dir not specified in config")
+            
+            if not Path(onnx_path).exists():
+                raise FileNotFoundError(f"ONNX model not found: {onnx_path}")
+            
+            # 根据配置选择执行提供器
+            device = self.config.get('device', 'cpu')
+            if device == 'gpu':
+                # Mac 支持 CoreML
+                providers = ['CoreMLExecutionProvider', 'CPUExecutionProvider']
+            else:
+                providers = ['CPUExecutionProvider']
+            
+            self.session = ort.InferenceSession(onnx_path, providers=providers)
+            
+            # 获取模型输入输出信息
+            self.inputs = {inp.name: inp for inp in self.session.get_inputs()}
+            self.outputs = {out.name: out for out in self.session.get_outputs()}
+            
+            # 自动检测输入尺寸
+            self.target_size = self._detect_input_size()
+            
+            logger.info(f"✅ Table Cell Detector initialized: {Path(onnx_path).name}, "
+                       f"target_size={self.target_size}, device={device}")
+            
+        except Exception as e:
+            logger.error(f"❌ Failed to initialize Table Cell Detector: {e}")
+            raise
+    
+    def cleanup(self):
+        """清理资源"""
+        self.session = None
+        self.inputs = {}
+        self.outputs = {}
+    
+    def _detect_input_size(self) -> int:
+        """自动检测模型的输入尺寸"""
+        if 'image' in self.inputs:
+            shape = self.inputs['image'].shape
+            if len(shape) >= 3:
+                for dim in shape[2:]:
+                    if isinstance(dim, int) and dim > 0:
+                        return dim
+        return 640  # 默认值
+    
+    def detect(
+        self, 
+        img: np.ndarray,
+        conf_threshold: Optional[float] = None
+    ) -> List[Dict[str, Any]]:
+        """
+        检测表格单元格
+        
+        Args:
+            img: 输入图像 (BGR 格式)
+            conf_threshold: 置信度阈值 (可选,覆盖初始化时的阈值)
+            
+        Returns:
+            检测结果列表,每个元素包含:
+            - bbox: [x1, y1, x2, y2] (原图坐标)
+            - score: 置信度
+            - category_id: 类别ID (0=cell)
+            - category_name: 类别名称 ('cell')
+            - width: 单元格宽度
+            - height: 单元格高度
+        """
+        if self.session is None:
+            raise RuntimeError("Model not initialized. Call initialize() first.")
+        
+        if conf_threshold is None:
+            conf_threshold = self.conf_threshold
+        
+        # 预处理
+        input_dict, scale, orig_shape = self._preprocess(img)
+        
+        # ONNX 推理
+        output_names = [out.name for out in self.session.get_outputs()]
+        outputs = self.session.run(output_names, input_dict)
+        
+        # 后处理
+        results = self._postprocess(outputs, scale, orig_shape, conf_threshold)
+        
+        logger.debug(f"RT-DETR detected {len(results)} cells (conf>{conf_threshold})")
+        
+        return results
+    
+    def _preprocess(
+        self, 
+        img: np.ndarray
+    ) -> Tuple[Dict[str, np.ndarray], Tuple[float, float], Tuple[int, int]]:
+        """
+        预处理图像 (根据 inference.yml 配置)
+        
+        预处理步骤:
+        1. Resize: target_size=[640,640], keep_ratio=false, interp=2
+        2. NormalizeImage: mean=[0,0,0], std=[1,1,1], norm_type=none (只做 /255)
+        3. Permute: 转换为 CHW 格式
+        
+        Returns:
+            input_dict: 包含所有输入的字典
+            scale: (scale_h, scale_w) 缩放因子
+            orig_shape: (h, w) 原始图像尺寸
+        """
+        orig_h, orig_w = img.shape[:2]
+        target_size = self.target_size  # 640
+        
+        # 1. Resize 到目标尺寸,不保持长宽比 (keep_ratio: false)
+        img_resized = cv2.resize(
+            img, 
+            (target_size, target_size), 
+            interpolation=cv2.INTER_LINEAR  # interp: 2
+        )
+        
+        # 2. 转换为 RGB
+        img_rgb = cv2.cvtColor(img_resized, cv2.COLOR_BGR2RGB)
+        
+        # 3. 归一化 (mean=[0,0,0], std=[1,1,1], norm_type=none)
+        # 只做 /255,不做均值减法和标准差除法
+        img_normalized = img_rgb.astype(np.float32) / 255.0
+        
+        # 4. 转换为 CHW 格式
+        img_chw = img_normalized.transpose(2, 0, 1)
+        img_tensor = img_chw[None, ...].astype(np.float32)  # [1, 3, 640, 640]
+        
+        # 5. 准备所有输入
+        input_dict = {}
+        
+        # 主图像输入
+        if 'image' in self.inputs:
+            input_dict['image'] = img_tensor
+        elif 'images' in self.inputs:
+            input_dict['images'] = img_tensor
+        else:
+            # 使用第一个输入
+            first_input_name = list(self.inputs.keys())[0]
+            input_dict[first_input_name] = img_tensor
+        
+        # 计算缩放因子 (原始尺寸 / 目标尺寸)
+        scale_h = orig_h / target_size
+        scale_w = orig_w / target_size
+        
+        # im_shape 输入 (原始图像尺寸)
+        if 'im_shape' in self.inputs:
+            im_shape = np.array([[float(orig_h), float(orig_w)]], dtype=np.float32)
+            input_dict['im_shape'] = im_shape
+        
+        # scale_factor 输入
+        if 'scale_factor' in self.inputs:
+            scale_factor = np.array([[scale_h, scale_w]], dtype=np.float32)
+            input_dict['scale_factor'] = scale_factor
+        
+        return input_dict, (scale_h, scale_w), (orig_h, orig_w)
+    
+    def _postprocess(
+        self, 
+        outputs: List[np.ndarray], 
+        scale: Tuple[float, float],  # (scale_h, scale_w)
+        orig_shape: Tuple[int, int],
+        conf_threshold: float = 0.5
+    ) -> List[Dict]:
+        """
+        后处理模型输出
+        
+        Args:
+            outputs: ONNX 模型输出
+            scale: (scale_h, scale_w) 缩放因子
+            orig_shape: (h, w) 原始图像尺寸
+            conf_threshold: 置信度阈值
+            
+        Returns:
+            检测结果列表
+        """
+        scale_h, scale_w = scale
+        orig_h, orig_w = orig_shape
+        
+        # 解析输出格式
+        if len(outputs) >= 2:
+            output0_shape = outputs[0].shape
+            output1_shape = outputs[1].shape
+            
+            # RT-DETR ONNX 格式: (num_boxes, 6)
+            # 格式: [label_id, score, x1, y1, x2, y2]
+            if len(output0_shape) == 2 and output0_shape[1] == 6:
+                pred = outputs[0]
+                labels = pred[:, 0].astype(int)
+                scores = pred[:, 1]
+                bboxes = pred[:, 2:6].copy()
+                
+            # 情况2: output0 是 (batch, num_boxes, 6)
+            elif len(output0_shape) == 3 and output0_shape[2] == 6:
+                pred = outputs[0][0]
+                labels = pred[:, 0].astype(int)
+                scores = pred[:, 1]
+                bboxes = pred[:, 2:6].copy()
+                
+            # 情况3: output0 是 bboxes, output1 是 scores
+            elif len(output0_shape) == 2 and output0_shape[1] == 4:
+                bboxes = outputs[0].copy()
+                if len(output1_shape) == 1:
+                    scores = outputs[1]
+                    labels = np.zeros(len(scores), dtype=int)
+                elif len(output1_shape) == 2:
+                    scores_all = outputs[1]
+                    scores = scores_all.max(axis=1)
+                    labels = scores_all.argmax(axis=1)
+                else:
+                    raise ValueError(f"Unexpected output1 shape: {output1_shape}")
+        
+            # 情况4: RT-DETR 格式 (batch, num_boxes, 4) + (batch, num_boxes, num_classes)
+            elif len(output0_shape) == 3 and output0_shape[2] == 4:
+                bboxes = outputs[0][0].copy()
+                scores_all = outputs[1][0]
+                scores = scores_all.max(axis=1)
+                labels = scores_all.argmax(axis=1)
+            
+            else:
+                raise ValueError(f"Unexpected output format: {output0_shape}, {output1_shape}")
+        
+        elif len(outputs) == 1:
+            # 单一输出
+            output_shape = outputs[0].shape
+            
+            if len(output_shape) == 2 and output_shape[1] == 6:
+                pred = outputs[0]
+                labels = pred[:, 0].astype(int)
+                scores = pred[:, 1]
+                bboxes = pred[:, 2:6].copy()
+            
+            elif len(output_shape) == 3 and output_shape[2] == 6:
+                pred = outputs[0][0]
+                labels = pred[:, 0].astype(int)
+                scores = pred[:, 1]
+                bboxes = pred[:, 2:6].copy()
+            
+            else:
+                raise ValueError(f"Unexpected single output shape: {output_shape}")
+        
+        else:
+            raise ValueError(f"Unexpected number of outputs: {len(outputs)}")
+        
+        # 将坐标从 640×640 还原到原图尺度
+        bboxes[:, [0, 2]] *= scale_w
+        bboxes[:, [1, 3]] *= scale_h
+        
+        # 过滤低分框
+        mask = scores > conf_threshold
+        bboxes = bboxes[mask]
+        scores = scores[mask]
+        labels = labels[mask]
+        
+        # 过滤完全在图像外的框
+        valid_mask = (
+            (bboxes[:, 2] > 0) &
+            (bboxes[:, 3] > 0) &
+            (bboxes[:, 0] < orig_w) &
+            (bboxes[:, 1] < orig_h)
+        )
+        bboxes = bboxes[valid_mask]
+        scores = scores[valid_mask]
+        labels = labels[valid_mask]
+        
+        # 裁剪坐标到图像范围
+        bboxes[:, [0, 2]] = np.clip(bboxes[:, [0, 2]], 0, orig_w)
+        bboxes[:, [1, 3]] = np.clip(bboxes[:, [1, 3]], 0, orig_h)
+        
+        # 构造结果
+        results = []
+        for box, score, label in zip(bboxes, scores, labels):
+            x1, y1, x2, y2 = box
+            
+            # 过滤无效框
+            width = x2 - x1
+            height = y2 - y1
+            
+            # 过滤太小的框(单元格通常不会太小)
+            if width < 5 or height < 5:
+                continue
+            
+            # 过滤面积异常大的框
+            area = width * height
+            img_area = orig_w * orig_h
+            if area > img_area * 0.95:
+                continue
+                
+            results.append({
+                'bbox': [float(x1), float(y1), float(x2), float(y2)],
+                'score': float(score),
+                'category_id': int(label),
+                'category_name': self.CATEGORY_NAMES.get(int(label), 'cell'),
+                'width': float(width),
+                'height': float(height)
+            })
+        
+        return results