clean_memory.py 479 B

1234567891011121314151617
  1. # Copyright (c) Opendatalab. All rights reserved.
  2. import torch
  3. import gc
  4. def clean_memory(device='cuda'):
  5. if device == 'cuda':
  6. if torch.cuda.is_available():
  7. torch.cuda.empty_cache()
  8. torch.cuda.ipc_collect()
  9. elif str(device).startswith("npu"):
  10. import torch_npu
  11. if torch_npu.npu.is_available():
  12. torch_npu.npu.empty_cache()
  13. elif str(device).startswith("mps"):
  14. torch.mps.empty_cache()
  15. gc.collect()