import torch
import pkg_resources


def check_flash_attention():
    print("🔍 Flash Attention status check")
    print("=" * 50)

    # Check installed packages
    try:
        flash_attn_version = pkg_resources.get_distribution("flash-attn").version
        print(f"✅ flash-attn: {flash_attn_version}")
    except pkg_resources.DistributionNotFound:
        print("❌ flash-attn: not installed")

    try:
        flashinfer_version = pkg_resources.get_distribution("flashinfer").version
        print(f"✅ flashinfer: {flashinfer_version}")
    except pkg_resources.DistributionNotFound:
        print("❌ flashinfer: not installed")

    # Check CUDA availability
    print("\n🔧 CUDA status:")
    print(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"CUDA version: {torch.version.cuda}")
        print(f"GPU count: {torch.cuda.device_count()}")
        for i in range(torch.cuda.device_count()):
            print(f"GPU {i}: {torch.cuda.get_device_name(i)}")

    # Check that the Flash Attention module itself imports
    try:
        import flash_attn
        print("\n✅ Flash Attention is importable")
        print(f"Flash Attention version: {flash_attn.__version__}")
    except ImportError as e:
        print(f"\n❌ Flash Attention import failed: {e}")


if __name__ == "__main__":
    check_flash_attention()
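

# pkg_resources is deprecated in recent setuptools releases. Below is a
# minimal sketch of the same version lookup using the standard library's
# importlib.metadata (Python 3.8+); the helper name `report_version` is
# illustrative, not part of the original script.
from importlib.metadata import PackageNotFoundError, version


def report_version(package: str) -> None:
    """Print the installed version of `package`, or mark it as missing."""
    try:
        print(f"✅ {package}: {version(package)}")
    except PackageNotFoundError:
        print(f"❌ {package}: not installed")


# Example usage (same package names as queried above):
# report_version("flash-attn")
# report_version("flashinfer")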