
feat: implement dynamic batch size calculation based on GPU memory in vlm_analyze.py

myhloli, 2 months ago
commit 3ca520a3fe
1 file changed, 25 insertions(+), 1 deletion(-)

+ 25 - 1
mineru/backend/vlm/vlm_analyze.py

@@ -1,4 +1,5 @@
 # Copyright (c) Opendatalab. All rights reserved.
+import os
 import time
 
 from loguru import logger
@@ -6,8 +7,10 @@ from loguru import logger
 from .model_output_to_middle_json import result_to_middle_json
 from ...data.data_reader_writer import DataWriter
 from mineru.utils.pdf_image_tools import load_images_from_pdf
+from ...utils.config_reader import get_device
 
 from ...utils.enum_class import ImageType
+from ...utils.model_utils import get_vram
 from ...utils.models_download_utils import auto_download_and_get_model_root_path
 
 from mineru_vl_utils import MinerUClient
@@ -36,6 +39,7 @@ class ModelSingleton:
             processor = None
             vllm_llm = None
             vllm_async_llm = None
+            batch_size = 0
             if backend in ['transformers', 'vllm-engine', "vllm-async-engine"] and not model_path:
                 model_path = auto_download_and_get_model_root_path("/","vlm")
                 if backend == "transformers":
@@ -53,15 +57,34 @@ class ModelSingleton:
                         dtype_key = "dtype"
                     else:
                         dtype_key = "torch_dtype"
+                    device = get_device()
                     model = Qwen2VLForConditionalGeneration.from_pretrained(
                         model_path,
-                        device_map="auto",
+                        device_map={"": device},
                         **{dtype_key: "auto"},  # type: ignore
                     )
                     processor = AutoProcessor.from_pretrained(
                         model_path,
                         use_fast=True,
                     )
+                    try:
+                        vram = get_vram(device)
+                        if vram is not None:
+                            gpu_memory = int(os.getenv('MINERU_VIRTUAL_VRAM_SIZE', round(vram)))
+                            if gpu_memory >= 16:
+                                batch_size = 8
+                            elif gpu_memory >= 8:
+                                batch_size = 4
+                            else:
+                                batch_size = 1
+                            logger.info(f'gpu_memory: {gpu_memory} GB, batch_size: {batch_size}')
+                        else:
+                            # Default batch_size when VRAM can't be determined
+                            batch_size = 1
+                            logger.info(f'Could not determine GPU memory, using default batch_size: {batch_size}')
+                    except Exception as e:
+                        logger.warning(f'Error determining VRAM: {e}, using default batch_size: 1')
+                        batch_size = 1
                 elif backend == "vllm-engine":
                     try:
                         import vllm
@@ -92,6 +115,7 @@ class ModelSingleton:
                 vllm_llm=vllm_llm,
                 vllm_async_llm=vllm_async_llm,
                 server_url=server_url,
+                batch_size=batch_size,
             )
             elapsed = round(time.time() - start_time, 2)
             logger.info(f"get {backend} predictor cost: {elapsed}s")
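Note: the VRAM-to-batch-size mapping added by this commit can be read in isolation. The sketch below is a minimal standalone illustration, not code from the repository: choose_batch_size is a hypothetical helper name, and it simply mirrors the thresholds and the MINERU_VIRTUAL_VRAM_SIZE override that appear in the diff above.

import os

def choose_batch_size(detected_vram_gb):
    # Hypothetical helper mirroring the thresholds in this commit.
    # detected_vram_gb is None when VRAM detection fails; the env var
    # MINERU_VIRTUAL_VRAM_SIZE (in GB) overrides the detected value.
    if detected_vram_gb is None:
        return 1  # conservative default when VRAM cannot be determined
    gpu_memory = int(os.getenv('MINERU_VIRTUAL_VRAM_SIZE', round(detected_vram_gb)))
    if gpu_memory >= 16:
        return 8
    elif gpu_memory >= 8:
        return 4
    return 1

# Example: a 24 GB GPU maps to batch_size 8; setting MINERU_VIRTUAL_VRAM_SIZE=6
# before launching would force batch_size 1 on the same machine.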