@@ -31,13 +31,54 @@ class ModelSingleton:
     ) -> MinerUClient:
         key = (backend, model_path, server_url)
         if key not in self._models:
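+            # Backend-specific handles; the branch below populates exactly one set.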
+            model = None
+            processor = None
+            vllm_llm = None
             if backend in ['transformers', 'vllm-engine'] and not model_path:
                 model_path = auto_download_and_get_model_root_path("/","vlm")
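+            # The transformers backend loads the model and processor in-process.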
+            if backend == "transformers":
+                if not model_path:
+                    raise ValueError("model_path must be provided when model or processor is None.")
+
+                try:
+                    from transformers import (
+                        AutoProcessor,
+                        Qwen2VLForConditionalGeneration,
+                    )
+                    from transformers import __version__ as transformers_version
+                except ImportError:
+                    raise ImportError("Please install transformers to use the transformers backend.")
+
+                from packaging import version
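+                # transformers >= 4.56.0 accepts `dtype`; older releases expect `torch_dtype`.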
+                if version.parse(transformers_version) >= version.parse("4.56.0"):
+                    dtype_key = "dtype"
+                else:
+                    dtype_key = "torch_dtype"
+                model = Qwen2VLForConditionalGeneration.from_pretrained(
+                    model_path,
+                    device_map="auto",
+                    **{dtype_key: "auto"},  # type: ignore
+                )
+                processor = AutoProcessor.from_pretrained(
+                    model_path,
+                    use_fast=True,
+                )
+            elif backend == "vllm-engine":
+                if not model_path:
+                    raise ValueError("model_path must be provided when vllm_llm is None.")
+                try:
+                    import vllm
+                except ImportError:
+                    raise ImportError("Please install vllm to use the vllm-engine backend.")
+                logger.debug(kwargs)
+                # Use kwargs as the vllm initialization parameters.
+                vllm_llm = vllm.LLM(model_path, **kwargs)
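+            # Hand the constructed objects to the client instead of a raw model path.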
             self._models[key] = MinerUClient(
                 backend=backend,
-                model_path=model_path,
+                model=model,
+                processor=processor,
+                vllm_llm=vllm_llm,
                 server_url=server_url,
-                **kwargs,
             )
         return self._models[key]