import argparse
import base64
import io
import os

from openai import OpenAI, OpenAIError
from PIL import Image
from transformers.utils.versions import require_version

from dots_ocr.utils import dict_promptmode_to_prompt
from dots_ocr.utils.image_utils import PILimage_to_base64
# from dots_ocr.model.inference import inference_with_vllm
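
# Illustrative only: PILimage_to_base64 is provided by dots_ocr.utils.image_utils;
# the sketch below is an assumed equivalent, showing the kind of value the
# "image_url" field expects -- a base64 data URL built from the PIL image.
def _pil_image_to_data_url(image, fmt="PNG"):
    buffer = io.BytesIO()
    image.save(buffer, format=fmt)  # serialize the PIL image to in-memory bytes
    encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return f"data:image/{fmt.lower()};base64,{encoded}"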

def inference_with_vllm(
    image,
    prompt,
    ip="localhost",
    port=8000,
    temperature=0.1,
    top_p=0.9,
    # max_completion_tokens=32768,
    model_name="model",
):
    """Send one image plus a prompt to a vLLM OpenAI-compatible server and return the generated text."""
    addr = f"http://{ip}:{port}/v1"
    client = OpenAI(api_key=os.environ.get("API_KEY", "0"), base_url=addr)
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {"url": PILimage_to_base64(image)},
                },
                # without the explicit "<|img|><|imgpad|><|endofimg|>" prefix, vLLM v1 inserts a "\n" here
                {"type": "text", "text": f"<|img|><|imgpad|><|endofimg|>{prompt}"},
            ],
        }
    ]
    try:
        response = client.chat.completions.create(
            messages=messages,
            model=model_name,
            # max_completion_tokens=max_completion_tokens,
            temperature=temperature,
            top_p=top_p,
        )
        return response.choices[0].message.content
    except OpenAIError as e:  # the openai>=1.x client raises its own exception types, not requests exceptions
        print(f"request error: {e}")
        return None
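
# Optional sanity check, not part of the original script: the vLLM OpenAI-compatible
# server exposes /v1/models, so listing models is a cheap way to confirm the endpoint
# is reachable and to see the exact served model name.
def check_server(ip="localhost", port=8000):
    client = OpenAI(api_key=os.environ.get("API_KEY", "0"), base_url=f"http://{ip}:{port}/v1")
    try:
        return [m.id for m in client.models.list().data]
    except OpenAIError as e:
        print(f"server not reachable: {e}")
        return []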

parser = argparse.ArgumentParser()
parser.add_argument("--ip", type=str, default="localhost")
parser.add_argument("--port", type=str, default="8101")
parser.add_argument("--model_name", type=str, default="DotsOCR")
parser.add_argument("--prompt_mode", type=str, default="prompt_layout_all_en")
args = parser.parse_args()

require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0")

def main():
    image_path = "../demo/demo_image1.jpg"
    prompt = dict_promptmode_to_prompt[args.prompt_mode]
    image = Image.open(image_path)
    response = inference_with_vllm(
        image,
        prompt,
        ip=args.ip,
        port=args.port,
        temperature=0.1,
        top_p=0.9,
        # max_completion_tokens=32768*2,
        model_name=args.model_name,
    )
    print(f"response: {response}")

if __name__ == "__main__":
    main()