import argparse
import os

from openai import OpenAI, OpenAIError
from PIL import Image
from transformers.utils.versions import require_version

from dots_ocr.utils import dict_promptmode_to_prompt
from dots_ocr.utils.image_utils import PILimage_to_base64
# from dots_ocr.model.inference import inference_with_vllm


def inference_with_vllm(
    image,
    prompt,
    ip="localhost",
    port=8000,
    temperature=0.1,
    top_p=0.9,
    # max_completion_tokens=32768,
    model_name="model",
):
    addr = f"http://{ip}:{port}/v1"
    client = OpenAI(api_key=os.environ.get("API_KEY", "0"), base_url=addr)
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {"url": PILimage_to_base64(image)},
                },
                # Without the explicit "<|img|><|imgpad|><|endofimg|>" tokens here,
                # vLLM v1 inserts a "\n" instead.
                {"type": "text", "text": f"<|img|><|imgpad|><|endofimg|>{prompt}"},
            ],
        }
    ]
    try:
        response = client.chat.completions.create(
            messages=messages,
            model=model_name,
            # max_completion_tokens=max_completion_tokens,
            temperature=temperature,
            top_p=top_p,
        )
        return response.choices[0].message.content
    except OpenAIError as e:  # the OpenAI client raises its own error types, not requests exceptions
        print(f"request error: {e}")
        return None


parser = argparse.ArgumentParser()
parser.add_argument("--ip", type=str, default="localhost")
parser.add_argument("--port", type=str, default="8101")
parser.add_argument("--model_name", type=str, default="DotsOCR")
parser.add_argument("--prompt_mode", type=str, default="prompt_layout_all_en")
args = parser.parse_args()

require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0")


def main():
    image_path = "../demo/demo_image1.jpg"
    prompt = dict_promptmode_to_prompt[args.prompt_mode]
    image = Image.open(image_path)

    response = inference_with_vllm(
        image,
        prompt,
        ip=args.ip,
        port=args.port,
        temperature=0.1,
        top_p=0.9,
        # max_completion_tokens=32768*2,
        model_name=args.model_name,
    )
    print(f"response: {response}")


if __name__ == "__main__":
    main()
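
# Usage sketch (the file name "demo_vllm.py" is hypothetical; this assumes a
# vLLM OpenAI-compatible server is already running and serving the model under
# the name passed via --model_name, reachable at the given ip/port):
#
#   python demo_vllm.py --ip localhost --port 8101 \
#       --model_name DotsOCR --prompt_mode prompt_layout_all_en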