feat: Add Flash Attention status check tool

zhch158_admin, 2 months ago
commit df20ec45df
1 changed file with 40 additions and 0 deletions

zhch/utils/verify_flash_attention.py  +40 -0

@@ -0,0 +1,40 @@
+import torch
+import pkg_resources
+
+def check_flash_attention():
+    print("🔍 Flash Attention 状态检查")
+    print("=" * 50)
+    
+    # Check installed packages
+    try:
+        flash_attn_version = pkg_resources.get_distribution("flash-attn").version
+        print(f"✅ flash-attn: {flash_attn_version}")
+    except pkg_resources.DistributionNotFound:
+        print("❌ flash-attn: not installed")
+    
+    try:
+        flashinfer_version = pkg_resources.get_distribution("flashinfer").version
+        print(f"✅ flashinfer: {flashinfer_version}")
+    except pkg_resources.DistributionNotFound:
+        print("❌ flashinfer: not installed")
+    
+    # Check CUDA availability
+    print("\n🔧 CUDA status:")
+    print(f"CUDA available: {torch.cuda.is_available()}")
+    if torch.cuda.is_available():
+        print(f"CUDA 版本: {torch.version.cuda}")
+        print(f"GPU 数量: {torch.cuda.device_count()}")
+        for i in range(torch.cuda.device_count()):
+            print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
+    
+    # Check that Flash Attention can be imported
+    try:
+        import flash_attn
+        print(f"\n✅ Flash Attention 可导入")
+        print(f"Flash Attention 版本: {flash_attn.__version__}")
+    except ImportError as e:
+        print(f"\n❌ Flash Attention 导入失败: {e}")
+
+if __name__ == "__main__":
+    check_flash_attention()
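
Note: the committed script only reports package versions and whether flash_attn imports. If a functional check is also wanted, a tiny forward pass through flash_attn_func (from the flash-attn package) could be run as well. The sketch below is an illustration of that idea, not part of this commit; the tensor shapes and the smoke_test_flash_attention name are arbitrary choices for the example.

# Minimal functional smoke test sketch (not in the committed script): runs a tiny
# attention forward pass to confirm the CUDA kernels work, not just that the
# package imports. Shapes/dtypes below are illustrative.
import torch

def smoke_test_flash_attention():
    if not torch.cuda.is_available():
        print("CUDA not available, skipping functional test")
        return
    try:
        from flash_attn import flash_attn_func
    except ImportError as e:
        print(f"flash-attn not importable: {e}")
        return
    # flash_attn_func expects (batch, seqlen, nheads, headdim) tensors in fp16/bf16 on GPU
    q = torch.randn(1, 16, 2, 64, dtype=torch.float16, device="cuda")
    k = torch.randn(1, 16, 2, 64, dtype=torch.float16, device="cuda")
    v = torch.randn(1, 16, 2, 64, dtype=torch.float16, device="cuda")
    out = flash_attn_func(q, k, v, causal=True)
    print(f"flash_attn_func OK, output shape: {tuple(out.shape)}")

if __name__ == "__main__":
    smoke_test_flash_attention()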