| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667 |
import gc
import os
import time
from typing import List, Sequence

import torch

try:
    import torch_npu
except ImportError:
    pass
def detect_available_gpus() -> List[int]:
    """Return the indices of the CUDA devices PyTorch can see.

    The result is ``[0, 1, ..., n - 1]`` where ``n`` is
    ``torch.cuda.device_count()``; the list is also printed. If anything in
    that process raises, the error is printed and an empty list is returned.
    """
    try:
        device_ids = list(range(torch.cuda.device_count()))
        print(f"检测到 {len(device_ids)} 个可用GPU: {device_ids}")
    except Exception as err:
        print(f"GPU检测失败: {err}")
        device_ids = []
    return device_ids
def monitor_gpu_memory(gpu_ids: Sequence[int] = (0, 1)) -> None:
    """Print total, allocated and reserved memory for each listed CUDA device.

    Args:
        gpu_ids: CUDA device indices to report on. Defaults to devices 0 and 1.
            The default is an immutable tuple so it cannot be mutated between
            calls; any sequence of ints (e.g. a list) is accepted.

    Each device is queried independently. A failure on one index (for example
    the index does not exist, or CUDA is unavailable) is printed and the
    remaining devices are still reported. The current CUDA device is not
    changed, because the memory queries below take the device index directly.
    """
    gib = 1024 ** 3
    for gpu_id in gpu_ids:
        try:
            total = torch.cuda.get_device_properties(gpu_id).total_memory / gib
            allocated = torch.cuda.memory_allocated(gpu_id) / gib
            reserved = torch.cuda.memory_reserved(gpu_id) / gib
        except Exception as e:
            # Best-effort monitoring: report and move on to the next device.
            print(f"GPU内存监控失败 (GPU {gpu_id}): {e}")
            continue
        print(f"GPU {gpu_id} - 显存: {total:.2f}GB, 已分配: {allocated:.2f}GB, 已预留: {reserved:.2f}GB")
def clean_memory(device='cuda'):
    """Run Python's garbage collector, then release cached accelerator memory.

    Args:
        device: Device name or ``torch.device``. Only the prefix of its string
            form is inspected, so ``"cuda"``, ``"cuda:0"`` and
            ``torch.device("cuda", 1)`` are all treated as CUDA (likewise
            ``"npu..."`` and ``"mps..."``). Any other device (e.g. ``"cpu"``)
            only triggers ``gc.collect()``.

    Each backend is skipped silently when it is not available. This includes
    NPU when ``torch_npu`` is not installed.
    """
    device_str = str(device)
    # Collect first, so tensors kept alive only by unreachable reference cycles
    # return their blocks to the caching allocator before the cache is emptied.
    gc.collect()
    if device_str.startswith("cuda"):
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
    elif device_str.startswith("npu"):
        # torch_npu is an optional dependency; import it here so a missing
        # package means "nothing to clean" rather than a NameError.
        try:
            import torch_npu
        except ImportError:
            torch_npu = None
        if torch_npu is not None and torch_npu.npu.is_available():
            torch_npu.npu.empty_cache()
    elif device_str.startswith("mps"):
        if hasattr(torch, "mps") and torch.backends.mps.is_available():
            torch.mps.empty_cache()
def clean_vram(device, vram_threshold=8):
    """Free cached device memory, but only on devices with little memory.

    The device's total memory is obtained from ``get_vram``. If that returns
    ``None`` (unsupported or unavailable device), nothing is done. Otherwise
    the size is rounded to a whole number of GB. The environment variable
    ``VIRTUAL_VRAM_SIZE`` (GB), if set, replaces the detected size; this lets a
    large device be treated as a small one. Cleanup via ``clean_memory``
    runs when the resulting size is non-zero and ``<= vram_threshold``.

    Args:
        device: Device identifier passed through to ``get_vram`` and
            ``clean_memory``.
        vram_threshold: Size in GB at or below which cleanup is performed.

    Raises:
        ValueError: if ``VIRTUAL_VRAM_SIZE`` is set but is not a number.
    """
    total_memory = get_vram(device)
    if total_memory is None:
        return
    override = os.getenv('VIRTUAL_VRAM_SIZE')
    # Round the override the same way as the detected size, so "7.5" is
    # accepted instead of crashing int().
    vram_gb = round(total_memory) if override is None else round(float(override))
    # A size of 0 (e.g. VIRTUAL_VRAM_SIZE=0) disables cleanup, as before.
    if vram_gb and vram_gb <= vram_threshold:
        clean_memory(device)
def get_vram(device):
    """Return the total memory of an accelerator device in GiB, or ``None``.

    Args:
        device: Device name or ``torch.device``. Its string form is checked
            for a ``"cuda"`` or ``"npu"`` prefix.

    Returns:
        ``total_memory / 1024**3`` (a float) for an available CUDA or NPU
        device. Returns ``None`` for any other device type (CPU, MPS, ...),
        when the backend is not available, or when the optional ``torch_npu``
        package is not installed.
    """
    device_str = str(device)
    if device_str.startswith("cuda"):
        if not torch.cuda.is_available():
            return None
        return torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)
    if device_str.startswith("npu"):
        # torch_npu is optional; without it NPU memory cannot be queried.
        try:
            import torch_npu
        except ImportError:
            return None
        if not torch_npu.npu.is_available():
            return None
        return torch_npu.npu.get_device_properties(device).total_memory / (1024 ** 3)
    return None
|