Quellcode durchsuchen

Refactor custom logits processors to include vllm version checks and improve logging

myhloli vor 1 Monat
Ursprung
Commit
7c481796f8

+ 12 - 3
mineru/backend/vlm/custom_logits_processors.py

@@ -1,10 +1,12 @@
 import os
 
 from loguru import logger
+from packaging import version
 
 
 def enable_custom_logits_processors():
     import torch
+    from vllm import __version__ as vllm_version
 
     if not torch.cuda.is_available():
         logger.info("CUDA not available, disabling custom_logits_processors")
@@ -25,9 +27,16 @@ def enable_custom_logits_processors():
     if vllm_use_v1 == 0:
         logger.info("VLLM_USE_V1 is set to 0, disabling custom_logits_processors")
         return False
-    elif compute_capability < 8.0:
-        logger.info(f"compute_capability: {compute_capability} < 8.0, disable custom_logits_processors")
+    elif version.parse(vllm_version) < version.parse("0.10.1"):
+        logger.info(f"vllm version: {vllm_version} < 0.10.1, disable custom_logits_processors")
         return False
+    elif compute_capability < 8.0:
+        if version.parse(vllm_version) >= version.parse("0.10.2"):
+            logger.info(f"compute_capability: {compute_capability} < 8.0, but vllm version: {vllm_version} >= 0.10.2, enable custom_logits_processors")
+            return True
+        else:
+            logger.info(f"compute_capability: {compute_capability} < 8.0 and vllm version: {vllm_version} < 0.10.2, disable custom_logits_processors")
+            return False
     else:
-        logger.info(f"compute_capability: {compute_capability} >= 8.0, enable custom_logits_processors")
+        logger.info(f"compute_capability: {compute_capability} >= 8.0 and vllm version: {vllm_version} >= 0.10.1, enable custom_logits_processors")
         return True

+ 2 - 4
mineru/backend/vlm/vlm_analyze.py

@@ -92,7 +92,6 @@ class ModelSingleton:
                 elif backend == "vllm-engine":
                     try:
                         import vllm
-                        vllm_version = vllm.__version__
                         from mineru_vl_utils import MinerULogitsProcessor
                     except ImportError:
                         raise ImportError("Please install vllm to use the vllm-engine backend.")
@@ -100,7 +99,7 @@ class ModelSingleton:
                         kwargs["gpu_memory_utilization"] = 0.5
                     if "model" not in kwargs:
                         kwargs["model"] = model_path
-                    if custom_logits_processors and version.parse(vllm_version) >= version.parse("0.10.1") and "logits_processors" not in kwargs:
+                    if custom_logits_processors and "logits_processors" not in kwargs:
                         kwargs["logits_processors"] = [MinerULogitsProcessor]
                     # 使用kwargs为 vllm初始化参数
                     vllm_llm = vllm.LLM(**kwargs)
@@ -108,7 +107,6 @@ class ModelSingleton:
                     try:
                         from vllm.engine.arg_utils import AsyncEngineArgs
                         from vllm.v1.engine.async_llm import AsyncLLM
-                        from vllm import __version__ as vllm_version
                         from mineru_vl_utils import MinerULogitsProcessor
                     except ImportError:
                         raise ImportError("Please install vllm to use the vllm-async-engine backend.")
@@ -116,7 +114,7 @@ class ModelSingleton:
                         kwargs["gpu_memory_utilization"] = 0.5
                     if "model" not in kwargs:
                         kwargs["model"] = model_path
-                    if custom_logits_processors and version.parse(vllm_version) >= version.parse("0.10.1") and "logits_processors" not in kwargs:
+                    if custom_logits_processors and "logits_processors" not in kwargs:
                         kwargs["logits_processors"] = [MinerULogitsProcessor]
                     # 使用kwargs为 vllm初始化参数
                     vllm_async_llm = AsyncLLM.from_engine_args(AsyncEngineArgs(**kwargs))

+ 1 - 3
mineru/model/vlm_vllm_model/server.py

@@ -4,8 +4,6 @@ from mineru.backend.vlm.custom_logits_processors import enable_custom_logits_pro
 from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
 
 from vllm.entrypoints.cli.main import main as vllm_main
-from vllm import __version__ as vllm_version
-from packaging import version
 
 
 def main():
@@ -47,7 +45,7 @@ def main():
         args.extend(["--gpu-memory-utilization", "0.5"])
     if not model_path:
         model_path = auto_download_and_get_model_root_path("/", "vlm")
-    if not has_logits_processors_arg and custom_logits_processors and version.parse(vllm_version) >= version.parse("0.10.1"):
+    if not has_logits_processors_arg and custom_logits_processors:
         args.extend(["--logits-processors", "mineru_vl_utils:MinerULogitsProcessor"])
 
     # 重构参数,将模型路径作为位置参数