demo_hf.py

import os

# Default LOCAL_RANK so the script also runs when not started by a distributed launcher.
if "LOCAL_RANK" not in os.environ:
    os.environ["LOCAL_RANK"] = "0"

import torch
from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer
from qwen_vl_utils import process_vision_info
from dots_ocr.utils import dict_promptmode_to_prompt

# Force the use of a single GPU
os.environ["CUDA_VISIBLE_DEVICES"] = "0"


def inference(image_path, prompt, model, processor):
    # image_path = "demo/demo_image1.jpg"
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": image_path
                },
                {"type": "text", "text": prompt}
            ]
        }
    ]

    # Preparation for inference
    text = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to("cuda")

    # Inference: generation of the output
    generated_ids = model.generate(**inputs, max_new_tokens=24000)
    # Strip the prompt tokens so only the newly generated tokens are decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    print(output_text)


if __name__ == "__main__":
    # We recommend enabling flash_attention_2 for better acceleration and memory saving,
    # especially in multi-image and video scenarios.
    model_path = "./weights/DotsOCR"
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True
    )
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

    image_path = "demo/demo_image1.jpg"
    # Run every available prompt mode against the demo image.
    for prompt_mode, prompt in dict_promptmode_to_prompt.items():
        print(f"prompt: {prompt}")
        inference(image_path, prompt, model, processor)
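
    # Optional variant (a sketch, not part of the original demo): if flash-attn is not
    # installed, the model can be loaded with Transformers' built-in SDPA attention instead,
    # and a single prompt mode can be run directly. The key "prompt_layout_all_en" is an
    # assumed example key of dict_promptmode_to_prompt and may differ in your checkout.
    #
    # model = AutoModelForCausalLM.from_pretrained(
    #     model_path,
    #     attn_implementation="sdpa",
    #     torch_dtype=torch.bfloat16,
    #     device_map="auto",
    #     trust_remote_code=True,
    # )
    # inference(image_path, dict_promptmode_to_prompt["prompt_layout_all_en"], model, processor)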