""" Apple Silicon (macOS) compatible version of demo_hf.py Optimized for float16 + MPS acceleration, with CPU fallback """ import os import platform import torch import argparse from transformers.models.auto.modeling_auto import AutoModelForCausalLM from transformers.models.auto.processing_auto import AutoProcessor from qwen_vl_utils import process_vision_info from dots_ocr.utils import dict_promptmode_to_prompt def get_optimal_device_and_dtype(): """Get the best available device and dtype for macOS""" if torch.backends.mps.is_available(): print("๐Ÿš€ MPS (Metal Performance Shaders) available") return "mps", torch.float16 else: print("โš ๏ธ MPS not available, falling back to CPU") return "cpu", torch.float32 def inference_macos(image_path, prompt, model, processor, device="mps", dtype=torch.float16): """ Inference function optimized for macOS/Apple Silicon with float16 + MPS """ messages = [ { "role": "user", "content": [ { "type": "image", "image": image_path }, {"type": "text", "text": prompt} ] } ] # Preparation for inference text = processor.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) # Handle process_vision_info return values properly try: vision_info = process_vision_info(messages) # Safely unpack the return value image_inputs = vision_info[0] if len(vision_info) > 0 else None video_inputs = vision_info[1] if len(vision_info) > 1 else None except Exception as e: print(f"Warning: Error processing vision info: {e}") image_inputs, video_inputs = None, None inputs = processor( text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt", ) # Move inputs to device and convert to appropriate dtype inputs = inputs.to(device) # Convert floating point tensors to the target dtype for consistency for key, value in inputs.items(): if isinstance(value, torch.Tensor) and value.dtype in [torch.float32, torch.float16, torch.bfloat16]: if dtype == torch.float16 and value.dtype != torch.float16: inputs[key] = value.to(dtype) print(f"๐Ÿ”„ Converted {key} to {dtype}") # Inference: Generation of the output with optimized settings print(f"๐Ÿš€ Starting inference on {device} with {dtype}") with torch.no_grad(): # Save memory on Apple Silicon generated_ids = model.generate( **inputs, max_new_tokens=8000, # Increased for float16 efficiency do_sample=False, # Use greedy for consistency pad_token_id=processor.tokenizer.eos_token_id, eos_token_id=processor.tokenizer.eos_token_id, use_cache=True, # Enable KV cache for speed output_attentions=False, output_hidden_states=False, ) generated_ids_trimmed = [ out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) ] output_text = processor.batch_decode( generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False ) print(output_text[0]) return output_text[0] def load_model_macos(model_path, use_float16=True): """Load model with macOS-optimized settings for float16 + MPS""" device, dtype = get_optimal_device_and_dtype() # Override dtype based on parameter if not use_float16: dtype = torch.float32 device = "cpu" # Force CPU for float32 safety print(f"Loading model on {device} with {dtype}...") print(f"Platform: {platform.platform()}") print(f"PyTorch version: {torch.__version__}") # Configuration for Apple Silicon with float16 model_kwargs = { "torch_dtype": dtype, "trust_remote_code": True, "low_cpu_mem_usage": True, } # Handle device mapping if device == "mps": model_kwargs["device_map"] = None # Load on CPU first, then move to MPS print("๐Ÿ”„ Loading model 
    else:
        model_kwargs["device_map"] = "cpu"

    try:
        model = AutoModelForCausalLM.from_pretrained(model_path, **model_kwargs)

        # Move to MPS if available
        if device == "mps":
            print("🚀 Moving model to MPS for GPU acceleration...")
            model = model.to("mps")
            # Verify the model is on MPS
            sample_param = next(model.parameters())
            print(f"✅ Model device: {sample_param.device}, dtype: {sample_param.dtype}")
        else:
            print("✅ Model loaded on CPU")

        model.eval()  # Set to evaluation mode
        return model, device, dtype
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        if "mps" in str(e).lower() or device == "mps":
            print("🔄 MPS loading failed, falling back to CPU with float32...")
            return load_model_macos(model_path, use_float16=False)
        else:
            raise


if __name__ == "__main__":
    # Parse command line arguments
    parser = argparse.ArgumentParser(description="DotsOCR Apple Silicon Demo")
    parser.add_argument("--model_path", default="./weights/DotsOCR", help="Path to the model")
    # float16 is enabled by default; use --force_cpu to disable it and run in float32
    parser.add_argument("--use_float16", action="store_true", default=True,
                        help="Use float16 + MPS for faster inference (default: True)")
    parser.add_argument("--force_cpu", action="store_true", help="Force CPU inference with float32")
    args = parser.parse_args()

    # Check system information
    print(f"System: {platform.system()} {platform.release()}")
    print(f"Machine: {platform.machine()}")
    print(f"MPS available: {torch.backends.mps.is_available()}")

    model_path = args.model_path
    use_float16 = args.use_float16 and not args.force_cpu

    # Auto-detect a pre-converted float16 model if available
    float16_path = model_path + "_float16"
    if use_float16 and os.path.exists(float16_path):
        print(f"🎯 Found float16 model at {float16_path}")
        model_path = float16_path
    elif use_float16 and not os.path.exists(float16_path):
        print(f"⚠️ Float16 model not found at {float16_path}")
        print("💡 Consider running: python tools/convert_model_float16.py")
        print("🔄 Falling back to original model with auto-conversion...")

    if not os.path.exists(model_path):
        print(f"❌ Model not found at {model_path}")
        exit(1)

    # Load model and processor
    try:
        model, device, dtype = load_model_macos(model_path, use_float16)
        processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
        print("✅ Model and processor loaded successfully")
    except Exception as e:
        print(f"❌ Failed to load model: {e}")
        exit(1)

    image_path = "demo/demo_image1.jpg"
    if not os.path.exists(image_path):
        print(f"❌ Demo image not found: {image_path}")
        exit(1)

    # Test with different prompt modes
    print("\n" + "=" * 60)
    print(f"🚀 Starting inference tests on {device} with {dtype}")
    print("=" * 60)

    # Test with a simple prompt first
    test_prompt = "Extract all text content from this image."
    print("\n🧪 Quick test with simple prompt...")
    print(f"Prompt: {test_prompt}")
    print("-" * 50)

    try:
        result = inference_macos(image_path, test_prompt, model, processor, device, dtype)
        print("✅ Quick test successful!")
        print(f"Result preview: {result[:200]}..." if len(result) > 200 else f"Result: {result}")
    except Exception as e:
        print(f"❌ Quick test failed: {str(e)}")
        print("🔄 This might be a compatibility issue. Try running:")
Try running:") print(" python tools/convert_model_float16.py # for float16 + MPS") print(" python --force_cpu # for CPU fallback") exit(1) # If quick test passed, run full tests print(f"\n" + "="*60) print("๐ŸŽฏ Running full prompt mode tests...") print("="*60) success_count = 0 total_count = len(dict_promptmode_to_prompt) for prompt_mode, prompt in dict_promptmode_to_prompt.items(): print(f"\n--- Testing prompt mode: {prompt_mode} ---") print(f"Prompt: {prompt}") print("---") try: result = inference_macos(image_path, prompt, model, processor, device, dtype) print(f"โœ… Success for {prompt_mode}") success_count += 1 # Show a preview of longer results if len(result) > 300: print(f"Result preview: {result}") else: print(f"Result: {result}") except Exception as e: print(f"โŒ Error for {prompt_mode}: {str(e)}") print("-" * 60) print(f"\n๐ŸŽŠ Test Summary: {success_count}/{total_count} prompt modes successful") if success_count == total_count: print("๐ŸŽ‰ All tests passed! Your setup is working perfectly.") elif success_count > 0: print("โš ๏ธ Some tests passed. The model is working but may have compatibility issues.") else: print("โŒ All tests failed. Please check your setup.") print(f"\n๐Ÿ’ก Performance info:") print(f" Device: {device}") print(f" Data type: {dtype}") if device == "mps" and dtype == torch.float16: print(" ๐Ÿš€ You're using the fastest configuration (float16 + MPS)!") elif device == "cpu": print(" ๐ŸŒ Using CPU inference. Consider float16 + MPS for better performance.") print(f"\n๐Ÿ“Š To optimize further:") if device != "mps": print(" โ€ข Run: python tools/convert_model_float16.py") print(" โ€ข Then use the converted model for ~2x speedup") print(" โ€ข Reduce max_new_tokens for faster but shorter outputs") print(" โ€ข Use do_sample=False for deterministic results")