unified_model_loader.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. """
  2. unified_model_loader.py
  3. 统一的PyTorch模型加载器
  4. """
  5. import torch
  6. import torch.nn as nn
  7. from pathlib import Path
  8. from typing import Union, Dict, Any
  9. class UnifiedModelLoader:
  10. """统一的PyTorch模型加载器"""
  11. def __init__(self, models_root: str = "./unified_pytorch_models"):
  12. self.models_root = Path(models_root)
  13. # 模型注册表
  14. self.model_registry = {
  15. # OCR模型
  16. 'ocr_det_ch': 'OCR/Det/ch_PP-OCRv4_det_infer.pth',
  17. 'ocr_det_en': 'OCR/Det/en_PP-OCRv4_det_infer.pth',
  18. 'ocr_rec_ch': 'OCR/Rec/ch_PP-OCRv4_rec_infer.pth',
  19. 'ocr_rec_en': 'OCR/Rec/en_PP-OCRv4_rec_infer.pth',
  20. 'ocr_cls': 'OCR/Cls/orientation_cls.pth',
  21. # 表格模型
  22. 'table_cls': 'Table/Cls/table_cls.pth',
  23. 'table_rec_wired': 'Table/Rec/unet_table.pth',
  24. 'table_rec_wireless': 'Table/Rec/slanet_plus.pth',
  25. # Layout模型 (已是PyTorch)
  26. 'layout_yolo': 'Layout/YOLO/doclayout_yolo.pt',
  27. # 公式识别 (已是PyTorch)
  28. 'formula_rec': 'MFR/unimernet_small.safetensors',
  29. # VLM模型 (已是PyTorch)
  30. 'vlm_mineru': 'VLM/MinerU-VLM-1.2B.safetensors',
  31. 'vlm_paddleocr': 'VLM/PaddleOCR-VL-0.9B.safetensors',
  32. }
  33. def load_model(
  34. self,
  35. model_key: str,
  36. device: str = 'cpu',
  37. **kwargs
  38. ) -> nn.Module:
  39. """
  40. 加载模型
  41. Args:
  42. model_key: 模型键名 (如 'ocr_det_ch')
  43. device: 设备 ('cpu', 'cuda', 'cuda:0')
  44. **kwargs: 额外参数
  45. Returns:
  46. PyTorch模型
  47. """
  48. if model_key not in self.model_registry:
  49. raise ValueError(f"未知模型: {model_key}")
  50. model_path = self.models_root / self.model_registry[model_key]
  51. if not model_path.exists():
  52. raise FileNotFoundError(f"模型文件不存在: {model_path}")
  53. print(f"📦 加载模型: {model_key} from {model_path.name}")
  54. # 加载模型
  55. if model_path.suffix == '.safetensors':
  56. model = self._load_safetensors(model_path, device)
  57. elif model_path.suffix in ['.pt', '.pth']:
  58. model = self._load_pytorch(model_path, device)
  59. else:
  60. raise ValueError(f"不支持的模型格式: {model_path.suffix}")
  61. return model
  62. def _load_pytorch(self, model_path: Path, device: str) -> nn.Module:
  63. """加载标准PyTorch模型"""
  64. checkpoint = torch.load(model_path, map_location=device)
  65. if 'model' in checkpoint:
  66. # 完整模型
  67. model = checkpoint['model']
  68. elif 'model_state_dict' in checkpoint:
  69. # 仅权重 - 需要先创建模型架构
  70. raise NotImplementedError("需要提供模型架构")
  71. else:
  72. # 直接是state_dict
  73. raise NotImplementedError("需要提供模型架构")
  74. model.eval()
  75. return model.to(device)
  76. def _load_safetensors(self, model_path: Path, device: str) -> nn.Module:
  77. """加载Safetensors格式模型 (通常用于HuggingFace)"""
  78. from transformers import AutoModel
  79. model = AutoModel.from_pretrained(
  80. model_path.parent,
  81. torch_dtype=torch.float16 if 'cuda' in device else torch.float32,
  82. device_map=device
  83. )
  84. model.eval()
  85. return model
  86. def list_available_models(self) -> Dict[str, str]:
  87. """列出所有可用模型"""
  88. available = {}
  89. for key, rel_path in self.model_registry.items():
  90. full_path = self.models_root / rel_path
  91. available[key] = {
  92. 'path': str(rel_path),
  93. 'exists': full_path.exists(),
  94. 'size': full_path.stat().st_size if full_path.exists() else 0
  95. }
  96. return available
  97. # 使用示例
  98. def test_unified_loader():
  99. """测试统一加载器"""
  100. loader = UnifiedModelLoader("./unified_pytorch_models")
  101. # 列出所有模型
  102. print("📋 可用模型:")
  103. for key, info in loader.list_available_models().items():
  104. status = "✅" if info['exists'] else "❌"
  105. print(f" {status} {key}: {info['path']}")
  106. # 加载OCR检测模型
  107. try:
  108. ocr_det_model = loader.load_model('ocr_det_ch', device='cuda:0')
  109. print(f"\n✅ OCR检测模型加载成功: {type(ocr_det_model)}")
  110. except Exception as e:
  111. print(f"\n❌ 加载失败: {e}")
  112. if __name__ == "__main__":
  113. test_unified_loader()