Преглед на файлове

Refactor logits processor handling in server.py and vlm_analyze.py for improved clarity and consistency

myhloli преди 1 месец
родител
ревизия
f1659eb7a7
променени са 2 файла, в които са добавени 3 реда и са изтрити 6 реда
  1. 2 5
      mineru/backend/vlm/vlm_analyze.py
  2. 1 1
      mineru/model/vlm_vllm_model/server.py

+ 2 - 5
mineru/backend/vlm/vlm_analyze.py

@@ -44,9 +44,6 @@ class ModelSingleton:
             batch_size = 0
             if backend in ['transformers', 'vllm-engine', "vllm-async-engine"] and not model_path:
                 model_path = auto_download_and_get_model_root_path("/","vlm")
-
-                custom_logits_processors = enable_custom_logits_processors()
-
                 if backend == "transformers":
                     try:
                         from transformers import (
@@ -99,7 +96,7 @@ class ModelSingleton:
                         kwargs["gpu_memory_utilization"] = 0.5
                     if "model" not in kwargs:
                         kwargs["model"] = model_path
-                    if custom_logits_processors and "logits_processors" not in kwargs:
+                    if enable_custom_logits_processors() and ("logits_processors" not in kwargs):
                         kwargs["logits_processors"] = [MinerULogitsProcessor]
                     # 使用kwargs为 vllm初始化参数
                     vllm_llm = vllm.LLM(**kwargs)
@@ -114,7 +111,7 @@ class ModelSingleton:
                         kwargs["gpu_memory_utilization"] = 0.5
                     if "model" not in kwargs:
                         kwargs["model"] = model_path
-                    if custom_logits_processors and "logits_processors" not in kwargs:
+                    if enable_custom_logits_processors() and ("logits_processors" not in kwargs):
                         kwargs["logits_processors"] = [MinerULogitsProcessor]
                     # 使用kwargs为 vllm初始化参数
                     vllm_async_llm = AsyncLLM.from_engine_args(AsyncEngineArgs(**kwargs))

+ 1 - 1
mineru/model/vlm_vllm_model/server.py

@@ -45,7 +45,7 @@ def main():
         args.extend(["--gpu-memory-utilization", "0.5"])
     if not model_path:
         model_path = auto_download_and_get_model_root_path("/", "vlm")
-    if not has_logits_processors_arg and custom_logits_processors:
+    if (not has_logits_processors_arg) and custom_logits_processors:
         args.extend(["--logits-processors", "mineru_vl_utils:MinerULogitsProcessor"])
 
     # 重构参数,将模型路径作为位置参数