server.py 1.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. import os
  2. import sys
  3. from fastapi import Request
  4. from sglang.srt.entrypoints.http_server import app, generate_request, launch_server
  5. from sglang.srt.managers.io_struct import GenerateReqInput
  6. from sglang.srt.server_args import prepare_server_args
  7. from sglang.srt.utils import kill_process_tree
  8. from .logit_processor import Mineru2LogitProcessor
  9. _custom_logit_processor_str = Mineru2LogitProcessor().to_str()
  10. # remote the existing /generate route
  11. for route in app.routes[:]:
  12. if hasattr(route, "path") and getattr(route, "path") == "/generate":
  13. app.routes.remove(route)
  14. # add the custom /generate route
  15. @app.api_route("/generate", methods=["POST", "PUT"])
  16. async def custom_generate_request(obj: GenerateReqInput, request: Request):
  17. if obj.custom_logit_processor is None:
  18. obj.custom_logit_processor = _custom_logit_processor_str
  19. return await generate_request(obj, request)
  20. def main():
  21. server_args = prepare_server_args(sys.argv[1:])
  22. if server_args.chat_template is None:
  23. server_args.chat_template = "chatml"
  24. server_args.enable_custom_logit_processor = True
  25. try:
  26. launch_server(server_args)
  27. finally:
  28. kill_process_tree(os.getpid(), include_parent=False)
# Start the server only when this module is executed directly, not when it is
# imported.
if __name__ == "__main__":
    main()