- """
- Apple Silicon (macOS) compatible version of demo_hf.py
- Optimized for float16 + MPS acceleration, with CPU fallback
- """
- import os
- import platform
- import torch
- import argparse
- from transformers.models.auto.modeling_auto import AutoModelForCausalLM
- from transformers.models.auto.processing_auto import AutoProcessor
- from qwen_vl_utils import process_vision_info
- from dots_ocr.utils import dict_promptmode_to_prompt
def get_optimal_device_and_dtype():
    """Get the best available device and dtype for macOS"""
    if torch.backends.mps.is_available():
        print("🚀 MPS (Metal Performance Shaders) available")
        return "mps", torch.float16
    else:
        print("⚠️ MPS not available, falling back to CPU")
        return "cpu", torch.float32


def inference_macos(image_path, prompt, model, processor, device="mps", dtype=torch.float16):
    """
    Inference function optimized for macOS/Apple Silicon with float16 + MPS
    """
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": image_path
                },
                {"type": "text", "text": prompt}
            ]
        }
    ]

    # Preparation for inference
    text = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    # Handle the process_vision_info return value defensively: it returns
    # (image_inputs, video_inputs), but unpack by index in case the installed
    # version returns a tuple of a different length.
    try:
        vision_info = process_vision_info(messages)
        image_inputs = vision_info[0] if len(vision_info) > 0 else None
        video_inputs = vision_info[1] if len(vision_info) > 1 else None
    except Exception as e:
        print(f"Warning: error processing vision info: {e}")
        image_inputs, video_inputs = None, None

    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
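    # `inputs` is a BatchFeature holding the tokenized prompt (input_ids,
    # attention_mask) together with the preprocessed image tensors.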
    # Move inputs to the device and cast floating-point tensors to the target
    # dtype so they match the model weights (e.g. float16 on MPS).
    inputs = inputs.to(device)
    for key, value in inputs.items():
        if isinstance(value, torch.Tensor) and value.is_floating_point() and value.dtype != dtype:
            inputs[key] = value.to(dtype)
            print(f"🔄 Converted {key} to {dtype}")

    # Inference: generate the output with memory-friendly settings
    print(f"🚀 Starting inference on {device} with {dtype}")
    with torch.no_grad():  # Save memory on Apple Silicon
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=8000,  # Generous budget for long OCR outputs
            do_sample=False,  # Use greedy decoding for consistency
            pad_token_id=processor.tokenizer.eos_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            use_cache=True,  # Enable KV cache for speed
            output_attentions=False,
            output_hidden_states=False,
        )

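    # model.generate() returns the prompt tokens followed by the new tokens;
    # slice off the prompt so only the generated continuation is decoded.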
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    print(output_text[0])
    return output_text[0]


def load_model_macos(model_path, use_float16=True):
    """Load the model with macOS-optimized settings for float16 + MPS"""
    device, dtype = get_optimal_device_and_dtype()

    # Override the auto-detected device and dtype based on the parameter
    if not use_float16:
        dtype = torch.float32
        device = "cpu"  # Force CPU for float32 safety

    print(f"Loading model on {device} with {dtype}...")
    print(f"Platform: {platform.platform()}")
    print(f"PyTorch version: {torch.__version__}")

    # Configuration for Apple Silicon with float16
    model_kwargs = {
        "torch_dtype": dtype,
        "trust_remote_code": True,
        "low_cpu_mem_usage": True,
    }
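    # trust_remote_code=True is required because the DotsOCR checkpoint ships
    # custom modeling code; low_cpu_mem_usage=True keeps peak RAM lower while
    # the weights are loaded.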

    # Handle device mapping
    if device == "mps":
        model_kwargs["device_map"] = None  # Load on CPU first, then move to MPS
        print("🔄 Loading model on CPU first, then moving to MPS...")
    else:
        model_kwargs["device_map"] = "cpu"

    try:
        model = AutoModelForCausalLM.from_pretrained(model_path, **model_kwargs)

        # Move to MPS if available
        if device == "mps":
            print("🚀 Moving model to MPS for GPU acceleration...")
            model = model.to("mps")

            # Verify the model is on MPS
            sample_param = next(model.parameters())
            print(f"✅ Model device: {sample_param.device}, dtype: {sample_param.dtype}")
        else:
            print("✅ Model loaded on CPU")

        model.eval()  # Set to evaluation mode
        return model, device, dtype

    except Exception as e:
        print(f"❌ Error loading model: {e}")
        if "mps" in str(e).lower() or device == "mps":
            print("🔄 MPS loading failed, falling back to CPU with float32...")
            return load_model_macos(model_path, use_float16=False)
        else:
            raise


if __name__ == "__main__":
    # Parse command line arguments
    parser = argparse.ArgumentParser(description="DotsOCR Apple Silicon Demo")
    parser.add_argument("--model_path", default="./weights/DotsOCR",
                        help="Path to the model")
    parser.add_argument("--use_float16", action="store_true", default=True,
                        help="Use float16 + MPS for faster inference "
                             "(enabled by default; pass --force_cpu to disable)")
    parser.add_argument("--force_cpu", action="store_true",
                        help="Force CPU inference with float32")
    args = parser.parse_args()
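    # Example invocations (assuming this script is saved as demo/demo_hf_macos.py):
    #   python demo/demo_hf_macos.py               # MPS + float16 when available
    #   python demo/demo_hf_macos.py --force_cpu   # force CPU + float32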

    # Check system information
    print(f"System: {platform.system()} {platform.release()}")
    print(f"Machine: {platform.machine()}")
    print(f"MPS available: {torch.backends.mps.is_available()}")

    model_path = args.model_path
    use_float16 = args.use_float16 and not args.force_cpu

    # Auto-detect float16 model if available
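    # The conversion script mentioned below is assumed to write its output next
    # to the original weights with a "_float16" suffix, which is the naming
    # this check relies on.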
    float16_path = model_path + "_float16"
    if use_float16 and os.path.exists(float16_path):
        print(f"🎯 Found float16 model at {float16_path}")
        model_path = float16_path
    elif use_float16:
        print(f"⚠️ Float16 model not found at {float16_path}")
        print("💡 Consider running: python tools/convert_model_float16.py")
        print("🔄 Falling back to original model with auto-conversion...")

    if not os.path.exists(model_path):
        print(f"❌ Model not found at {model_path}")
        exit(1)

    # Load model and processor
    try:
        model, device, dtype = load_model_macos(model_path, use_float16)
        processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
        print("✅ Model and processor loaded successfully")
    except Exception as e:
        print(f"❌ Failed to load model: {e}")
        exit(1)
    image_path = "demo/demo_image1.jpg"

    if not os.path.exists(image_path):
        print(f"❌ Demo image not found: {image_path}")
        exit(1)

    # Test with different prompt modes
    print("\n" + "=" * 60)
    print(f"🚀 Starting inference tests on {device} with {dtype}")
    print("=" * 60)

    # Test with a simple prompt first
    test_prompt = "Extract all text content from this image."
    print("\n🧪 Quick test with simple prompt...")
    print(f"Prompt: {test_prompt}")
    print("-" * 50)

    try:
        result = inference_macos(image_path, test_prompt, model, processor, device, dtype)
        print("✅ Quick test successful!")
        print(f"Result preview: {result[:200]}..." if len(result) > 200 else f"Result: {result}")
    except Exception as e:
        print(f"❌ Quick test failed: {e}")
        print("🔄 This might be a compatibility issue. Try running:")
        print("   python tools/convert_model_float16.py  # for float16 + MPS")
        print(f"   python {os.path.basename(__file__)} --force_cpu  # for CPU fallback")
        exit(1)

    # If quick test passed, run full tests
    print("\n" + "=" * 60)
    print("🎯 Running full prompt mode tests...")
    print("=" * 60)

    success_count = 0
    total_count = len(dict_promptmode_to_prompt)
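    # dict_promptmode_to_prompt maps each prompt mode supported by dots.ocr to
    # its prompt text; every mode is run against the same demo image.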

    for prompt_mode, prompt in dict_promptmode_to_prompt.items():
        print(f"\n--- Testing prompt mode: {prompt_mode} ---")
        print(f"Prompt: {prompt}")
        print("---")

        try:
            result = inference_macos(image_path, prompt, model, processor, device, dtype)
            print(f"✅ Success for {prompt_mode}")
            success_count += 1

            # Show only a preview of longer results
            if len(result) > 300:
                print(f"Result preview: {result[:300]}...")
            else:
                print(f"Result: {result}")

        except Exception as e:
            print(f"❌ Error for {prompt_mode}: {e}")

        print("-" * 60)

    print(f"\n🎊 Test Summary: {success_count}/{total_count} prompt modes successful")

    if success_count == total_count:
        print("🎉 All tests passed! Your setup is working perfectly.")
    elif success_count > 0:
        print("⚠️ Some tests passed. The model is working but may have compatibility issues.")
    else:
        print("❌ All tests failed. Please check your setup.")

    print("\n💡 Performance info:")
    print(f"   Device: {device}")
    print(f"   Data type: {dtype}")
    if device == "mps" and dtype == torch.float16:
        print("   🚀 You're using the fastest configuration (float16 + MPS)!")
    elif device == "cpu":
        print("   🐌 Using CPU inference. Consider float16 + MPS for better performance.")
- print(f"\n📊 To optimize further:")
- if device != "mps":
- print(" • Run: python tools/convert_model_float16.py")
- print(" • Then use the converted model for ~2x speedup")
- print(" • Reduce max_new_tokens for faster but shorter outputs")
- print(" • Use do_sample=False for deterministic results")