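"""demo_vllm.py: query a DotsOCR model served by vLLM through its OpenAI-compatible API.

Loads a demo image, pairs it with the prompt for the selected prompt mode,
sends both to the server, and prints the model's response.
"""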
import argparse
import os

from openai import OpenAI, OpenAIError
from PIL import Image
from transformers.utils.versions import require_version

from dots_ocr.utils import dict_promptmode_to_prompt
from dots_ocr.utils.image_utils import PILimage_to_base64


# Inlined from dots_ocr.model.inference so the demo is self-contained:
# from dots_ocr.model.inference import inference_with_vllm
def inference_with_vllm(
    image,
    prompt,
    ip="localhost",
    port=8000,
    temperature=0.1,
    top_p=0.9,
    # max_completion_tokens=32768,
    model_name="model",
):
    """Send a single image + prompt request to a vLLM server and return the text reply."""
    addr = f"http://{ip}:{port}/v1"
    client = OpenAI(api_key=os.environ.get("API_KEY", "0"), base_url=addr)
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    # The image is sent inline as a base64 data URL.
                    "image_url": {"url": PILimage_to_base64(image)},
                },
                # Without the explicit "<|img|><|imgpad|><|endofimg|>" prefix,
                # vLLM v1 inserts a "\n" between the image and the prompt.
                {"type": "text", "text": f"<|img|><|imgpad|><|endofimg|>{prompt}"},
            ],
        }
    ]
    try:
        response = client.chat.completions.create(
            messages=messages,
            model=model_name,
            # max_completion_tokens=max_completion_tokens,
            temperature=temperature,
            top_p=top_p,
        )
        return response.choices[0].message.content
    except OpenAIError as e:  # connection/API errors raised by the openai client
        print(f"request error: {e}")
        return None


parser = argparse.ArgumentParser()
parser.add_argument("--ip", type=str, default="localhost")
parser.add_argument("--port", type=str, default="8101")
parser.add_argument("--model_name", type=str, default="DotsOCR")
parser.add_argument("--prompt_mode", type=str, default="prompt_layout_all_en")
args = parser.parse_args()

require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0")
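
# Assumes a vLLM OpenAI-compatible server is already serving the model; the
# exact launch command depends on your setup (see the dots.ocr docs), e.g.:
#   vllm serve <path/to/DotsOCR/weights> --port 8101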

def main():
    image_path = "../demo/demo_image1.jpg"
    prompt = dict_promptmode_to_prompt[args.prompt_mode]  # prompt text for the chosen mode
    image = Image.open(image_path)
    response = inference_with_vllm(
        image,
        prompt,
        ip=args.ip,
        port=args.port,
        temperature=0.1,
        top_p=0.9,
        # max_completion_tokens=32768*2,
        model_name=args.model_name,
    )
    print(f"response: {response}")


if __name__ == "__main__":
    main()
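
# Example invocation (defaults shown):
#   python demo_vllm.py --ip localhost --port 8101 \
#       --model_name DotsOCR --prompt_mode prompt_layout_all_en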