2022 年,斯坦福大学发表了一篇论文,彻底改变了注意力计算的实现方式。
标题是:《FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness》
核心创新:不再关注计算复杂度,而是关注 IO 复杂度。
这个视角的转换,让注意力计算的速度提升了 2-4 倍,显存占用从 O(N²) 降到 O(N)。后续的 Flash Attention 2/3/4 进一步优化,成为现代大模型训练的标配。
Flash Attention 重新定义了注意力计算的边界。
本文要点
- 注意力计算的瓶颈分析
- Flash Attention 1:IO-aware 设计
- Flash Attention 2:并行优化
- Flash Attention 3:Tensor Core 利用
- Flash Attention 4:长上下文突破
- 实际性能对比
一、注意力计算的瓶颈
1.1 标准 Attention 的实现
1.2 GPU 内存层次
二、Flash Attention 1:IO-Aware 设计
2.1 核心思想:Tiling
2.2 在线 Softmax
def online_softmax(Q, K, V): """ 在线 Softmax:逐块计算,不存储完整注意力矩阵
关键技巧: - 使用 max 值的增量更新 - 使用 sum 值的增量更新 - 数值稳定 """ N, d = Q.shape B_r = 64 # 行块大小 B_c = 64 # 列块大小
O = torch.zeros(N, d) l = torch.zeros(N) # softmax 的分母 m = torch.full((N,), float('-inf')) # 每行的最大值
for i in range(0, N, B_r): Q_i = Q[i:i+B_r] # 加载 Q 块到 SRAM
for j in range(0, N, B_c): K_j = K[j:j+B_c] # 加载 K 块 V_j = V[j:j+B_c] # 加载 V 块
# 计算局部注意力分数 S_ij = Q_i @ K_j.T
# 在线更新 max m_new = torch.max(m[i:i+B_r], S_ij.max(dim=1).values) # 在线更新分母 l_new = ( l[i:i+B_r] * torch.exp(m[i:i+B_r] - m_new) + torch.exp(S_ij - m_new.unsqueeze(1)).sum(dim=1) )
# 在线更新输出 O[i:i+B_r] = ( O[i:i+B_r] * l[i:i+B_r].unsqueeze(1) * torch.exp(m[i:i+B_r].unsqueeze(1) - m_new.unsqueeze(1)) + torch.exp(S_ij - m_new.unsqueeze(1)) @ V_j ) / l_new.unsqueeze(1)
m[i:i+B_r] = m_new l[i:i+B_r] = l_new
return O2.3 IO 复杂度分析
三、Flash Attention 2:并行优化
3.1 主要改进
3.2 并行策略
# Flash Attention 2 的并行策略def flash_attention_2(Q, K, V): """ Flash Attention 2: 更好的并行性 """ batch_size, num_heads, seq_len, head_dim = Q.shape
# 每个注意力头可以独立计算 # 在 batch 和 head 维度并行
# 块大小调优 BLOCK_M = 128 # Q 块大小 BLOCK_N = 64 # K, V 块大小
# 总并行线程数 = batch_size * num_heads * (seq_len / BLOCK_M) # 充分利用 GPU
output = torch.empty_like(Q)
# 并行核函数 for b in range(batch_size): for h in range(num_heads): # 每个 (b, h) 对应一个线程块 output[b, h] = flash_attention_kernel( Q[b, h], K[b, h], V[b, h], BLOCK_M, BLOCK_N )
return output四、Flash Attention 3:Tensor Core 优化
4.1 Hopper GPU 特性
4.2 Flash Attention 3 的创新
4.3 代码架构
# Flash Attention 3 的核函数结构def flash_attention_3_kernel(Q, K, V, output, lse, cu_seqlens): """ Flash Attention 3: 异步 + Tensor Core """ # 1. 使用 TMA 异步加载数据 # TMA 可以在后台传输数据,不占用计算资源 tma_load_async(Q_block, smem_Q) tma_load_async(K_block, smem_K) tma_load_async(V_block, smem_V)
# 2. 计算和传输重叠 # 当一个块在计算时,下一个块在传输 for block in blocks: # 等待数据到达 tma_wait()
# Tensor Core 矩阵乘法 wgmma(smem_Q, smem_K, acc)
# 启动下一个块的传输 tma_load_async(next_Q_block, smem_Q_next)
# 3. 使用 FP32 累积,保证精度 # 4. 使用 WGMMA 指令最大化 Tensor Core 利用五、Flash Attention 4:长上下文突破
5.1 核心改进
六、性能对比与选型
6.1 各版本对比
| 特性 | FA1 | FA2 | FA3 ||---------------|-------------|-------------|------------------|| 发布时间 | 2022 | 2023 | 2024 || 相对标准加速 | 2-4x | 4-8x | 6-15x || 显存占用 | O(N) | O(N) | O(N) || GPU 利用率 | ~30% | ~50% | ~75% || Tensor Core | 否 | 部分 | 是 || 异步执行 | 否 | 否 | 是 || FP8 支持 | 否 | 否 | 是 || 最大序列长度 | ~64K | ~128K | 256K+ |6.2 实际性能数据
| 序列长度 | 标准 Attn | Flash Attn 2| 加速比 ||---------------|-------------|-------------|------------------|| 1024 | 1.2ms | 0.3ms | 4x || 4096 | 18ms | 2.5ms | 7x || 16384 | OOM | 40ms | ∞ || 32768 | OOM | 160ms | ∞ |6.3 框架支持
七、使用指南
7.1 PyTorch 集成
import torchfrom torch.nn.functional import scaled_dot_product_attention
# PyTorch 2.0+ 自动使用 Flash AttentionQ = torch.randn(1, 8, 4096, 64, device='cuda', dtype=torch.float16)K = torch.randn(1, 8, 4096, 64, device='cuda', dtype=torch.float16)V = torch.randn(1, 8, 4096, 64, device='cuda', dtype=torch.float16)
# 自动选择最优实现(包括 Flash Attention)output = scaled_dot_product_attention(Q, K, V, is_causal=True)7.2 显式使用 Flash Attention 2
from flash_attn import flash_attn_func
# Flash Attention 2 显式调用output = flash_attn_func( Q, K, V, causal=True, # 是否使用因果掩码 softmax_scale=None # softmax 缩放因子)7.3 变长序列支持
from flash_attn import flash_attn_varlen_func
# 支持变长序列(多序列打包)# cu_seqlens 指定每个序列的起止位置cu_seqlens = torch.tensor([0, 128, 256, 512], device='cuda')
output = flash_attn_varlen_func( Q_packed, # [total_seq_len, num_heads, head_dim] K_packed, V_packed, cu_seqlens, max_seqlen=384 # 最大序列长度)常见问题 FAQ
Q1:Flash Attention 会改变计算结果吗?
A:不会。Flash Attention 是数值精确的,结果与标准实现完全一致。只是改变了计算顺序和内存访问模式。
Q2:所有 GPU 都支持 Flash Attention 吗?
A:需要较新的 GPU。FA2 需要 Ampere (A100) 或更新架构。FA3 需要 Hopper (H100)。旧 GPU 会自动回退到标准实现。
Q3:Flash Attention 支持自定义注意力掩码吗?
A:有限支持。FA 主要针对因果注意力和无掩码注意力优化。自定义掩码需要特殊处理,可能无法获得全部优化收益。
Q4:如何选择 Flash Attention 版本?
A:根据硬件选择。H100/A100 用 FA3/FA2,旧 GPU 用 FA1 或标准实现。大多数框架会自动选择。
Q5:Flash Attention 对训练和推理都有用吗?
A:是的。训练时减少显存、加速计算;推理时支持更长上下文、降低延迟。
小结
Flash Attention 从 IO 复杂度角度重新审视注意力计算,彻底改变了实现方式。
核心贡献:
Flash Attention 让长上下文模型成为可能。
参考资料
- FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness - Dao et al. 2022
- FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning - Dao 2023
- FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision - Dao et al. 2024
- Flash Attention GitHub
支持与分享
如果这篇文章对你有帮助,欢迎支持作者或分享给更多人
部分信息可能已经过时






