# server.py — launch a vllm OpenAI-compatible server preconfigured for MinerU's vlm model.
import sys

# Project/third-party dependencies: MinerU's model auto-download helper and
# vllm's CLI entry point (re-exported here as vllm_main).
from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
from vllm.entrypoints.cli.main import main as vllm_main
  4. def main():
  5. args = sys.argv[1:]
  6. has_port_arg = False
  7. has_gpu_memory_utilization_arg = False
  8. model_path = None
  9. model_arg_indices = []
  10. # 检查现有参数
  11. for i, arg in enumerate(args):
  12. if arg == "--port" or arg.startswith("--port="):
  13. has_port_arg = True
  14. if arg == "--gpu-memory-utilization" or arg.startswith("--gpu-memory-utilization="):
  15. has_gpu_memory_utilization_arg = True
  16. if arg == "--model":
  17. if i + 1 < len(args):
  18. model_path = args[i + 1]
  19. model_arg_indices.extend([i, i + 1])
  20. elif arg.startswith("--model="):
  21. model_path = arg.split("=", 1)[1]
  22. model_arg_indices.append(i)
  23. # 从参数列表中移除 --model 参数
  24. if model_arg_indices:
  25. for index in sorted(model_arg_indices, reverse=True):
  26. args.pop(index)
  27. # 添加默认参数
  28. if not has_port_arg:
  29. args.extend(["--port", "30000"])
  30. if not has_gpu_memory_utilization_arg:
  31. args.extend(["--gpu-memory-utilization", "0.5"])
  32. if not model_path:
  33. model_path = auto_download_and_get_model_root_path("/", "vlm")
  34. # 重构参数,将模型路径作为位置参数
  35. sys.argv = [sys.argv[0]] + ["serve", model_path] + args
  36. # 启动vllm服务器
  37. print(f"start vllm server: {sys.argv}")
  38. vllm_main()
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()