import gc
import os
import time
from typing import List

import torch

try:
    import torch_npu
    NPU_AVAILABLE = True
except ImportError:
    NPU_AVAILABLE = False


def detect_available_gpus() -> List[int]:
    """Detect the GPUs available to this process."""
    try:
        gpu_count = torch.cuda.device_count()
        available_gpus = list(range(gpu_count))
        print(f"Detected {gpu_count} available GPU(s): {available_gpus}")
        return available_gpus
    except Exception as e:
        print(f"GPU detection failed: {e}")
        return []


def monitor_gpu_memory(gpu_ids: List[int] = [0, 1]):
    """Report total, allocated, and reserved memory for each listed GPU."""
    try:
        for gpu_id in gpu_ids:
            torch.cuda.set_device(gpu_id)
            total = torch.cuda.get_device_properties(gpu_id).total_memory / 1024**3
            allocated = torch.cuda.memory_allocated(gpu_id) / 1024**3
            reserved = torch.cuda.memory_reserved(gpu_id) / 1024**3
            print(f"GPU {gpu_id} - total: {total:.2f}GB, allocated: {allocated:.2f}GB, reserved: {reserved:.2f}GB")
    except Exception as e:
        print(f"GPU memory monitoring failed: {e}")


def clean_memory(device='cuda'):
    """Release cached allocator memory on the given device, then run Python GC."""
    # Match any CUDA device string ("cuda", "cuda:0", ...), not just the bare "cuda".
    if str(device).startswith('cuda'):
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
    elif str(device).startswith("npu"):
        # Guard against torch_npu being absent, since the import above may have failed.
        if NPU_AVAILABLE and torch_npu.npu.is_available():
            torch_npu.npu.empty_cache()
    elif str(device).startswith("mps"):
        torch.mps.empty_cache()
    gc.collect()


def clean_vram(device, vram_threshold=8):
    """Run a cache cleanup on low-VRAM devices (total memory <= vram_threshold GB).

    The VIRTUAL_VRAM_SIZE environment variable, if set, overrides the detected size.
    """
    total_memory = get_vram(device)
    if total_memory is not None:
        total_memory = int(os.getenv('VIRTUAL_VRAM_SIZE', round(total_memory)))
    if total_memory and total_memory <= vram_threshold:
        gc_start = time.time()
        clean_memory(device)
        gc_time = round(time.time() - gc_start, 2)
        # logger.info(f"gc time: {gc_time}")


def get_vram(device):
    """Return the device's total memory in GB, or None if it cannot be determined."""
    if torch.cuda.is_available() and str(device).startswith("cuda"):
        return torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)  # bytes -> GB
    elif str(device).startswith("npu"):
        if NPU_AVAILABLE and torch_npu.npu.is_available():
            return torch_npu.npu.get_device_properties(device).total_memory / (1024 ** 3)  # bytes -> GB
    return None
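

# Minimal usage sketch (an assumption, not part of the original module): exercises
# the helpers above on the first detected CUDA device. The device string and the
# vram_threshold value are illustrative.
if __name__ == "__main__":
    gpus = detect_available_gpus()
    if gpus:
        monitor_gpu_memory(gpus)
        device = f"cuda:{gpus[0]}"
        print(f"{device} total VRAM: {get_vram(device)} GB")
        # On cards at or below the threshold, this empties the allocator cache.
        clean_vram(device, vram_threshold=8)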