| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142 |
- """
- unified_model_loader.py
- 统一的PyTorch模型加载器
- """
- import torch
- import torch.nn as nn
- from pathlib import Path
- from typing import Union, Dict, Any
class UnifiedModelLoader:
    """Unified PyTorch model loader.

    Resolves short model keys (e.g. ``'ocr_det_ch'``) to files under
    ``models_root`` via ``model_registry``. Files are loaded either with
    ``torch.load`` (``.pt`` / ``.pth``) or with HuggingFace ``transformers``
    (``.safetensors``).
    """

    def __init__(self, models_root: Union[str, Path] = "./unified_pytorch_models"):
        """Create a loader rooted at ``models_root``.

        Args:
            models_root: Directory that the registry's relative paths are
                resolved against. Either a ``str`` or a ``Path`` is accepted.
        """
        self.models_root = Path(models_root)

        # Model registry: key -> path relative to ``models_root``.
        # The file suffix selects the loading backend in ``load_model``.
        self.model_registry = {
            # OCR models
            'ocr_det_ch': 'OCR/Det/ch_PP-OCRv4_det_infer.pth',
            'ocr_det_en': 'OCR/Det/en_PP-OCRv4_det_infer.pth',
            'ocr_rec_ch': 'OCR/Rec/ch_PP-OCRv4_rec_infer.pth',
            'ocr_rec_en': 'OCR/Rec/en_PP-OCRv4_rec_infer.pth',
            'ocr_cls': 'OCR/Cls/orientation_cls.pth',

            # Table models
            'table_cls': 'Table/Cls/table_cls.pth',
            'table_rec_wired': 'Table/Rec/unet_table.pth',
            'table_rec_wireless': 'Table/Rec/slanet_plus.pth',

            # Layout model (already PyTorch)
            'layout_yolo': 'Layout/YOLO/doclayout_yolo.pt',

            # Formula recognition (already PyTorch)
            'formula_rec': 'MFR/unimernet_small.safetensors',

            # VLM models (already PyTorch)
            'vlm_mineru': 'VLM/MinerU-VLM-1.2B.safetensors',
            'vlm_paddleocr': 'VLM/PaddleOCR-VL-0.9B.safetensors',
        }

    def load_model(
        self,
        model_key: str,
        device: str = 'cpu',
        **kwargs
    ) -> nn.Module:
        """Load a registered model onto ``device``.

        Args:
            model_key: Registry key (e.g. ``'ocr_det_ch'``).
            device: Target device (``'cpu'``, ``'cuda'``, ``'cuda:0'``).
            **kwargs: Accepted for forward compatibility; currently ignored.

        Returns:
            The loaded model, switched to eval mode.

        Raises:
            ValueError: ``model_key`` is not registered, or its file suffix is
                not a supported format.
            FileNotFoundError: The registered file does not exist.
            NotImplementedError: A ``.pt``/``.pth`` checkpoint holds only
                weights, so a model architecture would be required.
        """
        if model_key not in self.model_registry:
            raise ValueError(f"未知模型: {model_key}")

        model_path = self.models_root / self.model_registry[model_key]

        if not model_path.exists():
            raise FileNotFoundError(f"模型文件不存在: {model_path}")

        print(f"📦 加载模型: {model_key} from {model_path.name}")

        # Dispatch on the file suffix. Compare case-insensitively so that
        # e.g. ``.PT`` files are handled the same as ``.pt``.
        suffix = model_path.suffix.lower()
        if suffix == '.safetensors':
            return self._load_safetensors(model_path, device)
        if suffix in ('.pt', '.pth'):
            return self._load_pytorch(model_path, device)
        raise ValueError(f"不支持的模型格式: {model_path.suffix}")

    def _load_pytorch(self, model_path: Path, device: str) -> nn.Module:
        """Load a pickled PyTorch checkpoint that contains a full model.

        Supported layouts:
        - the checkpoint itself is an ``nn.Module``;
        - the checkpoint is a dict whose ``'model'`` entry is an ``nn.Module``.

        Raises:
            NotImplementedError: The checkpoint holds only weights (a
                ``'model_state_dict'`` entry, a bare state_dict, or a
                ``'model'`` entry that is a state_dict). This loader has no
                architecture to load such weights into.
        """
        # SECURITY: weights_only=False performs full unpickling, which can
        # execute arbitrary code. Only load checkpoint files from trusted
        # sources. The argument is passed explicitly because torch >= 2.6
        # defaults to weights_only=True, which rejects pickled nn.Module
        # objects -- the only layouts this method can return.
        checkpoint = torch.load(model_path, map_location=device, weights_only=False)

        if isinstance(checkpoint, nn.Module):
            # Whole model saved directly via torch.save(model, path).
            model = checkpoint
        elif isinstance(checkpoint, dict) and isinstance(checkpoint.get('model'), nn.Module):
            # Full model stored under the 'model' key.
            model = checkpoint['model']
        else:
            # Weights only: the model architecture must be created first.
            raise NotImplementedError("需要提供模型架构")

        model.eval()
        return model.to(device)

    def _load_safetensors(self, model_path: Path, device: str) -> nn.Module:
        """Load a HuggingFace-style model through ``transformers.AutoModel``.

        ``from_pretrained`` is given the *directory* containing
        ``model_path``; the file name itself is not passed. Weights are
        loaded as float16 when ``device`` names a CUDA device, float32
        otherwise.
        """
        # Imported lazily so this module can be imported without transformers
        # installed, as long as no .safetensors model is requested.
        from transformers import AutoModel

        # NOTE(review): only the parent directory is passed, so registry
        # entries that share a directory (e.g. both 'VLM/...' models) resolve
        # to the same from_pretrained() call. Presumably each model should
        # live in its own directory next to a config.json -- confirm layout.
        model = AutoModel.from_pretrained(
            model_path.parent,
            torch_dtype=torch.float16 if 'cuda' in device else torch.float32,
            device_map=device,
        )

        model.eval()
        return model

    def list_available_models(self) -> Dict[str, Dict[str, Any]]:
        """Report every registered model and whether its file is present.

        Returns:
            Mapping of model key to a dict with ``'path'`` (the
            registry-relative path as ``str``), ``'exists'`` (``bool``) and
            ``'size'`` (file size in bytes, ``0`` when missing).
        """
        available = {}
        for key, rel_path in self.model_registry.items():
            full_path = self.models_root / rel_path
            exists = full_path.exists()
            available[key] = {
                'path': str(rel_path),
                'exists': exists,
                'size': full_path.stat().st_size if exists else 0,
            }
        return available
# Usage example
def test_unified_loader(
    models_root: Union[str, Path] = "./unified_pytorch_models",
    device: Union[str, None] = None,
) -> None:
    """Demo: list the registered models, then try to load the Chinese OCR detector.

    Args:
        models_root: Directory handed to ``UnifiedModelLoader``.
        device: Device to load onto. ``None`` (the default) picks ``'cuda:0'``
            when CUDA is available and ``'cpu'`` otherwise, so the demo also
            runs on CPU-only machines.
    """
    loader = UnifiedModelLoader(models_root)

    # List every registered model and whether its file is present.
    print("📋 可用模型:")
    for key, info in loader.list_available_models().items():
        status = "✅" if info['exists'] else "❌"
        print(f" {status} {key}: {info['path']}")

    if device is None:
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    # Load the OCR detection model. This is a demo entry point, so any
    # failure is reported instead of propagated.
    try:
        ocr_det_model = loader.load_model('ocr_det_ch', device=device)
    except Exception as e:
        print(f"\n❌ 加载失败: {e}")
    else:
        print(f"\n✅ OCR检测模型加载成功: {type(ocr_det_model)}")
if __name__ == "__main__":
    # Run the demo only when this file is executed as a script, not on import.
    test_unified_loader()
|