demo_hf_macos_bfloat16.py

import os

# The model's remote code may expect LOCAL_RANK to be set (normally provided by
# torchrun); default to a single process when it is missing.
if "LOCAL_RANK" not in os.environ:
    os.environ["LOCAL_RANK"] = "0"

import torch
from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer
from qwen_vl_utils import process_vision_info
from dots_ocr.utils import dict_promptmode_to_prompt
def inference(image_path, prompt, model, processor):
    # image_path = "demo/demo_image1.jpg"
    # Build a single-turn chat message pairing the input image with the prompt text.
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": image_path
                },
                {"type": "text", "text": prompt}
            ]
        }
    ]

    # Preparation for inference
    text = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    # inputs = inputs.to("cuda")  # not used here: this demo runs on CPU

    # Inference: generate the output, then strip the prompt tokens from each sequence
    # so only the newly generated text is decoded.
    generated_ids = model.generate(**inputs, max_new_tokens=24000)
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    print(output_text)
if __name__ == "__main__":
    # flash_attention_2 is recommended for better acceleration and memory saving,
    # especially in multi-image and video scenarios, but it requires a CUDA GPU,
    # so it stays disabled in this CPU/bfloat16 demo for macOS.
    # model_path = "./weights/DotsOCR"
    model_path = "./weights/DotsOCR_CPU_bfloat16"
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        # attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
        # device_map="auto",
        device_map="cpu",
        trust_remote_code=True
    )
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

    # image_path = "demo/demo_image1.jpg"
    image_path = "./zhch/demo_1.jpg"

    # Run every prompt mode defined by dots_ocr against the same image.
    for prompt_mode, prompt in dict_promptmode_to_prompt.items():
        print(f"prompt: {prompt}")
        inference(image_path, prompt, model, processor)
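
# Usage sketch (assumption: you only want a single prompt mode rather than the full
# loop above; no specific mode name is assumed, the first one defined in
# dict_promptmode_to_prompt is used):
#
#     mode = next(iter(dict_promptmode_to_prompt))
#     inference(image_path, dict_promptmode_to_prompt[mode], model, processor)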