Ver código fonte

feat: 添加GPU内存管理功能,支持NPU和MPS设备的内存清理

zhch158_admin 3 semanas atrás
pai
commit
cf08d39079
2 arquivos alterados com 62 adições e 1 exclusões
  1. 44 1
      zhch/utils/cuda_utils.py
  2. 18 0
      zhch/utils/text_utils.py

+ 44 - 1
zhch/utils/cuda_utils.py

@@ -1,5 +1,13 @@
 import torch
+try:
+    import torch_npu
+except ImportError:
+    pass
 from typing import List
+import gc
+import os
+import time
+
 def detect_available_gpus() -> List[int]:
     """检测可用的GPU"""
     try:
@@ -21,4 +29,39 @@ def monitor_gpu_memory(gpu_ids: List[int] = [0, 1]):
             reserved = torch.cuda.memory_reserved(gpu_id) / 1024**3
             print(f"GPU {gpu_id} - 显存: {total:.2f}GB, 已分配: {allocated:.2f}GB, 已预留: {reserved:.2f}GB")
     except Exception as e:
-        print(f"GPU内存监控失败: {e}")
+        print(f"GPU内存监控失败: {e}")
+
+def clean_memory(device='cuda'):
+    if device == 'cuda':
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
+    elif str(device).startswith("npu"):
+        if torch_npu.npu.is_available():
+            torch_npu.npu.empty_cache()
+    elif str(device).startswith("mps"):
+        torch.mps.empty_cache()
+    gc.collect()
+
+
+def clean_vram(device, vram_threshold=8):
+    total_memory = get_vram(device)
+    if total_memory is not None:
+        total_memory = int(os.getenv('VIRTUAL_VRAM_SIZE', round(total_memory)))
+    if total_memory and total_memory <= vram_threshold:
+        gc_start = time.time()
+        clean_memory(device)
+        gc_time = round(time.time() - gc_start, 2)
+        # logger.info(f"gc time: {gc_time}")
+
+
+def get_vram(device):
+    if torch.cuda.is_available() and str(device).startswith("cuda"):
+        total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)  # 将字节转换为 GB
+        return total_memory
+    elif str(device).startswith("npu"):
+        if torch_npu.npu.is_available():
+            total_memory = torch_npu.npu.get_device_properties(device).total_memory / (1024 ** 3)  # 转为 GB
+            return total_memory
+    else:
+        return None

+ 18 - 0
zhch/utils/text_utils.py

@@ -0,0 +1,18 @@
+def full_to_half(text: str) -> str:
+    """Convert full-width characters to half-width characters using code point manipulation.
+
+    Args:
+        text: String containing full-width characters
+
+    Returns:
+        String with full-width characters converted to half-width
+    """
+    result = []
+    for char in text:
+        code = ord(char)
+        # Full-width letters and numbers (FF21-FF3A for A-Z, FF41-FF5A for a-z, FF10-FF19 for 0-9)
+        if (0xFF21 <= code <= 0xFF3A) or (0xFF41 <= code <= 0xFF5A) or (0xFF10 <= code <= 0xFF19):
+            result.append(chr(code - 0xFEE0))  # Shift to ASCII range
+        else:
+            result.append(char)
+    return ''.join(result)