| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198 |
#!/usr/bin/env python3
"""
DotsOCR demo for Apple Silicon (final version).

Loads the float32 DotsOCR weights on the CPU and runs two prompts against a
sample image: plain text extraction and document layout analysis.

Usage:
    python demo_apple_silicon.py
"""
- import os
- import platform
- import torch
- from transformers.models.auto.modeling_auto import AutoModelForCausalLM
- from transformers.models.auto.processing_auto import AutoProcessor
- from qwen_vl_utils import process_vision_info
def load_model(model_path="./weights/DotsOCR_float32"):
    """Load the DotsOCR model and processor for float32 inference on the CPU.

    Prints some environment info (OS, architecture, PyTorch version) first.

    Args:
        model_path: Local directory holding the converted weights. Only local
            paths are accepted; the existence check below rejects anything
            that is not on disk.

    Returns:
        ``(model, processor)`` on success, or ``(None, None)`` when the path
        does not exist or loading raises (the error is printed, not raised).
    """
    print("🍎 Apple Silicon DotsOCR v1.0")
    print(f"系统: {platform.system()} {platform.machine()}")
    print(f"PyTorch: {torch.__version__}")

    # The weights must be converted locally before this demo can run.
    if not os.path.exists(model_path):
        print(f"❌ 模型未找到: {model_path}")
        print("请先运行: python tools/convert_model_macos.py")
        return None, None

    print(f"📦 加载模型: {model_path}")

    model_kwargs = {
        "torch_dtype": torch.float32,
        # Executes the checkpoint's own modeling code.
        "trust_remote_code": True,
        "low_cpu_mem_usage": True,
        "device_map": "cpu",
    }

    try:
        model = AutoModelForCausalLM.from_pretrained(model_path, **model_kwargs)
        model.eval()
        print("✅ 模型加载成功")

        processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
        print("✅ 处理器加载成功")
    except Exception as e:
        # Demo boundary: report the failure and let the caller bail out.
        print(f"❌ 加载失败: {e}")
        return None, None

    return model, processor
def ocr_inference(image_path, prompt, model, processor, max_new_tokens=500):
    """Run one image + prompt through the model on the CPU and return the answer.

    Args:
        image_path: Image reference placed in the chat message and passed to
            ``process_vision_info``.
        prompt: User instruction text.
        model: Model object exposing ``generate`` (as returned by ``load_model``).
        processor: Matching processor (as returned by ``load_model``).
        max_new_tokens: Upper bound on generated tokens (default 500, the
            previously hard-coded value).

    Returns:
        The generated text with the prompt tokens removed. On failure, a
        string starting with ``"错误"`` naming the stage that failed. Errors
        are printed, not raised.
    """
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image_path},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    try:
        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    except Exception as e:
        print(f"⚠️ 模板处理警告: {e}")
        # NOTE(review): the bare prompt carries no image placeholder tokens, so
        # the processor/model may reject it later; kept as a best-effort fallback.
        text = prompt

    try:
        vision_info = process_vision_info(messages)
        image_inputs = vision_info[0] if len(vision_info) > 0 else None
        video_inputs = vision_info[1] if len(vision_info) > 1 else None
    except Exception as e:
        print(f"❌ 视觉信息处理失败: {e}")
        return "错误: 视觉处理失败"

    try:
        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
    except Exception as e:
        print(f"❌ 输入处理失败: {e}")
        return "错误: 输入处理失败"

    # Keep every tensor on the CPU and up-cast half-precision floats to float32.
    inputs = inputs.to("cpu")
    half_dtypes = (torch.float16, torch.bfloat16)
    for key in list(inputs.keys()):  # snapshot keys; entries are reassigned below
        value = inputs[key]
        if isinstance(value, torch.Tensor) and value.dtype in half_dtypes:
            inputs[key] = value.to(torch.float32)

    try:
        print("🚀 开始推理...")
        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=processor.tokenizer.eos_token_id,
                eos_token_id=processor.tokenizer.eos_token_id,
                output_attentions=False,
                output_hidden_states=False,
            )
    except Exception as e:
        print(f"❌ 推理失败: {e}")
        return "错误: 推理失败"

    try:
        # generate() returns the prompt tokens followed by the continuation.
        # Drop the prompt by token count: comparing decoded strings is
        # unreliable (decoding a token prefix need not produce a prefix of the
        # full decode) and, on mismatch, leaked the whole prompt into the result.
        prompt_len = inputs.input_ids.shape[1]
        new_token_ids = generated_ids[0][prompt_len:]
        return processor.tokenizer.decode(
            new_token_ids,
            skip_special_tokens=True,
        ).strip()
    except Exception as e:
        print(f"❌ 解码失败: {e}")
        return "错误: 解码失败"
def _run_case(title, success_msg, prompt, image_path, model, processor):
    """Print a section banner, run one prompt via ocr_inference and report it.

    Returns:
        True if ocr_inference produced a result, False if it returned an
        error string (one starting with "错误").
    """
    print("\n" + "=" * 40)
    print(title)
    print("=" * 40)
    print(f"提示词: {prompt}")
    print("-" * 40)

    result = ocr_inference(image_path, prompt, model, processor)
    # ocr_inference signals failure with a "错误..." string, not an exception.
    # NOTE(review): genuine OCR output that happens to start with "错误" would
    # be misreported as a failure; a structured return would avoid this.
    if result.startswith("错误"):
        print(f"❌ {result}")
        return False

    print(success_msg)
    # Show at most the first 300 characters of long results.
    print(f"结果: {result[:300]}..." if len(result) > 300 else f"结果: {result}")
    return True


def main(image_path="demo/demo_image1.jpg"):
    """Run the demo: load the model, then text extraction and layout analysis.

    Args:
        image_path: Image that both demo prompts are run against.

    Stops at the first failing step, so the final "all tests completed"
    banner and usage tips are only printed when every case succeeded.
    """
    print("=" * 60)
    print("🎉 DotsOCR Apple Silicon 版本 - 完全兼容!")
    print("=" * 60)

    model, processor = load_model()
    if model is None:
        return

    if not os.path.exists(image_path):
        print(f"❌ 测试图片未找到: {image_path}")
        return

    print(f"\n📸 测试图片: {image_path}")

    cases = (
        (
            "🔤 测试1: 文本提取",
            "✅ 文本提取成功!",
            "请提取图片中的所有文字内容。",
        ),
        (
            "📐 测试2: 布局分析",
            "✅ 布局分析成功!",
            "请分析这个文档的布局结构,包括表格、文本块等元素的位置信息。",
        ),
    )
    for title, success_msg, prompt in cases:
        if not _run_case(title, success_msg, prompt, image_path, model, processor):
            # Previously a failure in test 2 still printed the
            # "everything ran perfectly" banner below; bail out instead.
            return

    print("\n" + "=" * 60)
    print("🎊 所有测试完成! DotsOCR 在 Apple Silicon 上完美运行!")
    print("=" * 60)

    print("\n💡 使用提示:")
    print("- 本版本已完全解决 Apple Silicon 兼容性问题")
    print("- 支持文本识别、表格解析、布局分析等所有功能")
    print("- 使用 CPU 推理,稳定可靠但速度较慢")
    print("- 如需更快速度,建议使用在线版本: https://dotsocr.xiaohongshu.com/")
# Script entry point: run the demo only when executed directly, not on import.
if __name__ == "__main__":
    main()
|