|
|
@@ -0,0 +1,142 @@
|
|
|
+"""
|
|
|
+unified_model_loader.py
|
|
|
+统一的PyTorch模型加载器
|
|
|
+"""
|
|
|
+import torch
|
|
|
+import torch.nn as nn
|
|
|
+from pathlib import Path
|
|
|
+from typing import Union, Dict, Any
|
|
|
+
|
|
|
+
|
|
|
+class UnifiedModelLoader:
|
|
|
+ """统一的PyTorch模型加载器"""
|
|
|
+
|
|
|
+ def __init__(self, models_root: str = "./unified_pytorch_models"):
|
|
|
+ self.models_root = Path(models_root)
|
|
|
+
|
|
|
+ # 模型注册表
|
|
|
+ self.model_registry = {
|
|
|
+ # OCR模型
|
|
|
+ 'ocr_det_ch': 'OCR/Det/ch_PP-OCRv4_det_infer.pth',
|
|
|
+ 'ocr_det_en': 'OCR/Det/en_PP-OCRv4_det_infer.pth',
|
|
|
+ 'ocr_rec_ch': 'OCR/Rec/ch_PP-OCRv4_rec_infer.pth',
|
|
|
+ 'ocr_rec_en': 'OCR/Rec/en_PP-OCRv4_rec_infer.pth',
|
|
|
+ 'ocr_cls': 'OCR/Cls/orientation_cls.pth',
|
|
|
+
|
|
|
+ # 表格模型
|
|
|
+ 'table_cls': 'Table/Cls/table_cls.pth',
|
|
|
+ 'table_rec_wired': 'Table/Rec/unet_table.pth',
|
|
|
+ 'table_rec_wireless': 'Table/Rec/slanet_plus.pth',
|
|
|
+
|
|
|
+ # Layout模型 (已是PyTorch)
|
|
|
+ 'layout_yolo': 'Layout/YOLO/doclayout_yolo.pt',
|
|
|
+
|
|
|
+ # 公式识别 (已是PyTorch)
|
|
|
+ 'formula_rec': 'MFR/unimernet_small.safetensors',
|
|
|
+
|
|
|
+ # VLM模型 (已是PyTorch)
|
|
|
+ 'vlm_mineru': 'VLM/MinerU-VLM-1.2B.safetensors',
|
|
|
+ 'vlm_paddleocr': 'VLM/PaddleOCR-VL-0.9B.safetensors',
|
|
|
+ }
|
|
|
+
|
|
|
+ def load_model(
|
|
|
+ self,
|
|
|
+ model_key: str,
|
|
|
+ device: str = 'cpu',
|
|
|
+ **kwargs
|
|
|
+ ) -> nn.Module:
|
|
|
+ """
|
|
|
+ 加载模型
|
|
|
+
|
|
|
+ Args:
|
|
|
+ model_key: 模型键名 (如 'ocr_det_ch')
|
|
|
+ device: 设备 ('cpu', 'cuda', 'cuda:0')
|
|
|
+ **kwargs: 额外参数
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ PyTorch模型
|
|
|
+ """
|
|
|
+ if model_key not in self.model_registry:
|
|
|
+ raise ValueError(f"未知模型: {model_key}")
|
|
|
+
|
|
|
+ model_path = self.models_root / self.model_registry[model_key]
|
|
|
+
|
|
|
+ if not model_path.exists():
|
|
|
+ raise FileNotFoundError(f"模型文件不存在: {model_path}")
|
|
|
+
|
|
|
+ print(f"📦 加载模型: {model_key} from {model_path.name}")
|
|
|
+
|
|
|
+ # 加载模型
|
|
|
+ if model_path.suffix == '.safetensors':
|
|
|
+ model = self._load_safetensors(model_path, device)
|
|
|
+ elif model_path.suffix in ['.pt', '.pth']:
|
|
|
+ model = self._load_pytorch(model_path, device)
|
|
|
+ else:
|
|
|
+ raise ValueError(f"不支持的模型格式: {model_path.suffix}")
|
|
|
+
|
|
|
+ return model
|
|
|
+
|
|
|
+ def _load_pytorch(self, model_path: Path, device: str) -> nn.Module:
|
|
|
+ """加载标准PyTorch模型"""
|
|
|
+ checkpoint = torch.load(model_path, map_location=device)
|
|
|
+
|
|
|
+ if 'model' in checkpoint:
|
|
|
+ # 完整模型
|
|
|
+ model = checkpoint['model']
|
|
|
+ elif 'model_state_dict' in checkpoint:
|
|
|
+ # 仅权重 - 需要先创建模型架构
|
|
|
+ raise NotImplementedError("需要提供模型架构")
|
|
|
+ else:
|
|
|
+ # 直接是state_dict
|
|
|
+ raise NotImplementedError("需要提供模型架构")
|
|
|
+
|
|
|
+ model.eval()
|
|
|
+ return model.to(device)
|
|
|
+
|
|
|
+ def _load_safetensors(self, model_path: Path, device: str) -> nn.Module:
|
|
|
+ """加载Safetensors格式模型 (通常用于HuggingFace)"""
|
|
|
+ from transformers import AutoModel
|
|
|
+
|
|
|
+ model = AutoModel.from_pretrained(
|
|
|
+ model_path.parent,
|
|
|
+ torch_dtype=torch.float16 if 'cuda' in device else torch.float32,
|
|
|
+ device_map=device
|
|
|
+ )
|
|
|
+
|
|
|
+ model.eval()
|
|
|
+ return model
|
|
|
+
|
|
|
+ def list_available_models(self) -> Dict[str, str]:
|
|
|
+ """列出所有可用模型"""
|
|
|
+ available = {}
|
|
|
+ for key, rel_path in self.model_registry.items():
|
|
|
+ full_path = self.models_root / rel_path
|
|
|
+ available[key] = {
|
|
|
+ 'path': str(rel_path),
|
|
|
+ 'exists': full_path.exists(),
|
|
|
+ 'size': full_path.stat().st_size if full_path.exists() else 0
|
|
|
+ }
|
|
|
+ return available
|
|
|
+
|
|
|
+
|
|
|
+# 使用示例
|
|
|
+def test_unified_loader():
|
|
|
+ """测试统一加载器"""
|
|
|
+ loader = UnifiedModelLoader("./unified_pytorch_models")
|
|
|
+
|
|
|
+ # 列出所有模型
|
|
|
+ print("📋 可用模型:")
|
|
|
+ for key, info in loader.list_available_models().items():
|
|
|
+ status = "✅" if info['exists'] else "❌"
|
|
|
+ print(f" {status} {key}: {info['path']}")
|
|
|
+
|
|
|
+ # 加载OCR检测模型
|
|
|
+ try:
|
|
|
+ ocr_det_model = loader.load_model('ocr_det_ch', device='cuda:0')
|
|
|
+ print(f"\n✅ OCR检测模型加载成功: {type(ocr_det_model)}")
|
|
|
+ except Exception as e:
|
|
|
+ print(f"\n❌ 加载失败: {e}")
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+ test_unified_loader()
|