server.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. import os
  2. import sys
  3. from fastapi import Request
  4. from sglang.srt.entrypoints.http_server import app, generate_request, launch_server
  5. from sglang.srt.managers.io_struct import GenerateReqInput
  6. from sglang.srt.server_args import prepare_server_args
  7. from sglang.srt.utils import kill_process_tree
  8. from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
  9. from .logit_processor import Mineru2LogitProcessor
  10. _custom_logit_processor_str = Mineru2LogitProcessor().to_str()
  11. # remote the existing /generate route
  12. for route in app.routes[:]:
  13. if hasattr(route, "path") and getattr(route, "path") == "/generate":
  14. app.routes.remove(route)
  15. # add the custom /generate route
  16. @app.api_route("/generate", methods=["POST", "PUT"])
  17. async def custom_generate_request(obj: GenerateReqInput, request: Request):
  18. if obj.custom_logit_processor is None:
  19. obj.custom_logit_processor = _custom_logit_processor_str
  20. return await generate_request(obj, request)
  21. def main():
  22. # 检查命令行参数中是否包含--model-path
  23. args = sys.argv[1:]
  24. has_model_path_arg = False
  25. for i, arg in enumerate(args):
  26. if arg == "--model-path" or arg.startswith("--model-path="):
  27. has_model_path_arg = True
  28. break
  29. # 如果没有--model-path参数,在参数列表中添加它
  30. if not has_model_path_arg:
  31. default_path = auto_download_and_get_model_root_path("/", "vlm")
  32. args.extend(["--model-path", default_path])
  33. server_args = prepare_server_args(args)
  34. if server_args.chat_template is None:
  35. server_args.chat_template = "chatml"
  36. server_args.enable_custom_logit_processor = True
  37. try:
  38. launch_server(server_args)
  39. finally:
  40. kill_process_tree(os.getpid(), include_parent=False)
# Entry point when this module is executed directly. Because of the relative
# `from .logit_processor import ...` at the top, it must be run as part of its
# package (e.g. with `python -m`), not as a standalone script.
if __name__ == "__main__":
    main()