|
|
@@ -165,12 +165,14 @@ def doc_analyze(
|
|
|
import torch_npu
|
|
|
if torch_npu.npu.is_available():
|
|
|
npu_support = True
|
|
|
+ torch.npu.set_compile_mode(jit_compile=False)
|
|
|
|
|
|
if torch.cuda.is_available() and device != 'cpu' or npu_support:
|
|
|
gpu_memory = int(os.getenv("VIRTUAL_VRAM_SIZE", round(get_vram(device))))
|
|
|
if gpu_memory is not None and gpu_memory >= 8:
|
|
|
-
|
|
|
- if gpu_memory >= 16:
|
|
|
+ if gpu_memory >= 20:
|
|
|
+ batch_ratio = 16
|
|
|
+ elif gpu_memory >= 15:
|
|
|
batch_ratio = 8
|
|
|
elif gpu_memory >= 10:
|
|
|
batch_ratio = 4
|