demo_hf_macos_float32.py

#!/usr/bin/env python3
"""
DotsOCR Apple Silicon final version
All compatibility issues have been resolved; supports text recognition and layout analysis.
Usage:
    python demo_hf_macos_float32.py
"""
import os
import platform

import torch
from transformers.models.auto.modeling_auto import AutoModelForCausalLM
from transformers.models.auto.processing_auto import AutoProcessor
from qwen_vl_utils import process_vision_info
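

# Note (assumption about the helper library): qwen_vl_utils.process_vision_info(messages)
# is expected to return an (image_inputs, video_inputs) pair extracted from the chat
# messages; that is how its result is unpacked in ocr_inference() below.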


def load_model(model_path="./weights/DotsOCR_float32"):
    """Load the Apple Silicon compatible model."""
    print("🍎 Apple Silicon DotsOCR v1.0")
    print(f"System: {platform.system()} {platform.machine()}")
    print(f"PyTorch: {torch.__version__}")
    if not os.path.exists(model_path):
        print(f"❌ Model not found: {model_path}")
        print("Please run first: python tools/convert_model_macos.py")
        return None, None
    print(f"📦 Loading model: {model_path}")
    model_kwargs = {
        "torch_dtype": torch.float32,
        "trust_remote_code": True,
        "low_cpu_mem_usage": True,
        "device_map": "cpu",
    }
    try:
        model = AutoModelForCausalLM.from_pretrained(model_path, **model_kwargs)
        model.eval()
        print("✅ Model loaded")
        processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
        print("✅ Processor loaded")
        return model, processor
    except Exception as e:
        print(f"❌ Loading failed: {e}")
        return None, None
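

# Optional sanity check (sketch, not part of the original flow): confirm that the
# converted checkpoint really was saved in float32 before running inference.
#
#   model, processor = load_model()
#   if model is not None:
#       assert next(model.parameters()).dtype == torch.float32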


def ocr_inference(image_path, prompt, model, processor):
    """OCR inference tuned for Apple Silicon (CPU, float32)."""
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image_path},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    try:
        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    except Exception as e:
        print(f"⚠️ Chat template warning: {e}")
        text = prompt
    try:
        vision_info = process_vision_info(messages)
        image_inputs = vision_info[0] if len(vision_info) > 0 else None
        video_inputs = vision_info[1] if len(vision_info) > 1 else None
    except Exception as e:
        print(f"❌ Vision info processing failed: {e}")
        return "Error: vision processing failed"
    try:
        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
    except Exception as e:
        print(f"❌ Input processing failed: {e}")
        return "Error: input processing failed"
    # Make sure every tensor is on the CPU and in float32.
    inputs = inputs.to("cpu")
    for key, value in inputs.items():
        if isinstance(value, torch.Tensor) and value.dtype in [torch.float16, torch.bfloat16]:
            inputs[key] = value.to(torch.float32)
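    # Note: this defensive upcast assumes any half-precision tensors produced by the
    # processor should match the float32 weights; on CPU, mixed fp16/fp32 inputs can
    # otherwise raise dtype errors during generation.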
    try:
        print("🚀 Running inference...")
        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=500,
                do_sample=False,
                pad_token_id=processor.tokenizer.eos_token_id,
                eos_token_id=processor.tokenizer.eos_token_id,
                output_attentions=False,
                output_hidden_states=False,
            )
    except Exception as e:
        print(f"❌ Inference failed: {e}")
        return "Error: inference failed"
    try:
        generated_text = processor.tokenizer.decode(
            generated_ids[0],
            skip_special_tokens=True,
        )
        input_text = processor.tokenizer.decode(
            inputs.input_ids[0],
            skip_special_tokens=True,
        )
        # Strip the echoed prompt so only the newly generated text is returned.
        if generated_text.startswith(input_text):
            result = generated_text[len(input_text):].strip()
        else:
            result = generated_text
        return result
    except Exception as e:
        print(f"❌ Decoding failed: {e}")
        return "Error: decoding failed"
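

# Note: the prompt-echo trimming above compares decoded strings. A common alternative
# (sketch, not taken from the original script) is to slice by token count instead:
#
#   prompt_len = inputs.input_ids.shape[1]
#   result = processor.tokenizer.decode(generated_ids[0][prompt_len:], skip_special_tokens=True)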


def main():
    """Main entry point."""
    print("=" * 60)
    print("🎉 DotsOCR Apple Silicon build - fully compatible!")
    print("=" * 60)
    # Load the model
    model, processor = load_model()
    if model is None:
        return
    # Test image
    image_path = "demo/demo_image1.jpg"
    if not os.path.exists(image_path):
        print(f"❌ Test image not found: {image_path}")
        return
    print(f"\n📸 Test image: {image_path}")
    # Test 1: text extraction
    print("\n" + "=" * 40)
    print("🔤 Test 1: text extraction")
    print("=" * 40)
    text_prompt = "Please extract all text content from the image."
    print(f"Prompt: {text_prompt}")
    print("-" * 40)
    result = ocr_inference(image_path, text_prompt, model, processor)
    if not result.startswith("Error"):
        print("✅ Text extraction succeeded!")
        print(f"Result: {result[:300]}..." if len(result) > 300 else f"Result: {result}")
    else:
        print(f"❌ {result}")
        return
    # Test 2: layout analysis
    print("\n" + "=" * 40)
    print("📐 Test 2: layout analysis")
    print("=" * 40)
    layout_prompt = "Please analyze the layout of this document, including the positions of tables, text blocks, and other elements."
    print(f"Prompt: {layout_prompt}")
    print("-" * 40)
    result2 = ocr_inference(image_path, layout_prompt, model, processor)
    if not result2.startswith("Error"):
        print("✅ Layout analysis succeeded!")
        print(f"Result: {result2[:300]}..." if len(result2) > 300 else f"Result: {result2}")
    else:
        print(f"❌ {result2}")
    print("\n" + "=" * 60)
    print("🎊 All tests complete! DotsOCR runs on Apple Silicon!")
    print("=" * 60)
    print("\n💡 Tips:")
    print("- This build resolves the Apple Silicon compatibility issues")
    print("- Supports text recognition, table parsing, layout analysis, and the other features")
    print("- Uses CPU inference: stable and reliable, but slower")
    print("- For faster inference, the hosted version is recommended: https://dotsocr.xiaohongshu.com/")


if __name__ == "__main__":
    main()