|
@@ -0,0 +1,733 @@
|
|
|
|
|
+"""DiT Layout Detector 适配器
|
|
|
|
|
+
|
|
|
|
|
+基于 DiT (Document Image Transformer) 的布局检测适配器,参考 docling_layout_adapter 的实现方式。
|
|
|
|
|
+支持 PubLayNet 数据集的 5 个类别:text, title, list, table, figure。
|
|
|
|
|
+
|
|
|
|
|
+支持的配置:
|
|
|
|
|
+- config_file: DiT 配置文件路径
|
|
|
|
|
+- model_weights: 模型权重路径或 URL
|
|
|
|
|
+- device: 运行设备 ('cpu', 'cuda', 'mps')
|
|
|
|
|
+- conf: 置信度阈值 (默认 0.3)
|
|
|
|
|
+- remove_overlap: 是否启用重叠框处理 (默认 True)
|
|
|
|
|
+- iou_threshold: IoU 阈值 (默认 0.8)
|
|
|
|
|
+- overlap_ratio_threshold: 重叠比例阈值 (默认 0.8)
|
|
|
|
|
+"""
|
|
|
|
|
+
|
|
|
|
|
+import cv2
|
|
|
|
|
+import numpy as np
|
|
|
|
|
+import threading
|
|
|
|
|
+from pathlib import Path
|
|
|
|
|
+from typing import Dict, List, Union, Any, Optional
|
|
|
|
|
+from PIL import Image
|
|
|
|
|
+
|
|
|
|
|
+try:
|
|
|
|
|
+ from .base import BaseLayoutDetector
|
|
|
|
|
+except ImportError:
|
|
|
|
|
+ from base import BaseLayoutDetector
|
|
|
|
|
+
|
|
|
|
|
+# 全局锁,防止模型初始化时的线程问题
|
|
|
|
|
+_model_init_lock = threading.Lock()
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
class LayoutUtils:
    """Geometry helpers for post-processing layout boxes.

    Self-contained (no dependency on external project modules). Every box is
    a sequence whose first four items are ``x1, y1, x2, y2``; any additional
    trailing items are ignored.
    """

    @staticmethod
    def _box_area(bbox: List[float]) -> float:
        """Return the area of ``bbox`` computed as width * height."""
        return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])

    @staticmethod
    def _intersection_area(bbox1: List[float], bbox2: List[float]) -> float:
        """Return the overlap area of two boxes, or 0.0 when they do not overlap."""
        x1_1, y1_1, x2_1, y2_1 = bbox1[:4]
        x1_2, y1_2, x2_2, y2_2 = bbox2[:4]

        inter_w = min(x2_1, x2_2) - max(x1_1, x1_2)
        inter_h = min(y2_1, y2_2) - max(y1_1, y1_2)

        # Boxes that merely touch (zero width or height overlap) do not intersect.
        if inter_w <= 0 or inter_h <= 0:
            return 0.0
        return inter_w * inter_h

    @staticmethod
    def calculate_iou(bbox1: List[float], bbox2: List[float]) -> float:
        """Compute the IoU (intersection over union) of two boxes.

        Args:
            bbox1: First box, ``[x1, y1, x2, y2]``.
            bbox2: Second box, ``[x1, y1, x2, y2]``.

        Returns:
            The IoU, or 0.0 when the boxes do not overlap or the union is zero.
        """
        intersection = LayoutUtils._intersection_area(bbox1, bbox2)
        if intersection <= 0:
            return 0.0

        union = (
            LayoutUtils._box_area(bbox1)
            + LayoutUtils._box_area(bbox2)
            - intersection
        )
        if union == 0:
            return 0.0

        return intersection / union

    @staticmethod
    def calculate_overlap_ratio(bbox1: List[float], bbox2: List[float]) -> float:
        """Compute the overlap area as a fraction of the smaller box's area.

        Args:
            bbox1: First box, ``[x1, y1, x2, y2]``.
            bbox2: Second box, ``[x1, y1, x2, y2]``.

        Returns:
            ``intersection / min(area1, area2)``, or 0.0 when the boxes do not
            overlap or the smaller area is zero.
        """
        intersection = LayoutUtils._intersection_area(bbox1, bbox2)
        if intersection <= 0:
            return 0.0

        min_area = min(LayoutUtils._box_area(bbox1), LayoutUtils._box_area(bbox2))
        if min_area == 0:
            return 0.0

        return intersection / min_area

    @staticmethod
    def remove_overlapping_boxes(
        layout_results: List[Dict[str, Any]],
        iou_threshold: float = 0.8,
        overlap_ratio_threshold: float = 0.8
    ) -> List[Dict[str, Any]]:
        """Resolve overlapping layout boxes.

        Two rules are applied to every pair of remaining boxes:

        1. IoU above ``iou_threshold``: keep the box with the higher score
           (``confidence`` key, falling back to ``score``, then 0); on a tie
           the earlier box wins.
        2. Otherwise, if the overlap covers more than
           ``overlap_ratio_threshold`` of the smaller box: drop the smaller
           box and grow the larger one to the union of both.

        Args:
            layout_results: Detection dicts, each with a ``bbox`` entry.
            iou_threshold: IoU above which two boxes count as duplicates.
            overlap_ratio_threshold: Fraction of the smaller box that must be
                covered for it to be absorbed by the larger box.

        Returns:
            The filtered list. Input dicts are shallow-copied, so the caller's
            entries keep their original ``bbox``. With zero or one entry the
            input list itself is returned unchanged.
        """
        if not layout_results or len(layout_results) <= 1:
            return layout_results

        # Shallow copies: 'bbox' is rebound (never mutated in place) below.
        results = [item.copy() for item in layout_results]
        need_remove = set()

        for i in range(len(results)):
            if i in need_remove:
                continue

            for j in range(i + 1, len(results)):
                if j in need_remove:
                    continue

                # Re-read box i every time: it may have been enlarged by an
                # earlier containment merge in this inner loop.
                bbox1 = results[i].get('bbox', [0, 0, 0, 0])
                bbox2 = results[j].get('bbox', [0, 0, 0, 0])

                # Skip pairs with a missing or malformed box.
                if bbox1 is None or bbox2 is None:
                    continue
                if len(bbox1) < 4 or len(bbox2) < 4:
                    continue

                # Only the first four values are coordinates.
                bbox1 = bbox1[:4]
                bbox2 = bbox2[:4]

                if LayoutUtils.calculate_iou(bbox1, bbox2) > iou_threshold:
                    # Near-duplicates: keep the more confident detection.
                    score1 = results[i].get('confidence', results[i].get('score', 0))
                    score2 = results[j].get('confidence', results[j].get('score', 0))

                    if score1 >= score2:
                        need_remove.add(j)
                        continue

                    need_remove.add(i)
                    break  # box i is gone; nothing left to compare it with

                overlap_ratio = LayoutUtils.calculate_overlap_ratio(bbox1, bbox2)
                if overlap_ratio <= overlap_ratio_threshold:
                    continue

                # Containment: the smaller box is absorbed by the larger one.
                if LayoutUtils._box_area(bbox1) <= LayoutUtils._box_area(bbox2):
                    small_idx, large_idx = i, j
                else:
                    small_idx, large_idx = j, i

                results[large_idx]['bbox'] = [
                    min(bbox1[0], bbox2[0]),
                    min(bbox1[1], bbox2[1]),
                    max(bbox1[2], bbox2[2]),
                    max(bbox1[3], bbox2[3]),
                ]
                need_remove.add(small_idx)

                if small_idx == i:
                    break  # box i is gone; nothing left to compare it with

        return [res for idx, res in enumerate(results) if idx not in need_remove]
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
class DitLayoutDetector(BaseLayoutDetector):
    """Layout detector adapter built on DiT (Document Image Transformer).

    Runs a detectron2 ``DefaultPredictor`` configured with the DiT backbone
    and converts its predictions into the category names listed in
    ``CATEGORY_MAP``.
    """

    # Class id -> label name for the PubLayNet-trained model.
    DIT_LABELS = {
        0: 'text',
        1: 'title',
        2: 'list',
        3: 'table',
        4: 'figure',
    }

    # Class id -> label name when the config's test dataset is
    # 'icdar2019_test', which has a single "table" class.
    ICDAR_LABELS = {
        0: 'table',
    }

    # PubLayNet label -> MinerU / EnhancedDocPipeline category.
    # Reference: universal_doc_parser/core/pipeline_manager_v2.py
    # (EnhancedDocPipeline category definitions).
    CATEGORY_MAP = {
        'text': 'text',           # Text -> text (TEXT_CATEGORIES)
        'title': 'title',         # Title -> title (TEXT_CATEGORIES)
        'list': 'text',           # List-item -> text (TEXT_CATEGORIES)
        'table': 'table_body',    # Table -> table_body (TABLE_BODY_CATEGORIES)
        'figure': 'image_body',   # Figure -> image_body (IMAGE_BODY_CATEGORIES)
    }

    # Boxes narrower or shorter than this many pixels are discarded.
    MIN_BOX_SIZE = 10
    # Boxes covering more than this fraction of the page are discarded.
    MAX_AREA_RATIO = 0.95

    def __init__(self, config: Dict[str, Any]):
        """Store the configuration; the model is loaded later by ``initialize``.

        Args:
            config: Configuration dict. Recognised keys:
                - config_file: path to the DiT config file
                  (defaults to cascade_dit_large.yaml)
                - model_weights: path or URL of the model weights
                - device: device to run on ('cpu', 'cuda', 'mps')
                - conf: confidence threshold (default 0.3)
                - remove_overlap: enable overlap handling (default True)
                - iou_threshold: IoU threshold (default 0.8)
                - overlap_ratio_threshold: overlap-ratio threshold (default 0.8)
        """
        super().__init__(config)
        self.predictor = None
        self.cfg = None
        self._device = None
        self._labels = self.DIT_LABELS
        self._threshold = 0.3
        self._remove_overlap = True
        self._iou_threshold = 0.8
        self._overlap_ratio_threshold = 0.8

    @staticmethod
    def _patch_torch_load(torch) -> None:
        """Make ``torch.load`` default to ``weights_only=False`` on PyTorch >= 2.6.

        The patch is applied at most once per process, so repeated
        ``initialize`` calls do not stack wrappers.

        NOTE(security): ``weights_only=False`` allows arbitrary pickled
        objects to be loaded; only use checkpoints from a trusted source.
        """
        if not hasattr(torch, '__version__'):
            return

        torch_version = tuple(map(int, torch.__version__.split('.')[:2]))
        if torch_version < (2, 6):
            return

        if getattr(torch.load, '_dit_weights_only_patched', False):
            return

        _original_torch_load = torch.load

        def _patched_torch_load(f, map_location=None, pickle_module=None,
                                weights_only=None, **kwargs):
            if weights_only is None:
                weights_only = False
            return _original_torch_load(f, map_location=map_location,
                                        pickle_module=pickle_module,
                                        weights_only=weights_only, **kwargs)

        _patched_torch_load._dit_weights_only_patched = True
        torch.load = _patched_torch_load

    @staticmethod
    def _to_bgr(image):
        """Convert the input to the BGR array handed to the predictor.

        PIL images are converted to RGB (if needed) and then to BGR exactly
        once. 2-D arrays are treated as grayscale. 3-channel uint8 arrays are
        assumed to be RGB and are swapped to BGR; any other array is returned
        unchanged.
        """
        if isinstance(image, Image.Image):
            if image.mode != 'RGB':
                # Normalise palette / grayscale / alpha modes to three channels.
                image = image.convert('RGB')
            return cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

        if isinstance(image, np.ndarray):
            if image.ndim == 2:
                return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
            if image.ndim == 3 and image.shape[2] == 3 and image.dtype == np.uint8:
                return cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

        return image

    def initialize(self):
        """Load the DiT configuration and build the detectron2 predictor.

        Raises:
            ImportError: if torch, detectron2 or ditod cannot be imported.
            FileNotFoundError: if the configuration file does not exist.
            Exception: any other initialisation error is reported and re-raised.
        """
        import os
        import sys

        try:
            import torch
            from detectron2.config import get_cfg
            from detectron2.engine import DefaultPredictor
            from detectron2.data import MetadataCatalog

            # PyTorch 2.6+ compatibility fix.
            self._patch_torch_load(torch)

            # Make the bundled dit_support package importable.
            current_dir = os.path.dirname(os.path.abspath(__file__))
            dit_support_path = os.path.join(current_dir, '..', 'dit_support')
            if dit_support_path not in sys.path:
                sys.path.insert(0, dit_support_path)

            from ditod import add_vit_config

            # Read configuration values.
            config_file = self.config.get(
                'config_file',
                os.path.join(current_dir, '..', 'dit_support', 'configs',
                             'cascade', 'cascade_dit_large.yaml')
            )
            model_weights = self.config.get(
                'model_weights',
                'https://huggingface.co/HYPJUDY/dit/resolve/main/dit-fts/publaynet_dit-l_cascade.pth'
            )
            device = self.config.get('device', 'cpu')
            self._threshold = self.config.get('conf', 0.3)
            self._remove_overlap = self.config.get('remove_overlap', True)
            self._iou_threshold = self.config.get('iou_threshold', 0.8)
            self._overlap_ratio_threshold = self.config.get('overlap_ratio_threshold', 0.8)

            self._device = torch.device(device)

            if not os.path.exists(config_file):
                raise FileNotFoundError(f"Config file not found: {config_file}")

            # Build the detectron2 config.
            self.cfg = get_cfg()
            add_vit_config(self.cfg)
            self.cfg.merge_from_file(config_file)
            self.cfg.merge_from_list(["MODEL.WEIGHTS", model_weights])
            self.cfg.MODEL.DEVICE = str(self._device)

            # Pick the label table matching the dataset the config targets,
            # and register the same class names as dataset metadata.
            dataset_name = self.cfg.DATASETS.TEST[0]
            md = MetadataCatalog.get(dataset_name)
            if dataset_name == 'icdar2019_test':
                self._labels = self.ICDAR_LABELS
            else:
                self._labels = self.DIT_LABELS
            md.set(thing_classes=[self._labels[k] for k in sorted(self._labels)])

            # Build the predictor under the module-level lock.
            with _model_init_lock:
                self.predictor = DefaultPredictor(self.cfg)

            print("✅ DiT Layout Detector initialized")
            print(f" - Config: {config_file}")
            print(f" - Device: {self._device}")
            print(f" - Threshold: {self._threshold}")
            print(f" - Remove overlap: {self._remove_overlap}")

        except ImportError as e:
            print(f"❌ Failed to import required libraries: {e}")
            print(" Please ensure detectron2 and ditod are installed")
            raise
        except Exception as e:
            print(f"❌ Failed to initialize DiT Layout Detector: {e}")
            raise

    def cleanup(self):
        """Drop the references to the predictor, config and device."""
        self.predictor = None
        self.cfg = None
        self._device = None

    def detect(self, image: Union[np.ndarray, Image.Image]) -> List[Dict[str, Any]]:
        """Detect layout regions in one image.

        Args:
            image: Input image, either a PIL image or a numpy array
                (3-channel uint8 arrays are assumed to be RGB).

        Returns:
            A list of dicts, each containing:
                - category: mapped category name (see ``CATEGORY_MAP``)
                - bbox: ``[x1, y1, x2, y2]`` as ints, clamped to the image
                - confidence: detection score
                - raw: original label, label id, polygon, width and height

        Raises:
            RuntimeError: if ``initialize`` has not been called.
        """
        if self.predictor is None:
            raise RuntimeError("Model not initialized. Call initialize() first.")

        image = self._to_bgr(image)
        orig_h, orig_w = image.shape[:2]
        img_area = orig_w * orig_h

        # Run inference and pull all predictions to the CPU in one transfer.
        outputs = self.predictor(image)
        instances = outputs["instances"].to("cpu")
        scores = instances.scores.tolist()
        class_ids = instances.pred_classes.tolist()
        boxes = instances.pred_boxes.tensor.tolist()

        formatted_results = []
        for score, class_id, raw_box in zip(scores, class_ids, boxes):
            score = float(score)
            class_id = int(class_id)

            # Drop low-confidence detections.
            if score < self._threshold:
                continue

            original_label = self._labels.get(class_id, f'unknown_{class_id}')
            # Labels without a mapping fall back to 'text'.
            mineru_category = self.CATEGORY_MAP.get(original_label, 'text')

            # Clamp to the image bounds, then truncate to integer pixels.
            x1 = max(0, min(orig_w, float(raw_box[0])))
            y1 = max(0, min(orig_h, float(raw_box[1])))
            x2 = max(0, min(orig_w, float(raw_box[2])))
            y2 = max(0, min(orig_h, float(raw_box[3])))
            bbox = [int(x1), int(y1), int(x2), int(y2)]

            width = bbox[2] - bbox[0]
            height = bbox[3] - bbox[1]

            # Drop boxes that are too small.
            if width < self.MIN_BOX_SIZE or height < self.MIN_BOX_SIZE:
                continue

            # Drop boxes that cover almost the whole page.
            if width * height > img_area * self.MAX_AREA_RATIO:
                continue

            # Polygon corners: top-left, top-right, bottom-right, bottom-left.
            poly = [
                bbox[0], bbox[1],
                bbox[2], bbox[1],
                bbox[2], bbox[3],
                bbox[0], bbox[3],
            ]

            formatted_results.append({
                'category': mineru_category,
                'bbox': bbox,
                'confidence': score,
                'raw': {
                    'original_label': original_label,
                    'original_label_id': class_id,
                    'poly': poly,
                    'width': width,
                    'height': height
                }
            })

        # Optional overlap handling.
        if self._remove_overlap and len(formatted_results) > 1:
            formatted_results = LayoutUtils.remove_overlapping_boxes(
                formatted_results,
                iou_threshold=self._iou_threshold,
                overlap_ratio_threshold=self._overlap_ratio_threshold
            )

        return formatted_results

    def detect_batch(
        self,
        images: List[Union[np.ndarray, Image.Image]]
    ) -> List[List[Dict[str, Any]]]:
        """Detect layout regions in several images, one at a time.

        Args:
            images: Input images.

        Returns:
            One result list per input image, in input order.

        Raises:
            RuntimeError: if ``initialize`` has not been called.
        """
        if self.predictor is None:
            raise RuntimeError("Model not initialized. Call initialize() first.")

        if not images:
            return []

        return [self.detect(image) for image in images]

    def visualize(
        self,
        img: np.ndarray,
        results: List[Dict],
        output_path: Optional[str] = None,
        show_confidence: bool = True,
        min_confidence: float = 0.0
    ) -> np.ndarray:
        """Draw detection results on a copy of the image.

        Args:
            img: Input image (BGR).
            results: Detection results as returned by ``detect``.
            output_path: Where to save the annotated image (optional).
            show_confidence: Whether to append the score to each label.
            min_confidence: Results scoring below this value are not drawn.

        Returns:
            The annotated image; ``img`` itself is not modified.
        """
        import random

        vis_img = img.copy()

        # Predefined per-category colours (kept in line with EnhancedDocPipeline).
        predefined_colors = {
            # text
            'text': (153, 0, 76),
            'title': (102, 102, 255),
            'header': (128, 128, 128),
            'footer': (128, 128, 128),
            'page_footnote': (200, 200, 200),
            # tables
            'table_body': (204, 204, 0),
            'table_caption': (255, 255, 102),
            # images
            'image_body': (153, 255, 51),
            'image_caption': (102, 178, 255),
            # equations
            'interline_equation': (0, 255, 0),
            # code
            'code': (102, 0, 204),
            # discarded
            'abandon': (100, 100, 100),
        }

        filtered_results = [
            res for res in results
            if res['confidence'] >= min_confidence
        ]

        if not filtered_results:
            print(f"⚠️ No results to visualize (min_confidence={min_confidence})")
            return vis_img

        # Assign a colour to every category present; categories without a
        # predefined colour get a random one.
        category_colors = {}
        for res in filtered_results:
            cat = res['category']
            if cat in category_colors:
                continue
            if cat in predefined_colors:
                category_colors[cat] = predefined_colors[cat]
            else:
                category_colors[cat] = (
                    random.randint(50, 255),
                    random.randint(50, 255),
                    random.randint(50, 255)
                )

        font = cv2.FONT_HERSHEY_SIMPLEX
        for res in filtered_results:
            # OpenCV drawing functions require integer pixel coordinates.
            x1, y1, x2, y2 = (int(v) for v in res['bbox'][:4])
            cat = res['category']
            confidence = res['confidence']
            color = category_colors[cat]

            original_label = res.get('raw', {}).get('original_label', cat)

            # Box outline.
            cv2.rectangle(vis_img, (x1, y1), (x2, y2), color, 2)

            if show_confidence:
                label = f"{original_label}->{cat} {confidence:.2f}"
            else:
                label = f"{original_label}->{cat}"

            (label_w, label_h), _baseline = cv2.getTextSize(label, font, 0.4, 1)

            # Filled label background just above the box, then the text.
            cv2.rectangle(
                vis_img,
                (x1, y1 - label_h - 4),
                (x1 + label_w, y1),
                color,
                -1
            )
            cv2.putText(
                vis_img,
                label,
                (x1, y1 - 2),
                font,
                0.4,
                (255, 255, 255),
                1,
                cv2.LINE_AA
            )

        if category_colors:
            self._draw_legend(vis_img, category_colors, len(filtered_results))

        # Save the annotated image if requested.
        if output_path:
            output_path_obj = Path(output_path)
            output_path_obj.parent.mkdir(parents=True, exist_ok=True)
            cv2.imwrite(str(output_path_obj), vis_img)
            print(f"💾 Visualization saved to: {output_path_obj}")

        return vis_img

    def _draw_legend(
        self,
        img: np.ndarray,
        category_colors: Dict[str, tuple],
        total_count: int
    ):
        """Draw a colour legend in the top-right corner of ``img`` (in place).

        Args:
            img: Image to draw on; modified in place.
            category_colors: Category name -> colour tuple.
            total_count: Number shown in the legend title.
        """
        legend_x = img.shape[1] - 200
        legend_y = 20
        line_height = 25
        font = cv2.FONT_HERSHEY_SIMPLEX

        # Semi-transparent white panel behind the legend.
        overlay = img.copy()
        cv2.rectangle(
            overlay,
            (legend_x - 10, legend_y - 10),
            (img.shape[1] - 10, legend_y + len(category_colors) * line_height + 30),
            (255, 255, 255),
            -1
        )
        cv2.addWeighted(overlay, 0.7, img, 0.3, 0, img)

        # Title.
        cv2.putText(
            img,
            f"Legend ({total_count} total)",
            (legend_x, legend_y),
            font,
            0.5,
            (0, 0, 0),
            1,
            cv2.LINE_AA
        )

        # One row per category, sorted by name.
        y_offset = legend_y + line_height
        for cat, color in sorted(category_colors.items()):
            swatch_tl = (legend_x, y_offset - 10)
            swatch_br = (legend_x + 15, y_offset)

            # Filled colour swatch with a black outline.
            cv2.rectangle(img, swatch_tl, swatch_br, color, -1)
            cv2.rectangle(img, swatch_tl, swatch_br, (0, 0, 0), 1)

            cv2.putText(
                img,
                cat,
                (legend_x + 20, y_offset - 2),
                font,
                0.4,
                (0, 0, 0),
                1,
                cv2.LINE_AA
            )

            y_offset += line_height
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
# Manual smoke test: python <this file> [image_path]
if __name__ == "__main__":
    import sys
    import os

    # Test configuration.
    project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
    config = {
        'config_file': os.path.join(project_root, 'dit', 'object_detection',
                                    'publaynet_configs', 'cascade', 'cascade_dit_large.yaml'),
        'model_weights': 'https://huggingface.co/HYPJUDY/dit/resolve/main/dit-fts/publaynet_dit-l_cascade.pth',
        'device': 'cpu',
        'conf': 0.3,
        'remove_overlap': True,
        'iou_threshold': 0.8,
        'overlap_ratio_threshold': 0.8
    }

    # Image to test: first command-line argument, else the original default.
    default_img_path = "/Users/zhch158/workspace/data/流水分析/2023年度报告母公司/paddleocr_vl_results/2023年度报告母公司/2023年度报告母公司_page_021.png"
    img_path = sys.argv[1] if len(sys.argv) > 1 else default_img_path

    # Initialise the detector.
    print("🔧 Initializing DiT Layout Detector...")
    detector = DitLayoutDetector(config)
    detector.initialize()

    try:
        # Load the test image (cv2.imread returns BGR).
        print(f"\n📖 Loading image: {img_path}")
        img = cv2.imread(img_path)

        if img is None:
            print(f"❌ Failed to load image: {img_path}")
            sys.exit(1)

        print(f" Image shape: {img.shape}")

        # Run detection. detect() assumes 3-channel arrays are RGB, so the
        # BGR image from OpenCV is converted first.
        print("\n🔍 Detecting layout...")
        results = detector.detect(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

        print(f"\n✅ 检测到 {len(results)} 个区域:")
        for i, res in enumerate(results, 1):
            print(f" [{i}] {res['category']}: "
                  f"score={res['confidence']:.3f}, "
                  f"bbox={res['bbox']}, "
                  f"original={res['raw']['original_label']}")

        # Per-category counts.
        category_counts = {}
        for res in results:
            cat = res['category']
            category_counts[cat] = category_counts.get(cat, 0) + 1

        print("\n📊 类别统计 (MinerU格式):")
        for cat, count in sorted(category_counts.items()):
            print(f" - {cat}: {count}")

        # Visualisation (visualize() expects the BGR image and reports the
        # saved path itself).
        if len(results) > 0:
            print("\n🎨 Generating visualization...")

            output_dir = Path(__file__).parent / "output"
            output_dir.mkdir(parents=True, exist_ok=True)
            output_path = output_dir / f"{Path(img_path).stem}_dit_layout_vis.jpg"

            vis_img = detector.visualize(
                img,
                results,
                output_path=str(output_path),
                show_confidence=True,
                min_confidence=0.0
            )
    finally:
        # Release the model even when loading or detection fails.
        detector.cleanup()

    print("\n✅ 测试完成!")
|
|
|
|
|
+
|