@@ -0,0 +1,271 @@
+"""
+Apple Silicon (macOS) compatible version of demo_hf.py
+Optimized for float16 + MPS acceleration, with CPU fallback
+"""
+import os
+import platform
+import torch
+import argparse
+from transformers.models.auto.modeling_auto import AutoModelForCausalLM
+from transformers.models.auto.processing_auto import AutoProcessor
+from qwen_vl_utils import process_vision_info
+from dots_ocr.utils import dict_promptmode_to_prompt
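+# Third-party helpers: process_vision_info prepares the image/video inputs for the
+# processor, and dict_promptmode_to_prompt maps each dots.ocr prompt mode to its
+# prompt string; both are used below.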
+
+def get_optimal_device_and_dtype():
+    """Get the best available device and dtype for macOS"""
+    if torch.backends.mps.is_available():
+        print("🚀 MPS (Metal Performance Shaders) available")
+        return "mps", torch.float16
+    else:
+        print("⚠️ MPS not available, falling back to CPU")
+        return "cpu", torch.float32
+
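+# Note on the device choice above: torch.backends.mps.is_available() only reports that
+# the Metal backend is usable; individual operators may still be unsupported on MPS.
+# Setting PYTORCH_ENABLE_MPS_FALLBACK=1 in the environment lets such ops fall back to
+# CPU instead of raising an error.
+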
+def inference_macos(image_path, prompt, model, processor, device="mps", dtype=torch.float16):
+    """
+    Inference function optimized for macOS/Apple Silicon with float16 + MPS
+    """
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "image",
+                    "image": image_path
+                },
+                {"type": "text", "text": prompt}
+            ]
+        }
+    ]
+
+    # Preparation for inference
+    text = processor.apply_chat_template(
+        messages,
+        tokenize=False,
+        add_generation_prompt=True
+    )
+
+    # Handle process_vision_info return values defensively, in case it returns a
+    # different number of values than expected
+    try:
+        vision_info = process_vision_info(messages)
+        # Safely unpack the return value
+        image_inputs = vision_info[0] if len(vision_info) > 0 else None
+        video_inputs = vision_info[1] if len(vision_info) > 1 else None
+    except Exception as e:
+        print(f"Warning: Error processing vision info: {e}")
+        image_inputs, video_inputs = None, None
+
+    inputs = processor(
+        text=[text],
+        images=image_inputs,
+        videos=video_inputs,
+        padding=True,
+        return_tensors="pt",
+    )
+
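+    # Note: image processors typically emit float32 pixel values; the cast below keeps
+    # the inputs consistent with float16 weights and avoids dtype-mismatch errors on MPS.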
+    # Move inputs to device and convert to appropriate dtype
+    inputs = inputs.to(device)
+
+    # Convert floating point tensors to the target dtype for consistency
+    for key, value in inputs.items():
+        if isinstance(value, torch.Tensor) and value.dtype in [torch.float32, torch.float16, torch.bfloat16]:
+            if dtype == torch.float16 and value.dtype != torch.float16:
+                inputs[key] = value.to(dtype)
+                print(f"🔄 Converted {key} to {dtype}")
+
+    # Inference: Generation of the output with optimized settings
+    print(f"🚀 Starting inference on {device} with {dtype}")
+    with torch.no_grad():  # Save memory on Apple Silicon
+        generated_ids = model.generate(
+            **inputs,
+            max_new_tokens=8000,  # Increased for float16 efficiency
+            do_sample=False,  # Use greedy for consistency
+            pad_token_id=processor.tokenizer.eos_token_id,
+            eos_token_id=processor.tokenizer.eos_token_id,
+            use_cache=True,  # Enable KV cache for speed
+            output_attentions=False,
+            output_hidden_states=False,
+        )
+
+    generated_ids_trimmed = [
+        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+    ]
+    output_text = processor.batch_decode(
+        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+    )
+    print(output_text[0])
+    return output_text[0]
+
+def load_model_macos(model_path, use_float16=True):
+    """Load model with macOS-optimized settings for float16 + MPS"""
+    device, dtype = get_optimal_device_and_dtype()
+
+    # Override dtype based on parameter
+    if not use_float16:
+        dtype = torch.float32
+        device = "cpu"  # Force CPU for float32 safety
+
+    print(f"Loading model on {device} with {dtype}...")
+    print(f"Platform: {platform.platform()}")
+    print(f"PyTorch version: {torch.__version__}")
+
+    # Configuration for Apple Silicon with float16
+    model_kwargs = {
+        "torch_dtype": dtype,
+        "trust_remote_code": True,
+        "low_cpu_mem_usage": True,
+    }
+
+    # Handle device mapping
+    if device == "mps":
+        model_kwargs["device_map"] = None  # Load on CPU first, then move to MPS
+        print("🔄 Loading model on CPU first, then moving to MPS...")
+    else:
+        model_kwargs["device_map"] = "cpu"
+
+    try:
+        model = AutoModelForCausalLM.from_pretrained(model_path, **model_kwargs)
+
+        # Move to MPS if available
+        if device == "mps":
+            print("🚀 Moving model to MPS for GPU acceleration...")
+            model = model.to("mps")
+
+            # Verify model is on MPS
+            sample_param = next(model.parameters())
+            print(f"✅ Model device: {sample_param.device}, dtype: {sample_param.dtype}")
+        else:
+            print("✅ Model loaded on CPU")
+
+        model.eval()  # Set to evaluation mode
+        return model, device, dtype
+
+    except Exception as e:
+        print(f"❌ Error loading model: {e}")
+        if "mps" in str(e).lower() or device == "mps":
+            print("🔄 MPS loading failed, falling back to CPU with float32...")
+            return load_model_macos(model_path, use_float16=False)
+        else:
+            raise
+
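+# A minimal usage sketch of the two helpers above, using the same defaults as the
+# __main__ block below:
+#
+#   model, device, dtype = load_model_macos("./weights/DotsOCR")
+#   processor = AutoProcessor.from_pretrained("./weights/DotsOCR", trust_remote_code=True)
+#   inference_macos("demo/demo_image1.jpg", "Extract all text content from this image.",
+#                   model, processor, device, dtype)
+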
+if __name__ == "__main__":
+    # Parse command line arguments
+    parser = argparse.ArgumentParser(description="DotsOCR Apple Silicon Demo")
+    parser.add_argument("--model_path", default="./weights/DotsOCR",
+                        help="Path to the model")
+    parser.add_argument("--use_float16", action="store_true", default=True,
+                        help="Use float16 + MPS for faster inference (default: True)")
+    parser.add_argument("--force_cpu", action="store_true",
+                        help="Force CPU inference with float32")
+    args = parser.parse_args()
+
+    # Check system information
+    print(f"System: {platform.system()} {platform.release()}")
+    print(f"Machine: {platform.machine()}")
+    print(f"MPS available: {torch.backends.mps.is_available()}")
+
+    model_path = args.model_path
+    use_float16 = args.use_float16 and not args.force_cpu
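+    # Note: --use_float16 defaults to True, so float16 + MPS is used unless --force_cpu
+    # is passed.
+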
+    # Auto-detect float16 model if available
+    float16_path = model_path + "_float16"
+    if use_float16 and os.path.exists(float16_path):
+        print(f"🎯 Found float16 model at {float16_path}")
+        model_path = float16_path
+    elif use_float16 and not os.path.exists(float16_path):
+        print(f"⚠️ Float16 model not found at {float16_path}")
+        print("💡 Consider running: python tools/convert_model_float16.py")
+        print("🔄 Falling back to original model with auto-conversion...")
+
+    if not os.path.exists(model_path):
+        print(f"❌ Model not found at {model_path}")
+        exit(1)
+
+    # Load model and processor
+    try:
+        model, device, dtype = load_model_macos(model_path, use_float16)
+        processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
+        print("✅ Model and processor loaded successfully")
+    except Exception as e:
+        print(f"❌ Failed to load model: {e}")
+        exit(1)
+
+    image_path = "demo/demo_image1.jpg"
+
+    if not os.path.exists(image_path):
+        print(f"❌ Demo image not found: {image_path}")
+        exit(1)
+
+    # Test with different prompt modes
+    print("\n" + "="*60)
+    print(f"🚀 Starting inference tests on {device} with {dtype}")
+    print("="*60)
+
+    # Test with a simple prompt first
+    test_prompt = "Extract all text content from this image."
+    print("\n🧪 Quick test with simple prompt...")
+    print(f"Prompt: {test_prompt}")
+    print("-" * 50)
+
+    try:
+        result = inference_macos(image_path, test_prompt, model, processor, device, dtype)
+        print("✅ Quick test successful!")
+        print(f"Result preview: {result[:200]}..." if len(result) > 200 else f"Result: {result}")
+    except Exception as e:
+        print(f"❌ Quick test failed: {str(e)}")
+        print("🔄 This might be a compatibility issue. Try running:")
+        print("   python tools/convert_model_float16.py   # for float16 + MPS")
+        print("   or re-run this script with --force_cpu   # for CPU fallback")
+        exit(1)
+
+    # If quick test passed, run full tests
+    print("\n" + "="*60)
+    print("🎯 Running full prompt mode tests...")
+    print("="*60)
+
+    success_count = 0
+    total_count = len(dict_promptmode_to_prompt)
+
+    for prompt_mode, prompt in dict_promptmode_to_prompt.items():
+        print(f"\n--- Testing prompt mode: {prompt_mode} ---")
+        print(f"Prompt: {prompt}")
+        print("---")
+
+        try:
+            result = inference_macos(image_path, prompt, model, processor, device, dtype)
+            print(f"✅ Success for {prompt_mode}")
+            success_count += 1
+
+            # Show a preview of longer results
+            if len(result) > 300:
+                print(f"Result preview: {result[:300]}...")
+            else:
+                print(f"Result: {result}")
+
+        except Exception as e:
+            print(f"❌ Error for {prompt_mode}: {str(e)}")
+
+        print("-" * 60)
+
+ print(f"\n🎊 Test Summary: {success_count}/{total_count} prompt modes successful")
|
|
|
|
|
+
|
|
|
|
|
+ if success_count == total_count:
|
|
|
|
|
+ print("🎉 All tests passed! Your setup is working perfectly.")
|
|
|
|
|
+ elif success_count > 0:
|
|
|
|
|
+ print("⚠️ Some tests passed. The model is working but may have compatibility issues.")
|
|
|
|
|
+ else:
|
|
|
|
|
+ print("❌ All tests failed. Please check your setup.")
|
|
|
|
|
+
|
|
|
|
|
+ print(f"\n💡 Performance info:")
|
|
|
|
|
+ print(f" Device: {device}")
|
|
|
|
|
+ print(f" Data type: {dtype}")
|
|
|
|
|
+ if device == "mps" and dtype == torch.float16:
|
|
|
|
|
+ print(" 🚀 You're using the fastest configuration (float16 + MPS)!")
|
|
|
|
|
+ elif device == "cpu":
|
|
|
|
|
+ print(" 🐌 Using CPU inference. Consider float16 + MPS for better performance.")
|
|
|
|
|
+
|
|
|
|
|
+ print(f"\n📊 To optimize further:")
|
|
|
|
|
+ if device != "mps":
|
|
|
|
|
+ print(" • Run: python tools/convert_model_float16.py")
|
|
|
|
|
+ print(" • Then use the converted model for ~2x speedup")
|
|
|
|
|
+ print(" • Reduce max_new_tokens for faster but shorter outputs")
|
|
|
|
|
+ print(" • Use do_sample=False for deterministic results")
|
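+
+# Example invocations (assuming this file is saved as, e.g., demo/demo_hf_macos.py):
+#   python demo/demo_hf_macos.py --model_path ./weights/DotsOCR
+#   python demo/demo_hf_macos.py --force_cpu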