  1. """
  2. Apple Silicon (macOS) compatible version of demo_hf.py
  3. Optimized for float16 + MPS acceleration, with CPU fallback
  4. """
  5. import os
  6. import platform
  7. import torch
  8. import argparse
  9. from transformers.models.auto.modeling_auto import AutoModelForCausalLM
  10. from transformers.models.auto.processing_auto import AutoProcessor
  11. from qwen_vl_utils import process_vision_info
  12. from dots_ocr.utils import dict_promptmode_to_prompt
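
# Illustrative invocations (a sketch, not part of the original script; the paths are
# assumptions based on the defaults used below: model weights under ./weights/DotsOCR
# and the demo image at demo/demo_image1.jpg, both resolved relative to the current
# working directory):
#
#   python demo_hf_macos_float16.py                               # MPS + float16 when available
#   python demo_hf_macos_float16.py --force_cpu                   # force CPU inference with float32
#   python demo_hf_macos_float16.py --model_path ./weights/DotsOCR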


def get_optimal_device_and_dtype():
    """Get the best available device and dtype for macOS."""
    if torch.backends.mps.is_available():
        print("🚀 MPS (Metal Performance Shaders) available")
        return "mps", torch.float16
    else:
        print("⚠️ MPS not available, falling back to CPU")
        return "cpu", torch.float32


def inference_macos(image_path, prompt, model, processor, device="mps", dtype=torch.float16):
    """
    Inference function optimized for macOS/Apple Silicon with float16 + MPS.
    """
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image_path},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    # Preparation for inference
    text = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Handle process_vision_info return values properly
    try:
        vision_info = process_vision_info(messages)
        # Safely unpack the return value
        image_inputs = vision_info[0] if len(vision_info) > 0 else None
        video_inputs = vision_info[1] if len(vision_info) > 1 else None
    except Exception as e:
        print(f"Warning: Error processing vision info: {e}")
        image_inputs, video_inputs = None, None

    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )

    # Move inputs to the device and convert floating-point tensors to the target dtype
    inputs = inputs.to(device)
    for key, value in inputs.items():
        if isinstance(value, torch.Tensor) and value.dtype in (torch.float32, torch.float16, torch.bfloat16):
            if dtype == torch.float16 and value.dtype != torch.float16:
                inputs[key] = value.to(dtype)
                print(f"🔄 Converted {key} to {dtype}")

    # Inference: generation of the output with optimized settings
    print(f"🚀 Starting inference on {device} with {dtype}")
    with torch.no_grad():  # Save memory on Apple Silicon
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=8000,  # Generous budget for long documents
            do_sample=False,  # Greedy decoding for reproducible output
            pad_token_id=processor.tokenizer.eos_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            use_cache=True,  # Enable the KV cache for speed
            output_attentions=False,
            output_hidden_states=False,
        )

    # Strip the prompt tokens so only newly generated text is decoded
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    print(output_text[0])
    return output_text[0]


def load_model_macos(model_path, use_float16=True):
    """Load the model with macOS-optimized settings for float16 + MPS."""
    device, dtype = get_optimal_device_and_dtype()

    # Override device/dtype based on the parameter
    if not use_float16:
        dtype = torch.float32
        device = "cpu"  # Force CPU for float32 safety

    print(f"Loading model on {device} with {dtype}...")
    print(f"Platform: {platform.platform()}")
    print(f"PyTorch version: {torch.__version__}")

    # Configuration for Apple Silicon with float16
    model_kwargs = {
        "torch_dtype": dtype,
        "trust_remote_code": True,
        "low_cpu_mem_usage": True,
    }

    # Handle device mapping
    if device == "mps":
        model_kwargs["device_map"] = None  # Load on CPU first, then move to MPS
        print("🔄 Loading model on CPU first, then moving to MPS...")
    else:
        model_kwargs["device_map"] = "cpu"

    try:
        model = AutoModelForCausalLM.from_pretrained(model_path, **model_kwargs)

        # Move to MPS if available
        if device == "mps":
            print("🚀 Moving model to MPS for GPU acceleration...")
            model = model.to("mps")
            # Verify the model is on MPS
            sample_param = next(model.parameters())
            print(f"✅ Model device: {sample_param.device}, dtype: {sample_param.dtype}")
        else:
            print("✅ Model loaded on CPU")

        model.eval()  # Set to evaluation mode
        return model, device, dtype
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        if "mps" in str(e).lower() or device == "mps":
            print("🔄 MPS loading failed, falling back to CPU with float32...")
            return load_model_macos(model_path, use_float16=False)
        else:
            raise


if __name__ == "__main__":
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="DotsOCR Apple Silicon Demo")
    parser.add_argument("--model_path", default="./weights/DotsOCR",
                        help="Path to the model")
    parser.add_argument("--use_float16", action="store_true", default=True,
                        help="Use float16 + MPS for faster inference (default: True)")
    parser.add_argument("--force_cpu", action="store_true",
                        help="Force CPU inference with float32")
    args = parser.parse_args()

    # Print system information
    print(f"System: {platform.system()} {platform.release()}")
    print(f"Machine: {platform.machine()}")
    print(f"MPS available: {torch.backends.mps.is_available()}")

    model_path = args.model_path
    use_float16 = args.use_float16 and not args.force_cpu

    # Auto-detect a pre-converted float16 model if available
    float16_path = model_path + "_float16"
    if use_float16 and os.path.exists(float16_path):
        print(f"🎯 Found float16 model at {float16_path}")
        model_path = float16_path
    elif use_float16 and not os.path.exists(float16_path):
        print(f"⚠️ Float16 model not found at {float16_path}")
        print("💡 Consider running: python tools/convert_model_float16.py")
        print("🔄 Falling back to original model with auto-conversion...")

    if not os.path.exists(model_path):
        print(f"❌ Model not found at {model_path}")
        exit(1)

    # Load model and processor
    try:
        model, device, dtype = load_model_macos(model_path, use_float16)
        processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
        print("✅ Model and processor loaded successfully")
    except Exception as e:
        print(f"❌ Failed to load model: {e}")
        exit(1)

    image_path = "demo/demo_image1.jpg"
    if not os.path.exists(image_path):
        print(f"❌ Demo image not found: {image_path}")
        exit(1)

    # Test with different prompt modes
    print("\n" + "=" * 60)
    print(f"🚀 Starting inference tests on {device} with {dtype}")
    print("=" * 60)

    # Run a quick test with a simple prompt first
    test_prompt = "Extract all text content from this image."
    print("\n🧪 Quick test with simple prompt...")
    print(f"Prompt: {test_prompt}")
    print("-" * 50)
    try:
        result = inference_macos(image_path, test_prompt, model, processor, device, dtype)
        print("✅ Quick test successful!")
        print(f"Result preview: {result[:200]}..." if len(result) > 200 else f"Result: {result}")
    except Exception as e:
        print(f"❌ Quick test failed: {e}")
        print("🔄 This might be a compatibility issue. Try running:")
        print("   python tools/convert_model_float16.py        # for float16 + MPS")
        print("   python demo_hf_macos_float16.py --force_cpu  # for CPU fallback")
        exit(1)

    # If the quick test passed, run the full tests
    print("\n" + "=" * 60)
    print("🎯 Running full prompt mode tests...")
    print("=" * 60)

    success_count = 0
    total_count = len(dict_promptmode_to_prompt)
    for prompt_mode, prompt in dict_promptmode_to_prompt.items():
        print(f"\n--- Testing prompt mode: {prompt_mode} ---")
        print(f"Prompt: {prompt}")
        print("---")
        try:
            result = inference_macos(image_path, prompt, model, processor, device, dtype)
            print(f"✅ Success for {prompt_mode}")
            success_count += 1
            # Show only a preview of longer results
            if len(result) > 300:
                print(f"Result preview: {result[:300]}...")
            else:
                print(f"Result: {result}")
        except Exception as e:
            print(f"❌ Error for {prompt_mode}: {e}")
        print("-" * 60)

    print(f"\n🎊 Test Summary: {success_count}/{total_count} prompt modes successful")
    if success_count == total_count:
        print("🎉 All tests passed! Your setup is working perfectly.")
    elif success_count > 0:
        print("⚠️ Some tests passed. The model is working but may have compatibility issues.")
    else:
        print("❌ All tests failed. Please check your setup.")

    print("\n💡 Performance info:")
    print(f"   Device: {device}")
    print(f"   Data type: {dtype}")
    if device == "mps" and dtype == torch.float16:
        print("   🚀 You're using the fastest configuration (float16 + MPS)!")
    elif device == "cpu":
        print("   🐌 Using CPU inference. Consider float16 + MPS for better performance.")

    print("\n📊 To optimize further:")
    if device != "mps":
        print("   • Run: python tools/convert_model_float16.py")
        print("   • Then use the converted model for ~2x speedup")
    print("   • Reduce max_new_tokens for faster but shorter outputs")
    print("   • Use do_sample=False for deterministic results")