|
@@ -0,0 +1,82 @@
|
|
|
|
|
+import argparse
|
|
|
|
|
+import os
|
|
|
|
|
+
|
|
|
|
|
+from openai import OpenAI
|
|
|
|
|
+from transformers.utils.versions import require_version
|
|
|
|
|
+from PIL import Image
|
|
|
|
|
+import io
|
|
|
|
|
+import base64
|
|
|
|
|
+from dots_ocr.utils.image_utils import PILimage_to_base64
|
|
|
|
|
+from dots_ocr.utils import dict_promptmode_to_prompt
|
|
|
|
|
+# from dots_ocr.model.inference import inference_with_vllm
|
|
|
|
|
+
|
|
|
|
|
def inference_with_vllm(
    image,
    prompt,
    ip="localhost",
    port=8000,
    temperature=0.1,
    top_p=0.9,
    model_name='model',
):
    """Send one image+prompt chat completion to a vLLM OpenAI-compatible server.

    Args:
        image: PIL image; encoded to a base64 data URL via ``PILimage_to_base64``.
        prompt: Text prompt appended after the explicit image tokens.
        ip: Server host (default ``"localhost"``).
        port: Server port (default ``8000``).
        temperature: Sampling temperature forwarded to the server.
        top_p: Nucleus-sampling cutoff forwarded to the server.
        model_name: Model identifier registered with the server.

    Returns:
        The response message content (str), or ``None`` if the request fails.
    """
    addr = f"http://{ip}:{port}/v1"
    # API key defaults to "0" — local vLLM servers typically don't check auth.
    client = OpenAI(api_key=os.environ.get("API_KEY", "0"), base_url=addr)
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {"url": PILimage_to_base64(image)},
                },
                # Explicit image tokens are required: without
                # "<|img|><|imgpad|><|endofimg|>" here, vLLM v1 inserts "\n".
                {"type": "text", "text": f"<|img|><|imgpad|><|endofimg|>{prompt}"},
            ],
        }
    ]
    try:
        response = client.chat.completions.create(
            messages=messages,
            model=model_name,
            temperature=temperature,
            top_p=top_p,
        )
        return response.choices[0].message.content
    except Exception as e:
        # BUG FIX: the original caught requests.exceptions.RequestException, but
        # `requests` is never imported (NameError on any failure), and the OpenAI
        # client raises openai.* exceptions, not requests ones. Catch broadly at
        # this boundary and preserve the original best-effort None return.
        print(f"request error: {e}")
        return None
|
|
|
|
|
+
|
|
|
|
|
# Command-line interface for the demo: server location, model, and prompt mode.
parser = argparse.ArgumentParser()
for _flag, _default in (
    ("--ip", "localhost"),
    ("--port", "8101"),
    ("--model_name", "DotsOCR"),
    ("--prompt_mode", "prompt_layout_all_en"),
):
    parser.add_argument(_flag, type=str, default=_default)

# Parsed once at import time; main() reads this module-level namespace.
args = parser.parse_args()

# The OpenAI-compatible client API used above needs openai >= 1.5.0.
require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0")
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def main():
    """Demo entry point: send one sample image + prompt to the server and print the reply.

    Reads connection/model settings from the module-level ``args`` namespace.
    """
    # FIX: the original computed an `addr` URL here that was never used —
    # inference_with_vllm builds its own endpoint URL from ip/port.
    image_path = "../demo/demo_image1.jpg"
    # Look up the prompt text for the selected mode (e.g. "prompt_layout_all_en").
    prompt = dict_promptmode_to_prompt[args.prompt_mode]
    image = Image.open(image_path)
    response = inference_with_vllm(
        image,
        prompt,
        ip=args.ip,
        port=args.port,
        temperature=0.1,
        top_p=0.9,
        model_name=args.model_name,
    )
    print(f"response: {response}")
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+if __name__ == "__main__":
|
|
|
|
|
+ main()
|