- """使用 ONNX Runtime 进行布局检测的统一接口 (符合 BaseLayoutDetector 规范)"""
import cv2
import numpy as np
import onnxruntime as ort
from pathlib import Path
from typing import Dict, List, Tuple, Union, Any
from PIL import Image
import sys

try:
    from .base import BaseLayoutDetector
except ImportError:
    # If the relative import fails, fall back to an absolute import (useful in test environments).
    from base import BaseLayoutDetector


class PaddleLayoutDetector(BaseLayoutDetector):
    """PaddleX RT-DETR layout detector (ONNX version)."""

    # Note: uses the official RT-DETR-H_layout_17cls category definitions,
    # mapped onto MinerU's category scheme.
    CATEGORY_MAP = {
        0: 'title',               # paragraph_title -> title
        1: 'image_body',          # image -> image_body
        2: 'text',                # text -> text
        3: 'text',                # number -> text (merged into text)
        4: 'text',                # abstract -> text
        5: 'text',                # content -> text
        6: 'image_caption',       # figure_title -> image_caption
        7: 'interline_equation',  # formula -> interline_equation
        8: 'table_body',          # table -> table_body
        9: 'table_caption',       # table_title -> table_caption
        10: 'text',               # reference -> text
        11: 'title',              # doc_title -> title
        12: 'table_footnote',     # footnote -> table_footnote
        13: 'abandon',            # header -> abandon (page headers are usually discarded)
        14: 'text',               # algorithm -> text
        15: 'abandon',            # footer -> abandon (page footers are usually discarded)
        16: 'abandon'             # seal -> abandon (seals are usually discarded)
    }

    ORIGINAL_CATEGORY_NAMES = {
        0: 'paragraph_title',
        1: 'image',
        2: 'text',
        3: 'number',
        4: 'abstract',
        5: 'content',
        6: 'figure_title',
        7: 'formula',
        8: 'table',
        9: 'table_title',
        10: 'reference',
        11: 'doc_title',
        12: 'footnote',
        13: 'header',
        14: 'algorithm',
        15: 'footer',
        16: 'seal'
    }

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.session = None
        self.inputs = {}
        self.outputs = {}
        self.target_size = 640

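    # Expected configuration keys (illustrative example; the model path below is
    # a placeholder, not the project's actual location):
    #   config = {
    #       'model_dir': '/path/to/RT-DETR-H_layout_17cls.onnx',
    #       'device': 'cpu',   # 'gpu' tries the CoreML execution provider on macOS
    #       'conf': 0.25,
    #   }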
    def initialize(self):
        """Initialize the ONNX model."""
        try:
            onnx_path = self.config.get('model_dir')
            if not onnx_path:
                raise ValueError("model_dir not specified in config")

            if not Path(onnx_path).exists():
                raise FileNotFoundError(f"ONNX model not found: {onnx_path}")

            # Select execution providers based on the configuration.
            device = self.config.get('device', 'cpu')
            if device == 'gpu':
                # macOS supports CoreML.
                providers = ['CoreMLExecutionProvider', 'CPUExecutionProvider']
            else:
                providers = ['CPUExecutionProvider']

            self.session = ort.InferenceSession(onnx_path, providers=providers)

            # Cache the model's input/output metadata.
            self.inputs = {inp.name: inp for inp in self.session.get_inputs()}
            self.outputs = {out.name: out for out in self.session.get_outputs()}

            # Auto-detect the expected input size.
            self.target_size = self._detect_input_size()

            print("✅ PaddleX Layout Detector initialized")
            print(f"   - Model: {Path(onnx_path).name}")
            print(f"   - Target size: {self.target_size}")
            print(f"   - Device: {device}")
            print(f"   - Providers: {self.session.get_providers()}")

        except Exception as e:
            print(f"❌ Failed to initialize PaddleX Layout Detector: {e}")
            raise

    def cleanup(self):
        """Release resources."""
        self.session = None
        self.inputs = {}
        self.outputs = {}

    def detect(self, image: Union[np.ndarray, Image.Image]) -> List[Dict[str, Any]]:
        """
        Detect the page layout.

        Args:
            image: Input image (numpy array or PIL image).

        Returns:
            A list of detections, each containing:
            - category: MinerU category name
            - bbox: [x1, y1, x2, y2]
            - confidence: detection score
            - raw: the raw detection result
        """
        if self.session is None:
            raise RuntimeError("Model not initialized. Call initialize() first.")

        # Convert PIL input to a BGR numpy array (OpenCV convention).
        if isinstance(image, Image.Image):
            image = np.array(image)
            if image.ndim == 2:  # grayscale
                image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
            elif image.shape[2] == 4:  # RGBA
                image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
            elif image.shape[2] == 3:  # RGB
                image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

        # Run prediction.
        conf_threshold = self.config.get('conf', 0.25)
        results = self._predict(image, conf_threshold)

        # Convert to the MinerU format.
        formatted_results = []
        for result in results:
            # Map the raw category id onto the MinerU category.
            original_category_id = result['category_id']
            mineru_category = self.CATEGORY_MAP.get(original_category_id, 'text')

            formatted_results.append({
                'category': mineru_category,
                'bbox': result['bbox'],
                'confidence': result['score'],
                'raw': {
                    'original_category_id': original_category_id,
                    'original_category_name': result['category_name'],
                    'poly': result['poly'],
                    'width': result['width'],
                    'height': result['height']
                }
            })

        return formatted_results

    def _detect_input_size(self) -> int:
        """Auto-detect the model's expected input size."""
        if 'image' in self.inputs:
            shape = self.inputs['image'].shape
            # The shape is usually [batch, channels, height, width].
            if len(shape) >= 3:
                # Try to read the size from shape[2] or shape[3]; dynamic exports
                # report symbolic (non-int) dims, which fall through to the default.
                for dim in shape[2:]:
                    if isinstance(dim, int) and dim > 0:
                        return dim
        return 640  # default

    def _preprocess(
        self,
        img: np.ndarray
    ) -> Tuple[Dict[str, np.ndarray], Tuple[float, float], Tuple[int, int]]:
        """
        Preprocess the image (per the RT-DETR configuration).

        Returns:
            input_dict: dictionary containing every model input
            scale: (scale_h, scale_w) scaling factors
            orig_shape: (h, w) original image size
        """
        orig_h, orig_w = img.shape[:2]
        target_size = self.target_size  # 640

        # 1. Resize to the target size (aspect ratio is not preserved).
        img_resized = cv2.resize(
            img,
            (target_size, target_size),
            interpolation=cv2.INTER_LINEAR
        )

        # 2. Convert to RGB.
        img_rgb = cv2.cvtColor(img_resized, cv2.COLOR_BGR2RGB)

        # 3. Normalize (mean=[0,0,0], std=[1,1,1], norm_type=none):
        #    only divide by 255, no mean subtraction or std division.
        img_normalized = img_rgb.astype(np.float32) / 255.0

        # 4. Convert to CHW layout.
        img_chw = img_normalized.transpose(2, 0, 1)
        img_tensor = img_chw[None, ...].astype(np.float32)  # [1, 3, H, W]

        # 5. Assemble all model inputs.
        input_dict = {}

        # Main image input.
        if 'image' in self.inputs:
            input_dict['image'] = img_tensor
        elif 'images' in self.inputs:
            input_dict['images'] = img_tensor
        else:
            # Fall back to the first declared input.
            first_input_name = list(self.inputs.keys())[0]
            input_dict[first_input_name] = img_tensor

        # Scaling factors (original size / target size).
        scale_h = orig_h / target_size
        scale_w = orig_w / target_size

        # im_shape input (original image size).
        if 'im_shape' in self.inputs:
            im_shape = np.array([[float(orig_h), float(orig_w)]], dtype=np.float32)
            input_dict['im_shape'] = im_shape

        # scale_factor input.
        if 'scale_factor' in self.inputs:
            # Note: this is the ratio of original size to target size.
            scale_factor = np.array([[scale_h, scale_w]], dtype=np.float32)
            input_dict['scale_factor'] = scale_factor

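        # Illustrative result for a 1200x1700 (w x h) page, assuming an export that
        # declares 'image', 'im_shape' and 'scale_factor' inputs (example values only):
        #   input_dict['image'].shape  -> (1, 3, 640, 640)
        #   input_dict['im_shape']     -> [[1700.0, 1200.0]]
        #   input_dict['scale_factor'] -> [[1700/640, 1200/640]]
        #   scale, orig_shape          -> (2.656, 1.875), (1700, 1200)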
        # The returned scale is used to map coordinates back to the original image in
        # postprocessing; since the aspect ratio is not preserved, the x and y scales
        # are tracked separately.
        return input_dict, (scale_h, scale_w), (orig_h, orig_w)

    def _postprocess(
        self,
        outputs: List[np.ndarray],
        scale: Tuple[float, float],  # (scale_h, scale_w)
        orig_shape: Tuple[int, int],
        conf_threshold: float = 0.5
    ) -> List[Dict]:
        """
        Postprocess the model outputs.

        Args:
            outputs: ONNX model outputs
            scale: (scale_h, scale_w) scaling factors
            orig_shape: (h, w) original image size
            conf_threshold: confidence threshold

        Returns:
            A list of detection results.
        """
        scale_h, scale_w = scale
        orig_h, orig_w = orig_shape

        # Parse the output format.
        if len(outputs) >= 2:
            output0_shape = outputs[0].shape
            output1_shape = outputs[1].shape

            # Case 1: RT-DETR ONNX format (num_boxes, 6),
            # laid out as [label_id, score, x1, y1, x2, y2].
            if len(output0_shape) == 2 and output0_shape[1] == 6:
                pred = outputs[0]
                labels = pred[:, 0].astype(int)
                scores = pred[:, 1]
                bboxes = pred[:, 2:6].copy()  # [x1, y1, x2, y2] at the 640x640 scale

            # Case 2: (batch, num_boxes, 6) - merged format with a batch dimension.
            elif len(output0_shape) == 3 and output0_shape[2] == 6:
                pred = outputs[0][0]
                labels = pred[:, 0].astype(int)
                scores = pred[:, 1]
                bboxes = pred[:, 2:6].copy()

            # Case 3: output0 holds bboxes, output1 holds scores (split format).
            elif len(output0_shape) == 2 and output0_shape[1] == 4:
                bboxes = outputs[0].copy()
                if len(output1_shape) == 1:
                    scores = outputs[1]
                    labels = np.zeros(len(scores), dtype=int)
                elif len(output1_shape) == 2:
                    scores_all = outputs[1]
                    scores = scores_all.max(axis=1)
                    labels = scores_all.argmax(axis=1)
                else:
                    raise ValueError(f"Unexpected output1 shape: {output1_shape}")

            # Case 4: RT-DETR format (batch, num_boxes, 4) + (batch, num_boxes, num_classes).
            elif len(output0_shape) == 3 and output0_shape[2] == 4:
                bboxes = outputs[0][0].copy()
                scores_all = outputs[1][0]
                scores = scores_all.max(axis=1)
                labels = scores_all.argmax(axis=1)

            else:
                raise ValueError(f"Unexpected output format: {output0_shape}, {output1_shape}")

        elif len(outputs) == 1:
            # Single output.
            output_shape = outputs[0].shape

            if len(output_shape) == 2 and output_shape[1] == 6:
                pred = outputs[0]
                labels = pred[:, 0].astype(int)
                scores = pred[:, 1]
                bboxes = pred[:, 2:6].copy()

            elif len(output_shape) == 3 and output_shape[2] == 6:
                pred = outputs[0][0]
                labels = pred[:, 0].astype(int)
                scores = pred[:, 1]
                bboxes = pred[:, 2:6].copy()

            else:
                raise ValueError(f"Unexpected single output shape: {output_shape}")

        else:
            raise ValueError(f"Unexpected number of outputs: {len(outputs)}")

        # Map coordinates from the 640x640 scale back to the original image.
        bboxes[:, [0, 2]] *= scale_w
        bboxes[:, [1, 3]] *= scale_h

        # Adaptive threshold: if no box clears the configured threshold, lower it.
        max_score = scores.max() if len(scores) > 0 else 0
        if max_score < conf_threshold:
            adjusted_threshold = max(max_score * 0.5, 0.05)
            conf_threshold = adjusted_threshold

        # Drop low-score boxes.
        mask = scores > conf_threshold
        bboxes = bboxes[mask]
        scores = scores[mask]
        labels = labels[mask]

        # Drop boxes that lie entirely outside the image.
        valid_mask = (
            (bboxes[:, 2] > 0) &       # x2 > 0
            (bboxes[:, 3] > 0) &       # y2 > 0
            (bboxes[:, 0] < orig_w) &  # x1 < width
            (bboxes[:, 1] < orig_h)    # y1 < height
        )
        bboxes = bboxes[valid_mask]
        scores = scores[valid_mask]
        labels = labels[valid_mask]

        # Clip coordinates to the image bounds.
        bboxes[:, [0, 2]] = np.clip(bboxes[:, [0, 2]], 0, orig_w)
        bboxes[:, [1, 3]] = np.clip(bboxes[:, [1, 3]], 0, orig_h)

        # Build the result list.
        results = []
        for box, score, label in zip(bboxes, scores, labels):
            x1, y1, x2, y2 = box

            width = x2 - x1
            height = y2 - y1

            # Skip boxes that are too small.
            if width < 10 or height < 10:
                continue

            # Skip boxes with an abnormally large area.
            area = width * height
            img_area = orig_w * orig_h
            if area > img_area * 0.95:
                continue

            results.append({
                'category_id': int(label),
                'category_name': self.ORIGINAL_CATEGORY_NAMES.get(int(label), f'unknown_{label}'),
                'bbox': [int(x1), int(y1), int(x2), int(y2)],
                'poly': [int(x1), int(y1), int(x2), int(y1), int(x2), int(y2), int(x1), int(y2)],
                'score': float(score),
                'width': int(width),
                'height': int(height)
            })

        return results

    def _predict(
        self,
        img: np.ndarray,
        conf_threshold: float = 0.25
    ) -> List[Dict]:
        """Run prediction: preprocess, ONNX inference, postprocess."""
        # Preprocess.
        input_dict, scale, orig_shape = self._preprocess(img)

        # ONNX inference.
        output_names = [out.name for out in self.session.get_outputs()]
        outputs = self.session.run(output_names, input_dict)

        # Postprocess.
        results = self._postprocess(outputs, scale, orig_shape, conf_threshold)

        return results

    def visualize(
        self,
        img: np.ndarray,
        results: List[Dict],
        output_path: str = None,
        show_confidence: bool = True,
        min_confidence: float = 0.0
    ) -> np.ndarray:
        """
        Visualize detection results.

        Args:
            img: input image
            results: detection results (MinerU format)
            output_path: output path (optional)
            show_confidence: whether to display confidence scores
            min_confidence: minimum confidence threshold

        Returns:
            The annotated image.
        """
        import random

        vis_img = img.copy()

        # Assign a fixed color to each category.
        category_colors = {}

        # Predefined colors for common categories (BGR tuples, OpenCV convention).
        predefined_colors = {
            'text': (0, 255, 0),                  # green
            'title': (255, 0, 0),                 # blue
            'table_body': (0, 0, 255),            # red
            'table_caption': (255, 255, 0),       # cyan
            'table_footnote': (255, 128, 0),      # azure
            'image_body': (255, 0, 255),          # magenta
            'image_caption': (128, 0, 255),       # violet
            'interline_equation': (0, 255, 255),  # yellow
            'abandon': (128, 128, 128),           # gray
        }

        # Drop low-confidence results.
        filtered_results = [
            res for res in results
            if res['confidence'] >= min_confidence
        ]

        if not filtered_results:
            print(f"⚠️ No results to visualize (min_confidence={min_confidence})")
            return vis_img

        # Assign a color to every category that appears.
        for res in filtered_results:
            cat = res['category']
            if cat not in category_colors:
                if cat in predefined_colors:
                    category_colors[cat] = predefined_colors[cat]
                else:
                    # Generate a random color.
                    category_colors[cat] = (
                        random.randint(50, 255),
                        random.randint(50, 255),
                        random.randint(50, 255)
                    )

        # Draw the detection boxes.
        for res in filtered_results:
            bbox = res['bbox']
            x1, y1, x2, y2 = bbox
            cat = res['category']
            confidence = res['confidence']
            color = category_colors[cat]

            # Bounding rectangle.
            cv2.rectangle(vis_img, (x1, y1), (x2, y2), color, 2)

            # Label text.
            if show_confidence:
                label = f"{cat} {confidence:.2f}"
            else:
                label = cat

            # Measure the label.
            label_size, baseline = cv2.getTextSize(
                label,
                cv2.FONT_HERSHEY_SIMPLEX,
                0.5,
                1
            )
            label_w, label_h = label_size

            # Label background (filled rectangle).
            cv2.rectangle(
                vis_img,
                (x1, y1 - label_h - 4),
                (x1 + label_w, y1),
                color,
                -1  # filled
            )

            # Label text (white).
            cv2.putText(
                vis_img,
                label,
                (x1, y1 - 2),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.5,
                (255, 255, 255),  # white text
                1,
                cv2.LINE_AA
            )

        # Add a legend in the top-right corner of the image.
        if category_colors:
            self._draw_legend(vis_img, category_colors, len(filtered_results))

        # Save the visualization.
        if output_path:
            output_path = Path(output_path)
            output_path.parent.mkdir(parents=True, exist_ok=True)
            cv2.imwrite(str(output_path), vis_img)
            print(f"💾 Visualization saved to: {output_path}")

        return vis_img


    def _draw_legend(
        self,
        img: np.ndarray,
        category_colors: Dict[str, tuple],
        total_count: int
    ):
        """
        Draw a legend on the image.

        Args:
            img: image
            category_colors: category-to-color mapping
            total_count: total number of detections
        """
        legend_x = img.shape[1] - 200  # reserve 200 px on the right
        legend_y = 20
        line_height = 25

        # Semi-transparent background.
        overlay = img.copy()
        cv2.rectangle(
            overlay,
            (legend_x - 10, legend_y - 10),
            (img.shape[1] - 10, legend_y + len(category_colors) * line_height + 30),
            (255, 255, 255),
            -1
        )
        cv2.addWeighted(overlay, 0.7, img, 0.3, 0, img)

        # Title.
        cv2.putText(
            img,
            f"Legend ({total_count} total)",
            (legend_x, legend_y),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            (0, 0, 0),
            1,
            cv2.LINE_AA
        )

        # One entry per category.
        y_offset = legend_y + line_height
        for cat, color in sorted(category_colors.items()):
            # Color swatch.
            cv2.rectangle(
                img,
                (legend_x, y_offset - 10),
                (legend_x + 15, y_offset),
                color,
                -1
            )
            cv2.rectangle(
                img,
                (legend_x, y_offset - 10),
                (legend_x + 15, y_offset),
                (0, 0, 0),
                1
            )

            # Category name.
            cv2.putText(
                img,
                cat,
                (legend_x + 20, y_offset - 2),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.4,
                (0, 0, 0),
                1,
                cv2.LINE_AA
            )

            y_offset += line_height


# Test code
if __name__ == "__main__":
    # Test configuration.
    config = {
        'model_dir': '/Users/zhch158/workspace/repository.git/PaddleX/zhch/unified_pytorch_models/Layout/RT-DETR-H_layout_17cls.onnx',
        'device': 'cpu',
        'conf': 0.25
    }

    # Initialize the detector.
    print("🔧 Initializing detector...")
    detector = PaddleLayoutDetector(config)
    detector.initialize()

    # Load the test image.
    img_path = "/Users/zhch158/workspace/data/流水分析/B用户_扫描流水/PaddleOCR_VL_Results/B用户_扫描流水/B用户_扫描流水_page_001.png"
    print(f"\n📖 Loading image: {img_path}")
    img = cv2.imread(img_path)

    if img is None:
        print(f"❌ Failed to load image: {img_path}")
        sys.exit(1)

    print(f"   Image shape: {img.shape}")

    # Run layout detection.
    print("\n🔍 Detecting layout...")
    results = detector.detect(img)

    print(f"\n✅ Detected {len(results)} regions:")
    for i, res in enumerate(results, 1):
        print(f"  [{i}] {res['category']}: "
              f"score={res['confidence']:.3f}, "
              f"bbox={res['bbox']}, "
              f"original={res['raw']['original_category_name']}")

    # Count detections per category.
    category_counts = {}
    for res in results:
        cat = res['category']
        category_counts[cat] = category_counts.get(cat, 0) + 1

    print("\n📊 Category counts (MinerU format):")
    for cat, count in sorted(category_counts.items()):
        print(f"  - {cat}: {count}")

    # Visualize the results.
    if len(results) > 0:
        print("\n🎨 Generating visualization...")

        # Create the output directory.
        output_dir = Path(__file__).parent.parent.parent / "tests" / "output"
        output_dir.mkdir(parents=True, exist_ok=True)
        output_path = output_dir / f"{Path(img_path).stem}_layout_vis.jpg"

        # Call the visualization helper.
        vis_img = detector.visualize(
            img,
            results,
            output_path=str(output_path),
            show_confidence=True,
            min_confidence=0.0
        )

        print(f"💾 Visualization saved to: {output_path}")

    # Clean up.
    detector.cleanup()
    print("\n✅ Test finished!")