mobile wallpaper 1mobile wallpaper 2mobile wallpaper 3mobile wallpaper 4
733 字
2 分钟
Flash Attention:GPU 时代的注意力机制优化
2025-06-21

2022 年,斯坦福大学发表了一篇论文,彻底改变了注意力计算的实现方式。

标题是:《FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness》

核心创新:不再关注计算复杂度,而是关注 IO 复杂度。

这个视角的转换,让注意力计算的速度提升了 2-4 倍,显存占用从 O(N²) 降到 O(N)。后续的 Flash Attention 2/3/4 进一步优化,成为现代大模型训练的标配。

Flash Attention 重新定义了注意力计算的边界。

本文要点#

  • 注意力计算的瓶颈分析
  • Flash Attention 1:IO-aware 设计
  • Flash Attention 2:并行优化
  • Flash Attention 3:Tensor Core 利用
  • Flash Attention 4:长上下文突破
  • 实际性能对比

一、注意力计算的瓶颈#

1.1 标准 Attention 的实现#

flowchart TB subgraph 标准注意力计算流程 A[1. 计算 S = QK^T] A --> A1["O(N² × d)"] A --> A2[需要存储 N×N 的注意力矩阵] A --> A3[显存占用:O(N²)] B[2. 计算 P = softmax(S)] B --> B1["O(N²)"] B --> B2[需要读取整个 S 矩阵] C[3. 计算 O = PV] C --> C1["O(N² × d)"] C --> C2[需要读取 P 和 V] D[总显存占用] D --> D1[O(N²)] E[总 HBM 访问] E --> E1["O(N² + Nd)"] F[问题] F --> F1["N = 8K 时,N² = 64M,需要 512MB 显存"] F --> F2["N = 16K 时,N² = 256M,需要 2GB 显存"] F --> F3["N = 32K 时,N² = 1B,需要 8GB 显存"] end

1.2 GPU 内存层次#

flowchart TB A["GPU 内存层次"] --> B["HBM (High Bandwidth Memory)"] A --> C["SRAM (On-chip Memory)"] B --> D["容量:40-80 GB"] B --> E["带宽:1-2 TB/s"] C --> F["容量:~20 MB"] C --> G["带宽:~19 TB/s"] H["关键洞察"] --> I["SRAM 带宽是 HBM 的 10-20 倍"] I --> J["但容量很小"] J --> K["优化方向:减少 HBM 访问"] style C fill:#4caf50,color:#fff style G fill:#4caf50,color:#fff
flowchart TB subgraph IO复杂度才是真正的瓶颈 A[计算复杂度 vs IO 复杂度] B[计算复杂度] B --> B1["标准 Attention:O(N² × d)"] B --> B2[这是理论计算量] C[IO 复杂度] C --> C1[HBM 访问次数] C --> C2[这是实际瓶颈] D[类比] D --> D1[计算 = CPU 的处理速度] D --> D2[IO = 内存读写速度] D --> D3[当 IO << 计算时,IO 成为瓶颈] E[Flash Attention 的洞察] E --> E1[GPU 计算能力已经很强] E --> E2[HBM 带宽是瓶颈] E --> E3[应该优化 IO,而非计算] end

二、Flash Attention 1:IO-Aware 设计#

2.1 核心思想:Tiling#

flowchart TB A["Q, K, V 矩阵"] --> B["分块加载到 SRAM"] B --> C["在 SRAM 中计算"] C --> D["只写回最终结果"] E["关键创新"] --> F["不存储完整的 N×N 注意力矩阵"] F --> G["逐块计算,逐块输出"] G --> H["显存从 O(N²) 降到 O(N)"] style H fill:#4caf50,color:#fff
flowchart TB subgraph Flash Attention算法 A["输入:Q, K, V ∈ R^(N×d)"] A --> A1["输出:O ∈ R^(N×d)"] B[算法步骤] B --> B1[将 Q, K, V 分成小块(tiles)] B1 --> B1a[块大小根据 SRAM 容量确定] B1 --> B1b[典型值:B_r × B_c ≈ SRAM 大小] B --> B2[对每个 Q 块] B2 --> B2a[加载 Q_i 到 SRAM] B2 --> B2b[对每个 K, V 块] B2 --> B2c[加载 K_j, V_j 到 SRAM] B2 --> B2d[计算 S_ij = Q_i × K_j^T] B2 --> B2e[计算 P_ij = softmax(S_ij)] B2 --> B2f[累加 O_i += P_ij × V_j] B2 --> B2g[写回 O_i 到 HBM] C[关键] C --> C1[softmax 的在线计算] C --> C2[不需要存储完整的 S 矩阵] C --> C3[使用数值稳定的增量 softmax] end

2.2 在线 Softmax#

def online_softmax(Q, K, V):
"""
在线 Softmax:逐块计算,不存储完整注意力矩阵
关键技巧:
- 使用 max 值的增量更新
- 使用 sum 值的增量更新
- 数值稳定
"""
N, d = Q.shape
B_r = 64 # 行块大小
B_c = 64 # 列块大小
O = torch.zeros(N, d)
l = torch.zeros(N) # softmax 的分母
m = torch.full((N,), float('-inf')) # 每行的最大值
for i in range(0, N, B_r):
Q_i = Q[i:i+B_r] # 加载 Q 块到 SRAM
for j in range(0, N, B_c):
K_j = K[j:j+B_c] # 加载 K 块
V_j = V[j:j+B_c] # 加载 V 块
# 计算局部注意力分数
S_ij = Q_i @ K_j.T
# 在线更新 max
m_new = torch.max(m[i:i+B_r], S_ij.max(dim=1).values)
# 在线更新分母
l_new = (
l[i:i+B_r] * torch.exp(m[i:i+B_r] - m_new) +
torch.exp(S_ij - m_new.unsqueeze(1)).sum(dim=1)
)
# 在线更新输出
O[i:i+B_r] = (
O[i:i+B_r] * l[i:i+B_r].unsqueeze(1) *
torch.exp(m[i:i+B_r].unsqueeze(1) - m_new.unsqueeze(1)) +
torch.exp(S_ij - m_new.unsqueeze(1)) @ V_j
) / l_new.unsqueeze(1)
m[i:i+B_r] = m_new
l[i:i+B_r] = l_new
return O

2.3 IO 复杂度分析#

flowchart TB subgraph IO复杂度对比 A[标准 Attention] A --> A1["HBM 读取:Q(Nd) + K(Nd) + V(Nd) + S(N²) + P(N²)"] A --> A2["HBM 写入:S(N²) + P(N²) + O(Nd)"] A --> A3["总计:O(N²)"] B[Flash Attention] B --> B1["HBM 读取:Q(Nd) + K(Nd) + V(Nd)"] B --> B2["HBM 写入:O(Nd)"] B --> B3["总计:O(Nd)"] C[加速比] C --> C1["当 N >> d 时,O(N²) / O(Nd) ≈ N/d"] C --> C2["N = 4096, d = 64 时,理论加速 64 倍"] C --> C3[实际加速 2-4 倍(因为还有其他开销)] D[显存占用] D --> D1["标准:O(N² + Nd)"] D --> D2["Flash:O(Nd)"] D --> D3[大幅降低,支持更长序列] end

三、Flash Attention 2:并行优化#

3.1 主要改进#

flowchart TB subgraph Flash Attention 2改进 A[更好的并行性] A --> A1[FA1:按序列并行,GPU 利用率低] A --> A2[FA2:按批次和头并行,GPU 利用率高] B[减少非矩阵乘法操作] B --> B1[FA1:有额外的 rescaling 操作] B --> B2[FA2:优化计算顺序,减少操作] C[更好的块大小选择] C --> C1[FA2 自动调优块大小] C --> C2[适应不同 GPU 架构] D[性能提升] D --> D1[相比 FA1:约 2 倍加速] D --> D2[相比标准 Attention:约 4-8 倍加速] end

3.2 并行策略#

flowchart TB subgraph Flash Attention 1 A1["按序列并行"] --> B1["GPU 利用率低"] end subgraph Flash Attention 2 A2["按 Batch × Head 并行"] --> B2["GPU 利用率高"] end C["关键洞察"] --> D["注意力头的计算独立"] D --> E["可以完全并行"]
# Flash Attention 2 的并行策略
def flash_attention_2(Q, K, V):
"""
Flash Attention 2: 更好的并行性
"""
batch_size, num_heads, seq_len, head_dim = Q.shape
# 每个注意力头可以独立计算
# 在 batch 和 head 维度并行
# 块大小调优
BLOCK_M = 128 # Q 块大小
BLOCK_N = 64 # K, V 块大小
# 总并行线程数 = batch_size * num_heads * (seq_len / BLOCK_M)
# 充分利用 GPU
output = torch.empty_like(Q)
# 并行核函数
for b in range(batch_size):
for h in range(num_heads):
# 每个 (b, h) 对应一个线程块
output[b, h] = flash_attention_kernel(
Q[b, h], K[b, h], V[b, h],
BLOCK_M, BLOCK_N
)
return output

四、Flash Attention 3:Tensor Core 优化#

4.1 Hopper GPU 特性#

flowchart TB subgraph Hopper GPU特性 A[Tensor Core] A --> A1[专门用于矩阵乘法的硬件单元] A --> A2["理论性能:989 TFLOPS (FP16)"] A --> A3[比 CUDA Core 快得多] B[TMA (Tensor Memory Accelerator)] B --> B1[异步内存传输] B --> B2[计算和内存传输可以重叠] C[WGMA (Warp Group Matrix Multiply Accumulate)] C --> C1[Tensor Core 的新指令] C --> C2[更高的矩阵乘法效率] end

4.2 Flash Attention 3 的创新#

flowchart TB A["Flash Attention 3"] --> B["异步执行"] A --> C["Tensor Core 利用"] A --> D["低精度支持"] B --> B1["计算和内存传输重叠"] C --> C1["使用 WGMMA 指令"] D --> D1["FP8 支持"] E["性能提升"] --> F["相比 FA2: 1.5-2 倍加速"]
flowchart TB subgraph Flash Attention 3技术要点 A[异步 TMA 传输] A --> A1[使用 TMA 异步加载 Q, K, V] A --> A2[计算和传输重叠] A --> A3[隐藏内存延迟] B[Tensor Core GEMM+SOFTMAX 流水线] B --> B1[GEMM(矩阵乘法)使用 Tensor Core] B --> B2[Softmax 使用 CUDA Core] B --> B3[两者流水线执行] C[低精度优化] C --> C1[支持 FP8 输入] C --> C2[内部使用 FP32 累积] C --> C3[精度和速度的平衡] D[性能数据(H100)] D --> D1[相比 FA2:1.5-2 倍加速] D --> D2["GPU 利用率:从 ~40% 提升到 ~75%"] end

4.3 代码架构#

# Flash Attention 3 的核函数结构
def flash_attention_3_kernel(Q, K, V, output, lse, cu_seqlens):
"""
Flash Attention 3: 异步 + Tensor Core
"""
# 1. 使用 TMA 异步加载数据
# TMA 可以在后台传输数据,不占用计算资源
tma_load_async(Q_block, smem_Q)
tma_load_async(K_block, smem_K)
tma_load_async(V_block, smem_V)
# 2. 计算和传输重叠
# 当一个块在计算时,下一个块在传输
for block in blocks:
# 等待数据到达
tma_wait()
# Tensor Core 矩阵乘法
wgmma(smem_Q, smem_K, acc)
# 启动下一个块的传输
tma_load_async(next_Q_block, smem_Q_next)
# 3. 使用 FP32 累积,保证精度
# 4. 使用 WGMMA 指令最大化 Tensor Core 利用

五、Flash Attention 4:长上下文突破#

5.1 核心改进#

flowchart TB subgraph Flash Attention 4特性 A[支持更长上下文] A --> A1["FA3:最大 ~128K tokens"] A --> A2["FA4:支持 256K+ tokens"] B[技术改进] B --> B1[更激进的分块策略] B --> B2[内存效率优化] B --> B3[支持变长序列] C[新特性] C --> C1[支持滑动窗口注意力] C --> C2[支持局部注意力] C --> C3[更好的多查询支持] D[性能] D --> D1["长序列(>64K)效率显著提升"] D --> D2[显存占用进一步降低] end

六、性能对比与选型#

6.1 各版本对比#

| 特性 | FA1 | FA2 | FA3 |
|---------------|-------------|-------------|------------------|
| 发布时间 | 2022 | 2023 | 2024 |
| 相对标准加速 | 2-4x | 4-8x | 6-15x |
| 显存占用 | O(N) | O(N) | O(N) |
| GPU 利用率 | ~30% | ~50% | ~75% |
| Tensor Core | 否 | 部分 | 是 |
| 异步执行 | 否 | 否 | 是 |
| FP8 支持 | 否 | 否 | 是 |
| 最大序列长度 | ~64K | ~128K | 256K+ |

6.2 实际性能数据#

xychart-beta title "Attention 实现速度对比(A100, seq_len=4096)" x-axis ["Standard", "Flash Attn 1", "Flash Attn 2", "Flash Attn 3"] y-axis "速度 (TFLOPS)" 0 --> 300 bar [40, 120, 200, 280]
| 序列长度 | 标准 Attn | Flash Attn 2| 加速比 |
|---------------|-------------|-------------|------------------|
| 1024 | 1.2ms | 0.3ms | 4x |
| 4096 | 18ms | 2.5ms | 7x |
| 16384 | OOM | 40ms | ∞ |
| 32768 | OOM | 160ms | ∞ |

6.3 框架支持#

flowchart TB subgraph Flash Attention集成 A[PyTorch 2.0+] A --> A1[scaled_dot_product_attention 自动使用] A --> A2[无需手动调用] B[Hugging Face Transformers] B --> B1[默认使用 Flash Attention 2] B --> B2[支持自动回退] C[vLLM] C --> C1[深度集成 Flash Attention] C --> C2[支持所有版本] D[TensorRT-LLM] D --> D1[NVIDIA 官方优化] D --> D2[性能极致] end

七、使用指南#

7.1 PyTorch 集成#

import torch
from torch.nn.functional import scaled_dot_product_attention
# PyTorch 2.0+ 自动使用 Flash Attention
Q = torch.randn(1, 8, 4096, 64, device='cuda', dtype=torch.float16)
K = torch.randn(1, 8, 4096, 64, device='cuda', dtype=torch.float16)
V = torch.randn(1, 8, 4096, 64, device='cuda', dtype=torch.float16)
# 自动选择最优实现(包括 Flash Attention)
output = scaled_dot_product_attention(Q, K, V, is_causal=True)

7.2 显式使用 Flash Attention 2#

from flash_attn import flash_attn_func
# Flash Attention 2 显式调用
output = flash_attn_func(
Q, K, V,
causal=True, # 是否使用因果掩码
softmax_scale=None # softmax 缩放因子
)

7.3 变长序列支持#

from flash_attn import flash_attn_varlen_func
# 支持变长序列(多序列打包)
# cu_seqlens 指定每个序列的起止位置
cu_seqlens = torch.tensor([0, 128, 256, 512], device='cuda')
output = flash_attn_varlen_func(
Q_packed, # [total_seq_len, num_heads, head_dim]
K_packed,
V_packed,
cu_seqlens,
max_seqlen=384 # 最大序列长度
)

常见问题 FAQ#

Q1:Flash Attention 会改变计算结果吗?

A:不会。Flash Attention 是数值精确的,结果与标准实现完全一致。只是改变了计算顺序和内存访问模式。

Q2:所有 GPU 都支持 Flash Attention 吗?

A:需要较新的 GPU。FA2 需要 Ampere (A100) 或更新架构。FA3 需要 Hopper (H100)。旧 GPU 会自动回退到标准实现。

Q3:Flash Attention 支持自定义注意力掩码吗?

A:有限支持。FA 主要针对因果注意力和无掩码注意力优化。自定义掩码需要特殊处理,可能无法获得全部优化收益。

Q4:如何选择 Flash Attention 版本?

A:根据硬件选择。H100/A100 用 FA3/FA2,旧 GPU 用 FA1 或标准实现。大多数框架会自动选择。

Q5:Flash Attention 对训练和推理都有用吗?

A:是的。训练时减少显存、加速计算;推理时支持更长上下文、降低延迟。


小结#

Flash Attention 从 IO 复杂度角度重新审视注意力计算,彻底改变了实现方式。

核心贡献:

flowchart TB subgraph Flash Attention核心总结 A[核心洞察] --> A1[IO 复杂度是瓶颈,不是计算复杂度] B[关键技术] --> B1[Tiling + 在线 Softmax] C[效果] --> C1["显存 O(N²) → O(N),速度 2-15x 提升"] D[演进] --> D1[FA1 → FA2(并行)→ FA3(Tensor Core)] E[影响] --> E1[成为现代 LLM 训练标配] end

Flash Attention 让长上下文模型成为可能。


参考资料#

支持与分享

如果这篇文章对你有帮助,欢迎支持作者或分享给更多人

Flash Attention:GPU 时代的注意力机制优化
https://blog.souloss.com/posts/machine-learning/llm-paper-history/flashattention-efficient-attention/
作者
Souloss
发布于
2025-06-21
许可协议
CC BY-NC-SA 4.0

部分信息可能已经过时