Add a macOS-compatible inference script optimized for float16 + MPS acceleration, including model loading and inference functions

zhch158_admin committed 3 months ago
Commit ba172d8f17
1 file changed, 271 insertions(+), 0 deletions(-)
zhch/demo_hf_macos_float16.py  (+271, −0)

@@ -0,0 +1,271 @@
+"""
+Apple Silicon (macOS) compatible version of demo_hf.py
+Optimized for float16 + MPS acceleration, with CPU fallback
+"""
+import os
+import platform
+import torch
+import argparse
+from transformers import AutoModelForCausalLM
+from transformers import AutoProcessor
+from qwen_vl_utils import process_vision_info
+from dots_ocr.utils import dict_promptmode_to_prompt
+
+def get_optimal_device_and_dtype():
+    """Get the best available device and dtype for macOS"""
+    if torch.backends.mps.is_available():
+        print("🚀 MPS (Metal Performance Shaders) available")
+        return "mps", torch.float16
+    else:
+        print("⚠️  MPS not available, falling back to CPU")
+        return "cpu", torch.float32
+
+def inference_macos(image_path, prompt, model, processor, device="mps", dtype=torch.float16):
+    """
+    Inference function optimized for macOS/Apple Silicon with float16 + MPS
+    """
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "image",
+                    "image": image_path
+                },
+                {"type": "text", "text": prompt}
+            ]
+        }
+    ]
+
+    # Preparation for inference
+    text = processor.apply_chat_template(
+        messages, 
+        tokenize=False, 
+        add_generation_prompt=True
+    )
+    
+    # Handle process_vision_info return values properly
+    try:
+        vision_info = process_vision_info(messages)
+        # Safely unpack the return value
+        image_inputs = vision_info[0] if len(vision_info) > 0 else None
+        video_inputs = vision_info[1] if len(vision_info) > 1 else None
+    except Exception as e:
+        print(f"Warning: Error processing vision info: {e}")
+        image_inputs, video_inputs = None, None
+    
+    inputs = processor(
+        text=[text],
+        images=image_inputs,
+        videos=video_inputs,
+        padding=True,
+        return_tensors="pt",
+    )
+
+    # Move inputs to device and convert to appropriate dtype
+    inputs = inputs.to(device)
+    
+    # Convert floating point tensors to the target dtype for consistency
+    for key, value in inputs.items():
+        if isinstance(value, torch.Tensor) and value.dtype in [torch.float32, torch.float16, torch.bfloat16]:
+            if dtype == torch.float16 and value.dtype != torch.float16:
+                inputs[key] = value.to(dtype)
+                print(f"🔄 Converted {key} to {dtype}")
+
+    # Inference: Generation of the output with optimized settings
+    print(f"🚀 Starting inference on {device} with {dtype}")
+    with torch.no_grad():  # Save memory on Apple Silicon
+        generated_ids = model.generate(
+            **inputs, 
+            max_new_tokens=8000,  # Generous limit for long OCR outputs
+            do_sample=False,      # Use greedy for consistency
+            pad_token_id=processor.tokenizer.eos_token_id,
+            eos_token_id=processor.tokenizer.eos_token_id,
+            use_cache=True,       # Enable KV cache for speed
+            output_attentions=False,
+            output_hidden_states=False,
+        )
+    
+    generated_ids_trimmed = [
+        out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+    ]
+    output_text = processor.batch_decode(
+        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+    )
+    print(output_text[0])
+    return output_text[0]
+
+def load_model_macos(model_path, use_float16=True):
+    """Load model with macOS-optimized settings for float16 + MPS"""
+    device, dtype = get_optimal_device_and_dtype()
+    
+    # Override dtype based on parameter
+    if not use_float16:
+        dtype = torch.float32
+        device = "cpu"  # Force CPU for float32 safety
+    
+    print(f"Loading model on {device} with {dtype}...")
+    print(f"Platform: {platform.platform()}")
+    print(f"PyTorch version: {torch.__version__}")
+    
+    # Configuration for Apple Silicon with float16
+    model_kwargs = {
+        "torch_dtype": dtype,
+        "trust_remote_code": True,
+        "low_cpu_mem_usage": True,
+    }
+    
+    # Handle device mapping
+    if device == "mps":
+        model_kwargs["device_map"] = None  # Load on CPU first, then move to MPS
+        print("🔄 Loading model on CPU first, then moving to MPS...")
+    else:
+        model_kwargs["device_map"] = "cpu"
+    
+    try:
+        model = AutoModelForCausalLM.from_pretrained(model_path, **model_kwargs)
+        
+        # Move to MPS if available
+        if device == "mps":
+            print("🚀 Moving model to MPS for GPU acceleration...")
+            model = model.to("mps")
+            
+            # Verify model is on MPS
+            sample_param = next(model.parameters())
+            print(f"✅ Model device: {sample_param.device}, dtype: {sample_param.dtype}")
+        else:
+            print("✅ Model loaded on CPU")
+        
+        model.eval()  # Set to evaluation mode
+        return model, device, dtype
+        
+    except Exception as e:
+        print(f"❌ Error loading model: {e}")
+        if "mps" in str(e).lower() or device == "mps":
+            print("🔄 MPS loading failed, falling back to CPU with float32...")
+            return load_model_macos(model_path, use_float16=False)
+        else:
+            raise e
+
+if __name__ == "__main__":
+    # Parse command line arguments
+    parser = argparse.ArgumentParser(description="DotsOCR Apple Silicon Demo")
+    parser.add_argument("--model_path", default="./weights/DotsOCR", 
+                       help="Path to the model")
+    parser.add_argument("--use_float16", action="store_true", default=True,
+                       help="Use float16 + MPS for faster inference (default: True)")
+    parser.add_argument("--force_cpu", action="store_true", 
+                       help="Force CPU inference with float32")
+    args = parser.parse_args()
+    
+    # Check system information
+    print(f"System: {platform.system()} {platform.release()}")
+    print(f"Machine: {platform.machine()}")
+    print(f"MPS available: {torch.backends.mps.is_available()}")
+    
+    model_path = args.model_path
+    use_float16 = args.use_float16 and not args.force_cpu
+    
+    # Auto-detect float16 model if available
+    float16_path = model_path + "_float16"
+    if use_float16 and os.path.exists(float16_path):
+        print(f"🎯 Found float16 model at {float16_path}")
+        model_path = float16_path
+    elif use_float16 and not os.path.exists(float16_path):
+        print(f"⚠️  Float16 model not found at {float16_path}")
+        print("💡 Consider running: python tools/convert_model_float16.py")
+        print("🔄 Falling back to original model with auto-conversion...")
+    
+    if not os.path.exists(model_path):
+        print(f"❌ Model not found at {model_path}")
+        exit(1)
+    
+    # Load model and processor
+    try:
+        model, device, dtype = load_model_macos(model_path, use_float16)
+        processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
+        print(f"✅ Model and processor loaded successfully")
+    except Exception as e:
+        print(f"❌ Failed to load model: {e}")
+        exit(1)
+
+    image_path = "demo/demo_image1.jpg"
+    
+    if not os.path.exists(image_path):
+        print(f"❌ Demo image not found: {image_path}")
+        exit(1)
+    
+    # Test with different prompt modes
+    print("\n" + "="*60)
+    print(f"🚀 Starting inference tests on {device} with {dtype}")
+    print("="*60)
+    
+    # Test with a simple prompt first
+    test_prompt = "Extract all text content from this image."
+    print(f"\n🧪 Quick test with simple prompt...")
+    print(f"Prompt: {test_prompt}")
+    print("-" * 50)
+    
+    try:
+        result = inference_macos(image_path, test_prompt, model, processor, device, dtype)
+        print(f"✅ Quick test successful!")
+        print(f"Result preview: {result[:200]}..." if len(result) > 200 else f"Result: {result}")
+    except Exception as e:
+        print(f"❌ Quick test failed: {str(e)}")
+        print("🔄 This might be a compatibility issue. Try running:")
+        print("  python tools/convert_model_float16.py  # for float16 + MPS")
+        print("  python --force_cpu  # for CPU fallback")
+        exit(1)
+    
+    # If quick test passed, run full tests
+    print(f"\n" + "="*60)
+    print("🎯 Running full prompt mode tests...")
+    print("="*60)
+    
+    success_count = 0
+    total_count = len(dict_promptmode_to_prompt)
+    
+    for prompt_mode, prompt in dict_promptmode_to_prompt.items():
+        print(f"\n--- Testing prompt mode: {prompt_mode} ---")
+        print(f"Prompt: {prompt}")
+        print("---")
+        
+        try:
+            result = inference_macos(image_path, prompt, model, processor, device, dtype)
+            print(f"✅ Success for {prompt_mode}")
+            success_count += 1
+            
+            # Show a preview of longer results
+            if len(result) > 300:
+                print(f"Result preview: {result}")
+            else:
+                print(f"Result: {result}")
+                
+        except Exception as e:
+            print(f"❌ Error for {prompt_mode}: {str(e)}")
+        
+        print("-" * 60)
+    
+    print(f"\n🎊 Test Summary: {success_count}/{total_count} prompt modes successful")
+    
+    if success_count == total_count:
+        print("🎉 All tests passed! Your setup is working perfectly.")
+    elif success_count > 0:
+        print("⚠️  Some tests passed. The model is working but may have compatibility issues.")
+    else:
+        print("❌ All tests failed. Please check your setup.")
+    
+    print(f"\n💡 Performance info:")
+    print(f"  Device: {device}")
+    print(f"  Data type: {dtype}")
+    if device == "mps" and dtype == torch.float16:
+        print("  🚀 You're using the fastest configuration (float16 + MPS)!")
+    elif device == "cpu":
+        print("  🐌 Using CPU inference. Consider float16 + MPS for better performance.")
+        
+    print(f"\n📊 To optimize further:")
+    if device != "mps":
+        print("  • Run: python tools/convert_model_float16.py")
+        print("  • Then use the converted model for ~2x speedup")
+    print("  • Reduce max_new_tokens for faster but shorter outputs")
+    print("  • Use do_sample=False for deterministic results")