server.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. import sys
  2. from mineru.backend.vlm.custom_logits_processors import enable_custom_logits_processors
  3. from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
  4. from vllm.entrypoints.cli.main import main as vllm_main
  5. def main():
  6. args = sys.argv[1:]
  7. has_port_arg = False
  8. has_gpu_memory_utilization_arg = False
  9. has_logits_processors_arg = False
  10. model_path = None
  11. model_arg_indices = []
  12. # 检查现有参数
  13. for i, arg in enumerate(args):
  14. if arg == "--port" or arg.startswith("--port="):
  15. has_port_arg = True
  16. if arg == "--gpu-memory-utilization" or arg.startswith("--gpu-memory-utilization="):
  17. has_gpu_memory_utilization_arg = True
  18. if arg == "--logits-processors" or arg.startswith("--logits-processors="):
  19. has_logits_processors_arg = True
  20. if arg == "--model":
  21. if i + 1 < len(args):
  22. model_path = args[i + 1]
  23. model_arg_indices.extend([i, i + 1])
  24. elif arg.startswith("--model="):
  25. model_path = arg.split("=", 1)[1]
  26. model_arg_indices.append(i)
  27. # 从参数列表中移除 --model 参数
  28. if model_arg_indices:
  29. for index in sorted(model_arg_indices, reverse=True):
  30. args.pop(index)
  31. custom_logits_processors = enable_custom_logits_processors()
  32. # 添加默认参数
  33. if not has_port_arg:
  34. args.extend(["--port", "30000"])
  35. if not has_gpu_memory_utilization_arg:
  36. args.extend(["--gpu-memory-utilization", "0.5"])
  37. if not model_path:
  38. model_path = auto_download_and_get_model_root_path("/", "vlm")
  39. if not has_logits_processors_arg and custom_logits_processors:
  40. args.extend(["--logits-processors", "mineru_vl_utils:MinerULogitsProcessor"])
  41. # 重构参数,将模型路径作为位置参数
  42. sys.argv = [sys.argv[0]] + ["serve", model_path] + args
  43. # 启动vllm服务器
  44. print(f"start vllm server: {sys.argv}")
  45. vllm_main()
  46. if __name__ == "__main__":
  47. main()