|
|
@@ -6,6 +6,12 @@ import numpy as np
|
|
|
|
|
|
from mineru.utils.boxbase import get_minbox_if_overlap_by_ratio
|
|
|
|
|
|
+try:
|
|
|
+ import torch
|
|
|
+ import torch_npu
|
|
|
+except ImportError:
|
|
|
+ pass
|
|
|
+
|
|
|
|
|
|
def crop_img(input_res, input_img, crop_paste_x=0, crop_paste_y=0):
|
|
|
|
|
|
@@ -297,14 +303,11 @@ def get_res_list_from_layout_res(layout_res, iou_threshold=0.7, overlap_threshol
|
|
|
|
|
|
|
|
|
def clean_memory(device='cuda'):
|
|
|
- import torch
|
|
|
-
|
|
|
if device == 'cuda':
|
|
|
if torch.cuda.is_available():
|
|
|
torch.cuda.empty_cache()
|
|
|
torch.cuda.ipc_collect()
|
|
|
elif str(device).startswith("npu"):
|
|
|
- import torch_npu
|
|
|
if torch_npu.npu.is_available():
|
|
|
torch_npu.npu.empty_cache()
|
|
|
elif str(device).startswith("mps"):
|
|
|
@@ -322,13 +325,10 @@ def clean_vram(device, vram_threshold=8):
|
|
|
|
|
|
|
|
|
def get_vram(device):
|
|
|
- import torch
|
|
|
-
|
|
|
if torch.cuda.is_available() and str(device).startswith("cuda"):
|
|
|
total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3) # 将字节转换为 GB
|
|
|
return total_memory
|
|
|
elif str(device).startswith("npu"):
|
|
|
- import torch_npu
|
|
|
if torch_npu.npu.is_available():
|
|
|
total_memory = torch_npu.npu.get_device_properties(device).total_memory / (1024 ** 3) # 转为 GB
|
|
|
return total_memory
|