server.py 1.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. import os
  2. import sys
  3. from fastapi import Request
  4. from sglang.srt.entrypoints.http_server import app, generate_request, launch_server
  5. from sglang.srt.managers.io_struct import GenerateReqInput
  6. from sglang.srt.server_args import prepare_server_args
  7. from sglang.srt.utils import kill_process_tree
  8. from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
  9. from .logit_processor import Mineru2LogitProcessor
  10. _custom_logit_processor_str = Mineru2LogitProcessor().to_str()
  11. # remote the existing /generate route
  12. for route in app.routes[:]:
  13. if hasattr(route, "path") and getattr(route, "path") == "/generate":
  14. app.routes.remove(route)
  15. # add the custom /generate route
  16. @app.api_route("/generate", methods=["POST", "PUT"])
  17. async def custom_generate_request(obj: GenerateReqInput, request: Request):
  18. if obj.custom_logit_processor is None:
  19. obj.custom_logit_processor = _custom_logit_processor_str
  20. return await generate_request(obj, request)
  21. def main():
  22. server_args = prepare_server_args(sys.argv[1:])
  23. if server_args.chat_template is None:
  24. server_args.chat_template = "chatml"
  25. server_args.enable_custom_logit_processor = True
  26. if server_args.model_path is None:
  27. server_args.model_path = auto_download_and_get_model_root_path("/","vlm")
  28. try:
  29. launch_server(server_args)
  30. finally:
  31. kill_process_tree(os.getpid(), include_parent=False)
  32. if __name__ == "__main__":
  33. main()