- """
- Float16 Model converter for Apple Silicon MPS acceleration
- This script converts model weights to float16 and fixes compatibility issues
- """
- import os
- import shutil
- import torch
- from transformers.models.auto.modeling_auto import AutoModelForCausalLM
- from transformers.models.auto.processing_auto import AutoProcessor


def create_float16_model(model_path, output_path=None):
    """
    Create a float16 copy of the model with MPS compatibility fixes applied.
    """
    if output_path is None:
        output_path = model_path + "_float16"

    print(f"Creating float16 model from: {model_path}")
    print(f"Output path: {output_path}")

    if os.path.exists(output_path):
        print(f"⚠️ Output path already exists: {output_path}")
        response = input("Do you want to overwrite it? (y/N): ")
        if response.lower() != 'y':
            print("Conversion cancelled.")
            return None
        shutil.rmtree(output_path)

    # First copy all model files (configs, tokenizer, remote code)
    print("📁 Copying model files...")
    shutil.copytree(model_path, output_path)

    # Load the model on CPU, requesting float16 weights up front
    print("🔄 Loading and converting model to float16...")
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        device_map="cpu",
    )

    # Force-convert any remaining floating-point parameters and buffers to float16
    with torch.no_grad():
        converted_count = 0
        for name, param in model.named_parameters():
            if param.is_floating_point() and param.dtype != torch.float16:
                orig_dtype = param.dtype
                param.data = param.data.to(torch.float16)
                converted_count += 1
                if converted_count <= 5:  # Show the first 5 conversions
                    print(f"  Converted {name}: {orig_dtype} -> float16")

        for name, buffer in model.named_buffers():
            if buffer.is_floating_point() and buffer.dtype != torch.float16:
                orig_dtype = buffer.dtype
                buffer.data = buffer.data.to(torch.float16)
                converted_count += 1
                if converted_count <= 5:
                    print(f"  Converted buffer {name}: {orig_dtype} -> float16")

    print(f"✅ Converted {converted_count} parameters/buffers to float16")

    # Patch the copied modeling files for float16 + MPS compatibility
    fix_modeling_files_for_float16(output_path)

    # Save the converted weights over the copied files
    print("💾 Saving converted model...")
    model.save_pretrained(output_path, safe_serialization=True)

    print("✅ Float16 model conversion completed!")
    return output_path
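

# Optional check (a minimal sketch, not part of the original flow): the saved
# weights can be inspected on disk with the `safetensors` API to confirm they
# were written as float16. The filename below assumes a single, unsharded
# "model.safetensors"; a sharded checkpoint would need its index file instead.
def inspect_saved_dtypes(output_path):
    """Print the set of dtypes found in the saved safetensors file."""
    from safetensors import safe_open

    weights_file = os.path.join(output_path, "model.safetensors")
    with safe_open(weights_file, framework="pt", device="cpu") as f:
        dtypes = {str(f.get_tensor(key).dtype) for key in f.keys()}
    print(f"Dtypes on disk: {dtypes}")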


def fix_modeling_files_for_float16(model_path):
    """
    Patch the copied modeling files for float16 + MPS compatibility.
    """
    print("🔧 Fixing modeling files for float16 + MPS compatibility...")

    vision_file = os.path.join(model_path, "modeling_dots_vision.py")

    if not os.path.exists(vision_file):
        print(f"⚠️ Vision file not found: {vision_file}")
        return

    # Read the current file contents
    with open(vision_file, 'r', encoding='utf-8') as f:
        content = f.read()

    # Fix 1: Replace the Flash Attention fallback with a float16-friendly implementation
    old_fallback = """except ImportError:
    HAS_FLASH_ATTN = False
    def flash_attn_varlen_func(*args, **kwargs):
        print("Flash Attention not available. Using fallback implementation.")"""

    new_fallback = """except ImportError:
    HAS_FLASH_ATTN = False
    def flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal=False, **kwargs):
        \"\"\"
        Float16-optimized fallback implementation of flash_attn_varlen_func
        for Apple Silicon MPS.
        \"\"\"
        print("Flash Attention not available. Using float16 MPS-optimized fallback.")

        # q, k, v shapes: (total_seq_len, num_heads, head_dim)
        batch_size = len(cu_seqlens_q) - 1
        outputs = []

        for i in range(batch_size):
            start_q = cu_seqlens_q[i]
            end_q = cu_seqlens_q[i + 1]
            start_k = cu_seqlens_k[i]
            end_k = cu_seqlens_k[i + 1]

            q_seq = q[start_q:end_q]  # (seq_len_q, num_heads, head_dim)
            k_seq = k[start_k:end_k]  # (seq_len_k, num_heads, head_dim)
            v_seq = v[start_k:end_k]  # (seq_len_k, num_heads, head_dim)

            # Transpose for standard attention: (num_heads, seq_len, head_dim)
            q_seq = q_seq.transpose(0, 1)
            k_seq = k_seq.transpose(0, 1)
            v_seq = v_seq.transpose(0, 1)

            # Standard scaled dot-product attention
            scores = torch.matmul(q_seq, k_seq.transpose(-2, -1)) / math.sqrt(q_seq.size(-1))

            # Apply a causal mask if requested
            if causal and q_seq.size(1) > 1:
                seq_len = q_seq.size(1)
                causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=q.device, dtype=q.dtype), diagonal=1).bool()
                scores.masked_fill_(causal_mask, float('-inf'))

            # Use float32 for softmax stability, then convert back to the input dtype
            attn_weights = F.softmax(scores.float(), dim=-1).to(q.dtype)
            attn_output = torch.matmul(attn_weights, v_seq)

            # Transpose back: (seq_len, num_heads, head_dim)
            attn_output = attn_output.transpose(0, 1)
            outputs.append(attn_output)

        # Concatenate all per-sequence outputs
        return torch.cat(outputs, dim=0)"""

    if old_fallback in content:
        content = content.replace(old_fallback, new_fallback)
        print("  ✅ Updated Flash Attention fallback for float16")
    # Fix 2: Update the rotary position embedding helper for float16
    old_rotary = """def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    orig_dtype = tensor.dtype
    tensor = tensor.float()
    cos = freqs.cos()
    sin = freqs.sin()
    cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
    sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
    output = (tensor * cos) + (rotate_half(tensor) * sin)
    output = output.to(orig_dtype)
    return output"""

    new_rotary = """def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    orig_dtype = tensor.dtype
    # For float16, compute in float32 for numerical stability
    tensor = tensor.float()
    cos = freqs.cos()
    sin = freqs.sin()
    cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
    sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
    output = (tensor * cos) + (rotate_half(tensor) * sin)
    # Convert back to the original dtype (float16 for MPS efficiency)
    output = output.to(orig_dtype)
    return output"""

    if old_rotary in content:
        content = content.replace(old_rotary, new_rotary)
        print("  ✅ Updated rotary position embedding for float16")
    # Fix 3: Replace any problematic bfloat16 conversion with float16
    if "hidden_states.bfloat16()" in content:
        content = content.replace("hidden_states.bfloat16()", "hidden_states.to(torch.float16)")
        print("  ✅ Fixed bfloat16 conversion to float16")

    # Fix 4: Attention-weight dtype conversion (no change needed)
    if ".to(q.dtype)" in content and "softmax" in content:
        # The existing softmax -> .to(q.dtype) pattern already behaves correctly for float16
        print("  ✅ Attention weights dtype conversion already compatible")

    # Write the patched file back
    with open(vision_file, 'w', encoding='utf-8') as f:
        f.write(content)

    print("✅ Modeling files fixed for float16 + MPS")


def test_float16_model(model_path):
    """
    Check that the float16 model loads on MPS and that its floating-point parameters are float16.
    """
    print(f"Testing float16 model: {model_path}")

    if not torch.backends.mps.is_available():
        print("❌ MPS not available on this system")
        return False

    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            trust_remote_code=True,
            device_map=None,  # Load on CPU first
        )

        # Move to MPS and make sure the processor loads as well
        model = model.to("mps")
        processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

        # Verify dtypes
        non_float16_params = []
        for name, param in model.named_parameters():
            if param.is_floating_point() and param.dtype != torch.float16:
                non_float16_params.append((name, param.dtype))

        if non_float16_params:
            print("⚠️ Some parameters are not float16:")
            for name, dtype in non_float16_params[:3]:
                print(f"  {name}: {dtype}")
        else:
            print("✅ All floating-point parameters are float16")

        print("✅ Float16 model loads successfully on MPS!")
        return True

    except Exception as e:
        print(f"❌ Error testing float16 model: {e}")
        return False
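

# Programmatic use (a sketch; the module name below is a placeholder for
# whatever this file is saved as):
#
#     from convert_float16_mps import create_float16_model, test_float16_model
#     out = create_float16_model("./weights/DotsOCR")
#     if out is not None:
#         test_float16_model(out)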


if __name__ == "__main__":
    model_path = "./weights/DotsOCR"

    if not os.path.exists(model_path):
        print(f"❌ Model not found at {model_path}")
        print("Please ensure the model is downloaded first.")
        sys.exit(1)

    if not torch.backends.mps.is_available():
        print("❌ MPS (Metal Performance Shaders) is not available on this system")
        print("This converter is specifically for Apple Silicon Macs with MPS support.")
        sys.exit(1)

    print("🚀 DotsOCR Float16 + MPS Converter")
    print("=" * 50)

    try:
        output_path = create_float16_model(model_path)

        if output_path and test_float16_model(output_path):
            print(f"\n🎉 Success! Float16 model created at: {output_path}")
            print("You can now run faster inference with:")
            print(f"  python demo/demo_hf_macos.py --model_path {output_path}")
            print("\n💡 Benefits of float16 + MPS:")
            print("  - ~2x faster inference compared to float32 on CPU")
            print("  - ~50% less memory usage")
            print("  - Native Apple Silicon GPU acceleration")
        else:
            print("\n❌ Float16 conversion completed but testing failed.")

    except Exception as e:
        print(f"❌ Conversion failed: {e}")
        import traceback
        traceback.print_exc()