server.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. import os
  2. import sys
  3. from mineru.backend.vlm.custom_logits_processors import enable_custom_logits_processors
  4. from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
  5. from vllm.entrypoints.cli.main import main as vllm_main
  6. def main():
  7. args = sys.argv[1:]
  8. has_port_arg = False
  9. has_gpu_memory_utilization_arg = False
  10. has_logits_processors_arg = False
  11. model_path = None
  12. model_arg_indices = []
  13. # 检查现有参数
  14. for i, arg in enumerate(args):
  15. if arg == "--port" or arg.startswith("--port="):
  16. has_port_arg = True
  17. if arg == "--gpu-memory-utilization" or arg.startswith("--gpu-memory-utilization="):
  18. has_gpu_memory_utilization_arg = True
  19. if arg == "--logits-processors" or arg.startswith("--logits-processors="):
  20. has_logits_processors_arg = True
  21. if arg == "--model":
  22. if i + 1 < len(args):
  23. model_path = args[i + 1]
  24. model_arg_indices.extend([i, i + 1])
  25. elif arg.startswith("--model="):
  26. model_path = arg.split("=", 1)[1]
  27. model_arg_indices.append(i)
  28. # 从参数列表中移除 --model 参数
  29. if model_arg_indices:
  30. for index in sorted(model_arg_indices, reverse=True):
  31. args.pop(index)
  32. custom_logits_processors = enable_custom_logits_processors()
  33. # 添加默认参数
  34. if not has_port_arg:
  35. args.extend(["--port", "30000"])
  36. if not has_gpu_memory_utilization_arg:
  37. args.extend(["--gpu-memory-utilization", "0.5"])
  38. if not model_path:
  39. model_path = auto_download_and_get_model_root_path("/", "vlm")
  40. if (not has_logits_processors_arg) and custom_logits_processors:
  41. args.extend(["--logits-processors", "mineru_vl_utils:MinerULogitsProcessor"])
  42. # 重构参数,将模型路径作为位置参数
  43. sys.argv = [sys.argv[0]] + ["serve", model_path] + args
  44. os.environ["OMP_NUM_THREADS"] = "1"
  45. # 启动vllm服务器
  46. print(f"start vllm server: {sys.argv}")
  47. vllm_main()
  48. if __name__ == "__main__":
  49. main()