Add vLLM inference support: image input and custom prompts

zhch158_admin 3 months ago
Commit 6a3993fdf3
1 changed file with 79 additions and 0 deletions

zhch/demo_vllm.py (+79 -0)

@@ -0,0 +1,79 @@
+import argparse
+import os
+
+from openai import OpenAI, OpenAIError
+from transformers.utils.versions import require_version
+from PIL import Image
+from dots_ocr.utils.image_utils import PILimage_to_base64
+from dots_ocr.utils import dict_promptmode_to_prompt
+# Inlined below instead of: from dots_ocr.model.inference import inference_with_vllm
+
+def inference_with_vllm(
+        image,
+        prompt, 
+        ip="localhost",
+        port=8000,
+        temperature=0.1,
+        top_p=0.9,
+        # max_completion_tokens=32768,
+        model_name='model',
+        ):
+    """Send a single image + prompt request to a vLLM server through its OpenAI-compatible API."""
+    addr = f"http://{ip}:{port}/v1"
+    client = OpenAI(api_key=os.environ.get("API_KEY", "0"), base_url=addr)
+    messages = []
+    messages.append(
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "image_url",
+                    "image_url": {"url":  PILimage_to_base64(image)},
+                },
+                {"type": "text", "text": f"<|img|><|imgpad|><|endofimg|>{prompt}"}  # if no "<|img|><|imgpad|><|endofimg|>" here,vllm v1 will add "\n" here
+            ],
+        }
+    )
+    try:
+        response = client.chat.completions.create(
+            messages=messages, 
+            model=model_name, 
+            # max_completion_tokens=max_completion_tokens,
+            temperature=temperature,
+            top_p=top_p)
+        response = response.choices[0].message.content
+        return response
+    except OpenAIError as e:
+        print(f"request error: {e}")
+        return None
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--ip", type=str, default="localhost")
+parser.add_argument("--port", type=str, default="8101")
+parser.add_argument("--model_name", type=str, default="DotsOCR")
+parser.add_argument("--prompt_mode", type=str, default="prompt_layout_all_en")
+
+args = parser.parse_args()
+
+require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0")
+
+
+def main():
+    image_path = "../demo/demo_image1.jpg"
+    prompt = dict_promptmode_to_prompt[args.prompt_mode]
+    image = Image.open(image_path)
+    response = inference_with_vllm(
+        image,
+        prompt, 
+        ip=args.ip,
+        port=args.port,
+        temperature=0.1,
+        top_p=0.9,
+        # max_completion_tokens=32768*2,
+        model_name=args.model_name,
+    )
+    print(f"response: {response}")
+
+
+if __name__ == "__main__":
+    main()
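
The new script exercises one prompt mode per run against an already-running vLLM server (an OpenAI-compatible endpoint serving the model under the name DotsOCR). A minimal follow-up sketch, assuming that server is up at the argparse defaults (localhost:8101) and that demo_vllm.py is importable from the working directory, sweeps every mode in dict_promptmode_to_prompt and collects the responses:

import json
from pathlib import Path

from PIL import Image

from demo_vllm import inference_with_vllm
from dots_ocr.utils import dict_promptmode_to_prompt

# Same demo image that main() uses; adjust the path to your checkout.
image = Image.open("../demo/demo_image1.jpg")

results = {}
for mode, prompt in dict_promptmode_to_prompt.items():
    # One request per prompt mode; inference_with_vllm returns None on API errors.
    results[mode] = inference_with_vllm(
        image, prompt, ip="localhost", port=8101, model_name="DotsOCR"
    )

# "responses.json" is just a placeholder output path for this sketch.
Path("responses.json").write_text(json.dumps(results, ensure_ascii=False, indent=2))

Note that temperature and top_p fall back to the function defaults (0.1 and 0.9) here, and the special <|img|><|imgpad|><|endofimg|> prefix is added inside inference_with_vllm itself, so callers pass only the plain prompt text.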