|
|
@@ -0,0 +1,40 @@
|
|
|
+import torch
|
|
|
+import subprocess
|
|
|
+import pkg_resources
|
|
|
+
|
|
|
+def check_flash_attention():
|
|
|
+ print("🔍 Flash Attention 状态检查")
|
|
|
+ print("=" * 50)
|
|
|
+
|
|
|
+ # 检查已安装的包
|
|
|
+ try:
|
|
|
+ flash_attn_version = pkg_resources.get_distribution("flash-attn").version
|
|
|
+ print(f"✅ flash-attn: {flash_attn_version}")
|
|
|
+ except:
|
|
|
+ print("❌ flash-attn: 未安装")
|
|
|
+
|
|
|
+ try:
|
|
|
+ flashinfer_version = pkg_resources.get_distribution("flashinfer").version
|
|
|
+ print(f"✅ flashinfer: {flashinfer_version}")
|
|
|
+ except:
|
|
|
+ print("❌ flashinfer: 未安装")
|
|
|
+
|
|
|
+ # 检查 CUDA 可用性
|
|
|
+ print(f"\n🔧 CUDA 状态:")
|
|
|
+ print(f"CUDA 可用: {torch.cuda.is_available()}")
|
|
|
+ if torch.cuda.is_available():
|
|
|
+ print(f"CUDA 版本: {torch.version.cuda}")
|
|
|
+ print(f"GPU 数量: {torch.cuda.device_count()}")
|
|
|
+ for i in range(torch.cuda.device_count()):
|
|
|
+ print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
|
|
|
+
|
|
|
+ # 检查 Flash Attention 功能
|
|
|
+ try:
|
|
|
+ import flash_attn
|
|
|
+ print(f"\n✅ Flash Attention 可导入")
|
|
|
+ print(f"Flash Attention 版本: {flash_attn.__version__}")
|
|
|
+ except ImportError as e:
|
|
|
+ print(f"\n❌ Flash Attention 导入失败: {e}")
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+ check_flash_attention()
|