Kaynağa Gözat

feat: add support for vllm-async-engine backend in vlm_analyze.py

myhloli 2 ay önce
ebeveyn
işleme
7a71cfe288
1 değiştirilmiş dosya ile 19 ekleme ve 1 silme
  1. +19 -1
      mineru/backend/vlm/vlm_analyze.py

+ 19 - 1
mineru/backend/vlm/vlm_analyze.py

@@ -35,7 +35,8 @@ class ModelSingleton:
             model = None
             processor = None
             vllm_llm = None
-            if backend in ['transformers', 'vllm-engine'] and not model_path:
+            vllm_async_llm = None
+            if backend in ['transformers', 'vllm-engine', "vllm-async-engine"] and not model_path:
                 model_path = auto_download_and_get_model_root_path("/","vlm")
                 if backend == "transformers":
                     if not model_path:
@@ -78,11 +79,28 @@ class ModelSingleton:
                         kwargs["model"] = model_path
                     # Pass kwargs as the vllm initialization arguments
                     vllm_llm = vllm.LLM(**kwargs)
+                elif backend == "vllm-async-engine":
+                    if not model_path:
+                        raise ValueError("model_path must be provided when vllm_llm is None.")
+                    try:
+                        from vllm.engine.arg_utils import AsyncEngineArgs
+                        from vllm.v1.engine.async_llm import AsyncLLM
+                    except ImportError:
+                        raise ImportError("Please install vllm to use the vllm-async-engine backend.")
+
+                    # logger.debug(kwargs)
+                    if "gpu_memory_utilization" not in kwargs:
+                        kwargs["gpu_memory_utilization"] = 0.5
+                    if "model" not in kwargs:
+                        kwargs["model"] = model_path
+                    # Pass kwargs as the vllm initialization arguments
+                    vllm_async_llm = AsyncLLM.from_engine_args(AsyncEngineArgs(**kwargs))
             self._models[key] = MinerUClient(
                 backend=backend,
                 model=model,
                 processor=processor,
                 vllm_llm=vllm_llm,
+                vllm_async_llm=vllm_async_llm,
                 server_url=server_url,
             )
             elapsed = round(time.time() - start_time, 2)