# convert_model_float16.py

"""
Float16 Model converter for Apple Silicon MPS acceleration
This script converts model weights to float16 and fixes compatibility issues
"""
import os
import shutil

import torch
from transformers.models.auto.modeling_auto import AutoModelForCausalLM
from transformers.models.auto.processing_auto import AutoProcessor
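
# Typical usage (paths follow the layout assumed in the __main__ block below):
#   python convert_model_float16.py
# The script expects the original weights at ./weights/DotsOCR and writes the
# converted copy to "<model_path>_float16" (./weights/DotsOCR_float16 by default).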

def create_float16_model(model_path, output_path=None):
    """
    Create a float16 version of the model with MPS compatibility fixes
    """
    if output_path is None:
        output_path = model_path + "_float16"

    print(f"Creating float16 model from: {model_path}")
    print(f"Output path: {output_path}")

    if os.path.exists(output_path):
        print(f"⚠️ Output path already exists: {output_path}")
        response = input("Do you want to overwrite it? (y/N): ")
        if response.lower() != 'y':
            print("Conversion cancelled.")
            return None
        shutil.rmtree(output_path)

    # First copy all files
    print("📁 Copying model files...")
    shutil.copytree(model_path, output_path)

    # Load and convert model
    print("🔄 Loading and converting model to float16...")
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,  # Load as float16
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        device_map="cpu",
    )

    # Force convert all parameters to float16
    with torch.no_grad():
        converted_count = 0
        for name, param in model.named_parameters():
            if param.dtype not in [torch.float16, torch.int64, torch.long, torch.bool]:
                original_dtype = param.dtype
                param.data = param.data.to(torch.float16)
                converted_count += 1
                if converted_count <= 5:  # Show first 5
                    print(f" Converted {name}: {original_dtype} -> float16")
        for name, buffer in model.named_buffers():
            if buffer.dtype not in [torch.float16, torch.int64, torch.long, torch.bool]:
                original_dtype = buffer.dtype
                buffer.data = buffer.data.to(torch.float16)
                converted_count += 1
                if converted_count <= 5:
                    print(f" Converted buffer {name}: {original_dtype} -> float16")

    print(f"✅ Converted {converted_count} parameters/buffers to float16")

    # Now fix the modeling files for float16 + MPS compatibility
    fix_modeling_files_for_float16(output_path)

    # Save the converted model
    print("💾 Saving converted model...")
    model.save_pretrained(output_path, safe_serialization=True)

    print("✅ Float16 model conversion completed!")
    return output_path
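
# Example of calling the converter directly (hypothetical paths shown):
#   create_float16_model("./weights/DotsOCR", output_path="./weights/DotsOCR_fp16")
# When output_path is omitted, "<model_path>_float16" is used, as in __main__ below.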

def fix_modeling_files_for_float16(model_path):
    """
    Fix the modeling files for float16 + MPS compatibility
    """
    print("🔧 Fixing modeling files for float16 + MPS compatibility...")

    vision_file = os.path.join(model_path, "modeling_dots_vision.py")
    if not os.path.exists(vision_file):
        print(f"⚠️ Vision file not found: {vision_file}")
        return

    # Read the file
    with open(vision_file, 'r', encoding='utf-8') as f:
        content = f.read()

    # Fix 1: Update Flash Attention fallback for float16
    old_fallback = """except ImportError:
    HAS_FLASH_ATTN = False
    def flash_attn_varlen_func(*args, **kwargs):
        print("Flash Attention not available. Using fallback implementation.")"""

    new_fallback = """except ImportError:
    HAS_FLASH_ATTN = False
    def flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal=False, **kwargs):
        \"\"\"
        Float16 optimized fallback implementation for flash_attn_varlen_func.
        Optimized for Apple Silicon MPS.
        \"\"\"
        print("Flash Attention not available. Using float16 MPS-optimized fallback.")
        # q, k, v shapes: (total_seq_len, num_heads, head_dim)
        batch_size = len(cu_seqlens_q) - 1
        outputs = []
        for i in range(batch_size):
            start_q = cu_seqlens_q[i]
            end_q = cu_seqlens_q[i + 1]
            start_k = cu_seqlens_k[i]
            end_k = cu_seqlens_k[i + 1]
            q_seq = q[start_q:end_q]  # (seq_len_q, num_heads, head_dim)
            k_seq = k[start_k:end_k]  # (seq_len_k, num_heads, head_dim)
            v_seq = v[start_k:end_k]  # (seq_len_k, num_heads, head_dim)
            # Transpose for standard attention: (num_heads, seq_len, head_dim)
            q_seq = q_seq.transpose(0, 1)
            k_seq = k_seq.transpose(0, 1)
            v_seq = v_seq.transpose(0, 1)
            # Standard scaled dot-product attention with float16 optimization
            scores = torch.matmul(q_seq, k_seq.transpose(-2, -1)) / math.sqrt(q_seq.size(-1))
            # Apply causal mask if needed
            if causal and q_seq.size(1) > 1:
                seq_len = q_seq.size(1)
                causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=q.device, dtype=q.dtype), diagonal=1).bool()
                scores.masked_fill_(causal_mask, float('-inf'))
            # Use float32 for softmax stability, then convert back to float16
            attn_weights = F.softmax(scores.float(), dim=-1).to(q.dtype)
            attn_output = torch.matmul(attn_weights, v_seq)
            # Transpose back: (seq_len, num_heads, head_dim)
            attn_output = attn_output.transpose(0, 1)
            outputs.append(attn_output)
        # Concatenate all sequences
        return torch.cat(outputs, dim=0)"""

    if old_fallback in content:
        content = content.replace(old_fallback, new_fallback)
        print(" ✅ Updated Flash Attention fallback for float16")

    # Fix 2: Update rotary position embedding for float16
    old_rotary = """def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    orig_dtype = tensor.dtype
    tensor = tensor.float()
    cos = freqs.cos()
    sin = freqs.sin()
    cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
    sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
    output = (tensor * cos) + (rotate_half(tensor) * sin)
    output = output.to(orig_dtype)
    return output"""

    new_rotary = """def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    orig_dtype = tensor.dtype
    # For float16, use float32 for computation stability
    tensor = tensor.float()
    cos = freqs.cos()
    sin = freqs.sin()
    cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
    sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
    output = (tensor * cos) + (rotate_half(tensor) * sin)
    # Convert back to original dtype (float16 for MPS efficiency)
    output = output.to(orig_dtype)
    return output"""

    if old_rotary in content:
        content = content.replace(old_rotary, new_rotary)
        print(" ✅ Updated rotary position embedding for float16")

    # Fix 3: Remove problematic bfloat16 conversion if it exists
    if "hidden_states.bfloat16()" in content:
        content = content.replace("hidden_states.bfloat16()", "hidden_states.to(torch.float16)")
        print(" ✅ Fixed bfloat16 conversion to float16")

    # Fix 4: Update attention weights conversion
    if ".to(q.dtype)" in content and "softmax" in content:
        # Keep the original behavior for float16
        print(" ✅ Attention weights dtype conversion already compatible")

    # Write back the file
    with open(vision_file, 'w', encoding='utf-8') as f:
        f.write(content)

    print("✅ Modeling files fixed for float16 + MPS")
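
# Note: the fallback injected above relies on `math` and `F` (torch.nn.functional)
# already being imported in modeling_dots_vision.py (assumed true for the upstream
# file). It also materializes the full attention matrix for each sequence, so it is
# a compatibility path rather than a performance path.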

def test_float16_model(model_path):
    """
    Test the float16 model with MPS
    """
    print(f"Testing float16 model: {model_path}")

    if not torch.backends.mps.is_available():
        print("❌ MPS not available on this system")
        return False

    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            trust_remote_code=True,
            device_map=None,  # Load on CPU first
        )
        # Move to MPS
        model = model.to("mps")
        # Load the processor as a smoke test for the remote-code assets (not used further here)
        processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

        # Verify dtype
        non_float16_params = []
        for name, param in model.named_parameters():
            if param.dtype not in [torch.float16, torch.int64, torch.long, torch.bool]:
                non_float16_params.append((name, param.dtype))

        if non_float16_params:
            print("⚠️ Some parameters are not float16:")
            for name, dtype in non_float16_params[:3]:
                print(f" {name}: {dtype}")
        else:
            print("✅ All floating-point parameters are float16")

        print("✅ Float16 model loads successfully on MPS!")
        return True
    except Exception as e:
        print(f"❌ Error testing float16 model: {e}")
        return False

if __name__ == "__main__":
    model_path = "./weights/DotsOCR"

    if not os.path.exists(model_path):
        print(f"❌ Model not found at {model_path}")
        print("Please ensure the model is downloaded first.")
        exit(1)

    if not torch.backends.mps.is_available():
        print("❌ MPS (Metal Performance Shaders) not available on this system")
        print("This converter is specifically for Apple Silicon Macs with MPS support.")
        exit(1)

    print("🚀 DotsOCR Float16 + MPS Converter")
    print("=" * 50)

    try:
        output_path = create_float16_model(model_path)
        if output_path and test_float16_model(output_path):
            print(f"\n🎉 Success! Float16 model created at: {output_path}")
            print("You can now run faster inference with:")
            print(f" python demo/demo_hf_macos.py --model_path {output_path}")
            print("\n💡 Benefits of float16 + MPS:")
            print(" - ~2x faster inference compared to float32 CPU")
            print(" - ~50% less memory usage")
            print(" - Native Apple Silicon GPU acceleration")
        else:
            print("\n❌ Float16 conversion completed but testing failed.")
    except Exception as e:
        print(f"❌ Conversion failed: {e}")
        import traceback
        traceback.print_exc()