#!/usr/bin/env python3
"""DotsOCR demo for Apple Silicon.

Loads a float32 DotsOCR checkpoint on the CPU and runs two sample prompts
(text extraction and layout analysis) against a demo image.

Usage:
    python demo_apple_silicon.py
"""
import os
import platform

import torch
from transformers.models.auto.modeling_auto import AutoModelForCausalLM
from transformers.models.auto.processing_auto import AutoProcessor
from qwen_vl_utils import process_vision_info

# Half-precision tensor dtypes that are upcast to float32 before CPU inference.
_HALF_DTYPES = (torch.float16, torch.bfloat16)


def load_model(model_path="./weights/DotsOCR_float32"):
    """Load the DotsOCR model and processor for float32 CPU inference.

    Args:
        model_path: Directory containing the converted float32 checkpoint.

    Returns:
        A ``(model, processor)`` tuple, or ``(None, None)`` when the
        checkpoint directory is missing or loading fails (the reason is
        printed).
    """
    print("🍎 Apple Silicon DotsOCR v1.0")
    print(f"系统: {platform.system()} {platform.machine()}")
    print(f"PyTorch: {torch.__version__}")

    if not os.path.exists(model_path):
        print(f"❌ 模型未找到: {model_path}")
        print("请先运行: python tools/convert_model_macos.py")
        return None, None

    print(f"📦 加载模型: {model_path}")
    model_kwargs = {
        "torch_dtype": torch.float32,
        "trust_remote_code": True,
        "low_cpu_mem_usage": True,
        "device_map": "cpu",
    }

    try:
        model = AutoModelForCausalLM.from_pretrained(model_path, **model_kwargs)
        model.eval()
        print("✅ 模型加载成功")
        processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
        print("✅ 处理器加载成功")
    except Exception as e:  # demo boundary: report and let the caller bail out
        print(f"❌ 加载失败: {e}")
        return None, None
    return model, processor


def ocr_inference(image_path, prompt, model, processor, max_new_tokens=500):
    """Run one image + prompt through DotsOCR and return the generated text.

    Args:
        image_path: Path to the input image (handed to ``process_vision_info``).
        prompt: Instruction text sent alongside the image.
        model: Model returned by :func:`load_model`.
        processor: Processor returned by :func:`load_model`.
        max_new_tokens: Upper bound on generated tokens; layout output can be
            long, so raise this if results come back truncated.

    Returns:
        The decoded model continuation (prompt excluded, whitespace-stripped),
        or a string starting with ``"错误"`` if any stage fails (the
        underlying exception is printed).
    """
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image_path},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    try:
        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    except Exception as e:
        print(f"⚠️ 模板处理警告: {e}")
        # Best-effort fallback. NOTE(review): the raw prompt presumably lacks
        # the image placeholder the chat template would insert, so the
        # processor step below may still fail — confirm.
        text = prompt

    try:
        vision_info = process_vision_info(messages)
        image_inputs = vision_info[0] if len(vision_info) > 0 else None
        video_inputs = vision_info[1] if len(vision_info) > 1 else None
    except Exception as e:
        print(f"❌ 视觉信息处理失败: {e}")
        return "错误: 视觉处理失败"

    try:
        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
    except Exception as e:
        print(f"❌ 输入处理失败: {e}")
        return "错误: 输入处理失败"

    # Everything runs on the CPU in float32: move tensors there and upcast
    # any half-precision ones. Keys are collected first so the mapping is not
    # mutated while it is being iterated.
    inputs = inputs.to("cpu")
    half_keys = [
        key
        for key, value in inputs.items()
        if isinstance(value, torch.Tensor) and value.dtype in _HALF_DTYPES
    ]
    for key in half_keys:
        inputs[key] = inputs[key].to(torch.float32)

    try:
        print("🚀 开始推理...")
        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=processor.tokenizer.eos_token_id,
                eos_token_id=processor.tokenizer.eos_token_id,
                output_attentions=False,
                output_hidden_states=False,
            )
    except Exception as e:
        print(f"❌ 推理失败: {e}")
        return "错误: 推理失败"

    try:
        # For a causal LM, generate() returns prompt ids followed by the new
        # ids. Trim by token count rather than by comparing decoded strings:
        # skip_special_tokens / detokenization can change the prompt text so
        # a string-prefix check silently fails and leaks the prompt.
        prompt_len = inputs["input_ids"].shape[1]
        new_tokens = generated_ids[0][prompt_len:]
        return processor.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    except Exception as e:
        print(f"❌ 解码失败: {e}")
        return "错误: 解码失败"


def _run_test(heading, success_msg, prompt, image_path, model, processor):
    """Print a section header, run one prompt, and report the outcome.

    Returns:
        True when ``ocr_inference`` produced a result, False when it returned
        an error string.
    """
    print("\n" + "=" * 40)
    print(heading)
    print("=" * 40)
    print(f"提示词: {prompt}")
    print("-" * 40)

    result = ocr_inference(image_path, prompt, model, processor)
    if result.startswith("错误"):
        print(f"❌ {result}")
        return False

    print(success_msg)
    # Long outputs are shown as a 300-character preview.
    if len(result) > 300:
        print(f"结果: {result[:300]}...")
    else:
        print(f"结果: {result}")
    return True


def main():
    """Load the model and run the text-extraction and layout-analysis demos.

    Stops at the first failing step so the closing success banner is only
    printed when every test passed.
    """
    print("=" * 60)
    print("🎉 DotsOCR Apple Silicon 版本 - 完全兼容!")
    print("=" * 60)

    model, processor = load_model()
    if model is None:
        return

    image_path = "demo/demo_image1.jpg"
    if not os.path.exists(image_path):
        print(f"❌ 测试图片未找到: {image_path}")
        return
    print(f"\n📸 测试图片: {image_path}")

    text_prompt = "请提取图片中的所有文字内容。"
    if not _run_test(
        heading="🔤 测试1: 文本提取",
        success_msg="✅ 文本提取成功!",
        prompt=text_prompt,
        image_path=image_path,
        model=model,
        processor=processor,
    ):
        return

    layout_prompt = "请分析这个文档的布局结构,包括表格、文本块等元素的位置信息。"
    if not _run_test(
        heading="📐 测试2: 布局分析",
        success_msg="✅ 布局分析成功!",
        prompt=layout_prompt,
        image_path=image_path,
        model=model,
        processor=processor,
    ):
        return

    print("\n" + "=" * 60)
    print("🎊 所有测试完成! DotsOCR 在 Apple Silicon 上完美运行!")
    print("=" * 60)

    print("\n💡 使用提示:")
    print("- 本版本已完全解决 Apple Silicon 兼容性问题")
    print("- 支持文本识别、表格解析、布局分析等所有功能")
    print("- 使用 CPU 推理,稳定可靠但速度较慢")
    print("- 如需更快速度,建议使用在线版本: https://dotsocr.xiaohongshu.com/")


if __name__ == "__main__":
    main()