clean_memory.py 435 B

12345678910111213141516
  1. # Copyright (c) Opendatalab. All rights reserved.
  2. import torch
  3. import gc
  4. def clean_memory(device='cuda'):
  5. if device == 'cuda':
  6. if torch.cuda.is_available():
  7. torch.cuda.empty_cache()
  8. torch.cuda.ipc_collect()
  9. elif str(device).startswith("npu"):
  10. import torch_npu
  11. if torch.npu.is_available():
  12. torch_npu.empty_cache()
  13. torch_npu.ipc_collect()
  14. gc.collect()