# sglang.py
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import json
  15. import os
  16. import subprocess
  17. import sys
  18. import tempfile
  19. import textwrap
  20. from ....utils.deps import require_genai_engine_plugin
  21. def run_sglang_server(host, port, model_name, model_dir, config, chat_template_path):
  22. require_genai_engine_plugin("sglang-server")
  23. data = json.dumps(
  24. {
  25. "host": host,
  26. "port": port,
  27. "model_name": model_name,
  28. "model_dir": model_dir,
  29. "config": config,
  30. "chat_template_path": str(chat_template_path),
  31. }
  32. )
  33. # HACK
  34. code = textwrap.dedent(
  35. f"""
  36. import json
  37. import os
  38. from paddlex.inference.genai.configs.utils import (
  39. backend_config_to_args,
  40. set_config_defaults,
  41. update_backend_config,
  42. )
  43. from paddlex.inference.genai.models import get_model_components
  44. from sglang.srt.configs.model_config import multimodal_model_archs
  45. from sglang.srt.entrypoints.http_server import launch_server
  46. from sglang.srt.managers.multimodal_processor import PROCESSOR_MAPPING
  47. from sglang.srt.models.registry import ModelRegistry
  48. from sglang.srt.server_args import prepare_server_args
  49. from sglang.srt.utils import kill_process_tree
  50. data = json.loads({repr(data)})
  51. host = data["host"]
  52. port = data["port"]
  53. model_name = data["model_name"]
  54. model_dir = data["model_dir"]
  55. config = data["config"]
  56. chat_template_path = data["chat_template_path"]
  57. network_class, processor_class = get_model_components(model_name, "sglang")
  58. ModelRegistry.models[network_class.__name__] = network_class
  59. multimodal_model_archs.append(network_class.__name__)
  60. PROCESSOR_MAPPING[network_class] = processor_class
  61. set_config_defaults(config, {{"served-model-name": model_name}})
  62. if chat_template_path:
  63. set_config_defaults(config, {{"chat-template": chat_template_path}})
  64. set_config_defaults(config, {{"enable-metrics": True}})
  65. update_backend_config(
  66. config,
  67. {{
  68. "model-path": model_dir,
  69. "host": host,
  70. "port": port,
  71. }},
  72. )
  73. if __name__ == "__main__":
  74. args = backend_config_to_args(config)
  75. server_args = prepare_server_args(args)
  76. try:
  77. launch_server(server_args)
  78. finally:
  79. kill_process_tree(os.getpid(), include_parent=False)
  80. """
  81. )
  82. with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
  83. f.write(code)
  84. script_path = f.name
  85. try:
  86. subprocess.check_call([sys.executable, script_path])
  87. finally:
  88. os.unlink(script_path)