@@ -1,5 +1,13 @@
 import torch
+try:
+    import torch_npu  # optional Ascend NPU backend
+except ImportError:
+    pass
 from typing import List
+import gc
+import os
+import time
+
 def detect_available_gpus() -> List[int]:
     """Detect available GPUs."""
     try:
@@ -21,4 +29,43 @@ def monitor_gpu_memory(gpu_ids: List[int] = [0, 1]):
         reserved = torch.cuda.memory_reserved(gpu_id) / 1024**3
         print(f"GPU {gpu_id} - total: {total:.2f}GB, allocated: {allocated:.2f}GB, reserved: {reserved:.2f}GB")
     except Exception as e:
-        print(f"GPU memory monitoring failed: {e}")
+        print(f"GPU memory monitoring failed: {e}")
+
+def clean_memory(device='cuda'):
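+    # Release cached allocator blocks on the requested backend, then run Python GC.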
+    if str(device).startswith('cuda'):
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
+    elif str(device).startswith("npu"):
+        if torch_npu.npu.is_available():
+            torch_npu.npu.empty_cache()
+    elif str(device).startswith("mps"):
+        torch.mps.empty_cache()
+    gc.collect()
+
+
+def clean_vram(device, vram_threshold=8):
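+    # On low-memory devices (total VRAM <= vram_threshold GB), reclaim memory aggressively.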
+    total_memory = get_vram(device)
+    if total_memory is not None:
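+        # VIRTUAL_VRAM_SIZE overrides the detected size (interpreted in GB, like vram_threshold).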
+        total_memory = int(os.getenv('VIRTUAL_VRAM_SIZE', round(total_memory)))
+    if total_memory and total_memory <= vram_threshold:
+        gc_start = time.time()
+        clean_memory(device)
+        gc_time = round(time.time() - gc_start, 2)
+        # logger.info(f"gc time: {gc_time}")
+
+
+def get_vram(device):
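+    # Return the device's total memory in GB, or None when it cannot be determined.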
+    if torch.cuda.is_available() and str(device).startswith("cuda"):
+        total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)  # convert bytes to GB
+        return total_memory
+    elif str(device).startswith("npu"):
+        if torch_npu.npu.is_available():
+            total_memory = torch_npu.npu.get_device_properties(device).total_memory / (1024 ** 3)  # convert bytes to GB
+            return total_memory
+    else:
+        return None