Эх сурвалжийг харах

feat: 添加统一的PyTorch模型加载器,支持多种模型格式的加载和管理

zhch158_admin 3 долоо хоног өмнө
parent
commit
541075b260

+ 142 - 0
zhch/unified_pytorch_models/unified_model_loader.py

@@ -0,0 +1,142 @@
+"""
+unified_model_loader.py
+统一的PyTorch模型加载器
+"""
+import torch
+import torch.nn as nn
+from pathlib import Path
+from typing import Union, Dict, Any
+
+
+class UnifiedModelLoader:
+    """统一的PyTorch模型加载器"""
+    
+    def __init__(self, models_root: str = "./unified_pytorch_models"):
+        self.models_root = Path(models_root)
+        
+        # 模型注册表
+        self.model_registry = {
+            # OCR模型
+            'ocr_det_ch': 'OCR/Det/ch_PP-OCRv4_det_infer.pth',
+            'ocr_det_en': 'OCR/Det/en_PP-OCRv4_det_infer.pth',
+            'ocr_rec_ch': 'OCR/Rec/ch_PP-OCRv4_rec_infer.pth',
+            'ocr_rec_en': 'OCR/Rec/en_PP-OCRv4_rec_infer.pth',
+            'ocr_cls': 'OCR/Cls/orientation_cls.pth',
+            
+            # 表格模型
+            'table_cls': 'Table/Cls/table_cls.pth',
+            'table_rec_wired': 'Table/Rec/unet_table.pth',
+            'table_rec_wireless': 'Table/Rec/slanet_plus.pth',
+            
+            # Layout模型 (已是PyTorch)
+            'layout_yolo': 'Layout/YOLO/doclayout_yolo.pt',
+            
+            # 公式识别 (已是PyTorch)
+            'formula_rec': 'MFR/unimernet_small.safetensors',
+            
+            # VLM模型 (已是PyTorch)
+            'vlm_mineru': 'VLM/MinerU-VLM-1.2B.safetensors',
+            'vlm_paddleocr': 'VLM/PaddleOCR-VL-0.9B.safetensors',
+        }
+    
+    def load_model(
+        self, 
+        model_key: str, 
+        device: str = 'cpu',
+        **kwargs
+    ) -> nn.Module:
+        """
+        加载模型
+        
+        Args:
+            model_key: 模型键名 (如 'ocr_det_ch')
+            device: 设备 ('cpu', 'cuda', 'cuda:0')
+            **kwargs: 额外参数
+        
+        Returns:
+            PyTorch模型
+        """
+        if model_key not in self.model_registry:
+            raise ValueError(f"未知模型: {model_key}")
+        
+        model_path = self.models_root / self.model_registry[model_key]
+        
+        if not model_path.exists():
+            raise FileNotFoundError(f"模型文件不存在: {model_path}")
+        
+        print(f"📦 加载模型: {model_key} from {model_path.name}")
+        
+        # 加载模型
+        if model_path.suffix == '.safetensors':
+            model = self._load_safetensors(model_path, device)
+        elif model_path.suffix in ['.pt', '.pth']:
+            model = self._load_pytorch(model_path, device)
+        else:
+            raise ValueError(f"不支持的模型格式: {model_path.suffix}")
+        
+        return model
+    
+    def _load_pytorch(self, model_path: Path, device: str) -> nn.Module:
+        """加载标准PyTorch模型"""
+        checkpoint = torch.load(model_path, map_location=device)
+        
+        if 'model' in checkpoint:
+            # 完整模型
+            model = checkpoint['model']
+        elif 'model_state_dict' in checkpoint:
+            # 仅权重 - 需要先创建模型架构
+            raise NotImplementedError("需要提供模型架构")
+        else:
+            # 直接是state_dict
+            raise NotImplementedError("需要提供模型架构")
+        
+        model.eval()
+        return model.to(device)
+    
+    def _load_safetensors(self, model_path: Path, device: str) -> nn.Module:
+        """加载Safetensors格式模型 (通常用于HuggingFace)"""
+        from transformers import AutoModel
+        
+        model = AutoModel.from_pretrained(
+            model_path.parent,
+            torch_dtype=torch.float16 if 'cuda' in device else torch.float32,
+            device_map=device
+        )
+        
+        model.eval()
+        return model
+    
+    def list_available_models(self) -> Dict[str, str]:
+        """列出所有可用模型"""
+        available = {}
+        for key, rel_path in self.model_registry.items():
+            full_path = self.models_root / rel_path
+            available[key] = {
+                'path': str(rel_path),
+                'exists': full_path.exists(),
+                'size': full_path.stat().st_size if full_path.exists() else 0
+            }
+        return available
+
+
+# 使用示例
+def test_unified_loader():
+    """测试统一加载器"""
+    loader = UnifiedModelLoader("./unified_pytorch_models")
+    
+    # 列出所有模型
+    print("📋 可用模型:")
+    for key, info in loader.list_available_models().items():
+        status = "✅" if info['exists'] else "❌"
+        print(f"  {status} {key}: {info['path']}")
+    
+    # 加载OCR检测模型
+    try:
+        ocr_det_model = loader.load_model('ocr_det_ch', device='cuda:0')
+        print(f"\n✅ OCR检测模型加载成功: {type(ocr_det_model)}")
+    except Exception as e:
+        print(f"\n❌ 加载失败: {e}")
+
+
+if __name__ == "__main__":
+    test_unified_loader()