cuda_utils.py

import gc
import os
import time
from typing import List, Optional

import torch

try:
    import torch_npu
except ImportError:  # torch_npu is only present on Ascend NPU hosts
    torch_npu = None
def detect_available_gpus() -> List[int]:
    """Detect available CUDA GPUs and return their device indices."""
    try:
        gpu_count = torch.cuda.device_count()
        available_gpus = list(range(gpu_count))
        print(f"Detected {gpu_count} available GPU(s): {available_gpus}")
        return available_gpus
    except Exception as e:
        print(f"GPU detection failed: {e}")
        return []
def monitor_gpu_memory(gpu_ids: Optional[List[int]] = None):
    """Print total, allocated, and reserved memory for each GPU in gpu_ids."""
    if gpu_ids is None:  # avoid a mutable default argument
        gpu_ids = [0, 1]
    try:
        for gpu_id in gpu_ids:
            torch.cuda.set_device(gpu_id)
            total = torch.cuda.get_device_properties(gpu_id).total_memory / 1024**3
            allocated = torch.cuda.memory_allocated(gpu_id) / 1024**3
            reserved = torch.cuda.memory_reserved(gpu_id) / 1024**3
            print(f"GPU {gpu_id} - total: {total:.2f}GB, allocated: {allocated:.2f}GB, reserved: {reserved:.2f}GB")
    except Exception as e:
        print(f"GPU memory monitoring failed: {e}")
def clean_memory(device='cuda'):
    """Release cached device memory, then run the Python garbage collector."""
    if str(device).startswith("cuda"):  # also match specific devices like "cuda:0"
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
    elif str(device).startswith("npu"):
        if torch_npu is not None and torch_npu.npu.is_available():
            torch_npu.npu.empty_cache()
    elif str(device).startswith("mps"):
        torch.mps.empty_cache()
    gc.collect()
def clean_vram(device, vram_threshold=8):
    """Run clean_memory() only on low-VRAM devices (at most vram_threshold GB)."""
    total_memory = get_vram(device)
    if total_memory is not None:
        # VIRTUAL_VRAM_SIZE overrides the detected VRAM size (in GB).
        total_memory = int(os.getenv('VIRTUAL_VRAM_SIZE', round(total_memory)))
    if total_memory and total_memory <= vram_threshold:
        gc_start = time.time()
        clean_memory(device)
        gc_time = round(time.time() - gc_start, 2)
        # logger.info(f"gc time: {gc_time}")
def get_vram(device):
    """Return the device's total memory in GB, or None if it cannot be queried."""
    if torch.cuda.is_available() and str(device).startswith("cuda"):
        return torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)  # bytes -> GB
    elif str(device).startswith("npu"):
        if torch_npu is not None and torch_npu.npu.is_available():
            return torch_npu.npu.get_device_properties(device).total_memory / (1024 ** 3)  # bytes -> GB
    return None
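

# A minimal usage sketch (an assumption, not part of the original module):
# running the file directly probes the machine and exercises the helpers above.
if __name__ == "__main__":
    gpus = detect_available_gpus()
    if gpus:
        monitor_gpu_memory(gpus)
        # A high threshold forces the cleanup pass regardless of actual VRAM size.
        clean_vram("cuda:0", vram_threshold=1024)
        monitor_gpu_memory(gpus)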