|
|
@@ -110,8 +110,68 @@ python demo_gradio.py
|
|
|
https://dotsocr.xiaohongshu.com/
|
|
|
```
|
|
|
|
|
|
-# macOS 查看显卡信息
|
|
|
+# Flash Attention 详细介绍
|
|
|
+
|
|
|
+## Flash Attention 原理与优势
|
|
|
+
|
|
|
+### 核心概念
|
|
|
+Flash Attention 是一种高效的注意力机制实现,通过重新设计内存访问模式来显著提升性能:
|
|
|
+
|
|
|
+1. **分块计算 (Block-wise Computation)**
|
|
|
+ - 将大矩阵分解成小块逐块处理
|
|
|
+ - 避免存储完整的注意力矩阵 (N×N)
|
|
|
+
|
|
|
+2. **内存层次优化**
|
|
|
+ - 减少 GPU 显存 (HBM) 和 SRAM 之间的数据传输
|
|
|
+ - 利用 GPU 的内存层次结构提升效率
|
|
|
+
|
|
|
+3. **在线 Softmax**
|
|
|
+ - 采用数值稳定的在线算法
|
|
|
+ - 避免中间结果的完整存储
|
|
|
+
|
|
|
+### 性能优势
|
|
|
+- **内存效率**: 从 O(N²) 降低到 O(N)
|
|
|
+- **计算速度**: 长序列上可提速 2-4倍
|
|
|
+- **精度保持**: 与标准注意力数值相同
|
|
|
+
|
|
|
+## 项目中的 Flash Attention 使用
|
|
|
+
|
|
|
+### 当前配置
|
|
|
+```python
|
|
|
+# demo/demo_hf.py
|
|
|
+model = AutoModelForCausalLM.from_pretrained(
|
|
|
+ model_path,
|
|
|
+ attn_implementation="flash_attention_2", # 启用 Flash Attention 2
|
|
|
+ torch_dtype=torch.bfloat16,
|
|
|
+ device_map="auto",
|
|
|
+ trust_remote_code=True
|
|
|
+)
|
|
|
```
|
|
|
-# 使用 system_profiler 命令查看 GPU 信息
|
|
|
-system_profiler SPDisplaysDataType
|
|
|
+
|
|
|
+### 依赖要求
|
|
|
+```
|
|
|
+flash-attn==2.8.0.post2 # CUDA 专用,不支持 macOS
|
|
|
```
|
|
|
+
|
|
|
+## Apple Silicon 适配方案
|
|
|
+
|
|
|
+### 问题分析
|
|
|
+1. **硬件限制**: Flash Attention 专为 NVIDIA CUDA GPU 设计
|
|
|
+2. **架构差异**: Apple Silicon 使用 Metal Performance Shaders (MPS)
|
|
|
+3. **软件兼容**: flash-attn 库不支持 macOS/Metal
|
|
|
+
|
|
|
+## 解决方案
|
|
|
+### 方法1
|
|
|
+1. **移除 Flash Attention 依赖**: 在 macOS 上不使用 flash-attn
|
|
|
+2. 修改weights/DotsOCR_CPU_bfloat16/config.json,修改"attn_implementation": "sdpa",
|
|
|
+3. 调用程序:zhch/demo_hf_macos_bfloat16.py
|
|
|
+
|
|
|
+### 方法2
|
|
|
+1. **移除 Flash Attention 依赖**: 在 macOS 上不使用 flash-attn
|
|
|
+2. 运行zhch/convert_model_float16.py,将模型转换为 float16 格式,weights/DotsOCR_float16
|
|
|
+3. 调用程序:zhch/demo_hf_macos_float16.py
|
|
|
+
|
|
|
+### 方法3
|
|
|
+1. **移除 Flash Attention 依赖**: 在 macOS 上不使用 flash-attn
|
|
|
+2. 运行zhch/convert_model_macos_float32.py,将模型转换为 float32 格式,weights/DotsOCR_float32
|
|
|
+3. 调用程序:zhch/demo_hf_macos_float32.py
|