本文要点
大语言模型的推理效率是一个关键挑战。随着上下文增长,Transformer 的注意力计算复杂度为 ,而且 KV 缓存(Key-Value Cache) 的内存占用也成为瓶颈。
Multi-Query Attention(MQA) 和 Grouped Query Attention(GQA) 通过优化 KV 缓存机制,显著提升了推理吞吐量。
一、标准 Multi-Head Attention
1.1 回顾
标准 MHA(Multi-Head Attention)中,每个注意力头都有独立的 W_Q、W_K、W_V 投影:
对于 个注意力头,每个头的 K 和 V 都是独立的。
1.2 KV 缓存问题
推理时,需要缓存所有历史 token 的 Key 和 Value:
# 伪代码k_cache = [] # 形状: [batch, heads, seq_len, head_dim]v_cache = []
for token in new_tokens: k, v = compute_kv(token) k_cache.append(k) v_cache.append(v) attn_output = attention(q, k_cache, v_cache)问题:
- 个头的 K、V 都要缓存
- 内存占用与 成正比
- 长上下文场景内存爆炸
二、Multi-Query Attention
2.1 核心思想
MQA(Multi-Query Attention,2019)让所有注意力头共享同一个 Key 和 Value 投影:
数学上:
- 多组 Query:
- 单一 Key:
- 单一 Value:
2.2 内存节省
| 配置 | KV 头数 | KV 缓存内存 |
|---|---|---|
| MHA (32 heads) | 32 | 32x |
| MQA | 1 | 1x |
| GQA (8 groups) | 8 | 8x |
对于 LLaMA 70B(80GB 显存):
- MHA:KV 缓存需要 80GB+ 显存
- MQA:大幅减少,可支持更长上下文
2.3 代码实现
class MultiQueryAttention(nn.Module): def __init__(self, d_model, n_heads): super().__init__() self.n_heads = n_heads self.head_dim = d_model // n_heads
# 多 Query 头 self.W_q = nn.Linear(d_model, d_model) # 单一 K/V 头 self.W_k = nn.Linear(d_model, self.head_dim) self.W_v = nn.Linear(d_model, self.head_dim)
def forward(self, x, kv_cache=None): # 多 Query q = self.W_q(x).view(batch, seq, self.n_heads, self.head_dim)
# 单一 K/V k = self.W_k(x).unsqueeze(1) # [B, 1, seq, head_dim] v = self.W_v(x).unsqueeze(1)
if kv_cache is not None: k = torch.cat([kv_cache['k'], k], dim=2) v = torch.cat([kv_cache['v'], v], dim=2)
# 注意:k/v 会在所有 query 头间共享 attn = scaled_dot_product_attention(q, k, v) return attn三、Grouped Query Attention
3.1 MQA 的问题
MQA 虽然大幅减少 KV 缓存,但也存在问题:
- 所有头共享同一个 K/V:可能丢失多样性信息
- 质量下降:某些任务精度降低
GQA(Grouped Query Attention)提出分组共享的折中方案。
3.2 GQA 原理
GQA 在 MHA 和 MQA 之间引入分组概念:
关键特性:
- 个 Query 头
- 个 Key/Value 组
- 每个 KV 组被多个 Query 头共享
3.3 配置对比
| 注意力类型 | Query 头数 | KV 头数 | 内存比 |
|---|---|---|---|
| MHA | h | h | 1.0 |
| GQA | h | g (g << h) | g/h |
| MQA | h | 1 | 1/h |
典型配置(LLaMA 3 8B):
- MHA:8 个 Query 头,8 个 KV 头
- GQA:8 个 Query 头,2 个 KV 组(仅 1/4 内存)
3.4 实现
class GroupedQueryAttention(nn.Module): def __init__(self, d_model, n_query_heads=8, n_kv_groups=2): super().__init__() self.n_query_heads = n_query_heads self.n_kv_groups = n_kv_groups self.head_dim = d_model // n_query_heads
# Query 头数 = 8 self.W_q = nn.Linear(d_model, d_model) # KV 组数 = 2 self.W_k = nn.Linear(d_model, self.head_dim * n_kv_groups) self.W_v = nn.Linear(d_model, self.head_dim * n_kv_groups)
def forward(self, x, kv_cache=None): # Query: [B, seq, 8 heads] q = self.W_q(x).view(batch, seq, self.n_query_heads, self.head_dim)
# KV: [B, seq, 2 groups, head_dim] k = self.W_k(x).view(batch, seq, self.n_kv_groups, self.head_dim) v = self.W_v(x).view(batch, seq, self.n_kv_groups, self.head_dim)
# 扩展 KV 组到 Query 头数 # [B, seq, 2 groups, head_dim] -> [B, seq, 8 heads, head_dim] k = k.repeat_interleave(self.n_query_heads // self.n_kv_groups, dim=2) v = v.repeat_interleave(self.n_query_heads // self.n_kv_groups, dim=2)
return scaled_dot_product_attention(q, k, v)四、实际应用
4.1 模型采用情况
| 模型 | 注意力类型 | 说明 |
|---|---|---|
| Llama 1/2 | MQA | 早期版本 |
| Llama 3 | GQA (8Q, 2KV) | 显著改进 |
| Mistral 7B | GQA (8Q, 2KV) | 高效长上下文 |
| PaLM 2 | GQA | 多种配置 |
| Gemini | GQA | 定制配置 |
4.2 性能收益
以 Mistral 7B 为例:
# 上下文长度 32k tokenscontext_length = 32768head_dim = 128
# MHA 内存mha_kv_memory = 2 * 8 * 32768 * 128 * 4 # bytes (FP16)# ≈ 268 MB per layer
# GQA 内存 (2 KV heads)gqa_kv_memory = 2 * 2 * 32768 * 128 * 4 # bytes (FP16)# ≈ 67 MB per layer (4x 节省)4.3 推理框架支持
# transformers 启用 GQAmodel = AutoModelForCausalLM.from_pretrained( "mistralai/Mistral-7B-v0.1", attn_implementation="flash_attention_2" # 启用优化的注意力实现)
# vLLM 自动检测并优化 GQAfrom vllm import LLMllm = LLM(model="mistralai/Mistral-7B-v0.1")五、与 Flash Attention 的结合
GQA/MQA 与 Flash Attention 结合效果更佳:
# 组合优化config = { "attn_implementation": "flash_attention_2", "use_sliding_window": True, "kv_cache_dtype": "auto",}
# Flash Attention 的 tiling 策略天然适配 GQA# 减少 KV 缓存的 HBM 访问次数六、总结
| 技术 | KV 头数 | 内存节省 | 精度影响 |
|---|---|---|---|
| MHA | h | 1x | 无 |
| GQA | g << h | g/h | 很小 |
| MQA | 1 | 1/h | 可察觉 |
GQA 已成为现代 LLM 的标准配置:
- Llama 3、Mistral 等主流模型采用
- 平衡内存效率与模型质量
- 是长上下文支持的关键技术
小结
MQA 通过共享 KV 头将 KV Cache 内存减少到 1/H,但可能损失质量。GQA 在 MHA 和 MQA 之间取折衷,用少量 KV 组实现接近 MHA 的质量。LLaMA 2+、Mistral、Qwen 等主流模型已全部采用 GQA,它已成为 LLM 推理优化的标准组件。配合 Flash Attention 使用,GQA 可显著降低推理延迟和内存占用。
常见问题 FAQ
6.1 GQA 的 KV head 数应该设多少?
研究表明 KV heads = num_attention_heads / 4 是较好的起点。例如 32 个注意力头用 8 个 KV head。LLaMA 2 70B 用 8 个 KV head(32 attention heads),Mistral 7B 用 8 个 KV head(32 attention heads)。
6.2 MQA 会导致质量下降吗?
会,MQA 在某些任务上比 MHA 下降 1-3%。但 GQA 基本弥补了这个差距(下降 <1%),同时保留了大部分推理加速优势。这就是为什么新模型都选择 GQA 而非 MQA。
6.3 GQA 和 Flash Attention 是什么关系?
互补关系。GQA 减少 KV Cache 大小(减少内存带宽),Flash Attention 减少 attention 计算中的 HBM 访问次数。两者结合效果叠加,是当前 LLM 推理的标准优化组合。
6.4 哪些模型使用了 GQA?
LLaMA 2(70B)、LLaMA 3(所有规模)、Mistral 7B、Mixtral 8x7B、Qwen 2、DeepSeek V2 等主流模型均采用 GQA。只有少数早期模型(GPT-3、PaLM)使用 MHA。
参考资料
支持与分享
如果这篇文章对你有帮助,欢迎支持作者或分享给更多人
部分信息可能已经过时






