Selaa lähdekoodia

fix: improve GPU memory utilization handling and ensure OMP_NUM_THREADS is set only if not defined

myhloli 3 viikkoa sitten
vanhempi
commit
0d0ebfd7bc
2 muutettua tiedostoa jossa 10 lisäystä ja 3 poistoa
  1. 4 1
      mineru/backend/vlm/vlm_analyze.py
  2. 6 2
      mineru/model/vlm_vllm_model/server.py

+ 4 - 1
mineru/backend/vlm/vlm_analyze.py

@@ -76,7 +76,10 @@ class ModelSingleton:
                     if batch_size == 0:
                         batch_size = set_defult_batch_size()
                 else:
-                    os.environ["OMP_NUM_THREADS"] = "1"
+
+                    if os.getenv('OMP_NUM_THREADS') is None:
+                        os.environ["OMP_NUM_THREADS"] = "1"
+
                     if backend == "vllm-engine":
                         try:
                             import vllm

+ 6 - 2
mineru/model/vlm_vllm_model/server.py

@@ -2,6 +2,8 @@ import os
 import sys
 
 from mineru.backend.vlm.custom_logits_processors import enable_custom_logits_processors
+
+from mineru.backend.vlm.utils import set_defult_gpu_memory_utilization
 from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
 
 from vllm.entrypoints.cli.main import main as vllm_main
@@ -43,7 +45,8 @@ def main():
     if not has_port_arg:
         args.extend(["--port", "30000"])
     if not has_gpu_memory_utilization_arg:
-        args.extend(["--gpu-memory-utilization", "0.7"])
+        gpu_memory_utilization = str(set_defult_gpu_memory_utilization())
+        args.extend(["--gpu-memory-utilization", gpu_memory_utilization])
     if not model_path:
         model_path = auto_download_and_get_model_root_path("/", "vlm")
     if (not has_logits_processors_arg) and custom_logits_processors:
@@ -52,7 +55,8 @@ def main():
     # 重构参数,将模型路径作为位置参数
     sys.argv = [sys.argv[0]] + ["serve", model_path] + args
 
-    os.environ["OMP_NUM_THREADS"] = "1"
+    if os.getenv('OMP_NUM_THREADS') is None:
+        os.environ["OMP_NUM_THREADS"] = "1"
 
     # 启动vllm服务器
     print(f"start vllm server: {sys.argv}")