|
|
@@ -255,8 +255,9 @@ def may_batch_image_analyze(
|
|
|
torch.npu.set_compile_mode(jit_compile=False)
|
|
|
|
|
|
if str(device).startswith('npu') or str(device).startswith('cuda'):
|
|
|
- gpu_memory = int(os.getenv('VIRTUAL_VRAM_SIZE', round(get_vram(device))))
|
|
|
- if gpu_memory is not None:
|
|
|
+ vram = get_vram(device)
|
|
|
+ if vram is not None:
|
|
|
+ gpu_memory = int(os.getenv('VIRTUAL_VRAM_SIZE', round(vram)))
|
|
|
if gpu_memory >= 16:
|
|
|
batch_ratio = 16
|
|
|
elif gpu_memory >= 12:
|
|
|
@@ -268,6 +269,10 @@ def may_batch_image_analyze(
|
|
|
else:
|
|
|
batch_ratio = 1
|
|
|
logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
|
|
|
+ else:
|
|
|
+ # Default batch_ratio when VRAM can't be determined
|
|
|
+ batch_ratio = 1
|
|
|
+ logger.info(f'Could not determine GPU memory, using default batch_ratio: {batch_ratio}')
|
|
|
|
|
|
|
|
|
# doc_analyze_start = time.time()
|