Просмотр исходного кода

fix: enhance model initialization for transformers and vllm-engine backends in vlm_analyze.py

myhloli 2 месяца назад
Родитель
Commit
20e1dfe984
1 изменённый файл: 43 добавлено и 2 удалено
  1. 43 2
      mineru/backend/vlm/vlm_analyze.py

+ 43 - 2
mineru/backend/vlm/vlm_analyze.py

@@ -31,13 +31,54 @@ class ModelSingleton:
     ) -> MinerUClient:
         key = (backend, model_path, server_url)
         if key not in self._models:
+            model = None
+            processor = None
+            vllm_llm = None
             if backend in ['transformers', 'vllm-engine'] and not model_path:
                 model_path = auto_download_and_get_model_root_path("/","vlm")
+                if backend == "transformers":
+                    if not model_path:
+                        raise ValueError("model_path must be provided when model or processor is None.")
+
+                    try:
+                        from transformers import (
+                            AutoProcessor,
+                            Qwen2VLForConditionalGeneration,
+                        )
+                        from transformers import __version__ as transformers_version
+                    except ImportError:
+                        raise ImportError("Please install transformers to use the transformers backend.")
+
+                    from packaging import version
+                    if version.parse(transformers_version) >= version.parse("4.56.0"):
+                        dtype_key = "dtype"
+                    else:
+                        dtype_key = "torch_dtype"
+                    model = Qwen2VLForConditionalGeneration.from_pretrained(
+                        model_path,
+                        device_map="auto",
+                        **{dtype_key: "auto"},  # type: ignore
+                    )
+                    processor = AutoProcessor.from_pretrained(
+                        model_path,
+                        use_fast=True,
+                    )
+                elif backend == "vllm-engine":
+                    if not model_path:
+                        raise ValueError("model_path must be provided when vllm_llm is None.")
+                    try:
+                        import vllm
+                    except ImportError:
+                        raise ImportError("Please install vllm to use the vllm-engine backend.")
+                    logger.debug(kwargs)
+                    # 使用kwargs为 vllm初始化参数
+                    vllm_llm = vllm.LLM(model_path, **kwargs)
             self._models[key] = MinerUClient(
                 backend=backend,
-                model_path=model_path,
+                model=model,
+                processor=processor,
+                vllm_llm=vllm_llm,
                 server_url=server_url,
-                **kwargs,
             )
         return self._models[key]