# server.py — MinerU launcher that wraps the vLLM CLI with project defaults.
  1. import sys
  2. from loguru import logger
  3. from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
  4. from vllm.entrypoints.cli.main import main as vllm_main
  5. from vllm import __version__ as vllm_version
  6. from packaging import version
  7. def main():
  8. args = sys.argv[1:]
  9. has_port_arg = False
  10. has_gpu_memory_utilization_arg = False
  11. has_logits_processors_arg = False
  12. model_path = None
  13. model_arg_indices = []
  14. # 检查现有参数
  15. for i, arg in enumerate(args):
  16. if arg == "--port" or arg.startswith("--port="):
  17. has_port_arg = True
  18. if arg == "--gpu-memory-utilization" or arg.startswith("--gpu-memory-utilization="):
  19. has_gpu_memory_utilization_arg = True
  20. if arg == "--logits-processors" or arg.startswith("--logits-processors="):
  21. has_logits_processors_arg = True
  22. if arg == "--model":
  23. if i + 1 < len(args):
  24. model_path = args[i + 1]
  25. model_arg_indices.extend([i, i + 1])
  26. elif arg.startswith("--model="):
  27. model_path = arg.split("=", 1)[1]
  28. model_arg_indices.append(i)
  29. # 从参数列表中移除 --model 参数
  30. if model_arg_indices:
  31. for index in sorted(model_arg_indices, reverse=True):
  32. args.pop(index)
  33. import torch
  34. compute_capability = 0.0
  35. custom_logits_processors = False
  36. if torch.cuda.is_available():
  37. major, minor = torch.cuda.get_device_capability()
  38. compute_capability = float(major) + (float(minor) / 10.0)
  39. logger.info(f"compute_capability: {compute_capability}")
  40. if compute_capability >= 8.0:
  41. custom_logits_processors = True
  42. # 添加默认参数
  43. if not has_port_arg:
  44. args.extend(["--port", "30000"])
  45. if not has_gpu_memory_utilization_arg:
  46. args.extend(["--gpu-memory-utilization", "0.5"])
  47. if not model_path:
  48. model_path = auto_download_and_get_model_root_path("/", "vlm")
  49. if not has_logits_processors_arg and custom_logits_processors and version.parse(vllm_version) >= version.parse("0.10.1"):
  50. args.extend(["--logits-processors", "mineru_vl_utils:MinerULogitsProcessor"])
  51. # 重构参数,将模型路径作为位置参数
  52. sys.argv = [sys.argv[0]] + ["serve", model_path] + args
  53. # 启动vllm服务器
  54. print(f"start vllm server: {sys.argv}")
  55. vllm_main()
  56. if __name__ == "__main__":
  57. main()