
Merge pull request #3822 from opendatalab/dev

Dev
Xiaomeng Zhao 3 weeks ago
parent
commit
6131013ce9
2 changed files with 49 additions and 26 deletions
  1. mineru/backend/vlm/utils.py  +36 −2
  2. mineru/backend/vlm/vlm_analyze.py  +13 −24

+ 36 - 2
mineru/backend/vlm/custom_logits_processors.py → mineru/backend/vlm/utils.py

@@ -3,8 +3,11 @@ import os
 from loguru import logger
 from packaging import version
 
+from mineru.utils.config_reader import get_device
+from mineru.utils.model_utils import get_vram
 
-def enable_custom_logits_processors():
+
+def enable_custom_logits_processors() -> bool:
     import torch
     from vllm import __version__ as vllm_version
 
@@ -38,4 +41,35 @@ def enable_custom_logits_processors():
             return False
     else:
         logger.info(f"compute_capability: {compute_capability} >= 8.0 and vllm version: {vllm_version} >= 0.10.1, enable custom_logits_processors")
-        return True
+        return True
+
+
+def set_defult_gpu_memory_utilization() -> float:
+    from vllm import __version__ as vllm_version
+    if version.parse(vllm_version) >= version.parse("0.11.0"):
+        return 0.7
+    else:
+        return 0.5
+
+
+def set_defult_batch_size() -> int:
+    try:
+        device = get_device()
+        vram = get_vram(device)
+        if vram is not None:
+            gpu_memory = int(os.getenv('MINERU_VIRTUAL_VRAM_SIZE', round(vram)))
+            if gpu_memory >= 16:
+                batch_size = 8
+            elif gpu_memory >= 8:
+                batch_size = 4
+            else:
+                batch_size = 1
+            logger.info(f'gpu_memory: {gpu_memory} GB, batch_size: {batch_size}')
+        else:
+            # Default batch_size when VRAM can't be determined
+            batch_size = 1
+            logger.info(f'Could not determine GPU memory, using default batch_size: {batch_size}')
+    except Exception as e:
+        logger.warning(f'Error determining VRAM: {e}, using default batch_size: 1')
+        batch_size = 1
+    return batch_size
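
Taken together, utils.py now derives both defaults at call time: gpu_memory_utilization from the installed vllm version, and batch_size from detected VRAM, with the MINERU_VIRTUAL_VRAM_SIZE environment variable as an override. A standalone sketch of the same tiering logic, with mineru's get_device/get_vram replaced by a plain vram_gb argument for illustration:

    import os
    from packaging import version

    def pick_gpu_memory_utilization(vllm_version: str) -> float:
        # vllm >= 0.11.0 gets the higher 0.7 target; older versions stay at 0.5
        return 0.7 if version.parse(vllm_version) >= version.parse("0.11.0") else 0.5

    def pick_batch_size(vram_gb):
        if vram_gb is None:
            return 1  # VRAM unknown: fall back to the safest batch size
        # The env var overrides detected VRAM, as in set_defult_batch_size
        gpu_memory = int(os.getenv("MINERU_VIRTUAL_VRAM_SIZE", round(vram_gb)))
        if gpu_memory >= 16:
            return 8
        if gpu_memory >= 8:
            return 4
        return 1

    os.environ.pop("MINERU_VIRTUAL_VRAM_SIZE", None)  # make the checks below deterministic
    assert pick_gpu_memory_utilization("0.11.0") == 0.7
    assert pick_gpu_memory_utilization("0.10.2") == 0.5
    assert (pick_batch_size(24.0), pick_batch_size(8.5), pick_batch_size(None)) == (8, 4, 1)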

+ 13 - 24
mineru/backend/vlm/vlm_analyze.py

@@ -4,14 +4,13 @@ import time
 
 from loguru import logger
 
-from .custom_logits_processors import enable_custom_logits_processors
+from .utils import enable_custom_logits_processors, set_defult_gpu_memory_utilization, set_defult_batch_size
 from .model_output_to_middle_json import result_to_middle_json
 from ...data.data_reader_writer import DataWriter
 from mineru.utils.pdf_image_tools import load_images_from_pdf
 from ...utils.config_reader import get_device
 
 from ...utils.enum_class import ImageType
-from ...utils.model_utils import get_vram
 from ...utils.models_download_utils import auto_download_and_get_model_root_path
 
 from mineru_vl_utils import MinerUClient
@@ -41,8 +40,13 @@ class ModelSingleton:
             processor = None
             vllm_llm = None
             vllm_async_llm = None
-            batch_size = 0
-            max_concurrency = kwargs.get("max_concurrency", 100)
+            batch_size = kwargs.get("batch_size", 0)  # for transformers backend only
+            max_concurrency = kwargs.get("max_concurrency", 100)  # for http-client backend only
+            http_timeout = kwargs.get("http_timeout", 600)  # for http-client backend only
+            # Remove these parameters from kwargs so they are not passed to unrelated initialization functions
+            for param in ["batch_size", "max_concurrency", "http_timeout"]:
+                if param in kwargs:
+                    del kwargs[param]
             if backend in ['transformers', 'vllm-engine', "vllm-async-engine"] and not model_path:
                 model_path = auto_download_and_get_model_root_path("/","vlm")
                 if backend == "transformers":
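
These three options are consumed at the wrapper level and then deleted from kwargs, so that whatever remains can be forwarded verbatim to the backend constructors. A generic sketch of that consume-and-forward pattern; dict.pop collapses the get-then-del sequence above into one step (engine_cls is a stand-in, not a mineru API):

    def build_engine(engine_cls, **kwargs):
        # Pop wrapper-level options (with their defaults) out of kwargs
        # so the engine constructor never sees keys it does not accept.
        batch_size = kwargs.pop("batch_size", 0)              # transformers backend only
        max_concurrency = kwargs.pop("max_concurrency", 100)  # http-client backend only
        http_timeout = kwargs.pop("http_timeout", 600)        # http-client backend only
        return engine_cls(**kwargs), batch_size, max_concurrency, http_timeout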
@@ -69,24 +73,8 @@ class ModelSingleton:
                         model_path,
                         use_fast=True,
                     )
-                    try:
-                        vram = get_vram(device)
-                        if vram is not None:
-                            gpu_memory = int(os.getenv('MINERU_VIRTUAL_VRAM_SIZE', round(vram)))
-                            if gpu_memory >= 16:
-                                batch_size = 8
-                            elif gpu_memory >= 8:
-                                batch_size = 4
-                            else:
-                                batch_size = 1
-                            logger.info(f'gpu_memory: {gpu_memory} GB, batch_size: {batch_size}')
-                        else:
-                            # Default batch_ratio when VRAM can't be determined
-                            batch_size = 1
-                            logger.info(f'Could not determine GPU memory, using default batch_ratio: {batch_size}')
-                    except Exception as e:
-                        logger.warning(f'Error determining VRAM: {e}, using default batch_ratio: 1')
-                        batch_size = 1
+                    if batch_size == 0:
+                        batch_size = set_defult_batch_size()
                 else:
                     os.environ["OMP_NUM_THREADS"] = "1"
                     if backend == "vllm-engine":
@@ -96,7 +84,7 @@ class ModelSingleton:
                         except ImportError:
                             raise ImportError("Please install vllm to use the vllm-engine backend.")
                         if "gpu_memory_utilization" not in kwargs:
-                            kwargs["gpu_memory_utilization"] = 0.7
+                            kwargs["gpu_memory_utilization"] = set_defult_gpu_memory_utilization()
                         if "model" not in kwargs:
                             kwargs["model"] = model_path
                         if enable_custom_logits_processors() and ("logits_processors" not in kwargs):
@@ -111,7 +99,7 @@ class ModelSingleton:
                         except ImportError:
                             raise ImportError("Please install vllm to use the vllm-async-engine backend.")
                         if "gpu_memory_utilization" not in kwargs:
-                            kwargs["gpu_memory_utilization"] = 0.7
+                            kwargs["gpu_memory_utilization"] = set_defult_gpu_memory_utilization()
                         if "model" not in kwargs:
                             kwargs["model"] = model_path
                         if enable_custom_logits_processors() and ("logits_processors" not in kwargs):
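
Both engine branches apply the same guard-then-default idiom: a key is injected only when the caller has not already supplied it. dict.setdefault expresses each guard in one line; a sketch, with pick_gpu_memory_utilization from the sketch above standing in for set_defult_gpu_memory_utilization:

    def apply_engine_defaults(kwargs, model_path, vllm_version):
        # setdefault assigns only when the key is absent, mirroring the
        # `if "..." not in kwargs:` guards in the diff.
        kwargs.setdefault("gpu_memory_utilization",
                          pick_gpu_memory_utilization(vllm_version))
        kwargs.setdefault("model", model_path)
        return kwargs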
@@ -127,6 +115,7 @@ class ModelSingleton:
                 server_url=server_url,
                 batch_size=batch_size,
                 max_concurrency=max_concurrency,
+                http_timeout=http_timeout,
             )
             elapsed = round(time.time() - start_time, 2)
             logger.info(f"get {backend} predictor cost: {elapsed}s")