Эх сурвалжийг харах

Refactor logits processor handling in server.py and vlm_analyze.py for improved clarity and consistency

myhloli 1 сар өмнө
parent
commit
f1659eb7a7

+ 2 - 5
mineru/backend/vlm/vlm_analyze.py

@@ -44,9 +44,6 @@ class ModelSingleton:
             batch_size = 0
             if backend in ['transformers', 'vllm-engine', "vllm-async-engine"] and not model_path:
                 model_path = auto_download_and_get_model_root_path("/","vlm")
-
-                custom_logits_processors = enable_custom_logits_processors()
-
                 if backend == "transformers":
                     try:
                         from transformers import (
@@ -99,7 +96,7 @@ class ModelSingleton:
                         kwargs["gpu_memory_utilization"] = 0.5
                     if "model" not in kwargs:
                         kwargs["model"] = model_path
-                    if custom_logits_processors and "logits_processors" not in kwargs:
+                    if enable_custom_logits_processors() and ("logits_processors" not in kwargs):
                         kwargs["logits_processors"] = [MinerULogitsProcessor]
                     # 使用kwargs为 vllm初始化参数
                     vllm_llm = vllm.LLM(**kwargs)
@@ -114,7 +111,7 @@ class ModelSingleton:
                         kwargs["gpu_memory_utilization"] = 0.5
                     if "model" not in kwargs:
                         kwargs["model"] = model_path
-                    if custom_logits_processors and "logits_processors" not in kwargs:
+                    if enable_custom_logits_processors() and ("logits_processors" not in kwargs):
                         kwargs["logits_processors"] = [MinerULogitsProcessor]
                     # 使用kwargs为 vllm初始化参数
                     vllm_async_llm = AsyncLLM.from_engine_args(AsyncEngineArgs(**kwargs))

+ 1 - 1
mineru/model/vlm_vllm_model/server.py

@@ -45,7 +45,7 @@ def main():
         args.extend(["--gpu-memory-utilization", "0.5"])
     if not model_path:
         model_path = auto_download_and_get_model_root_path("/", "vlm")
-    if not has_logits_processors_arg and custom_logits_processors:
+    if (not has_logits_processors_arg) and custom_logits_processors:
         args.extend(["--logits-processors", "mineru_vl_utils:MinerULogitsProcessor"])
 
     # 重构参数,将模型路径作为位置参数