|
@@ -43,6 +43,16 @@ class ModelSingleton:
|
|
|
batch_size = 0
|
|
batch_size = 0
|
|
|
if backend in ['transformers', 'vllm-engine', "vllm-async-engine"] and not model_path:
|
|
if backend in ['transformers', 'vllm-engine', "vllm-async-engine"] and not model_path:
|
|
|
model_path = auto_download_and_get_model_root_path("/","vlm")
|
|
model_path = auto_download_and_get_model_root_path("/","vlm")
|
|
|
|
|
+ import torch
|
|
|
|
|
+ compute_capability = 0.0
|
|
|
|
|
+ custom_logits_processors = False
|
|
|
|
|
+ if torch.cuda.is_available():
|
|
|
|
|
+ major, minor = torch.cuda.get_device_capability()
|
|
|
|
|
+ compute_capability = float(major) + (float(minor) / 10.0)
|
|
|
|
|
+ logger.info(f"compute_capability: {compute_capability}")
|
|
|
|
|
+ if compute_capability >= 8.0:
|
|
|
|
|
+ custom_logits_processors = True
|
|
|
|
|
+
|
|
|
if backend == "transformers":
|
|
if backend == "transformers":
|
|
|
try:
|
|
try:
|
|
|
from transformers import (
|
|
from transformers import (
|
|
@@ -96,7 +106,7 @@ class ModelSingleton:
|
|
|
kwargs["gpu_memory_utilization"] = 0.5
|
|
kwargs["gpu_memory_utilization"] = 0.5
|
|
|
if "model" not in kwargs:
|
|
if "model" not in kwargs:
|
|
|
kwargs["model"] = model_path
|
|
kwargs["model"] = model_path
|
|
|
- if version.parse(vllm_version) >= version.parse("0.10.1") and "logits_processors" not in kwargs:
|
|
|
|
|
|
|
+ if custom_logits_processors and version.parse(vllm_version) >= version.parse("0.10.1") and "logits_processors" not in kwargs:
|
|
|
kwargs["logits_processors"] = [MinerULogitsProcessor]
|
|
kwargs["logits_processors"] = [MinerULogitsProcessor]
|
|
|
# 使用kwargs为 vllm初始化参数
|
|
# 使用kwargs为 vllm初始化参数
|
|
|
vllm_llm = vllm.LLM(**kwargs)
|
|
vllm_llm = vllm.LLM(**kwargs)
|
|
@@ -112,7 +122,7 @@ class ModelSingleton:
|
|
|
kwargs["gpu_memory_utilization"] = 0.5
|
|
kwargs["gpu_memory_utilization"] = 0.5
|
|
|
if "model" not in kwargs:
|
|
if "model" not in kwargs:
|
|
|
kwargs["model"] = model_path
|
|
kwargs["model"] = model_path
|
|
|
- if version.parse(vllm_version) >= version.parse("0.10.1") and "logits_processors" not in kwargs:
|
|
|
|
|
|
|
+ if custom_logits_processors and version.parse(vllm_version) >= version.parse("0.10.1") and "logits_processors" not in kwargs:
|
|
|
kwargs["logits_processors"] = [MinerULogitsProcessor]
|
|
kwargs["logits_processors"] = [MinerULogitsProcessor]
|
|
|
# 使用kwargs为 vllm初始化参数
|
|
# 使用kwargs为 vllm初始化参数
|
|
|
vllm_async_llm = AsyncLLM.from_engine_args(AsyncEngineArgs(**kwargs))
|
|
vllm_async_llm = AsyncLLM.from_engine_args(AsyncEngineArgs(**kwargs))
|