| 12345678910111213141516171819202122232425262728293031323334353637383940 |
- import torch
- import subprocess
- import pkg_resources
def _installed_version(dist_name):
    """Return the installed version string of *dist_name*, or None if absent.

    The lookup reads distribution metadata only; the package itself is not
    imported.
    """
    # Local import: keeps this block self-contained and avoids the
    # deprecated pkg_resources API.
    from importlib.metadata import PackageNotFoundError, version

    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


def check_flash_attention():
    """Print a diagnostic report about Flash Attention and CUDA.

    The report has three parts, all written to stdout:

    1. Whether the ``flash-attn`` and ``flashinfer`` distributions are
       installed, and which version.
    2. Whether torch sees CUDA and, if so, the CUDA version plus the name of
       every visible GPU.
    3. Whether the ``flash_attn`` module can actually be imported.

    Returns:
        None. All results are printed.
    """
    print("🔍 Flash Attention 状态检查")
    print("=" * 50)

    # Check the installed distributions (metadata lookup only).
    for dist_name in ("flash-attn", "flashinfer"):
        found = _installed_version(dist_name)
        if found is None:
            print(f"❌ {dist_name}: 未安装")
        else:
            print(f"✅ {dist_name}: {found}")

    # Check CUDA availability.
    print("\n🔧 CUDA 状态:")
    cuda_available = torch.cuda.is_available()
    print(f"CUDA 可用: {cuda_available}")
    if cuda_available:
        print(f"CUDA 版本: {torch.version.cuda}")
        gpu_count = torch.cuda.device_count()
        print(f"GPU 数量: {gpu_count}")
        for i in range(gpu_count):
            print(f"GPU {i}: {torch.cuda.get_device_name(i)}")

    # Check that the module really imports: installed metadata alone does not
    # prove the import works.
    try:
        import flash_attn
    except ImportError as e:
        print(f"\n❌ Flash Attention 导入失败: {e}")
    else:
        print("\n✅ Flash Attention 可导入")
        print(f"Flash Attention 版本: {flash_attn.__version__}")
# Run the diagnostic only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    check_flash_attention()
|