|
|
@@ -0,0 +1,679 @@
|
|
|
+"""使用 ONNX Runtime 进行布局检测的统一接口 (符合 BaseLayoutDetector 规范)"""
|
|
|
+
|
|
|
+import cv2
|
|
|
+import numpy as np
|
|
|
+import onnxruntime as ort
|
|
|
+from pathlib import Path
|
|
|
+from typing import Dict, List, Tuple, Union, Any
|
|
|
+from PIL import Image
|
|
|
+import sys
|
|
|
+
|
|
|
+try:
|
|
|
+ from .base import BaseLayoutDetector
|
|
|
+except ImportError:
|
|
|
+ # 如果相对导入失败,尝试绝对导入(适用于测试环境)
|
|
|
+ from base import BaseLayoutDetector
|
|
|
+
|
|
|
+class PaddleLayoutDetector(BaseLayoutDetector):
|
|
|
+ """PaddleX RT-DETR 布局检测器 (ONNX 版本)"""
|
|
|
+
|
|
|
+ # ⚠️ 修正:使用官方的 RT-DETR-H_layout_17cls 类别定义
|
|
|
+ # 映射到 MinerU 的类别体系
|
|
|
+ CATEGORY_MAP = {
|
|
|
+ 0: 'title', # paragraph_title -> title
|
|
|
+ 1: 'image_body', # image -> image_body
|
|
|
+ 2: 'text', # text -> text
|
|
|
+ 3: 'text', # number -> text (合并到text)
|
|
|
+ 4: 'text', # abstract -> text
|
|
|
+ 5: 'text', # content -> text
|
|
|
+ 6: 'image_caption', # figure_title -> image_caption
|
|
|
+ 7: 'interline_equation', # formula -> interline_equation
|
|
|
+ 8: 'table_body', # table -> table_body
|
|
|
+ 9: 'table_caption', # table_title -> table_caption
|
|
|
+ 10: 'text', # reference -> text
|
|
|
+ 11: 'title', # doc_title -> title
|
|
|
+ 12: 'table_footnote', # footnote -> table_footnote
|
|
|
+ 13: 'abandon', # header -> abandon (页眉通常不需要)
|
|
|
+ 14: 'text', # algorithm -> text
|
|
|
+ 15: 'abandon', # footer -> abandon (页脚通常不需要)
|
|
|
+ 16: 'abandon' # seal -> abandon (印章通常不需要)
|
|
|
+ }
|
|
|
+
|
|
|
+ ORIGINAL_CATEGORY_NAMES = {
|
|
|
+ 0: 'paragraph_title',
|
|
|
+ 1: 'image',
|
|
|
+ 2: 'text',
|
|
|
+ 3: 'number',
|
|
|
+ 4: 'abstract',
|
|
|
+ 5: 'content',
|
|
|
+ 6: 'figure_title',
|
|
|
+ 7: 'formula',
|
|
|
+ 8: 'table',
|
|
|
+ 9: 'table_title',
|
|
|
+ 10: 'reference',
|
|
|
+ 11: 'doc_title',
|
|
|
+ 12: 'footnote',
|
|
|
+ 13: 'header',
|
|
|
+ 14: 'algorithm',
|
|
|
+ 15: 'footer',
|
|
|
+ 16: 'seal'
|
|
|
+ }
|
|
|
+
|
|
|
+ def __init__(self, config: Dict[str, Any]):
|
|
|
+ super().__init__(config)
|
|
|
+ self.session = None
|
|
|
+ self.inputs = {}
|
|
|
+ self.outputs = {}
|
|
|
+ self.target_size = 640
|
|
|
+
|
|
|
+ def initialize(self):
|
|
|
+ """初始化 ONNX 模型"""
|
|
|
+ try:
|
|
|
+ onnx_path = self.config.get('model_dir')
|
|
|
+ if not onnx_path:
|
|
|
+ raise ValueError("model_dir not specified in config")
|
|
|
+
|
|
|
+ if not Path(onnx_path).exists():
|
|
|
+ raise FileNotFoundError(f"ONNX model not found: {onnx_path}")
|
|
|
+
|
|
|
+ # 根据配置选择执行提供器
|
|
|
+ device = self.config.get('device', 'cpu')
|
|
|
+ if device == 'gpu':
|
|
|
+ # Mac 支持 CoreML
|
|
|
+ providers = ['CoreMLExecutionProvider', 'CPUExecutionProvider']
|
|
|
+ else:
|
|
|
+ providers = ['CPUExecutionProvider']
|
|
|
+
|
|
|
+ self.session = ort.InferenceSession(onnx_path, providers=providers)
|
|
|
+
|
|
|
+ # 获取模型输入输出信息
|
|
|
+ self.inputs = {inp.name: inp for inp in self.session.get_inputs()}
|
|
|
+ self.outputs = {out.name: out for out in self.session.get_outputs()}
|
|
|
+
|
|
|
+ # 自动检测输入尺寸
|
|
|
+ self.target_size = self._detect_input_size()
|
|
|
+
|
|
|
+ print(f"✅ PaddleX Layout Detector initialized")
|
|
|
+ print(f" - Model: {Path(onnx_path).name}")
|
|
|
+ print(f" - Target size: {self.target_size}")
|
|
|
+ print(f" - Device: {device}")
|
|
|
+ print(f" - Providers: {self.session.get_providers()}")
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ print(f"❌ Failed to initialize PaddleX Layout Detector: {e}")
|
|
|
+ raise
|
|
|
+
|
|
|
+ def cleanup(self):
|
|
|
+ """清理资源"""
|
|
|
+ self.session = None
|
|
|
+ self.inputs = {}
|
|
|
+ self.outputs = {}
|
|
|
+
|
|
|
+ def detect(self, image: Union[np.ndarray, Image.Image]) -> List[Dict[str, Any]]:
|
|
|
+ """
|
|
|
+ 检测布局
|
|
|
+
|
|
|
+ Args:
|
|
|
+ image: 输入图像 (numpy数组或PIL图像)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 检测结果列表,每个元素包含:
|
|
|
+ - category: MinerU类别名称
|
|
|
+ - bbox: [x1, y1, x2, y2]
|
|
|
+ - confidence: 置信度
|
|
|
+ - raw: 原始检测结果
|
|
|
+ """
|
|
|
+ if self.session is None:
|
|
|
+ raise RuntimeError("Model not initialized. Call initialize() first.")
|
|
|
+
|
|
|
+ # 转换为numpy数组
|
|
|
+ if isinstance(image, Image.Image):
|
|
|
+ image = np.array(image)
|
|
|
+ if image.ndim == 2: # 灰度图
|
|
|
+ image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
|
|
|
+ elif image.shape[2] == 4: # RGBA
|
|
|
+ image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
|
|
|
+ elif image.shape[2] == 3: # RGB
|
|
|
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
|
|
|
+
|
|
|
+ # 执行预测
|
|
|
+ conf_threshold = self.config.get('conf', 0.25)
|
|
|
+ results = self._predict(image, conf_threshold)
|
|
|
+
|
|
|
+ # 转换为 MinerU 格式
|
|
|
+ formatted_results = []
|
|
|
+ for result in results:
|
|
|
+ # 映射类别
|
|
|
+ original_category_id = result['category_id']
|
|
|
+ mineru_category = self.CATEGORY_MAP.get(original_category_id, 'text')
|
|
|
+
|
|
|
+ formatted_results.append({
|
|
|
+ 'category': mineru_category,
|
|
|
+ 'bbox': result['bbox'],
|
|
|
+ 'confidence': result['score'],
|
|
|
+ 'raw': {
|
|
|
+ 'original_category_id': original_category_id,
|
|
|
+ 'original_category_name': result['category_name'],
|
|
|
+ 'poly': result['poly'],
|
|
|
+ 'width': result['width'],
|
|
|
+ 'height': result['height']
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ return formatted_results
|
|
|
+
|
|
|
+ def _detect_input_size(self) -> int:
|
|
|
+ """自动检测模型的输入尺寸"""
|
|
|
+ if 'image' in self.inputs:
|
|
|
+ shape = self.inputs['image'].shape
|
|
|
+ # shape 通常是 [batch, channels, height, width]
|
|
|
+ if len(shape) >= 3:
|
|
|
+ # 尝试从 shape[2] 或 shape[3] 获取尺寸
|
|
|
+ for dim in shape[2:]:
|
|
|
+ if isinstance(dim, int) and dim > 0:
|
|
|
+ return dim
|
|
|
+ return 640 # 默认值
|
|
|
+
|
|
|
+ def _preprocess(
|
|
|
+ self,
|
|
|
+ img: np.ndarray
|
|
|
+ ) -> Tuple[Dict[str, np.ndarray], Tuple[float, float], Tuple[int, int]]:
|
|
|
+ """
|
|
|
+ 预处理图像 (根据 RT-DETR 的配置)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ input_dict: 包含所有输入的字典
|
|
|
+ scale: (scale_h, scale_w) 缩放因子
|
|
|
+ orig_shape: (h, w) 原始图像尺寸
|
|
|
+ """
|
|
|
+ orig_h, orig_w = img.shape[:2]
|
|
|
+ target_size = self.target_size # 640
|
|
|
+
|
|
|
+ # 1. Resize 到目标尺寸 (不保持长宽比)
|
|
|
+ img_resized = cv2.resize(
|
|
|
+ img,
|
|
|
+ (target_size, target_size),
|
|
|
+ interpolation=cv2.INTER_LINEAR
|
|
|
+ )
|
|
|
+
|
|
|
+ # 2. 转换为 RGB
|
|
|
+ img_rgb = cv2.cvtColor(img_resized, cv2.COLOR_BGR2RGB)
|
|
|
+
|
|
|
+ # ✅ 修正 3: 归一化 (mean=[0,0,0], std=[1,1,1], norm_type=none)
|
|
|
+ # 只做 /255,不做均值减法和标准差除法
|
|
|
+ img_normalized = img_rgb.astype(np.float32) / 255.0
|
|
|
+
|
|
|
+ # 4. 转换为 CHW 格式
|
|
|
+ img_chw = img_normalized.transpose(2, 0, 1)
|
|
|
+ img_tensor = img_chw[None, ...].astype(np.float32) # [1, 3, H, W]
|
|
|
+
|
|
|
+ # 5. 准备所有输入
|
|
|
+ input_dict = {}
|
|
|
+
|
|
|
+ # 主图像输入
|
|
|
+ if 'image' in self.inputs:
|
|
|
+ input_dict['image'] = img_tensor
|
|
|
+ elif 'images' in self.inputs:
|
|
|
+ input_dict['images'] = img_tensor
|
|
|
+ else:
|
|
|
+ # 使用第一个输入
|
|
|
+ first_input_name = list(self.inputs.keys())[0]
|
|
|
+ input_dict[first_input_name] = img_tensor
|
|
|
+
|
|
|
+ # ✅ 修正 4: 计算缩放因子 (实际图像尺寸 / 目标尺寸)
|
|
|
+ scale_h = orig_h / target_size
|
|
|
+ scale_w = orig_w / target_size
|
|
|
+
|
|
|
+ # im_shape 输入 (原始图像尺寸)
|
|
|
+ if 'im_shape' in self.inputs:
|
|
|
+ im_shape = np.array([[float(orig_h), float(orig_w)]], dtype=np.float32)
|
|
|
+ input_dict['im_shape'] = im_shape
|
|
|
+
|
|
|
+ # scale_factor 输入
|
|
|
+ if 'scale_factor' in self.inputs:
|
|
|
+ # ⚠️ 注意:这里是原始尺寸/目标尺寸的比例
|
|
|
+ scale_factor = np.array([[scale_h, scale_w]], dtype=np.float32)
|
|
|
+ input_dict['scale_factor'] = scale_factor
|
|
|
+
|
|
|
+ # ✅ 返回的 scale 用于后处理坐标还原
|
|
|
+ # 因为不保持长宽比,所以需要分别记录 x 和 y 的缩放
|
|
|
+ return input_dict, (scale_h, scale_w), (orig_h, orig_w)
|
|
|
+
|
|
|
+ def _postprocess(
|
|
|
+ self,
|
|
|
+ outputs: List[np.ndarray],
|
|
|
+ scale: Tuple[float, float], # (scale_h, scale_w)
|
|
|
+ orig_shape: Tuple[int, int],
|
|
|
+ conf_threshold: float = 0.5
|
|
|
+ ) -> List[Dict]:
|
|
|
+ """
|
|
|
+ 后处理模型输出
|
|
|
+
|
|
|
+ Args:
|
|
|
+ outputs: ONNX 模型输出
|
|
|
+ scale: (scale_h, scale_w) 缩放因子
|
|
|
+ orig_shape: (h, w) 原始图像尺寸
|
|
|
+ conf_threshold: 置信度阈值
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 检测结果列表
|
|
|
+ """
|
|
|
+ scale_h, scale_w = scale
|
|
|
+ orig_h, orig_w = orig_shape
|
|
|
+
|
|
|
+ # 解析输出格式
|
|
|
+ if len(outputs) >= 2:
|
|
|
+ output0_shape = outputs[0].shape
|
|
|
+ output1_shape = outputs[1].shape
|
|
|
+
|
|
|
+ # RT-DETR ONNX 格式: (num_boxes, 6)
|
|
|
+ # 格式: [label_id, score, x1, y1, x2, y2]
|
|
|
+ if len(output0_shape) == 2 and output0_shape[1] == 6:
|
|
|
+ pred = outputs[0]
|
|
|
+ labels = pred[:, 0].astype(int)
|
|
|
+ scores = pred[:, 1]
|
|
|
+ bboxes = pred[:, 2:6].copy() # [x1, y1, x2, y2] - 在 640×640 尺度上
|
|
|
+
|
|
|
+ # 情况2: output0 是 (batch, num_boxes, 6) - 带batch的合并格式
|
|
|
+ elif len(output0_shape) == 3 and output0_shape[2] == 6:
|
|
|
+ pred = outputs[0][0]
|
|
|
+ labels = pred[:, 0].astype(int)
|
|
|
+ scores = pred[:, 1]
|
|
|
+ bboxes = pred[:, 2:6].copy()
|
|
|
+
|
|
|
+ # 情况3: output0 是 bboxes, output1 是 scores (分离格式)
|
|
|
+ elif len(output0_shape) == 2 and output0_shape[1] == 4:
|
|
|
+ bboxes = outputs[0].copy()
|
|
|
+ if len(output1_shape) == 1:
|
|
|
+ scores = outputs[1]
|
|
|
+ labels = np.zeros(len(scores), dtype=int)
|
|
|
+ elif len(output1_shape) == 2:
|
|
|
+ scores_all = outputs[1]
|
|
|
+ scores = scores_all.max(axis=1)
|
|
|
+ labels = scores_all.argmax(axis=1)
|
|
|
+ else:
|
|
|
+ raise ValueError(f"Unexpected output1 shape: {output1_shape}")
|
|
|
+
|
|
|
+ # 情况4: RT-DETR 格式 (batch, num_boxes, 4) + (batch, num_boxes, num_classes)
|
|
|
+ elif len(output0_shape) == 3 and output0_shape[2] == 4:
|
|
|
+ bboxes = outputs[0][0].copy()
|
|
|
+ scores_all = outputs[1][0]
|
|
|
+ scores = scores_all.max(axis=1)
|
|
|
+ labels = scores_all.argmax(axis=1)
|
|
|
+
|
|
|
+ else:
|
|
|
+ raise ValueError(f"Unexpected output format: {output0_shape}, {output1_shape}")
|
|
|
+
|
|
|
+ elif len(outputs) == 1:
|
|
|
+ # 单一输出
|
|
|
+ output_shape = outputs[0].shape
|
|
|
+
|
|
|
+ if len(output_shape) == 2 and output_shape[1] == 6:
|
|
|
+ pred = outputs[0]
|
|
|
+ labels = pred[:, 0].astype(int)
|
|
|
+ scores = pred[:, 1]
|
|
|
+ bboxes = pred[:, 2:6].copy()
|
|
|
+
|
|
|
+ elif len(output_shape) == 3 and output_shape[2] == 6:
|
|
|
+ pred = outputs[0][0]
|
|
|
+ labels = pred[:, 0].astype(int)
|
|
|
+ scores = pred[:, 1]
|
|
|
+ bboxes = pred[:, 2:6].copy()
|
|
|
+
|
|
|
+ else:
|
|
|
+ raise ValueError(f"Unexpected single output shape: {output_shape}")
|
|
|
+
|
|
|
+ else:
|
|
|
+ raise ValueError(f"Unexpected number of outputs: {len(outputs)}")
|
|
|
+
|
|
|
+ # 将坐标从 640×640 还原到原图尺度
|
|
|
+ bboxes[:, [0, 2]] *= scale_w
|
|
|
+ bboxes[:, [1, 3]] *= scale_h
|
|
|
+
|
|
|
+ # 自适应阈值
|
|
|
+ max_score = scores.max() if len(scores) > 0 else 0
|
|
|
+ if max_score < conf_threshold:
|
|
|
+ adjusted_threshold = max(max_score * 0.5, 0.05)
|
|
|
+ conf_threshold = adjusted_threshold
|
|
|
+
|
|
|
+ # 过滤低分框
|
|
|
+ mask = scores > conf_threshold
|
|
|
+ bboxes = bboxes[mask]
|
|
|
+ scores = scores[mask]
|
|
|
+ labels = labels[mask]
|
|
|
+
|
|
|
+ # 过滤完全在图像外的框
|
|
|
+ valid_mask = (
|
|
|
+ (bboxes[:, 2] > 0) & # x2 > 0
|
|
|
+ (bboxes[:, 3] > 0) & # y2 > 0
|
|
|
+ (bboxes[:, 0] < orig_w) & # x1 < width
|
|
|
+ (bboxes[:, 1] < orig_h) # y1 < height
|
|
|
+ )
|
|
|
+ bboxes = bboxes[valid_mask]
|
|
|
+ scores = scores[valid_mask]
|
|
|
+ labels = labels[valid_mask]
|
|
|
+
|
|
|
+ # 裁剪坐标到图像范围
|
|
|
+ bboxes[:, [0, 2]] = np.clip(bboxes[:, [0, 2]], 0, orig_w)
|
|
|
+ bboxes[:, [1, 3]] = np.clip(bboxes[:, [1, 3]], 0, orig_h)
|
|
|
+
|
|
|
+ # 构造结果
|
|
|
+ results = []
|
|
|
+ for box, score, label in zip(bboxes, scores, labels):
|
|
|
+ x1, y1, x2, y2 = box
|
|
|
+
|
|
|
+ # 过滤无效框
|
|
|
+ width = x2 - x1
|
|
|
+ height = y2 - y1
|
|
|
+
|
|
|
+ # 过滤太小的框
|
|
|
+ if width < 10 or height < 10:
|
|
|
+ continue
|
|
|
+
|
|
|
+ # 过滤面积异常大的框
|
|
|
+ area = width * height
|
|
|
+ img_area = orig_w * orig_h
|
|
|
+ if area > img_area * 0.95:
|
|
|
+ continue
|
|
|
+
|
|
|
+ results.append({
|
|
|
+ 'category_id': int(label),
|
|
|
+ 'category_name': self.ORIGINAL_CATEGORY_NAMES.get(int(label), f'unknown_{label}'),
|
|
|
+ 'bbox': [int(x1), int(y1), int(x2), int(y2)],
|
|
|
+ 'poly': [int(x1), int(y1), int(x2), int(y1), int(x2), int(y2), int(x1), int(y2)],
|
|
|
+ 'score': float(score),
|
|
|
+ 'width': int(width),
|
|
|
+ 'height': int(height)
|
|
|
+ })
|
|
|
+
|
|
|
+ return results
|
|
|
+
|
|
|
+ def _predict(
|
|
|
+ self,
|
|
|
+ img: np.ndarray,
|
|
|
+ conf_threshold: float = 0.25
|
|
|
+ ) -> List[Dict]:
|
|
|
+ """执行预测"""
|
|
|
+ # 预处理
|
|
|
+ input_dict, scale, orig_shape = self._preprocess(img)
|
|
|
+
|
|
|
+ # ONNX 推理
|
|
|
+ output_names = [out.name for out in self.session.get_outputs()]
|
|
|
+ outputs = self.session.run(output_names, input_dict)
|
|
|
+
|
|
|
+ # 后处理
|
|
|
+ results = self._postprocess(outputs, scale, orig_shape, conf_threshold)
|
|
|
+
|
|
|
+ return results
|
|
|
+
|
|
|
+ def visualize(
|
|
|
+ self,
|
|
|
+ img: np.ndarray,
|
|
|
+ results: List[Dict],
|
|
|
+ output_path: str = None,
|
|
|
+ show_confidence: bool = True,
|
|
|
+ min_confidence: float = 0.0
|
|
|
+ ) -> np.ndarray:
|
|
|
+ """
|
|
|
+ 可视化检测结果
|
|
|
+
|
|
|
+ Args:
|
|
|
+ img: 输入图像
|
|
|
+ results: 检测结果 (MinerU格式)
|
|
|
+ output_path: 输出路径(可选)
|
|
|
+ show_confidence: 是否显示置信度
|
|
|
+ min_confidence: 最小置信度阈值
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 标注后的图像
|
|
|
+ """
|
|
|
+ import random
|
|
|
+
|
|
|
+ vis_img = img.copy()
|
|
|
+
|
|
|
+ # 为每个类别分配固定颜色
|
|
|
+ category_colors = {}
|
|
|
+
|
|
|
+ # 预定义一些常用类别的颜色
|
|
|
+ predefined_colors = {
|
|
|
+ 'text': (0, 255, 0), # 绿色
|
|
|
+ 'title': (255, 0, 0), # 红色
|
|
|
+ 'table_body': (0, 0, 255), # 蓝色
|
|
|
+ 'table_caption': (255, 255, 0), # 青色
|
|
|
+ 'table_footnote': (255, 128, 0), # 橙色
|
|
|
+ 'image_body': (255, 0, 255), # 洋红
|
|
|
+ 'image_caption': (128, 0, 255), # 紫色
|
|
|
+ 'interline_equation': (0, 255, 255), # 黄色
|
|
|
+ 'abandon': (128, 128, 128), # 灰色
|
|
|
+ }
|
|
|
+
|
|
|
+ # 过滤低置信度结果
|
|
|
+ filtered_results = [
|
|
|
+ res for res in results
|
|
|
+ if res['confidence'] >= min_confidence
|
|
|
+ ]
|
|
|
+
|
|
|
+ if not filtered_results:
|
|
|
+ print(f"⚠️ No results to visualize (min_confidence={min_confidence})")
|
|
|
+ return vis_img
|
|
|
+
|
|
|
+ # 为每个出现的类别分配颜色
|
|
|
+ for res in filtered_results:
|
|
|
+ cat = res['category']
|
|
|
+ if cat not in category_colors:
|
|
|
+ if cat in predefined_colors:
|
|
|
+ category_colors[cat] = predefined_colors[cat]
|
|
|
+ else:
|
|
|
+ # 随机生成颜色
|
|
|
+ category_colors[cat] = (
|
|
|
+ random.randint(50, 255),
|
|
|
+ random.randint(50, 255),
|
|
|
+ random.randint(50, 255)
|
|
|
+ )
|
|
|
+
|
|
|
+ # 绘制检测框
|
|
|
+ for res in filtered_results:
|
|
|
+ bbox = res['bbox']
|
|
|
+ x1, y1, x2, y2 = bbox
|
|
|
+ cat = res['category']
|
|
|
+ confidence = res['confidence']
|
|
|
+ color = category_colors[cat]
|
|
|
+
|
|
|
+ # 绘制矩形边框
|
|
|
+ cv2.rectangle(vis_img, (x1, y1), (x2, y2), color, 2)
|
|
|
+
|
|
|
+ # 构造标签文本
|
|
|
+ if show_confidence:
|
|
|
+ label = f"{cat} {confidence:.2f}"
|
|
|
+ else:
|
|
|
+ label = cat
|
|
|
+
|
|
|
+ # 计算标签尺寸
|
|
|
+ label_size, baseline = cv2.getTextSize(
|
|
|
+ label,
|
|
|
+ cv2.FONT_HERSHEY_SIMPLEX,
|
|
|
+ 0.5,
|
|
|
+ 1
|
|
|
+ )
|
|
|
+ label_w, label_h = label_size
|
|
|
+
|
|
|
+ # 绘制标签背景 (填充矩形)
|
|
|
+ cv2.rectangle(
|
|
|
+ vis_img,
|
|
|
+ (x1, y1 - label_h - 4),
|
|
|
+ (x1 + label_w, y1),
|
|
|
+ color,
|
|
|
+ -1 # 填充
|
|
|
+ )
|
|
|
+
|
|
|
+ # 绘制标签文字 (白色)
|
|
|
+ cv2.putText(
|
|
|
+ vis_img,
|
|
|
+ label,
|
|
|
+ (x1, y1 - 2),
|
|
|
+ cv2.FONT_HERSHEY_SIMPLEX,
|
|
|
+ 0.5,
|
|
|
+ (255, 255, 255), # 白色文字
|
|
|
+ 1,
|
|
|
+ cv2.LINE_AA
|
|
|
+ )
|
|
|
+
|
|
|
+ # 添加图例 (在图像右上角)
|
|
|
+ if category_colors:
|
|
|
+ self._draw_legend(vis_img, category_colors, len(filtered_results))
|
|
|
+
|
|
|
+ # 保存可视化结果
|
|
|
+ if output_path:
|
|
|
+ output_path = Path(output_path)
|
|
|
+ output_path.parent.mkdir(parents=True, exist_ok=True)
|
|
|
+ cv2.imwrite(str(output_path), vis_img)
|
|
|
+ print(f"💾 Visualization saved to: {output_path}")
|
|
|
+
|
|
|
+ return vis_img
|
|
|
+
|
|
|
+ def _draw_legend(
|
|
|
+ self,
|
|
|
+ img: np.ndarray,
|
|
|
+ category_colors: Dict[str, tuple],
|
|
|
+ total_count: int
|
|
|
+ ):
|
|
|
+ """
|
|
|
+ 在图像上绘制图例
|
|
|
+
|
|
|
+ Args:
|
|
|
+ img: 图像
|
|
|
+ category_colors: 类别颜色映射
|
|
|
+ total_count: 总检测数量
|
|
|
+ """
|
|
|
+ legend_x = img.shape[1] - 200 # 右侧留200像素
|
|
|
+ legend_y = 20
|
|
|
+ line_height = 25
|
|
|
+
|
|
|
+ # 绘制半透明背景
|
|
|
+ overlay = img.copy()
|
|
|
+ cv2.rectangle(
|
|
|
+ overlay,
|
|
|
+ (legend_x - 10, legend_y - 10),
|
|
|
+ (img.shape[1] - 10, legend_y + len(category_colors) * line_height + 30),
|
|
|
+ (255, 255, 255),
|
|
|
+ -1
|
|
|
+ )
|
|
|
+ cv2.addWeighted(overlay, 0.7, img, 0.3, 0, img)
|
|
|
+
|
|
|
+ # 绘制标题
|
|
|
+ cv2.putText(
|
|
|
+ img,
|
|
|
+ f"Legend ({total_count} total)",
|
|
|
+ (legend_x, legend_y),
|
|
|
+ cv2.FONT_HERSHEY_SIMPLEX,
|
|
|
+ 0.5,
|
|
|
+ (0, 0, 0),
|
|
|
+ 1,
|
|
|
+ cv2.LINE_AA
|
|
|
+ )
|
|
|
+
|
|
|
+ # 绘制每个类别
|
|
|
+ y_offset = legend_y + line_height
|
|
|
+ for cat, color in sorted(category_colors.items()):
|
|
|
+ # 绘制颜色方块
|
|
|
+ cv2.rectangle(
|
|
|
+ img,
|
|
|
+ (legend_x, y_offset - 10),
|
|
|
+ (legend_x + 15, y_offset),
|
|
|
+ color,
|
|
|
+ -1
|
|
|
+ )
|
|
|
+ cv2.rectangle(
|
|
|
+ img,
|
|
|
+ (legend_x, y_offset - 10),
|
|
|
+ (legend_x + 15, y_offset),
|
|
|
+ (0, 0, 0),
|
|
|
+ 1
|
|
|
+ )
|
|
|
+
|
|
|
+ # 绘制类别名称
|
|
|
+ cv2.putText(
|
|
|
+ img,
|
|
|
+ cat,
|
|
|
+ (legend_x + 20, y_offset - 2),
|
|
|
+ cv2.FONT_HERSHEY_SIMPLEX,
|
|
|
+ 0.4,
|
|
|
+ (0, 0, 0),
|
|
|
+ 1,
|
|
|
+ cv2.LINE_AA
|
|
|
+ )
|
|
|
+
|
|
|
+ y_offset += line_height
|
|
|
+
|
|
|
+
|
|
|
+# 测试代码
|
|
|
+if __name__ == "__main__":
|
|
|
+ import yaml
|
|
|
+
|
|
|
+ # 测试配置
|
|
|
+ config = {
|
|
|
+ 'model_dir': '/Users/zhch158/workspace/repository.git/PaddleX/zhch/unified_pytorch_models/Layout/RT-DETR-H_layout_17cls.onnx',
|
|
|
+ 'device': 'cpu',
|
|
|
+ 'conf': 0.25
|
|
|
+ }
|
|
|
+
|
|
|
+ # 初始化检测器
|
|
|
+ print("🔧 Initializing detector...")
|
|
|
+ detector = PaddleLayoutDetector(config)
|
|
|
+ detector.initialize()
|
|
|
+
|
|
|
+ # 读取测试图像
|
|
|
+ img_path = "/Users/zhch158/workspace/data/流水分析/B用户_扫描流水/PaddleOCR_VL_Results/B用户_扫描流水/B用户_扫描流水_page_001.png"
|
|
|
+ print(f"\n📖 Loading image: {img_path}")
|
|
|
+ img = cv2.imread(img_path)
|
|
|
+
|
|
|
+ if img is None:
|
|
|
+ print(f"❌ Failed to load image: {img_path}")
|
|
|
+ exit(1)
|
|
|
+
|
|
|
+ print(f" Image shape: {img.shape}")
|
|
|
+
|
|
|
+ # 执行检测
|
|
|
+ print("\n🔍 Detecting layout...")
|
|
|
+ results = detector.detect(img)
|
|
|
+
|
|
|
+ print(f"\n✅ 检测到 {len(results)} 个区域:")
|
|
|
+ for i, res in enumerate(results, 1):
|
|
|
+ print(f" [{i}] {res['category']}: "
|
|
|
+ f"score={res['confidence']:.3f}, "
|
|
|
+ f"bbox={res['bbox']}, "
|
|
|
+ f"original={res['raw']['original_category_name']}")
|
|
|
+
|
|
|
+ # 统计各类别
|
|
|
+ category_counts = {}
|
|
|
+ for res in results:
|
|
|
+ cat = res['category']
|
|
|
+ category_counts[cat] = category_counts.get(cat, 0) + 1
|
|
|
+
|
|
|
+ print(f"\n📊 类别统计 (MinerU格式):")
|
|
|
+ for cat, count in sorted(category_counts.items()):
|
|
|
+ print(f" - {cat}: {count}")
|
|
|
+
|
|
|
+ # 使用新的可视化方法
|
|
|
+ if len(results) > 0:
|
|
|
+ print("\n🎨 Generating visualization...")
|
|
|
+
|
|
|
+ # 创建输出目录
|
|
|
+ output_dir = Path(__file__).parent.parent.parent / "tests" / "output"
|
|
|
+ output_dir.mkdir(parents=True, exist_ok=True)
|
|
|
+ output_path = output_dir / f"{Path(img_path).stem}_layout_vis.jpg"
|
|
|
+
|
|
|
+ # 调用可视化方法
|
|
|
+ vis_img = detector.visualize(
|
|
|
+ img,
|
|
|
+ results,
|
|
|
+ output_path=str(output_path),
|
|
|
+ show_confidence=True,
|
|
|
+ min_confidence=0.0
|
|
|
+ )
|
|
|
+
|
|
|
+ print(f"💾 Visualization saved to: {output_path}")
|
|
|
+
|
|
|
+ # 清理
|
|
|
+ detector.cleanup()
|
|
|
+ print("\n✅ 测试完成!")
|