Explorar el Código

Add compute capability check for custom logits processors in server.py and vlm_analyze.py

myhloli hace 1 mes
padre
commit
a4cac624df
Se han modificado 2 ficheros con 25 adiciones y 3 borrados
  1. mineru/backend/vlm/vlm_analyze.py (+12 −2)
  2. mineru/model/vlm_vllm_model/server.py (+13 −1)

+ 12 - 2
mineru/backend/vlm/vlm_analyze.py

@@ -43,6 +43,16 @@ class ModelSingleton:
             batch_size = 0
             if backend in ['transformers', 'vllm-engine', "vllm-async-engine"] and not model_path:
                 model_path = auto_download_and_get_model_root_path("/","vlm")
+                import torch
+                compute_capability = 0.0
+                custom_logits_processors = False
+                if torch.cuda.is_available():
+                    major, minor = torch.cuda.get_device_capability()
+                    compute_capability = float(major) + (float(minor) / 10.0)
+                    logger.info(f"compute_capability: {compute_capability}")
+                if compute_capability >= 8.0:
+                    custom_logits_processors = True
+
                 if backend == "transformers":
                     try:
                         from transformers import (
@@ -96,7 +106,7 @@ class ModelSingleton:
                         kwargs["gpu_memory_utilization"] = 0.5
                     if "model" not in kwargs:
                         kwargs["model"] = model_path
-                    if version.parse(vllm_version) >= version.parse("0.10.1") and "logits_processors" not in kwargs:
+                    if custom_logits_processors and version.parse(vllm_version) >= version.parse("0.10.1") and "logits_processors" not in kwargs:
                         kwargs["logits_processors"] = [MinerULogitsProcessor]
                     # 使用kwargs为 vllm初始化参数
                     vllm_llm = vllm.LLM(**kwargs)
@@ -112,7 +122,7 @@ class ModelSingleton:
                         kwargs["gpu_memory_utilization"] = 0.5
                     if "model" not in kwargs:
                         kwargs["model"] = model_path
-                    if version.parse(vllm_version) >= version.parse("0.10.1") and "logits_processors" not in kwargs:
+                    if custom_logits_processors and version.parse(vllm_version) >= version.parse("0.10.1") and "logits_processors" not in kwargs:
                         kwargs["logits_processors"] = [MinerULogitsProcessor]
                     # 使用kwargs为 vllm初始化参数
                     vllm_async_llm = AsyncLLM.from_engine_args(AsyncEngineArgs(**kwargs))

+ 13 - 1
mineru/model/vlm_vllm_model/server.py

@@ -1,5 +1,7 @@
 import sys
 
+from loguru import logger
+
 from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
 
 from vllm.entrypoints.cli.main import main as vllm_main
@@ -37,6 +39,16 @@ def main():
         for index in sorted(model_arg_indices, reverse=True):
             args.pop(index)
 
+    import torch
+    compute_capability = 0.0
+    custom_logits_processors = False
+    if torch.cuda.is_available():
+        major, minor = torch.cuda.get_device_capability()
+        compute_capability = float(major) + (float(minor) / 10.0)
+        logger.info(f"compute_capability: {compute_capability}")
+    if compute_capability >= 8.0:
+        custom_logits_processors = True
+
     # 添加默认参数
     if not has_port_arg:
         args.extend(["--port", "30000"])
@@ -44,7 +56,7 @@ def main():
         args.extend(["--gpu-memory-utilization", "0.5"])
     if not model_path:
         model_path = auto_download_and_get_model_root_path("/", "vlm")
-    if not has_logits_processors_arg and version.parse(vllm_version) >= version.parse("0.10.1"):
+    if not has_logits_processors_arg and custom_logits_processors and version.parse(vllm_version) >= version.parse("0.10.1"):
         args.extend(["--logits-processors", "mineru_vl_utils:MinerULogitsProcessor"])
 
     # 重构参数,将模型路径作为位置参数