# server.py — Mineru2 entry point wrapping the sglang HTTP server.
  1. import os
  2. import sys
  3. from fastapi import Request
  4. from sglang.srt.entrypoints.http_server import app, generate_request, launch_server
  5. from sglang.srt.managers.io_struct import GenerateReqInput
  6. from sglang.srt.server_args import prepare_server_args
  7. from sglang.srt.utils import kill_process_tree
  8. from sglang.srt.conversation import Conversation
  9. from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
  10. from .logit_processor import Mineru2LogitProcessor
  11. # mineru2.0的chat_template与chatml在换行上有微小区别
  12. def custom_get_prompt(self) -> str:
  13. system_prompt = self.system_template.format(system_message=self.system_message)
  14. if self.system_message == "":
  15. ret = ""
  16. else:
  17. ret = system_prompt + self.sep
  18. for role, message in self.messages:
  19. if message:
  20. ret += role + "\n" + message + self.sep
  21. else:
  22. ret += role + "\n"
  23. return ret
  24. _custom_logit_processor_str = Mineru2LogitProcessor().to_str()
  25. # remote the existing /generate route
  26. for route in app.routes[:]:
  27. if hasattr(route, "path") and getattr(route, "path") == "/generate":
  28. app.routes.remove(route)
# Add the custom /generate route (replacing the stock sglang one removed above).
@app.api_route("/generate", methods=["POST", "PUT"])
async def custom_generate_request(obj: GenerateReqInput, request: Request):
    """Handle /generate, defaulting the request's custom logit processor.

    If the caller did not specify ``custom_logit_processor``, inject the
    serialized Mineru2 processor, then delegate to sglang's stock
    ``generate_request`` handler.
    """
    if obj.custom_logit_processor is None:
        obj.custom_logit_processor = _custom_logit_processor_str
    return await generate_request(obj, request)
  35. def main():
  36. # 检查命令行参数中是否包含--model-path
  37. args = sys.argv[1:]
  38. has_model_path_arg = False
  39. for i, arg in enumerate(args):
  40. if arg == "--model-path" or arg.startswith("--model-path="):
  41. has_model_path_arg = True
  42. break
  43. # 如果没有--model-path参数,在参数列表中添加它
  44. if not has_model_path_arg:
  45. default_path = auto_download_and_get_model_root_path("/", "vlm")
  46. args.extend(["--model-path", default_path])
  47. server_args = prepare_server_args(args)
  48. if server_args.chat_template is None:
  49. server_args.chat_template = "chatml"
  50. Conversation.get_prompt = custom_get_prompt
  51. server_args.enable_custom_logit_processor = True
  52. try:
  53. launch_server(server_args)
  54. finally:
  55. kill_process_tree(os.getpid(), include_parent=False)
# Script entry point: run the server when executed directly.
if __name__ == "__main__":
    main()