""" Model dtype converter for Apple Silicon compatibility This script converts all model weights to float32 and saves a compatible version """ import os import torch from transformers.models.auto.modeling_auto import AutoModelForCausalLM from transformers.models.auto.processing_auto import AutoProcessor def convert_model_to_float32(model_path, output_path=None): """ Convert a model to float32 and optionally save it """ if output_path is None: output_path = model_path + "_float32" print(f"Loading model from: {model_path}") # Load model with float32 model = AutoModelForCausalLM.from_pretrained( model_path, torch_dtype=torch.float32, trust_remote_code=True, low_cpu_mem_usage=True, device_map="cpu" ) print("Converting all parameters to float32...") # Force convert all parameters with torch.no_grad(): converted_count = 0 for name, param in model.named_parameters(): if param.dtype != torch.float32: param.data = param.data.to(torch.float32) converted_count += 1 print(f" Converted {name}: {param.dtype} -> float32") for name, buffer in model.named_buffers(): if buffer.dtype not in [torch.float32, torch.int64, torch.long, torch.bool]: buffer.data = buffer.data.to(torch.float32) converted_count += 1 print(f" Converted buffer {name}: {buffer.dtype} -> float32") print(f"✅ Converted {converted_count} parameters/buffers to float32") # Save the converted model if not os.path.exists(output_path): os.makedirs(output_path) print(f"Saving converted model to: {output_path}") model.save_pretrained(output_path, safe_serialization=True) # Also copy the processor try: processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True) processor.save_pretrained(output_path) print("✅ Processor also saved") except Exception as e: print(f"Warning: Could not save processor: {e}") print("✅ Model conversion completed!") return output_path def test_converted_model(model_path): """ Test the converted model with a simple inference """ print(f"Testing converted model: {model_path}") try: model = AutoModelForCausalLM.from_pretrained( model_path, torch_dtype=torch.float32, trust_remote_code=True, device_map="cpu" ) processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True) # Verify all parameters are float32 non_float32_params = [] for name, param in model.named_parameters(): if param.dtype != torch.float32: non_float32_params.append((name, param.dtype)) if non_float32_params: print("❌ Still have non-float32 parameters:") for name, dtype in non_float32_params[:5]: # Show first 5 print(f" {name}: {dtype}") else: print("✅ All parameters are float32") print("✅ Converted model loads successfully!") return True except Exception as e: print(f"❌ Error testing converted model: {e}") return False if __name__ == "__main__": model_path = "./weights/DotsOCR" if not os.path.exists(model_path): print(f"❌ Model not found at {model_path}") print("Please ensure the model is downloaded first.") exit(1) print("🔧 DotsOCR Model Converter for Apple Silicon") print("=" * 50) # Convert the model try: output_path = model_path + "_float32" # 如果 output_path不存在,转换 if not os.path.exists(output_path): output_path = convert_model_to_float32(model_path, output_path) # Test the converted model if test_converted_model(output_path): print(f"\n🎉 Success! 
Use the converted model at: {output_path}") print("You can now run inference with this converted model:") print(f" python demo/demo_hf_macos_ultimate.py --model_path {output_path}") else: print(f"\n❌ Conversion completed but testing failed.") except Exception as e: print(f"❌ Conversion failed: {e}") print("\nTrying alternative approach...") # Alternative: try to load and immediately convert try: print("Loading model with explicit float32 conversion...") model = AutoModelForCausalLM.from_pretrained( model_path, trust_remote_code=True, device_map="cpu", torch_dtype=torch.float32 ) # Check if this worked mixed_precision = False for name, param in model.named_parameters(): if param.dtype != torch.float32: mixed_precision = True break if mixed_precision: print("Model still has mixed precision - this explains the dtype errors.") print("The model weights themselves contain mixed dtypes that cannot be easily converted.") print("\nRecommendations:") print("1. Use the online demo: https://dotsocr.xiaohongshu.com/") print("2. Run on a Linux machine with NVIDIA GPU") print("3. Wait for a native Apple Silicon compatible model release") else: print("✅ Model is actually float32 compatible!") except Exception as alt_error: print(f"❌ Alternative approach also failed: {alt_error}")
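
# Example: calling the converter from another script instead of running this
# file directly. This is a minimal sketch; the module name "convert_model_dtype"
# is an assumption based on this file's purpose, not confirmed by the repo.
#
#   from convert_model_dtype import convert_model_to_float32, test_converted_model
#
#   out_dir = convert_model_to_float32("./weights/DotsOCR")
#   if test_converted_model(out_dir):
#       print(f"Ready for inference: {out_dir}")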