| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475 |
- import os
- import sys
- from fastapi import Request
- from sglang.srt.entrypoints.http_server import app, generate_request, launch_server
- from sglang.srt.managers.io_struct import GenerateReqInput
- from sglang.srt.server_args import prepare_server_args
- from sglang.srt.utils import kill_process_tree
- from sglang.srt.conversation import Conversation
- from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
- from .logit_processor import Mineru2LogitProcessor
# MinerU 2.0's chat template differs slightly from ChatML in how newlines are placed.
def custom_get_prompt(self) -> str:
    """Render the conversation held by ``self`` into a single prompt string.

    Reads ``system_template``, ``system_message``, ``sep`` and ``messages``
    from ``self``. The system block (formatted template followed by ``sep``)
    is emitted only when ``system_message`` is not the empty string. Each
    turn is emitted as ``role + "\\n" + message + sep``, with nothing
    appended after the separator; a turn whose message is empty or ``None``
    contributes only ``role + "\\n"``.

    Returns:
        The concatenated prompt text.
    """
    pieces = []

    # Format the system template only when its result is actually used.
    if self.system_message != "":
        system_prompt = self.system_template.format(
            system_message=self.system_message
        )
        pieces.append(system_prompt + self.sep)

    for role, message in self.messages:
        if message:
            pieces.append(role + "\n" + message + self.sep)
        else:
            # No content for this turn: emit the role header without a separator.
            pieces.append(role + "\n")

    # Join once instead of growing a string with repeated ``+=``.
    return "".join(pieces)
# Build the string form of the MinerU2 logit processor once, at import time,
# so request handlers can reuse it instead of re-creating it per request.
_custom_logit_processor_str = Mineru2LogitProcessor().to_str()

# Remove the existing /generate route so the custom handler registered
# further down in this module takes its place.
for route in [r for r in app.routes if getattr(r, "path", None) == "/generate"]:
    app.routes.remove(route)
# Register the custom /generate route (POST and PUT).
#
# The handler fills in the module-level MinerU2 logit processor string when
# the incoming request does not carry its own ``custom_logit_processor``, then
# delegates to SGLang's ``generate_request`` and returns its result unchanged.
# (Documented with comments rather than a docstring so the route's generated
# API description is left as it was.)
@app.api_route("/generate", methods=["POST", "PUT"])
async def custom_generate_request(obj: GenerateReqInput, request: Request):
    caller_supplied_processor = obj.custom_logit_processor is not None
    if not caller_supplied_processor:
        obj.custom_logit_processor = _custom_logit_processor_str
    response = await generate_request(obj, request)
    return response
def main():
    """Launch the SGLang server with MinerU-specific defaults.

    Steps:
      1. Take the command-line arguments from ``sys.argv[1:]``.
      2. If neither ``--model-path`` nor ``--model-path=...`` is present,
         append ``--model-path`` with the path returned by
         ``auto_download_and_get_model_root_path("/", "vlm")``.
      3. Build the server arguments with ``prepare_server_args``.
      4. When no chat template was given, use ``"chatml"`` and replace
         ``Conversation.get_prompt`` with ``custom_get_prompt``.
      5. Enable custom logit processors and launch the server.

    ``kill_process_tree`` is always called for the current PID with
    ``include_parent=False`` once ``launch_server`` returns or raises.
    """
    cli_args = sys.argv[1:]

    # Check whether the command line already specifies --model-path.
    model_path_given = any(
        token == "--model-path" or token.startswith("--model-path=")
        for token in cli_args
    )

    # If it does not, append the default (auto-downloaded) model path.
    if not model_path_given:
        cli_args += ["--model-path", auto_download_and_get_model_root_path("/", "vlm")]

    server_args = prepare_server_args(cli_args)

    if server_args.chat_template is None:
        server_args.chat_template = "chatml"
        # NOTE(review): the monkey-patch is assumed to belong to this branch
        # (applied only when the default template is used) — confirm against
        # the upstream file, whose indentation was not visible here.
        Conversation.get_prompt = custom_get_prompt

    server_args.enable_custom_logit_processor = True

    try:
        launch_server(server_args)
    finally:
        # Always clean up the process tree rooted at this PID on exit.
        kill_process_tree(os.getpid(), include_parent=False)
# Script entry point: start the server only when this module is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
|