"""使用 ONNX Runtime 进行布局检测的统一接口 (符合 BaseLayoutDetector 规范)""" import cv2 import numpy as np import onnxruntime as ort from pathlib import Path from typing import Dict, List, Tuple, Union, Any from PIL import Image import sys try: from .base import BaseLayoutDetector except ImportError: # 如果相对导入失败,尝试绝对导入(适用于测试环境) from base import BaseLayoutDetector class PaddleLayoutDetector(BaseLayoutDetector): """PaddleX RT-DETR 布局检测器 (ONNX 版本)""" # ⚠️ 修正:使用官方的 RT-DETR-H_layout_17cls 类别定义 # 映射到 MinerU 的类别体系 CATEGORY_MAP = { 0: 'title', # paragraph_title -> title 1: 'image_body', # image -> image_body 2: 'text', # text -> text 3: 'text', # number -> text (合并到text) 4: 'text', # abstract -> text 5: 'text', # content -> text 6: 'image_caption', # figure_title -> image_caption 7: 'interline_equation', # formula -> interline_equation 8: 'table_body', # table -> table_body 9: 'table_caption', # table_title -> table_caption 10: 'text', # reference -> text 11: 'title', # doc_title -> title 12: 'table_footnote', # footnote -> table_footnote 13: 'abandon', # header -> abandon (页眉通常不需要) 14: 'text', # algorithm -> text 15: 'abandon', # footer -> abandon (页脚通常不需要) 16: 'abandon' # seal -> abandon (印章通常不需要) } ORIGINAL_CATEGORY_NAMES = { 0: 'paragraph_title', 1: 'image', 2: 'text', 3: 'number', 4: 'abstract', 5: 'content', 6: 'figure_title', 7: 'formula', 8: 'table', 9: 'table_title', 10: 'reference', 11: 'doc_title', 12: 'footnote', 13: 'header', 14: 'algorithm', 15: 'footer', 16: 'seal' } def __init__(self, config: Dict[str, Any]): super().__init__(config) self.session = None self.inputs = {} self.outputs = {} self.target_size = 640 def initialize(self): """初始化 ONNX 模型""" try: onnx_path = self.config.get('model_dir') if not onnx_path: raise ValueError("model_dir not specified in config") if not Path(onnx_path).exists(): raise FileNotFoundError(f"ONNX model not found: {onnx_path}") # 根据配置选择执行提供器 device = self.config.get('device', 'cpu') if device == 'gpu': # Mac 支持 CoreML providers = ['CoreMLExecutionProvider', 'CPUExecutionProvider'] else: providers = ['CPUExecutionProvider'] self.session = ort.InferenceSession(onnx_path, providers=providers) # 获取模型输入输出信息 self.inputs = {inp.name: inp for inp in self.session.get_inputs()} self.outputs = {out.name: out for out in self.session.get_outputs()} # 自动检测输入尺寸 self.target_size = self._detect_input_size() print(f"✅ PaddleX Layout Detector initialized") print(f" - Model: {Path(onnx_path).name}") print(f" - Target size: {self.target_size}") print(f" - Device: {device}") print(f" - Providers: {self.session.get_providers()}") except Exception as e: print(f"❌ Failed to initialize PaddleX Layout Detector: {e}") raise def cleanup(self): """清理资源""" self.session = None self.inputs = {} self.outputs = {} def detect(self, image: Union[np.ndarray, Image.Image]) -> List[Dict[str, Any]]: """ 检测布局 Args: image: 输入图像 (numpy数组或PIL图像) Returns: 检测结果列表,每个元素包含: - category: MinerU类别名称 - bbox: [x1, y1, x2, y2] - confidence: 置信度 - raw: 原始检测结果 """ if self.session is None: raise RuntimeError("Model not initialized. Call initialize() first.") # 转换为numpy数组 if isinstance(image, Image.Image): image = np.array(image) if image.ndim == 2: # 灰度图 image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) elif image.shape[2] == 4: # RGBA image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGR) elif image.shape[2] == 3: # RGB image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) # 执行预测 conf_threshold = self.config.get('conf', 0.25) results = self._predict(image, conf_threshold) # 转换为 MinerU 格式 formatted_results = [] for result in results: # 映射类别 original_category_id = result['category_id'] mineru_category = self.CATEGORY_MAP.get(original_category_id, 'text') formatted_results.append({ 'category': mineru_category, 'bbox': result['bbox'], 'confidence': result['score'], 'raw': { 'original_category_id': original_category_id, 'original_category_name': result['category_name'], 'poly': result['poly'], 'width': result['width'], 'height': result['height'] } }) return formatted_results def _detect_input_size(self) -> int: """自动检测模型的输入尺寸""" if 'image' in self.inputs: shape = self.inputs['image'].shape # shape 通常是 [batch, channels, height, width] if len(shape) >= 3: # 尝试从 shape[2] 或 shape[3] 获取尺寸 for dim in shape[2:]: if isinstance(dim, int) and dim > 0: return dim return 640 # 默认值 def _preprocess( self, img: np.ndarray ) -> Tuple[Dict[str, np.ndarray], Tuple[float, float], Tuple[int, int]]: """ 预处理图像 (根据 RT-DETR 的配置) Returns: input_dict: 包含所有输入的字典 scale: (scale_h, scale_w) 缩放因子 orig_shape: (h, w) 原始图像尺寸 """ orig_h, orig_w = img.shape[:2] target_size = self.target_size # 640 # 1. Resize 到目标尺寸 (不保持长宽比) img_resized = cv2.resize( img, (target_size, target_size), interpolation=cv2.INTER_LINEAR ) # 2. 转换为 RGB img_rgb = cv2.cvtColor(img_resized, cv2.COLOR_BGR2RGB) # ✅ 修正 3: 归一化 (mean=[0,0,0], std=[1,1,1], norm_type=none) # 只做 /255,不做均值减法和标准差除法 img_normalized = img_rgb.astype(np.float32) / 255.0 # 4. 转换为 CHW 格式 img_chw = img_normalized.transpose(2, 0, 1) img_tensor = img_chw[None, ...].astype(np.float32) # [1, 3, H, W] # 5. 准备所有输入 input_dict = {} # 主图像输入 if 'image' in self.inputs: input_dict['image'] = img_tensor elif 'images' in self.inputs: input_dict['images'] = img_tensor else: # 使用第一个输入 first_input_name = list(self.inputs.keys())[0] input_dict[first_input_name] = img_tensor # ✅ 修正 4: 计算缩放因子 (实际图像尺寸 / 目标尺寸) scale_h = orig_h / target_size scale_w = orig_w / target_size # im_shape 输入 (原始图像尺寸) if 'im_shape' in self.inputs: im_shape = np.array([[float(orig_h), float(orig_w)]], dtype=np.float32) input_dict['im_shape'] = im_shape # scale_factor 输入 if 'scale_factor' in self.inputs: # ⚠️ 注意:这里是原始尺寸/目标尺寸的比例 scale_factor = np.array([[scale_h, scale_w]], dtype=np.float32) input_dict['scale_factor'] = scale_factor # ✅ 返回的 scale 用于后处理坐标还原 # 因为不保持长宽比,所以需要分别记录 x 和 y 的缩放 return input_dict, (scale_h, scale_w), (orig_h, orig_w) def _postprocess( self, outputs: List[np.ndarray], scale: Tuple[float, float], # (scale_h, scale_w) orig_shape: Tuple[int, int], conf_threshold: float = 0.5 ) -> List[Dict]: """ 后处理模型输出 Args: outputs: ONNX 模型输出 scale: (scale_h, scale_w) 缩放因子 orig_shape: (h, w) 原始图像尺寸 conf_threshold: 置信度阈值 Returns: 检测结果列表 """ scale_h, scale_w = scale orig_h, orig_w = orig_shape # 解析输出格式 if len(outputs) >= 2: output0_shape = outputs[0].shape output1_shape = outputs[1].shape # RT-DETR ONNX 格式: (num_boxes, 6) # 格式: [label_id, score, x1, y1, x2, y2] if len(output0_shape) == 2 and output0_shape[1] == 6: pred = outputs[0] labels = pred[:, 0].astype(int) scores = pred[:, 1] bboxes = pred[:, 2:6].copy() # [x1, y1, x2, y2] - 在 640×640 尺度上 # 情况2: output0 是 (batch, num_boxes, 6) - 带batch的合并格式 elif len(output0_shape) == 3 and output0_shape[2] == 6: pred = outputs[0][0] labels = pred[:, 0].astype(int) scores = pred[:, 1] bboxes = pred[:, 2:6].copy() # 情况3: output0 是 bboxes, output1 是 scores (分离格式) elif len(output0_shape) == 2 and output0_shape[1] == 4: bboxes = outputs[0].copy() if len(output1_shape) == 1: scores = outputs[1] labels = np.zeros(len(scores), dtype=int) elif len(output1_shape) == 2: scores_all = outputs[1] scores = scores_all.max(axis=1) labels = scores_all.argmax(axis=1) else: raise ValueError(f"Unexpected output1 shape: {output1_shape}") # 情况4: RT-DETR 格式 (batch, num_boxes, 4) + (batch, num_boxes, num_classes) elif len(output0_shape) == 3 and output0_shape[2] == 4: bboxes = outputs[0][0].copy() scores_all = outputs[1][0] scores = scores_all.max(axis=1) labels = scores_all.argmax(axis=1) else: raise ValueError(f"Unexpected output format: {output0_shape}, {output1_shape}") elif len(outputs) == 1: # 单一输出 output_shape = outputs[0].shape if len(output_shape) == 2 and output_shape[1] == 6: pred = outputs[0] labels = pred[:, 0].astype(int) scores = pred[:, 1] bboxes = pred[:, 2:6].copy() elif len(output_shape) == 3 and output_shape[2] == 6: pred = outputs[0][0] labels = pred[:, 0].astype(int) scores = pred[:, 1] bboxes = pred[:, 2:6].copy() else: raise ValueError(f"Unexpected single output shape: {output_shape}") else: raise ValueError(f"Unexpected number of outputs: {len(outputs)}") # 将坐标从 640×640 还原到原图尺度 bboxes[:, [0, 2]] *= scale_w bboxes[:, [1, 3]] *= scale_h # 自适应阈值 max_score = scores.max() if len(scores) > 0 else 0 if max_score < conf_threshold: adjusted_threshold = max(max_score * 0.5, 0.05) conf_threshold = adjusted_threshold # 过滤低分框 mask = scores > conf_threshold bboxes = bboxes[mask] scores = scores[mask] labels = labels[mask] # 过滤完全在图像外的框 valid_mask = ( (bboxes[:, 2] > 0) & # x2 > 0 (bboxes[:, 3] > 0) & # y2 > 0 (bboxes[:, 0] < orig_w) & # x1 < width (bboxes[:, 1] < orig_h) # y1 < height ) bboxes = bboxes[valid_mask] scores = scores[valid_mask] labels = labels[valid_mask] # 裁剪坐标到图像范围 bboxes[:, [0, 2]] = np.clip(bboxes[:, [0, 2]], 0, orig_w) bboxes[:, [1, 3]] = np.clip(bboxes[:, [1, 3]], 0, orig_h) # 构造结果 results = [] for box, score, label in zip(bboxes, scores, labels): x1, y1, x2, y2 = box # 过滤无效框 width = x2 - x1 height = y2 - y1 # 过滤太小的框 if width < 10 or height < 10: continue # 过滤面积异常大的框 area = width * height img_area = orig_w * orig_h if area > img_area * 0.95: continue results.append({ 'category_id': int(label), 'category_name': self.ORIGINAL_CATEGORY_NAMES.get(int(label), f'unknown_{label}'), 'bbox': [int(x1), int(y1), int(x2), int(y2)], 'poly': [int(x1), int(y1), int(x2), int(y1), int(x2), int(y2), int(x1), int(y2)], 'score': float(score), 'width': int(width), 'height': int(height) }) return results def _predict( self, img: np.ndarray, conf_threshold: float = 0.25 ) -> List[Dict]: """执行预测""" # 预处理 input_dict, scale, orig_shape = self._preprocess(img) # ONNX 推理 output_names = [out.name for out in self.session.get_outputs()] outputs = self.session.run(output_names, input_dict) # 后处理 results = self._postprocess(outputs, scale, orig_shape, conf_threshold) return results def visualize( self, img: np.ndarray, results: List[Dict], output_path: str = None, show_confidence: bool = True, min_confidence: float = 0.0 ) -> np.ndarray: """ 可视化检测结果 Args: img: 输入图像 results: 检测结果 (MinerU格式) output_path: 输出路径(可选) show_confidence: 是否显示置信度 min_confidence: 最小置信度阈值 Returns: 标注后的图像 """ import random vis_img = img.copy() # 为每个类别分配固定颜色 category_colors = {} # 预定义一些常用类别的颜色 predefined_colors = { 'text': (0, 255, 0), # 绿色 'title': (255, 0, 0), # 红色 'table_body': (0, 0, 255), # 蓝色 'table_caption': (255, 255, 0), # 青色 'table_footnote': (255, 128, 0), # 橙色 'image_body': (255, 0, 255), # 洋红 'image_caption': (128, 0, 255), # 紫色 'interline_equation': (0, 255, 255), # 黄色 'abandon': (128, 128, 128), # 灰色 } # 过滤低置信度结果 filtered_results = [ res for res in results if res['confidence'] >= min_confidence ] if not filtered_results: print(f"⚠️ No results to visualize (min_confidence={min_confidence})") return vis_img # 为每个出现的类别分配颜色 for res in filtered_results: cat = res['category'] if cat not in category_colors: if cat in predefined_colors: category_colors[cat] = predefined_colors[cat] else: # 随机生成颜色 category_colors[cat] = ( random.randint(50, 255), random.randint(50, 255), random.randint(50, 255) ) # 绘制检测框 for res in filtered_results: bbox = res['bbox'] x1, y1, x2, y2 = bbox cat = res['category'] confidence = res['confidence'] color = category_colors[cat] # 绘制矩形边框 cv2.rectangle(vis_img, (x1, y1), (x2, y2), color, 2) # 构造标签文本 if show_confidence: label = f"{cat} {confidence:.2f}" else: label = cat # 计算标签尺寸 label_size, baseline = cv2.getTextSize( label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1 ) label_w, label_h = label_size # 绘制标签背景 (填充矩形) cv2.rectangle( vis_img, (x1, y1 - label_h - 4), (x1 + label_w, y1), color, -1 # 填充 ) # 绘制标签文字 (白色) cv2.putText( vis_img, label, (x1, y1 - 2), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), # 白色文字 1, cv2.LINE_AA ) # 添加图例 (在图像右上角) if category_colors: self._draw_legend(vis_img, category_colors, len(filtered_results)) # 保存可视化结果 if output_path: output_path = Path(output_path) output_path.parent.mkdir(parents=True, exist_ok=True) cv2.imwrite(str(output_path), vis_img) print(f"💾 Visualization saved to: {output_path}") return vis_img def _draw_legend( self, img: np.ndarray, category_colors: Dict[str, tuple], total_count: int ): """ 在图像上绘制图例 Args: img: 图像 category_colors: 类别颜色映射 total_count: 总检测数量 """ legend_x = img.shape[1] - 200 # 右侧留200像素 legend_y = 20 line_height = 25 # 绘制半透明背景 overlay = img.copy() cv2.rectangle( overlay, (legend_x - 10, legend_y - 10), (img.shape[1] - 10, legend_y + len(category_colors) * line_height + 30), (255, 255, 255), -1 ) cv2.addWeighted(overlay, 0.7, img, 0.3, 0, img) # 绘制标题 cv2.putText( img, f"Legend ({total_count} total)", (legend_x, legend_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA ) # 绘制每个类别 y_offset = legend_y + line_height for cat, color in sorted(category_colors.items()): # 绘制颜色方块 cv2.rectangle( img, (legend_x, y_offset - 10), (legend_x + 15, y_offset), color, -1 ) cv2.rectangle( img, (legend_x, y_offset - 10), (legend_x + 15, y_offset), (0, 0, 0), 1 ) # 绘制类别名称 cv2.putText( img, cat, (legend_x + 20, y_offset - 2), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 0), 1, cv2.LINE_AA ) y_offset += line_height # 测试代码 if __name__ == "__main__": import yaml # 测试配置 config = { 'model_dir': '/Users/zhch158/workspace/repository.git/PaddleX/zhch/unified_pytorch_models/Layout/RT-DETR-H_layout_17cls.onnx', 'device': 'cpu', 'conf': 0.25 } # 初始化检测器 print("🔧 Initializing detector...") detector = PaddleLayoutDetector(config) detector.initialize() # 读取测试图像 img_path = "/Users/zhch158/workspace/data/流水分析/B用户_扫描流水/PaddleOCR_VL_Results/B用户_扫描流水/B用户_扫描流水_page_001.png" print(f"\n📖 Loading image: {img_path}") img = cv2.imread(img_path) if img is None: print(f"❌ Failed to load image: {img_path}") exit(1) print(f" Image shape: {img.shape}") # 执行检测 print("\n🔍 Detecting layout...") results = detector.detect(img) print(f"\n✅ 检测到 {len(results)} 个区域:") for i, res in enumerate(results, 1): print(f" [{i}] {res['category']}: " f"score={res['confidence']:.3f}, " f"bbox={res['bbox']}, " f"original={res['raw']['original_category_name']}") # 统计各类别 category_counts = {} for res in results: cat = res['category'] category_counts[cat] = category_counts.get(cat, 0) + 1 print(f"\n📊 类别统计 (MinerU格式):") for cat, count in sorted(category_counts.items()): print(f" - {cat}: {count}") # 使用新的可视化方法 if len(results) > 0: print("\n🎨 Generating visualization...") # 创建输出目录 output_dir = Path(__file__).parent.parent.parent / "tests" / "output" output_dir.mkdir(parents=True, exist_ok=True) output_path = output_dir / f"{Path(img_path).stem}_layout_vis.jpg" # 调用可视化方法 vis_img = detector.visualize( img, results, output_path=str(output_path), show_confidence=True, min_confidence=0.0 ) print(f"💾 Visualization saved to: {output_path}") # 清理 detector.cleanup() print("\n✅ 测试完成!")