Переглянути джерело

feat: Add base adapter classes and MinerU integration for document processing

- Implemented base adapter classes: BaseAdapter, BasePreprocessor, BaseLayoutDetector, BaseVLRecognizer, and BaseOCRRecognizer.
- Created MinerUPreprocessor, MinerULayoutDetector, MinerUVLRecognizer, and MinerUOCRRecognizer classes for handling image preprocessing, layout detection, vision-language (VL) recognition, and OCR using MinerU components.
- Added PaddleLayoutDetector class for layout detection using ONNX Runtime with RT-DETR model.
- Included visualization methods for layout detection results.
- Added error handling and logging for model initialization and processing.
zhch158_admin 2 тижнів тому
батько
коміт
81b2cd6ebb

+ 105 - 0
zhch/universal_doc_parser/models/adapters/__init__.py

@@ -0,0 +1,105 @@
+"""
+模型适配器模块
+提供统一的接口适配不同的模型后端
+"""
+
+from .base import (
+    BaseAdapter,
+    BasePreprocessor,
+    BaseLayoutDetector,
+    BaseVLRecognizer,
+    BaseOCRRecognizer
+)
+
+from .paddle_layout_detector import PaddleLayoutDetector
+
+# 可选导入 MinerU 适配器
+try:
+    from .mineru_adapter import (
+        MinerUPreprocessor,
+        MinerULayoutDetector,
+        MinerUVLRecognizer,
+        MinerUOCRRecognizer
+    )
+    MINERU_AVAILABLE = True
+except ImportError:
+    MINERU_AVAILABLE = False
+
+__all__ = [
+    # 基类
+    'BaseAdapter',
+    'BasePreprocessor',
+    'BaseLayoutDetector',
+    'BaseVLRecognizer',
+    'BaseOCRRecognizer',
+    
+    # PaddleX 适配器
+    'PaddleLayoutDetector',
+]
+
+# 如果 MinerU 可用,添加到导出列表
+if MINERU_AVAILABLE:
+    __all__.extend([
+        'MinerUPreprocessor',
+        'MinerULayoutDetector',
+        'MinerUVLRecognizer',
+        'MinerUOCRRecognizer',
+    ])
+
+
+def get_layout_detector(config: dict):
+    """
+    根据配置获取布局检测器
+    
+    Args:
+        config: 配置字典,包含 module 和其他参数
+        
+    Returns:
+        BaseLayoutDetector 实例
+    """
+    module = config.get('module', 'paddle')
+    
+    if module == 'paddle':
+        return PaddleLayoutDetector(config)
+    elif module == 'mineru':
+        if not MINERU_AVAILABLE:
+            raise ImportError("MinerU adapter not available")
+        return MinerULayoutDetector(config)
+    else:
+        raise ValueError(f"Unknown layout detection module: {module}")
+
+
+def get_preprocessor(config: dict):
+    """根据配置获取预处理器"""
+    module = config.get('module', 'mineru')
+    
+    if module == 'mineru':
+        if not MINERU_AVAILABLE:
+            raise ImportError("MinerU adapter not available")
+        return MinerUPreprocessor(config)
+    else:
+        raise ValueError(f"Unknown preprocessor module: {module}")
+
+
+def get_vl_recognizer(config: dict):
+    """根据配置获取VL识别器"""
+    module = config.get('module', 'mineru')
+    
+    if module == 'mineru':
+        if not MINERU_AVAILABLE:
+            raise ImportError("MinerU adapter not available")
+        return MinerUVLRecognizer(config)
+    else:
+        raise ValueError(f"Unknown VL recognizer module: {module}")
+
+
+def get_ocr_recognizer(config: dict):
+    """根据配置获取OCR识别器"""
+    module = config.get('module', 'mineru')
+    
+    if module == 'mineru':
+        if not MINERU_AVAILABLE:
+            raise ImportError("MinerU adapter not available")
+        return MinerUOCRRecognizer(config)
+    else:
+        raise ValueError(f"Unknown OCR recognizer module: {module}")

+ 92 - 0
zhch/universal_doc_parser/models/adapters/base.py

@@ -0,0 +1,92 @@
+from abc import ABC, abstractmethod
+from typing import Dict, Any, List, Union
+import numpy as np
+from PIL import Image
+
+class BaseAdapter(ABC):
+    """基础适配器接口"""
+    
+    def __init__(self, config: Dict[str, Any]):
+        self.config = config
+    
+    @abstractmethod
+    def initialize(self):
+        """初始化模型"""
+        pass
+    
+    @abstractmethod 
+    def cleanup(self):
+        """清理资源"""
+        pass
+
+class BasePreprocessor(BaseAdapter):
+    """预处理器基类"""
+    
+    @abstractmethod
+    def process(self, image: Union[np.ndarray, Image.Image]) -> tuple[np.ndarray, int]:
+        """
+        处理图像
+        返回处理后的图像和旋转角度
+        """
+        pass
+    
+    def _apply_rotation(self, image: np.ndarray, rotation_label: int) -> np.ndarray:
+        """应用旋转"""
+        import cv2
+        if rotation_label == 1:  # 90度
+            return cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
+        elif rotation_label == 2:  # 180度
+            return cv2.rotate(image, cv2.ROTATE_180)
+        elif rotation_label == 3:  # 270度
+            return cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
+        return image
+
+class BaseLayoutDetector(BaseAdapter):
+    """版式检测器基类"""
+    
+    @abstractmethod
+    def detect(self, image: Union[np.ndarray, Image.Image]) -> List[Dict[str, Any]]:
+        """检测版式"""
+        pass
+    
+    def _map_category_id(self, category_id: int) -> str:
+        """映射类别ID到字符串"""
+        category_map = {
+            0: 'title',
+            1: 'text', 
+            2: 'abandon',
+            3: 'image_body',
+            4: 'image_caption',
+            5: 'table_body',
+            6: 'table_caption',
+            7: 'table_footnote',
+            8: 'interline_equation',
+            9: 'interline_equation_number',
+            13: 'inline_equation',
+            14: 'interline_equation_yolo',
+            15: 'ocr_text',
+            16: 'low_score_text',
+            101: 'image_footnote'
+        }
+        return category_map.get(category_id, f'unknown_{category_id}')
+
+class BaseVLRecognizer(BaseAdapter):
+    """VL识别器基类"""
+    
+    @abstractmethod
+    def recognize_table(self, image: Union[np.ndarray, Image.Image], **kwargs) -> Dict[str, Any]:
+        """识别表格"""
+        pass
+    
+    @abstractmethod
+    def recognize_formula(self, image: Union[np.ndarray, Image.Image], **kwargs) -> Dict[str, Any]:
+        """识别公式"""
+        pass
+
+class BaseOCRRecognizer(BaseAdapter):
+    """OCR识别器基类"""
+    
+    @abstractmethod
+    def recognize_text(self, image: Union[np.ndarray, Image.Image]) -> List[Dict[str, Any]]:
+        """识别文本"""
+        pass

+ 512 - 0
zhch/universal_doc_parser/models/adapters/mineru_adapter.py

@@ -0,0 +1,512 @@
+import sys
+from pathlib import Path
+from typing import Dict, Any, List, Union, Optional
+import numpy as np
+import cv2
+from PIL import Image
+from loguru import logger
+
+# 添加MinerU路径
+mineru_path = Path(__file__).parents[4] / "mineru"
+if str(mineru_path) not in sys.path:
+    sys.path.insert(0, str(mineru_path))
+
+from .base import BasePreprocessor, BaseLayoutDetector, BaseVLRecognizer, BaseOCRRecognizer
+
+# 导入MinerU组件
+try:
+    from mineru.backend.pipeline.model_init import AtomModelSingleton
+    from mineru.backend.vlm.vlm_analyze import ModelSingleton as VLMModelSingleton
+    from mineru.backend.pipeline.model_list import AtomicModel
+    from mineru.utils.config_reader import get_device
+    MINERU_AVAILABLE = True
+except ImportError as e:
+    print(f"Warning: MinerU components not available: {e}")
+    MINERU_AVAILABLE = False
+
+class MinerUPreprocessor(BasePreprocessor):
+    """MinerU预处理器适配器"""
+    
+    def __init__(self, config: Dict[str, Any]):
+        super().__init__(config)
+        if not MINERU_AVAILABLE:
+            raise ImportError("MinerU components not available")
+            
+        self.atom_model_manager = AtomModelSingleton()
+        self.orientation_classifier = None
+        
+    def initialize(self):
+        """初始化预处理组件"""
+        # 初始化方向分类器
+        if self.config.get('orientation_classifier', {}).get('enabled', True):
+            try:
+                self.orientation_classifier = self.atom_model_manager.get_atom_model(
+                    atom_model_name=AtomicModel.ImgOrientationCls,
+                )
+                print("✅ Orientation classifier initialized")
+            except Exception as e:
+                print(f"⚠️ Failed to initialize orientation classifier: {e}")
+        
+    def cleanup(self):
+        """清理资源"""
+        pass
+
+    def process(self, image: Union[np.ndarray, Image.Image]) -> tuple[np.ndarray, int]:
+        """图像预处理"""
+        # 转换为numpy数组
+        if isinstance(image, Image.Image):
+            image = np.array(image)
+
+        rotate_map = {0: 0, 1: 90, 2: 180, 3: 270}
+        rotate_label = 0
+        processed_image = image
+        
+        # 方向校正
+        if self.orientation_classifier is not None:
+            try:
+                rotate_label = self.orientation_classifier.predict(image)
+                processed_image = self._apply_rotation(processed_image, rotate_label)
+                logger.info(f"📐 Applied rotation: {rotate_label}")
+            except Exception as e:
+                logger.error(f"⚠️ Orientation classification failed: {e}")
+
+        return processed_image, rotate_map.get(rotate_label, 0)
+
+class MinerULayoutDetector(BaseLayoutDetector):
+    """MinerU版式检测适配器"""
+    
+    def __init__(self, config: Dict[str, Any]):
+        super().__init__(config)
+        if not MINERU_AVAILABLE:
+            raise ImportError("MinerU components not available")
+            
+        self.atom_model_manager = AtomModelSingleton()
+        self.layout_model = None
+        
+    def initialize(self):
+        """初始化版式检测模型"""
+        try:
+            # 获取模型配置
+            model_name = self.config.get('model_name', 'RT-DETR-H_layout_17cls')
+            model_dir = self.config.get('model_dir')
+            device = self.config.get('device', 'cpu')
+            
+            # 初始化版式检测模型
+            if model_dir:
+                # 使用自定义模型路径
+                self.layout_model = self.atom_model_manager.get_atom_model(
+                    atom_model_name=AtomicModel.Layout,
+                    doclayout_yolo_weights=model_dir,
+                    device=device
+                )
+            else:
+                # 使用默认模型
+                self.layout_model = self.atom_model_manager.get_atom_model(
+                    atom_model_name=AtomicModel.Layout,
+                    device=device
+                )
+            print(f"✅ Layout detector initialized: {model_name}")
+            
+        except Exception as e:
+            print(f"❌ Failed to initialize layout detector: {e}")
+            raise
+        
+    def cleanup(self):
+        """清理资源"""
+        pass
+        
+    def detect(self, image: Union[np.ndarray, Image.Image]) -> List[Dict[str, Any]]:
+        """版式检测"""
+        if self.layout_model is None:
+            raise RuntimeError("Layout model not initialized")
+            
+        # 转换为PIL图像
+        if isinstance(image, np.ndarray):
+            image = Image.fromarray(image)
+            
+        # 进行检测
+        try:
+            layout_results = self.layout_model.predict([image])
+            
+            # 转换结果格式
+            formatted_results = []
+            for result in layout_results[0]:  # 第一页结果
+                # 提取坐标信息
+                poly = result.get('poly', [0, 0, 0, 0, 0, 0, 0, 0])
+                if len(poly) >= 8:
+                    # 转换8点坐标为4点坐标 [x1,y1,x2,y2]
+                    bbox = [poly[0], poly[1], poly[4], poly[5]]
+                else:
+                    bbox = poly[:4] if len(poly) >= 4 else [0, 0, 0, 0]
+                    
+                formatted_results.append({
+                    'category': self._map_category_id(result.get('category_id', 1)),
+                    'bbox': bbox,
+                    'confidence': result.get('score', 0.0),
+                    'raw': result
+                })
+                
+            return formatted_results
+            
+        except Exception as e:
+            print(f"❌ Layout detection failed: {e}")
+            return []
+
+class MinerUVLRecognizer(BaseVLRecognizer):
+    """MinerU VL识别适配器"""
+    
+    def __init__(self, config: Dict[str, Any]):
+        super().__init__(config)
+        if not MINERU_AVAILABLE:
+            raise ImportError("MinerU components not available")
+            
+        self.vlm_model = None
+        # 🔧 添加图片尺寸限制配置
+        self.max_image_size = config.get('max_image_size', 1568)  # VLM 模型的最大尺寸
+        self.resize_mode = config.get('resize_mode', 'max')  # 'max' or 'fixed'
+        
+    def initialize(self):
+        """初始化VL模型"""
+        try:
+            backend = self.config.get('backend', 'http-client')
+            server_url = self.config.get('server_url')
+            model_params = self.config.get('model_params', {})
+            
+            # 初始化VLM模型
+            self.vlm_model = VLMModelSingleton().get_model(
+                backend=backend,
+                model_path=None,
+                server_url=server_url,
+                **model_params
+            )
+            print(f"✅ VL recognizer initialized: {backend}")
+            
+        except Exception as e:
+            print(f"❌ Failed to initialize VL recognizer: {e}")
+            raise
+        
+    def cleanup(self):
+        """清理资源"""
+        pass
+    
+    def _preprocess_image(self, image: Union[np.ndarray, Image.Image]) -> Image.Image:
+        """
+        预处理图片,控制尺寸避免序列长度超限
+        
+        Args:
+            image: 输入图片
+            
+        Returns:
+            处理后的PIL图片
+        """
+        # 转换为PIL图像
+        if isinstance(image, np.ndarray):
+            image = Image.fromarray(image)
+        
+        # 获取原始尺寸
+        orig_w, orig_h = image.size
+        
+        # 计算缩放比例
+        if self.resize_mode == 'max':
+            # 保持宽高比,最长边不超过 max_image_size
+            max_dim = max(orig_w, orig_h)
+            if max_dim > self.max_image_size:
+                scale = self.max_image_size / max_dim
+                new_w = int(orig_w * scale)
+                new_h = int(orig_h * scale)
+                
+                logger.debug(f"🔄 Resizing image: {orig_w}x{orig_h} → {new_w}x{new_h} (scale={scale:.3f})")
+                image = image.resize((new_w, new_h), Image.Resampling.LANCZOS)
+        
+        elif self.resize_mode == 'fixed':
+            # 固定尺寸(可能改变宽高比)
+            if orig_w != self.max_image_size or orig_h != self.max_image_size:
+                logger.debug(f"🔄 Resizing image: {orig_w}x{orig_h} → {self.max_image_size}x{self.max_image_size}")
+                image = image.resize((self.max_image_size, self.max_image_size), Image.Resampling.LANCZOS)
+        
+        return image
+
+    def recognize_table(self, image: Union[np.ndarray, Image.Image], **kwargs) -> Dict[str, Any]:
+        """表格识别"""
+        if self.vlm_model is None:
+            raise RuntimeError("VL model not initialized")
+        
+        try:
+            # 🔧 预处理图片
+            image = self._preprocess_image(image)
+            
+            # 直接调用 content_extract,指定类型为 table
+            table_content = self.vlm_model.content_extract(
+                image=image,
+                type="table"
+            )
+            
+            if not table_content:
+                return {'html': '', 'markdown': '', 'cells': []}
+            
+            # 解析表格内容(假设返回的是HTML格式)
+            return {
+                'html': table_content,
+                'markdown': self._html_to_markdown(table_content),
+                'cells': self._extract_cells_from_html(table_content) if kwargs.get('return_cells_coordinate', False) else []
+            }
+            
+        except Exception as e:
+            logger.error(f"❌ Table recognition failed: {e}")
+            return {'html': '', 'markdown': '', 'cells': []}
+    
+    def recognize_formula(self, image: Union[np.ndarray, Image.Image], **kwargs) -> Dict[str, Any]:
+        """识别公式"""
+        if self.vlm_model is None:
+            raise RuntimeError("VL model not initialized")
+        
+        try:
+            # 🔧 预处理图片
+            image = self._preprocess_image(image)
+            
+            # 直接调用 content_extract,指定类型为 equation
+            formula_content = self.vlm_model.content_extract(
+                image=image,
+                type="equation"
+            )
+            
+            if not formula_content:
+                return {'latex': '', 'confidence': 0.0, 'raw': {}}
+            
+            # 清理LaTeX格式
+            latex = self._clean_latex(formula_content)
+            
+            return {
+                'latex': latex,
+                'confidence': 0.9 if latex else 0.0,
+                'raw': {'raw_output': formula_content}
+            }
+            
+        except Exception as e:
+            logger.error(f"❌ Formula recognition failed: {e}")
+            return {'latex': '', 'confidence': 0.0, 'raw': {}}
+    
+    def recognize_text(self, image: Union[np.ndarray, Image.Image], **kwargs) -> Dict[str, Any]:
+        """识别文本区域"""
+        if self.vlm_model is None:
+            raise RuntimeError("VL model not initialized")
+            
+        try:
+            # 🔧 预处理图片
+            image = self._preprocess_image(image)
+
+            # 直接调用 content_extract,指定类型为 text
+            text_content = self.vlm_model.content_extract(
+                image=image,
+                type="text"
+            )
+            
+            return {
+                'text': text_content or '',
+                'confidence': 0.9 if text_content else 0.0
+            }
+            
+        except Exception as e:
+            print(f"❌ Text recognition failed: {e}")
+            return {'text': '', 'confidence': 0.0}
+    
+    def batch_recognize_table(
+        self, 
+        images: List[Union[np.ndarray, Image.Image]], 
+        **kwargs
+    ) -> List[Dict[str, Any]]:
+        """批量表格识别"""
+        if self.vlm_model is None:
+            raise RuntimeError("VL model not initialized")
+        
+        try:
+            # 🔧 批量预处理图片
+            pil_images = [self._preprocess_image(img) for img in images]
+            
+            # 批量调用 batch_content_extract
+            table_contents = self.vlm_model.batch_content_extract(
+                images=pil_images,
+                types="table"
+            )
+            
+            # 格式化结果
+            results = []
+            for content in table_contents:
+                if content:
+                    results.append({
+                        'html': content,
+                        'markdown': self._html_to_markdown(content),
+                        'cells': self._extract_cells_from_html(content) if kwargs.get('return_cells_coordinate', False) else []
+                    })
+                else:
+                    results.append({'html': '', 'markdown': '', 'cells': []})
+            
+            return results
+            
+        except Exception as e:
+            logger.error(f"❌ Batch table recognition failed: {e}")
+            return [{'html': '', 'markdown': '', 'cells': []} for _ in images]
+    
+    def batch_recognize_formula(
+        self, 
+        images: List[Union[np.ndarray, Image.Image]], 
+        **kwargs
+    ) -> List[Dict[str, Any]]:
+        """批量公式识别"""
+        if self.vlm_model is None:
+            raise RuntimeError("VL model not initialized")
+            
+        # 转换为PIL图像列表
+        pil_images = []
+        for img in images:
+            if isinstance(img, np.ndarray):
+                pil_images.append(Image.fromarray(img))
+            else:
+                pil_images.append(img)
+        
+        try:
+            # 批量调用 batch_content_extract,指定类型为 equation
+            formula_contents = self.vlm_model.batch_content_extract(
+                images=pil_images,
+                types="equation"
+            )
+            
+            # 格式化结果
+            results = []
+            for content in formula_contents:
+                latex = self._clean_latex(content) if content else ''
+                results.append({
+                    'latex': latex,
+                    'confidence': 0.9 if latex else 0.0,
+                    'raw': {'raw_output': content}
+                })
+            
+            return results
+            
+        except Exception as e:
+            print(f"❌ Batch formula recognition failed: {e}")
+            return [{'latex': '', 'confidence': 0.0, 'raw': {}} for _ in images]
+    
+    def _clean_latex(self, raw_latex: str) -> str:
+        """清理LaTeX格式"""
+        if not raw_latex:
+            return ''
+        
+        # 移除外层的 $$ 或 $
+        latex = raw_latex.strip()
+        if latex.startswith('$$') and latex.endswith('$$'):
+            latex = latex[2:-2].strip()
+        elif latex.startswith('$') and latex.endswith('$'):
+            latex = latex[1:-1].strip()
+        
+        return latex
+    
+    def _html_to_markdown(self, html: str) -> str:
+        """将HTML表格转换为Markdown格式"""
+        if not html:
+            return ''
+        
+        return html
+        # try:
+        #     # 简单的HTML到Markdown转换
+        #     # 实际应用中可以使用 markdownify 库
+        #     import re
+            
+        #     # 移除HTML标签,保留内容
+        #     markdown = re.sub(r'<tr[^>]*>', '\n', html)
+        #     markdown = re.sub(r'</tr>', '', markdown)
+        #     markdown = re.sub(r'<t[dh][^>]*>', '| ', markdown)
+        #     markdown = re.sub(r'</t[dh]>', ' ', markdown)
+        #     markdown = re.sub(r'<[^>]+>', '', markdown)
+            
+        #     return markdown.strip()
+            
+        # except Exception as e:
+        #     print(f"⚠️ HTML to Markdown conversion failed: {e}")
+        #     return html
+    
+    def _extract_cells_from_html(self, html: str) -> List[Dict[str, Any]]:
+        """从HTML中提取单元格信息(简化版本)"""
+        if not html:
+            return []
+        
+        try:
+            # 这里只是示例,实际需要解析HTML DOM
+            # 可以使用 BeautifulSoup 等库
+            cells = []
+            # TODO: 实现HTML解析逻辑
+            return cells
+            
+        except Exception as e:
+            print(f"⚠️ Cell extraction failed: {e}")
+            return []
+
+class MinerUOCRRecognizer(BaseOCRRecognizer):
+    """MinerU OCR识别适配器"""
+    
+    def __init__(self, config: Dict[str, Any]):
+        super().__init__(config)
+        if not MINERU_AVAILABLE:
+            raise ImportError("MinerU components not available")
+            
+        self.atom_model_manager = AtomModelSingleton()
+        self.ocr_model = None
+        
+    def initialize(self):
+        """初始化OCR模型"""
+        try:
+            # 初始化OCR模型
+            self.ocr_model = self.atom_model_manager.get_atom_model(
+                atom_model_name=AtomicModel.OCR,
+                det_db_box_thresh=self.config.get('det_threshold', 0.3),
+                lang=self.config.get('language', 'ch'),
+                det_db_unclip_ratio=self.config.get('unclip_ratio', 1.8),
+            )
+            print(f"✅ OCR recognizer initialized: {self.config.get('language', 'ch')}")
+            
+        except Exception as e:
+            print(f"❌ Failed to initialize OCR recognizer: {e}")
+            raise
+        
+    def cleanup(self):
+        """清理资源"""
+        pass
+    
+    def recognize_text(self, image: Union[np.ndarray, Image.Image]) -> List[Dict[str, Any]]:
+        """文本识别"""
+        if self.ocr_model is None:
+            raise RuntimeError("OCR model not initialized")
+            
+        # 转换为BGR格式
+        if isinstance(image, Image.Image):
+            image = np.array(image)
+        bgr_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+        
+        try:
+            # OCR识别
+            ocr_results = self.ocr_model.ocr(bgr_image, rec=True)
+            
+            # 格式化结果
+            formatted_results = []
+            if ocr_results and ocr_results[0]:
+                for item in ocr_results[0]:
+                    if len(item) >= 2 and len(item[1]) >= 2:
+                        formatted_results.append({
+                            'bbox': item[0],  # 坐标
+                            'text': item[1][0],  # 识别文本
+                            'confidence': item[1][1]  # 置信度
+                        })
+                        
+            return formatted_results
+            
+        except Exception as e:
+            print(f"❌ OCR recognition failed: {e}")
+            return []
+
+# 导出适配器类
+__all__ = [
+    'MinerUPreprocessor',
+    'MinerULayoutDetector', 
+    'MinerUVLRecognizer',
+    'MinerUOCRRecognizer'
+]

+ 679 - 0
zhch/universal_doc_parser/models/adapters/paddle_layout_detector.py

@@ -0,0 +1,679 @@
+"""使用 ONNX Runtime 进行布局检测的统一接口 (符合 BaseLayoutDetector 规范)"""
+
+import cv2
+import numpy as np
+import onnxruntime as ort
+from pathlib import Path
+from typing import Dict, List, Tuple, Union, Any
+from PIL import Image
+import sys
+
+try:
+    from .base import BaseLayoutDetector
+except ImportError:
+    # 如果相对导入失败,尝试绝对导入(适用于测试环境)
+    from base import BaseLayoutDetector
+
+class PaddleLayoutDetector(BaseLayoutDetector):
+    """PaddleX RT-DETR 布局检测器 (ONNX 版本)"""
+    
+    # ⚠️ Fix: use the official RT-DETR-H_layout_17cls category definition.
+    # Maps each model class id to a category name of the MinerU scheme.
+    CATEGORY_MAP = {
+        0: 'title',              # paragraph_title -> title
+        1: 'image_body',         # image -> image_body
+        2: 'text',               # text -> text
+        3: 'text',               # number -> text (merged into text)
+        4: 'text',               # abstract -> text
+        5: 'text',               # content -> text
+        6: 'image_caption',      # figure_title -> image_caption
+        7: 'interline_equation', # formula -> interline_equation
+        8: 'table_body',         # table -> table_body
+        9: 'table_caption',      # table_title -> table_caption
+        10: 'text',              # reference -> text
+        11: 'title',             # doc_title -> title
+        12: 'table_footnote',    # footnote -> table_footnote
+        13: 'abandon',           # header -> abandon (page headers are usually not needed)
+        14: 'text',              # algorithm -> text
+        15: 'abandon',           # footer -> abandon (page footers are usually not needed)
+        16: 'abandon'            # seal -> abandon (seals are usually not needed)
+    }
+    
+    # Original class names of the model, keyed by class id.
+    ORIGINAL_CATEGORY_NAMES = {
+        0: 'paragraph_title',
+        1: 'image',
+        2: 'text',
+        3: 'number',
+        4: 'abstract',
+        5: 'content',
+        6: 'figure_title',
+        7: 'formula',
+        8: 'table',
+        9: 'table_title',
+        10: 'reference',
+        11: 'doc_title',
+        12: 'footnote',
+        13: 'header',
+        14: 'algorithm',
+        15: 'footer',
+        16: 'seal'
+    }
+    
+    def __init__(self, config: Dict[str, Any]):
+        super().__init__(config)
+        self.session = None
+        self.inputs = {}
+        self.outputs = {}
+        self.target_size = 640
+    
+    def initialize(self):
+        """初始化 ONNX 模型"""
+        try:
+            onnx_path = self.config.get('model_dir')
+            if not onnx_path:
+                raise ValueError("model_dir not specified in config")
+            
+            if not Path(onnx_path).exists():
+                raise FileNotFoundError(f"ONNX model not found: {onnx_path}")
+            
+            # 根据配置选择执行提供器
+            device = self.config.get('device', 'cpu')
+            if device == 'gpu':
+                # Mac 支持 CoreML
+                providers = ['CoreMLExecutionProvider', 'CPUExecutionProvider']
+            else:
+                providers = ['CPUExecutionProvider']
+            
+            self.session = ort.InferenceSession(onnx_path, providers=providers)
+            
+            # 获取模型输入输出信息
+            self.inputs = {inp.name: inp for inp in self.session.get_inputs()}
+            self.outputs = {out.name: out for out in self.session.get_outputs()}
+            
+            # 自动检测输入尺寸
+            self.target_size = self._detect_input_size()
+            
+            print(f"✅ PaddleX Layout Detector initialized")
+            print(f"   - Model: {Path(onnx_path).name}")
+            print(f"   - Target size: {self.target_size}")
+            print(f"   - Device: {device}")
+            print(f"   - Providers: {self.session.get_providers()}")
+            
+        except Exception as e:
+            print(f"❌ Failed to initialize PaddleX Layout Detector: {e}")
+            raise
+    
+    def cleanup(self):
+        """清理资源"""
+        self.session = None
+        self.inputs = {}
+        self.outputs = {}
+    
+    def detect(self, image: Union[np.ndarray, Image.Image]) -> List[Dict[str, Any]]:
+        """
+        检测布局
+        
+        Args:
+            image: 输入图像 (numpy数组或PIL图像)
+            
+        Returns:
+            检测结果列表,每个元素包含:
+            - category: MinerU类别名称
+            - bbox: [x1, y1, x2, y2]
+            - confidence: 置信度
+            - raw: 原始检测结果
+        """
+        if self.session is None:
+            raise RuntimeError("Model not initialized. Call initialize() first.")
+        
+        # 转换为numpy数组
+        if isinstance(image, Image.Image):
+            image = np.array(image)
+            if image.ndim == 2:  # 灰度图
+                image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
+            elif image.shape[2] == 4:  # RGBA
+                image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
+            elif image.shape[2] == 3:  # RGB
+                image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+        
+        # 执行预测
+        conf_threshold = self.config.get('conf', 0.25)
+        results = self._predict(image, conf_threshold)
+        
+        # 转换为 MinerU 格式
+        formatted_results = []
+        for result in results:
+            # 映射类别
+            original_category_id = result['category_id']
+            mineru_category = self.CATEGORY_MAP.get(original_category_id, 'text')
+            
+            formatted_results.append({
+                'category': mineru_category,
+                'bbox': result['bbox'],
+                'confidence': result['score'],
+                'raw': {
+                    'original_category_id': original_category_id,
+                    'original_category_name': result['category_name'],
+                    'poly': result['poly'],
+                    'width': result['width'],
+                    'height': result['height']
+                }
+            })
+        
+        return formatted_results
+    
+    def _detect_input_size(self) -> int:
+        """自动检测模型的输入尺寸"""
+        if 'image' in self.inputs:
+            shape = self.inputs['image'].shape
+            # shape 通常是 [batch, channels, height, width]
+            if len(shape) >= 3:
+                # 尝试从 shape[2] 或 shape[3] 获取尺寸
+                for dim in shape[2:]:
+                    if isinstance(dim, int) and dim > 0:
+                        return dim
+        return 640  # 默认值
+    
+    def _preprocess(
+        self, 
+        img: np.ndarray
+    ) -> Tuple[Dict[str, np.ndarray], Tuple[float, float], Tuple[int, int]]:
+        """
+        预处理图像 (根据 RT-DETR 的配置)
+        
+        Returns:
+            input_dict: 包含所有输入的字典
+            scale: (scale_h, scale_w) 缩放因子
+            orig_shape: (h, w) 原始图像尺寸
+        """
+        orig_h, orig_w = img.shape[:2]
+        target_size = self.target_size  # 640
+        
+        # 1. Resize 到目标尺寸 (不保持长宽比)
+        img_resized = cv2.resize(
+            img, 
+            (target_size, target_size), 
+            interpolation=cv2.INTER_LINEAR
+        )
+        
+        # 2. 转换为 RGB
+        img_rgb = cv2.cvtColor(img_resized, cv2.COLOR_BGR2RGB)
+        
+        # ✅ 修正 3: 归一化 (mean=[0,0,0], std=[1,1,1], norm_type=none)
+        # 只做 /255,不做均值减法和标准差除法
+        img_normalized = img_rgb.astype(np.float32) / 255.0
+        
+        # 4. 转换为 CHW 格式
+        img_chw = img_normalized.transpose(2, 0, 1)
+        img_tensor = img_chw[None, ...].astype(np.float32)  # [1, 3, H, W]
+        
+        # 5. 准备所有输入
+        input_dict = {}
+        
+        # 主图像输入
+        if 'image' in self.inputs:
+            input_dict['image'] = img_tensor
+        elif 'images' in self.inputs:
+            input_dict['images'] = img_tensor
+        else:
+            # 使用第一个输入
+            first_input_name = list(self.inputs.keys())[0]
+            input_dict[first_input_name] = img_tensor
+        
+        # ✅ 修正 4: 计算缩放因子 (实际图像尺寸 / 目标尺寸)
+        scale_h = orig_h / target_size
+        scale_w = orig_w / target_size
+        
+        # im_shape 输入 (原始图像尺寸)
+        if 'im_shape' in self.inputs:
+            im_shape = np.array([[float(orig_h), float(orig_w)]], dtype=np.float32)
+            input_dict['im_shape'] = im_shape
+        
+        # scale_factor 输入
+        if 'scale_factor' in self.inputs:
+            # ⚠️ 注意:这里是原始尺寸/目标尺寸的比例
+            scale_factor = np.array([[scale_h, scale_w]], dtype=np.float32)
+            input_dict['scale_factor'] = scale_factor
+        
+        # ✅ 返回的 scale 用于后处理坐标还原
+        # 因为不保持长宽比,所以需要分别记录 x 和 y 的缩放
+        return input_dict, (scale_h, scale_w), (orig_h, orig_w)
+    
+    def _postprocess(
+        self, 
+        outputs: List[np.ndarray], 
+        scale: Tuple[float, float],  # (scale_h, scale_w)
+        orig_shape: Tuple[int, int],
+        conf_threshold: float = 0.5
+    ) -> List[Dict]:
+        """
+        后处理模型输出
+        
+        Args:
+            outputs: ONNX 模型输出
+            scale: (scale_h, scale_w) 缩放因子
+            orig_shape: (h, w) 原始图像尺寸
+            conf_threshold: 置信度阈值
+            
+        Returns:
+            检测结果列表
+        """
+        scale_h, scale_w = scale
+        orig_h, orig_w = orig_shape
+        
+        # 解析输出格式
+        if len(outputs) >= 2:
+            output0_shape = outputs[0].shape
+            output1_shape = outputs[1].shape
+            
+            # RT-DETR ONNX 格式: (num_boxes, 6)
+            # 格式: [label_id, score, x1, y1, x2, y2]
+            if len(output0_shape) == 2 and output0_shape[1] == 6:
+                pred = outputs[0]
+                labels = pred[:, 0].astype(int)
+                scores = pred[:, 1]
+                bboxes = pred[:, 2:6].copy()  # [x1, y1, x2, y2] - 在 640×640 尺度上
+                
+            # 情况2: output0 是 (batch, num_boxes, 6) - 带batch的合并格式
+            elif len(output0_shape) == 3 and output0_shape[2] == 6:
+                pred = outputs[0][0]
+                labels = pred[:, 0].astype(int)
+                scores = pred[:, 1]
+                bboxes = pred[:, 2:6].copy()
+                
+            # 情况3: output0 是 bboxes, output1 是 scores (分离格式)
+            elif len(output0_shape) == 2 and output0_shape[1] == 4:
+                bboxes = outputs[0].copy()
+                if len(output1_shape) == 1:
+                    scores = outputs[1]
+                    labels = np.zeros(len(scores), dtype=int)
+                elif len(output1_shape) == 2:
+                    scores_all = outputs[1]
+                    scores = scores_all.max(axis=1)
+                    labels = scores_all.argmax(axis=1)
+                else:
+                    raise ValueError(f"Unexpected output1 shape: {output1_shape}")
+        
+            # 情况4: RT-DETR 格式 (batch, num_boxes, 4) + (batch, num_boxes, num_classes)
+            elif len(output0_shape) == 3 and output0_shape[2] == 4:
+                bboxes = outputs[0][0].copy()
+                scores_all = outputs[1][0]
+                scores = scores_all.max(axis=1)
+                labels = scores_all.argmax(axis=1)
+            
+            else:
+                raise ValueError(f"Unexpected output format: {output0_shape}, {output1_shape}")
+        
+        elif len(outputs) == 1:
+            # 单一输出
+            output_shape = outputs[0].shape
+            
+            if len(output_shape) == 2 and output_shape[1] == 6:
+                pred = outputs[0]
+                labels = pred[:, 0].astype(int)
+                scores = pred[:, 1]
+                bboxes = pred[:, 2:6].copy()
+            
+            elif len(output_shape) == 3 and output_shape[2] == 6:
+                pred = outputs[0][0]
+                labels = pred[:, 0].astype(int)
+                scores = pred[:, 1]
+                bboxes = pred[:, 2:6].copy()
+            
+            else:
+                raise ValueError(f"Unexpected single output shape: {output_shape}")
+        
+        else:
+            raise ValueError(f"Unexpected number of outputs: {len(outputs)}")
+        
+        # 将坐标从 640×640 还原到原图尺度
+        bboxes[:, [0, 2]] *= scale_w
+        bboxes[:, [1, 3]] *= scale_h
+        
+        # 自适应阈值
+        max_score = scores.max() if len(scores) > 0 else 0
+        if max_score < conf_threshold:
+            adjusted_threshold = max(max_score * 0.5, 0.05)
+            conf_threshold = adjusted_threshold
+        
+        # 过滤低分框
+        mask = scores > conf_threshold
+        bboxes = bboxes[mask]
+        scores = scores[mask]
+        labels = labels[mask]
+        
+        # 过滤完全在图像外的框
+        valid_mask = (
+            (bboxes[:, 2] > 0) &  # x2 > 0
+            (bboxes[:, 3] > 0) &  # y2 > 0
+            (bboxes[:, 0] < orig_w) &  # x1 < width
+            (bboxes[:, 1] < orig_h)    # y1 < height
+        )
+        bboxes = bboxes[valid_mask]
+        scores = scores[valid_mask]
+        labels = labels[valid_mask]
+        
+        # 裁剪坐标到图像范围
+        bboxes[:, [0, 2]] = np.clip(bboxes[:, [0, 2]], 0, orig_w)
+        bboxes[:, [1, 3]] = np.clip(bboxes[:, [1, 3]], 0, orig_h)
+        
+        # 构造结果
+        results = []
+        for box, score, label in zip(bboxes, scores, labels):
+            x1, y1, x2, y2 = box
+            
+            # 过滤无效框
+            width = x2 - x1
+            height = y2 - y1
+            
+            # 过滤太小的框
+            if width < 10 or height < 10:
+                continue
+            
+            # 过滤面积异常大的框
+            area = width * height
+            img_area = orig_w * orig_h
+            if area > img_area * 0.95:
+                continue
+            
+            results.append({
+                'category_id': int(label),
+                'category_name': self.ORIGINAL_CATEGORY_NAMES.get(int(label), f'unknown_{label}'),
+                'bbox': [int(x1), int(y1), int(x2), int(y2)],
+                'poly': [int(x1), int(y1), int(x2), int(y1), int(x2), int(y2), int(x1), int(y2)],
+                'score': float(score),
+                'width': int(width),
+                'height': int(height)
+            })
+        
+        return results
+    
+    def _predict(
+        self, 
+        img: np.ndarray, 
+        conf_threshold: float = 0.25
+    ) -> List[Dict]:
+        """执行预测"""
+        # 预处理
+        input_dict, scale, orig_shape = self._preprocess(img)
+        
+        # ONNX 推理
+        output_names = [out.name for out in self.session.get_outputs()]
+        outputs = self.session.run(output_names, input_dict)
+        
+        # 后处理
+        results = self._postprocess(outputs, scale, orig_shape, conf_threshold)
+        
+        return results
+
+    def visualize(
+        self, 
+        img: np.ndarray, 
+        results: List[Dict],
+        output_path: str = None,
+        show_confidence: bool = True,
+        min_confidence: float = 0.0
+    ) -> np.ndarray:
+        """
+        可视化检测结果
+        
+        Args:
+            img: 输入图像
+            results: 检测结果 (MinerU格式)
+            output_path: 输出路径(可选)
+            show_confidence: 是否显示置信度
+            min_confidence: 最小置信度阈值
+            
+        Returns:
+            标注后的图像
+        """
+        import random
+        
+        vis_img = img.copy()
+        
+        # 为每个类别分配固定颜色
+        category_colors = {}
+        
+        # 预定义一些常用类别的颜色
+        predefined_colors = {
+            'text': (0, 255, 0),              # 绿色
+            'title': (255, 0, 0),             # 红色
+            'table_body': (0, 0, 255),        # 蓝色
+            'table_caption': (255, 255, 0),   # 青色
+            'table_footnote': (255, 128, 0),  # 橙色
+            'image_body': (255, 0, 255),      # 洋红
+            'image_caption': (128, 0, 255),   # 紫色
+            'interline_equation': (0, 255, 255),  # 黄色
+            'abandon': (128, 128, 128),       # 灰色
+        }
+        
+        # 过滤低置信度结果
+        filtered_results = [
+            res for res in results 
+            if res['confidence'] >= min_confidence
+        ]
+        
+        if not filtered_results:
+            print(f"⚠️ No results to visualize (min_confidence={min_confidence})")
+            return vis_img
+        
+        # 为每个出现的类别分配颜色
+        for res in filtered_results:
+            cat = res['category']
+            if cat not in category_colors:
+                if cat in predefined_colors:
+                    category_colors[cat] = predefined_colors[cat]
+                else:
+                    # 随机生成颜色
+                    category_colors[cat] = (
+                        random.randint(50, 255),
+                        random.randint(50, 255),
+                        random.randint(50, 255)
+                    )
+        
+        # 绘制检测框
+        for res in filtered_results:
+            bbox = res['bbox']
+            x1, y1, x2, y2 = bbox
+            cat = res['category']
+            confidence = res['confidence']
+            color = category_colors[cat]
+            
+            # 绘制矩形边框
+            cv2.rectangle(vis_img, (x1, y1), (x2, y2), color, 2)
+            
+            # 构造标签文本
+            if show_confidence:
+                label = f"{cat} {confidence:.2f}"
+            else:
+                label = cat
+            
+            # 计算标签尺寸
+            label_size, baseline = cv2.getTextSize(
+                label, 
+                cv2.FONT_HERSHEY_SIMPLEX, 
+                0.5, 
+                1
+            )
+            label_w, label_h = label_size
+            
+            # 绘制标签背景 (填充矩形)
+            cv2.rectangle(
+                vis_img,
+                (x1, y1 - label_h - 4),
+                (x1 + label_w, y1),
+                color,
+                -1  # 填充
+            )
+            
+            # 绘制标签文字 (白色)
+            cv2.putText(
+                vis_img,
+                label,
+                (x1, y1 - 2),
+                cv2.FONT_HERSHEY_SIMPLEX,
+                0.5,
+                (255, 255, 255),  # 白色文字
+                1,
+                cv2.LINE_AA
+            )
+        
+        # 添加图例 (在图像右上角)
+        if category_colors:
+            self._draw_legend(vis_img, category_colors, len(filtered_results))
+        
+        # 保存可视化结果
+        if output_path:
+            output_path = Path(output_path)
+            output_path.parent.mkdir(parents=True, exist_ok=True)
+            cv2.imwrite(str(output_path), vis_img)
+            print(f"💾 Visualization saved to: {output_path}")
+        
+        return vis_img
+    
+    def _draw_legend(
+        self, 
+        img: np.ndarray, 
+        category_colors: Dict[str, tuple],
+        total_count: int
+    ):
+        """
+        在图像上绘制图例
+        
+        Args:
+            img: 图像
+            category_colors: 类别颜色映射
+            total_count: 总检测数量
+        """
+        legend_x = img.shape[1] - 200  # 右侧留200像素
+        legend_y = 20
+        line_height = 25
+        
+        # 绘制半透明背景
+        overlay = img.copy()
+        cv2.rectangle(
+            overlay,
+            (legend_x - 10, legend_y - 10),
+            (img.shape[1] - 10, legend_y + len(category_colors) * line_height + 30),
+            (255, 255, 255),
+            -1
+        )
+        cv2.addWeighted(overlay, 0.7, img, 0.3, 0, img)
+        
+        # 绘制标题
+        cv2.putText(
+            img,
+            f"Legend ({total_count} total)",
+            (legend_x, legend_y),
+            cv2.FONT_HERSHEY_SIMPLEX,
+            0.5,
+            (0, 0, 0),
+            1,
+            cv2.LINE_AA
+        )
+        
+        # 绘制每个类别
+        y_offset = legend_y + line_height
+        for cat, color in sorted(category_colors.items()):
+            # 绘制颜色方块
+            cv2.rectangle(
+                img,
+                (legend_x, y_offset - 10),
+                (legend_x + 15, y_offset),
+                color,
+                -1
+            )
+            cv2.rectangle(
+                img,
+                (legend_x, y_offset - 10),
+                (legend_x + 15, y_offset),
+                (0, 0, 0),
+                1
+            )
+            
+            # 绘制类别名称
+            cv2.putText(
+                img,
+                cat,
+                (legend_x + 20, y_offset - 2),
+                cv2.FONT_HERSHEY_SIMPLEX,
+                0.4,
+                (0, 0, 0),
+                1,
+                cv2.LINE_AA
+            )
+            
+            y_offset += line_height
+
+
+# 测试代码
+if __name__ == "__main__":
+    import yaml
+    
+    # 测试配置
+    config = {
+        'model_dir': '/Users/zhch158/workspace/repository.git/PaddleX/zhch/unified_pytorch_models/Layout/RT-DETR-H_layout_17cls.onnx',
+        'device': 'cpu',
+        'conf': 0.25
+    }
+    
+    # 初始化检测器
+    print("🔧 Initializing detector...")
+    detector = PaddleLayoutDetector(config)
+    detector.initialize()
+    
+    # 读取测试图像
+    img_path = "/Users/zhch158/workspace/data/流水分析/B用户_扫描流水/PaddleOCR_VL_Results/B用户_扫描流水/B用户_扫描流水_page_001.png"
+    print(f"\n📖 Loading image: {img_path}")
+    img = cv2.imread(img_path)
+    
+    if img is None:
+        print(f"❌ Failed to load image: {img_path}")
+        exit(1)
+    
+    print(f"   Image shape: {img.shape}")
+    
+    # 执行检测
+    print("\n🔍 Detecting layout...")
+    results = detector.detect(img)
+    
+    print(f"\n✅ 检测到 {len(results)} 个区域:")
+    for i, res in enumerate(results, 1):
+        print(f"  [{i}] {res['category']}: "
+              f"score={res['confidence']:.3f}, "
+              f"bbox={res['bbox']}, "
+              f"original={res['raw']['original_category_name']}")
+    
+    # 统计各类别
+    category_counts = {}
+    for res in results:
+        cat = res['category']
+        category_counts[cat] = category_counts.get(cat, 0) + 1
+    
+    print(f"\n📊 类别统计 (MinerU格式):")
+    for cat, count in sorted(category_counts.items()):
+        print(f"  - {cat}: {count}")
+    
+    # 使用新的可视化方法
+    if len(results) > 0:
+        print("\n🎨 Generating visualization...")
+        
+        # 创建输出目录
+        output_dir = Path(__file__).parent.parent.parent / "tests" / "output"
+        output_dir.mkdir(parents=True, exist_ok=True)
+        output_path = output_dir / f"{Path(img_path).stem}_layout_vis.jpg"
+        
+        # 调用可视化方法
+        vis_img = detector.visualize(
+            img, 
+            results, 
+            output_path=str(output_path),
+            show_confidence=True,
+            min_confidence=0.0
+        )
+        
+        print(f"💾 Visualization saved to: {output_path}")
+    
+    # 清理
+    detector.cleanup()
+    print("\n✅ 测试完成!")