|
|
@@ -6,10 +6,26 @@ from sglang.srt.entrypoints.http_server import app, generate_request, launch_ser
|
|
|
from sglang.srt.managers.io_struct import GenerateReqInput
|
|
|
from sglang.srt.server_args import prepare_server_args
|
|
|
from sglang.srt.utils import kill_process_tree
|
|
|
+from sglang.srt.conversation import Conversation
|
|
|
|
|
|
from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
|
|
|
from .logit_processor import Mineru2LogitProcessor
|
|
|
|
|
|
+# mineru2.0的chat_template与chatml在换行上有微小区别
|
|
|
+def custom_get_prompt(self) -> str:
|
|
|
+ system_prompt = self.system_template.format(system_message=self.system_message)
|
|
|
+ if self.system_message == "":
|
|
|
+ ret = ""
|
|
|
+ else:
|
|
|
+ ret = system_prompt + self.sep
|
|
|
+
|
|
|
+ for role, message in self.messages:
|
|
|
+ if message:
|
|
|
+ ret += role + "\n" + message + self.sep
|
|
|
+ else:
|
|
|
+ ret += role + "\n"
|
|
|
+ return ret
|
|
|
+
|
|
|
_custom_logit_processor_str = Mineru2LogitProcessor().to_str()
|
|
|
|
|
|
# remote the existing /generate route
|
|
|
@@ -45,6 +61,7 @@ def main():
|
|
|
|
|
|
if server_args.chat_template is None:
|
|
|
server_args.chat_template = "chatml"
|
|
|
+ Conversation.get_prompt = custom_get_prompt
|
|
|
|
|
|
server_args.enable_custom_logit_processor = True
|
|
|
|