|
@@ -43,7 +43,7 @@ def get_res_list_from_layout_res(layout_res):
|
|
|
|
|
|
|
|
def clean_vram(device, vram_threshold=8):
|
|
def clean_vram(device, vram_threshold=8):
|
|
|
total_memory = get_vram(device)
|
|
total_memory = get_vram(device)
|
|
|
- if total_memory <= vram_threshold:
|
|
|
|
|
|
|
+ if total_memory and total_memory <= vram_threshold:
|
|
|
gc_start = time.time()
|
|
gc_start = time.time()
|
|
|
clean_memory()
|
|
clean_memory()
|
|
|
gc_time = round(time.time() - gc_start, 2)
|
|
gc_time = round(time.time() - gc_start, 2)
|
|
@@ -54,4 +54,4 @@ def get_vram(device):
|
|
|
if torch.cuda.is_available() and device != 'cpu':
|
|
if torch.cuda.is_available() and device != 'cpu':
|
|
|
total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3) # 将字节转换为 GB
|
|
total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3) # 将字节转换为 GB
|
|
|
return total_memory
|
|
return total_memory
|
|
|
- return 0
|
|
|
|
|
|
|
+ return None
|