|
|
@@ -0,0 +1,339 @@
|
|
|
+"""使用 ONNX Runtime 进行有线表格单元格检测的适配器"""
|
|
|
+
|
|
|
+import cv2
|
|
|
+import numpy as np
|
|
|
+import onnxruntime as ort
|
|
|
+from pathlib import Path
|
|
|
+from typing import Dict, List, Tuple, Any, Optional
|
|
|
+
|
|
|
+from loguru import logger
|
|
|
+
|
|
|
+
|
|
|
+class PaddleWiredTableCellsDetector:
|
|
|
+ """
|
|
|
+ PaddleX RT-DETR 有线表格单元格检测器 (ONNX 版本)
|
|
|
+
|
|
|
+ 专门用于检测有线表格中的单元格边界框,配合 UNet 线检测使用。
|
|
|
+ """
|
|
|
+
|
|
|
+ # 单元格检测只有一个类别
|
|
|
+ CATEGORY_NAMES = {
|
|
|
+ 0: 'cell'
|
|
|
+ }
|
|
|
+
|
|
|
+ def __init__(self, config: Dict[str, Any]):
|
|
|
+ """
|
|
|
+ 初始化检测器
|
|
|
+
|
|
|
+ Args:
|
|
|
+ config: 配置字典,必须包含:
|
|
|
+ - model_dir: ONNX 模型路径
|
|
|
+ - device: 'cpu' 或 'gpu' (Mac 只支持 CPU/CoreML)
|
|
|
+ - conf: 置信度阈值 (默认 0.5)
|
|
|
+ """
|
|
|
+ self.config = config
|
|
|
+ self.session = None
|
|
|
+ self.inputs = {}
|
|
|
+ self.outputs = {}
|
|
|
+ self.target_size = 640 # RT-DETR 固定输入尺寸
|
|
|
+ self.conf_threshold = config.get('conf', 0.5)
|
|
|
+
|
|
|
+ def initialize(self):
|
|
|
+ """初始化 ONNX 模型"""
|
|
|
+ try:
|
|
|
+ onnx_path = self.config.get('model_dir')
|
|
|
+ if not onnx_path:
|
|
|
+ raise ValueError("model_dir not specified in config")
|
|
|
+
|
|
|
+ if not Path(onnx_path).exists():
|
|
|
+ raise FileNotFoundError(f"ONNX model not found: {onnx_path}")
|
|
|
+
|
|
|
+ # 根据配置选择执行提供器
|
|
|
+ device = self.config.get('device', 'cpu')
|
|
|
+ if device == 'gpu':
|
|
|
+ # Mac 支持 CoreML
|
|
|
+ providers = ['CoreMLExecutionProvider', 'CPUExecutionProvider']
|
|
|
+ else:
|
|
|
+ providers = ['CPUExecutionProvider']
|
|
|
+
|
|
|
+ self.session = ort.InferenceSession(onnx_path, providers=providers)
|
|
|
+
|
|
|
+ # 获取模型输入输出信息
|
|
|
+ self.inputs = {inp.name: inp for inp in self.session.get_inputs()}
|
|
|
+ self.outputs = {out.name: out for out in self.session.get_outputs()}
|
|
|
+
|
|
|
+ # 自动检测输入尺寸
|
|
|
+ self.target_size = self._detect_input_size()
|
|
|
+
|
|
|
+ logger.info(f"✅ Table Cell Detector initialized: {Path(onnx_path).name}, "
|
|
|
+ f"target_size={self.target_size}, device={device}")
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"❌ Failed to initialize Table Cell Detector: {e}")
|
|
|
+ raise
|
|
|
+
|
|
|
+ def cleanup(self):
|
|
|
+ """清理资源"""
|
|
|
+ self.session = None
|
|
|
+ self.inputs = {}
|
|
|
+ self.outputs = {}
|
|
|
+
|
|
|
+ def _detect_input_size(self) -> int:
|
|
|
+ """自动检测模型的输入尺寸"""
|
|
|
+ if 'image' in self.inputs:
|
|
|
+ shape = self.inputs['image'].shape
|
|
|
+ if len(shape) >= 3:
|
|
|
+ for dim in shape[2:]:
|
|
|
+ if isinstance(dim, int) and dim > 0:
|
|
|
+ return dim
|
|
|
+ return 640 # 默认值
|
|
|
+
|
|
|
+ def detect(
|
|
|
+ self,
|
|
|
+ img: np.ndarray,
|
|
|
+ conf_threshold: Optional[float] = None
|
|
|
+ ) -> List[Dict[str, Any]]:
|
|
|
+ """
|
|
|
+ 检测表格单元格
|
|
|
+
|
|
|
+ Args:
|
|
|
+ img: 输入图像 (BGR 格式)
|
|
|
+ conf_threshold: 置信度阈值 (可选,覆盖初始化时的阈值)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 检测结果列表,每个元素包含:
|
|
|
+ - bbox: [x1, y1, x2, y2] (原图坐标)
|
|
|
+ - score: 置信度
|
|
|
+ - category_id: 类别ID (0=cell)
|
|
|
+ - category_name: 类别名称 ('cell')
|
|
|
+ - width: 单元格宽度
|
|
|
+ - height: 单元格高度
|
|
|
+ """
|
|
|
+ if self.session is None:
|
|
|
+ raise RuntimeError("Model not initialized. Call initialize() first.")
|
|
|
+
|
|
|
+ if conf_threshold is None:
|
|
|
+ conf_threshold = self.conf_threshold
|
|
|
+
|
|
|
+ # 预处理
|
|
|
+ input_dict, scale, orig_shape = self._preprocess(img)
|
|
|
+
|
|
|
+ # ONNX 推理
|
|
|
+ output_names = [out.name for out in self.session.get_outputs()]
|
|
|
+ outputs = self.session.run(output_names, input_dict)
|
|
|
+
|
|
|
+ # 后处理
|
|
|
+ results = self._postprocess(outputs, scale, orig_shape, conf_threshold)
|
|
|
+
|
|
|
+ logger.debug(f"RT-DETR detected {len(results)} cells (conf>{conf_threshold})")
|
|
|
+
|
|
|
+ return results
|
|
|
+
|
|
|
+ def _preprocess(
|
|
|
+ self,
|
|
|
+ img: np.ndarray
|
|
|
+ ) -> Tuple[Dict[str, np.ndarray], Tuple[float, float], Tuple[int, int]]:
|
|
|
+ """
|
|
|
+ 预处理图像 (根据 inference.yml 配置)
|
|
|
+
|
|
|
+ 预处理步骤:
|
|
|
+ 1. Resize: target_size=[640,640], keep_ratio=false, interp=2
|
|
|
+ 2. NormalizeImage: mean=[0,0,0], std=[1,1,1], norm_type=none (只做 /255)
|
|
|
+ 3. Permute: 转换为 CHW 格式
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ input_dict: 包含所有输入的字典
|
|
|
+ scale: (scale_h, scale_w) 缩放因子
|
|
|
+ orig_shape: (h, w) 原始图像尺寸
|
|
|
+ """
|
|
|
+ orig_h, orig_w = img.shape[:2]
|
|
|
+ target_size = self.target_size # 640
|
|
|
+
|
|
|
+ # 1. Resize 到目标尺寸,不保持长宽比 (keep_ratio: false)
|
|
|
+ img_resized = cv2.resize(
|
|
|
+ img,
|
|
|
+ (target_size, target_size),
|
|
|
+ interpolation=cv2.INTER_LINEAR # interp: 2
|
|
|
+ )
|
|
|
+
|
|
|
+ # 2. 转换为 RGB
|
|
|
+ img_rgb = cv2.cvtColor(img_resized, cv2.COLOR_BGR2RGB)
|
|
|
+
|
|
|
+ # 3. 归一化 (mean=[0,0,0], std=[1,1,1], norm_type=none)
|
|
|
+ # 只做 /255,不做均值减法和标准差除法
|
|
|
+ img_normalized = img_rgb.astype(np.float32) / 255.0
|
|
|
+
|
|
|
+ # 4. 转换为 CHW 格式
|
|
|
+ img_chw = img_normalized.transpose(2, 0, 1)
|
|
|
+ img_tensor = img_chw[None, ...].astype(np.float32) # [1, 3, 640, 640]
|
|
|
+
|
|
|
+ # 5. 准备所有输入
|
|
|
+ input_dict = {}
|
|
|
+
|
|
|
+ # 主图像输入
|
|
|
+ if 'image' in self.inputs:
|
|
|
+ input_dict['image'] = img_tensor
|
|
|
+ elif 'images' in self.inputs:
|
|
|
+ input_dict['images'] = img_tensor
|
|
|
+ else:
|
|
|
+ # 使用第一个输入
|
|
|
+ first_input_name = list(self.inputs.keys())[0]
|
|
|
+ input_dict[first_input_name] = img_tensor
|
|
|
+
|
|
|
+ # 计算缩放因子 (原始尺寸 / 目标尺寸)
|
|
|
+ scale_h = orig_h / target_size
|
|
|
+ scale_w = orig_w / target_size
|
|
|
+
|
|
|
+ # im_shape 输入 (原始图像尺寸)
|
|
|
+ if 'im_shape' in self.inputs:
|
|
|
+ im_shape = np.array([[float(orig_h), float(orig_w)]], dtype=np.float32)
|
|
|
+ input_dict['im_shape'] = im_shape
|
|
|
+
|
|
|
+ # scale_factor 输入
|
|
|
+ if 'scale_factor' in self.inputs:
|
|
|
+ scale_factor = np.array([[scale_h, scale_w]], dtype=np.float32)
|
|
|
+ input_dict['scale_factor'] = scale_factor
|
|
|
+
|
|
|
+ return input_dict, (scale_h, scale_w), (orig_h, orig_w)
|
|
|
+
|
|
|
+ def _postprocess(
|
|
|
+ self,
|
|
|
+ outputs: List[np.ndarray],
|
|
|
+ scale: Tuple[float, float], # (scale_h, scale_w)
|
|
|
+ orig_shape: Tuple[int, int],
|
|
|
+ conf_threshold: float = 0.5
|
|
|
+ ) -> List[Dict]:
|
|
|
+ """
|
|
|
+ 后处理模型输出
|
|
|
+
|
|
|
+ Args:
|
|
|
+ outputs: ONNX 模型输出
|
|
|
+ scale: (scale_h, scale_w) 缩放因子
|
|
|
+ orig_shape: (h, w) 原始图像尺寸
|
|
|
+ conf_threshold: 置信度阈值
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 检测结果列表
|
|
|
+ """
|
|
|
+ scale_h, scale_w = scale
|
|
|
+ orig_h, orig_w = orig_shape
|
|
|
+
|
|
|
+ # 解析输出格式
|
|
|
+ if len(outputs) >= 2:
|
|
|
+ output0_shape = outputs[0].shape
|
|
|
+ output1_shape = outputs[1].shape
|
|
|
+
|
|
|
+ # RT-DETR ONNX 格式: (num_boxes, 6)
|
|
|
+ # 格式: [label_id, score, x1, y1, x2, y2]
|
|
|
+ if len(output0_shape) == 2 and output0_shape[1] == 6:
|
|
|
+ pred = outputs[0]
|
|
|
+ labels = pred[:, 0].astype(int)
|
|
|
+ scores = pred[:, 1]
|
|
|
+ bboxes = pred[:, 2:6].copy()
|
|
|
+
|
|
|
+ # 情况2: output0 是 (batch, num_boxes, 6)
|
|
|
+ elif len(output0_shape) == 3 and output0_shape[2] == 6:
|
|
|
+ pred = outputs[0][0]
|
|
|
+ labels = pred[:, 0].astype(int)
|
|
|
+ scores = pred[:, 1]
|
|
|
+ bboxes = pred[:, 2:6].copy()
|
|
|
+
|
|
|
+ # 情况3: output0 是 bboxes, output1 是 scores
|
|
|
+ elif len(output0_shape) == 2 and output0_shape[1] == 4:
|
|
|
+ bboxes = outputs[0].copy()
|
|
|
+ if len(output1_shape) == 1:
|
|
|
+ scores = outputs[1]
|
|
|
+ labels = np.zeros(len(scores), dtype=int)
|
|
|
+ elif len(output1_shape) == 2:
|
|
|
+ scores_all = outputs[1]
|
|
|
+ scores = scores_all.max(axis=1)
|
|
|
+ labels = scores_all.argmax(axis=1)
|
|
|
+ else:
|
|
|
+ raise ValueError(f"Unexpected output1 shape: {output1_shape}")
|
|
|
+
|
|
|
+ # 情况4: RT-DETR 格式 (batch, num_boxes, 4) + (batch, num_boxes, num_classes)
|
|
|
+ elif len(output0_shape) == 3 and output0_shape[2] == 4:
|
|
|
+ bboxes = outputs[0][0].copy()
|
|
|
+ scores_all = outputs[1][0]
|
|
|
+ scores = scores_all.max(axis=1)
|
|
|
+ labels = scores_all.argmax(axis=1)
|
|
|
+
|
|
|
+ else:
|
|
|
+ raise ValueError(f"Unexpected output format: {output0_shape}, {output1_shape}")
|
|
|
+
|
|
|
+ elif len(outputs) == 1:
|
|
|
+ # 单一输出
|
|
|
+ output_shape = outputs[0].shape
|
|
|
+
|
|
|
+ if len(output_shape) == 2 and output_shape[1] == 6:
|
|
|
+ pred = outputs[0]
|
|
|
+ labels = pred[:, 0].astype(int)
|
|
|
+ scores = pred[:, 1]
|
|
|
+ bboxes = pred[:, 2:6].copy()
|
|
|
+
|
|
|
+ elif len(output_shape) == 3 and output_shape[2] == 6:
|
|
|
+ pred = outputs[0][0]
|
|
|
+ labels = pred[:, 0].astype(int)
|
|
|
+ scores = pred[:, 1]
|
|
|
+ bboxes = pred[:, 2:6].copy()
|
|
|
+
|
|
|
+ else:
|
|
|
+ raise ValueError(f"Unexpected single output shape: {output_shape}")
|
|
|
+
|
|
|
+ else:
|
|
|
+ raise ValueError(f"Unexpected number of outputs: {len(outputs)}")
|
|
|
+
|
|
|
+ # 将坐标从 640×640 还原到原图尺度
|
|
|
+ bboxes[:, [0, 2]] *= scale_w
|
|
|
+ bboxes[:, [1, 3]] *= scale_h
|
|
|
+
|
|
|
+ # 过滤低分框
|
|
|
+ mask = scores > conf_threshold
|
|
|
+ bboxes = bboxes[mask]
|
|
|
+ scores = scores[mask]
|
|
|
+ labels = labels[mask]
|
|
|
+
|
|
|
+ # 过滤完全在图像外的框
|
|
|
+ valid_mask = (
|
|
|
+ (bboxes[:, 2] > 0) &
|
|
|
+ (bboxes[:, 3] > 0) &
|
|
|
+ (bboxes[:, 0] < orig_w) &
|
|
|
+ (bboxes[:, 1] < orig_h)
|
|
|
+ )
|
|
|
+ bboxes = bboxes[valid_mask]
|
|
|
+ scores = scores[valid_mask]
|
|
|
+ labels = labels[valid_mask]
|
|
|
+
|
|
|
+ # 裁剪坐标到图像范围
|
|
|
+ bboxes[:, [0, 2]] = np.clip(bboxes[:, [0, 2]], 0, orig_w)
|
|
|
+ bboxes[:, [1, 3]] = np.clip(bboxes[:, [1, 3]], 0, orig_h)
|
|
|
+
|
|
|
+ # 构造结果
|
|
|
+ results = []
|
|
|
+ for box, score, label in zip(bboxes, scores, labels):
|
|
|
+ x1, y1, x2, y2 = box
|
|
|
+
|
|
|
+ # 过滤无效框
|
|
|
+ width = x2 - x1
|
|
|
+ height = y2 - y1
|
|
|
+
|
|
|
+ # 过滤太小的框(单元格通常不会太小)
|
|
|
+ if width < 5 or height < 5:
|
|
|
+ continue
|
|
|
+
|
|
|
+ # 过滤面积异常大的框
|
|
|
+ area = width * height
|
|
|
+ img_area = orig_w * orig_h
|
|
|
+ if area > img_area * 0.95:
|
|
|
+ continue
|
|
|
+
|
|
|
+ results.append({
|
|
|
+ 'bbox': [float(x1), float(y1), float(x2), float(y2)],
|
|
|
+ 'score': float(score),
|
|
|
+ 'category_id': int(label),
|
|
|
+ 'category_name': self.CATEGORY_NAMES.get(int(label), 'cell'),
|
|
|
+ 'width': float(width),
|
|
|
+ 'height': float(height)
|
|
|
+ })
|
|
|
+
|
|
|
+ return results
|