"""
unified_model_loader.py

Unified loader for PyTorch models.
"""

import torch
import torch.nn as nn
from pathlib import Path
from typing import Union, Dict, Any


class UnifiedModelLoader:
    """Resolve registered model keys to files under a root directory and load them."""

    def __init__(self, models_root: str = "./unified_pytorch_models"):
        """
        Create a loader rooted at ``models_root``.

        Args:
            models_root: Directory that the relative paths in
                ``model_registry`` are resolved against.
        """
        self.models_root = Path(models_root)

        # Model registry: model key -> path relative to ``models_root``.
        self.model_registry = {
            # OCR models
            'ocr_det_ch': 'OCR/Det/ch_PP-OCRv4_det_infer.pth',
            'ocr_det_en': 'OCR/Det/en_PP-OCRv4_det_infer.pth',
            'ocr_rec_ch': 'OCR/Rec/ch_PP-OCRv4_rec_infer.pth',
            'ocr_rec_en': 'OCR/Rec/en_PP-OCRv4_rec_infer.pth',
            'ocr_cls': 'OCR/Cls/orientation_cls.pth',
            # Table models
            'table_cls': 'Table/Cls/table_cls.pth',
            'table_rec_wired': 'Table/Rec/unet_table.pth',
            'table_rec_wireless': 'Table/Rec/slanet_plus.pth',
            # Layout model (already PyTorch)
            'layout_yolo': 'Layout/YOLO/doclayout_yolo.pt',
            # Formula recognition (already PyTorch)
            'formula_rec': 'MFR/unimernet_small.safetensors',
            # VLM models (already PyTorch)
            'vlm_mineru': 'VLM/MinerU-VLM-1.2B.safetensors',
            'vlm_paddleocr': 'VLM/PaddleOCR-VL-0.9B.safetensors',
        }

    def load_model(
        self,
        model_key: str,
        device: str = 'cpu',
        **kwargs
    ) -> nn.Module:
        """
        Load the model registered under ``model_key``.

        Args:
            model_key: Registry key (e.g. 'ocr_det_ch').
            device: Target device ('cpu', 'cuda', 'cuda:0').
            **kwargs: Extra arguments; accepted for forward compatibility
                but not used by this method at present.

        Returns:
            The loaded PyTorch model.

        Raises:
            ValueError: If ``model_key`` is not registered, or the file
                extension is not one of .safetensors / .pt / .pth.
            FileNotFoundError: If the resolved model file does not exist.
        """
        if model_key not in self.model_registry:
            raise ValueError(f"未知模型: {model_key}")

        model_path = self.models_root / self.model_registry[model_key]

        if not model_path.exists():
            raise FileNotFoundError(f"模型文件不存在: {model_path}")

        print(f"📦 加载模型: {model_key} from {model_path.name}")

        # Dispatch on the file extension.
        if model_path.suffix == '.safetensors':
            model = self._load_safetensors(model_path, device)
        elif model_path.suffix in ('.pt', '.pth'):
            model = self._load_pytorch(model_path, device)
        else:
            raise ValueError(f"不支持的模型格式: {model_path.suffix}")

        return model

    def _load_pytorch(self, model_path: Path, device: str) -> nn.Module:
        """
        Load a standard PyTorch checkpoint (.pt / .pth).

        Two layouts yield a usable model: a checkpoint that is itself an
        ``nn.Module``, and a dict whose 'model' entry is an ``nn.Module``.
        Anything else (a dict with only 'model_state_dict', a bare state
        dict, a 'model' entry that holds weights instead of a module, a
        non-dict payload) cannot be turned into a model without an
        architecture definition.

        Args:
            model_path: Path to the checkpoint file.
            device: Device passed to ``map_location`` and ``Module.to``.

        Returns:
            The module in eval mode, moved to ``device``.

        Raises:
            NotImplementedError: If the checkpoint contains only weights,
                so a model architecture would be needed to rebuild it.
        """
        # NOTE(review): torch.load relies on pickle, so only trusted files
        # should be loaded. Recent PyTorch releases appear to default to
        # weights_only=True, which would reject fully pickled modules —
        # confirm against the PyTorch version in use.
        checkpoint = torch.load(model_path, map_location=device)

        if isinstance(checkpoint, nn.Module):
            # The file holds a whole pickled module.
            model = checkpoint
        elif isinstance(checkpoint, dict) and 'model' in checkpoint:
            # Full-model checkpoint: the module is stored under 'model'.
            model = checkpoint['model']
        else:
            # Weights only ('model_state_dict' or a bare state dict).
            raise NotImplementedError("需要提供模型架构")

        if not isinstance(model, nn.Module):
            # 'model' holds a state dict rather than a module; fail with the
            # same explicit error instead of an AttributeError on .eval().
            raise NotImplementedError("需要提供模型架构")

        model.eval()
        return model.to(device)

    def _load_safetensors(self, model_path: Path, device: str) -> nn.Module:
        """
        Load a safetensors model through ``transformers.AutoModel``.

        The model is loaded from the directory containing ``model_path``
        (not from the file itself). float16 is requested when ``device``
        contains 'cuda', float32 otherwise.

        Args:
            model_path: Path to the .safetensors file.
            device: Device string, forwarded as ``device_map``.

        Returns:
            The loaded model in eval mode.
        """
        # Imported lazily so transformers is only required for this format.
        from transformers import AutoModel

        model = AutoModel.from_pretrained(
            model_path.parent,
            torch_dtype=torch.float16 if 'cuda' in device else torch.float32,
            device_map=device
        )
        model.eval()
        return model

    def list_available_models(self) -> Dict[str, Dict[str, Any]]:
        """
        Describe every registered model.

        Returns:
            A dict mapping each model key to a dict with:
            'path' (the registered relative path as a string),
            'exists' (whether the file is present under ``models_root``),
            'size' (file size in bytes, 0 when the file is missing).
        """
        available = {}
        for key, rel_path in self.model_registry.items():
            full_path = self.models_root / rel_path
            exists = full_path.exists()
            available[key] = {
                'path': str(rel_path),
                'exists': exists,
                'size': full_path.stat().st_size if exists else 0
            }
        return available


# Usage example
def test_unified_loader():
    """Smoke-test the loader: list the registry, then try to load one model."""
    loader = UnifiedModelLoader("./unified_pytorch_models")

    # List all registered models with a present/missing marker.
    print("📋 可用模型:")
    for key, info in loader.list_available_models().items():
        status = "✅" if info['exists'] else "❌"
        print(f" {status} {key}: {info['path']}")

    # Try loading the OCR detection model; any failure is only reported,
    # since this is a demo entry point.
    try:
        ocr_det_model = loader.load_model('ocr_det_ch', device='cuda:0')
        print(f"\n✅ OCR检测模型加载成功: {type(ocr_det_model)}")
    except Exception as e:
        print(f"\n❌ 加载失败: {e}")


if __name__ == "__main__":
    test_unified_loader()