|
@@ -178,9 +178,9 @@ def doc_analyze(
|
|
|
gpu_memory = int(os.getenv("VIRTUAL_VRAM_SIZE", round(get_vram(device))))
|
|
gpu_memory = int(os.getenv("VIRTUAL_VRAM_SIZE", round(get_vram(device))))
|
|
|
if gpu_memory is not None and gpu_memory >= 8:
|
|
if gpu_memory is not None and gpu_memory >= 8:
|
|
|
|
|
|
|
|
- if 8 <= gpu_memory <= 10:
|
|
|
|
|
|
|
+ if 8 <= gpu_memory < 10:
|
|
|
batch_ratio = 2
|
|
batch_ratio = 2
|
|
|
- elif 10 < gpu_memory <= 12:
|
|
|
|
|
|
|
+ elif 10 <= gpu_memory <= 12:
|
|
|
batch_ratio = 4
|
|
batch_ratio = 4
|
|
|
elif 12 < gpu_memory <= 16:
|
|
elif 12 < gpu_memory <= 16:
|
|
|
batch_ratio = 8
|
|
batch_ratio = 8
|