mobile wallpaper 1mobile wallpaper 2mobile wallpaper 3mobile wallpaper 4
1127 字
4 分钟
MQA/GQA:KV 缓存优化与注意力变体
2025-01-28

本文要点#

大语言模型的推理效率是一个关键挑战。随着上下文增长,Transformer 的注意力计算复杂度为 O(n2)O(n^2),而且 KV 缓存(Key-Value Cache) 的内存占用也成为瓶颈。

Multi-Query Attention(MQA)Grouped Query Attention(GQA) 通过优化 KV 缓存机制,显著提升了推理吞吐量。

一、标准 Multi-Head Attention#

1.1 回顾#

标准 MHA(Multi-Head Attention)中,每个注意力头都有独立的 W_Q、W_K、W_V 投影:

graph TB A["输入 X"] --> B["Q = XW_Q"] A --> C["K = XW_K"] A --> D["V = XW_V"] B --> E["Multi-Head Attention"] C --> E D --> E E --> F["输出"]

对于 hh 个注意力头,每个头的 K 和 V 都是独立的。

1.2 KV 缓存问题#

推理时,需要缓存所有历史 token 的 Key 和 Value:

# 伪代码
k_cache = [] # 形状: [batch, heads, seq_len, head_dim]
v_cache = []
for token in new_tokens:
k, v = compute_kv(token)
k_cache.append(k)
v_cache.append(v)
attn_output = attention(q, k_cache, v_cache)

问题:

  • hh 个头的 K、V 都要缓存
  • 内存占用与 h×seq_lenh \times \text{seq\_len} 成正比
  • 长上下文场景内存爆炸

二、Multi-Query Attention#

2.1 核心思想#

MQA(Multi-Query Attention,2019)让所有注意力头共享同一个 Key 和 Value 投影

graph TB A["输入 X"] --> B["Q = XW_Q (多头)"] A --> C["K = XW_K (单一)"] A --> D["V = XW_V (单一)"] B --> E["Multi-Query Attention"] C --> E D --> E E --> F["输出"]

数学上:

  • 多组 Query:Q1,Q2,...,QhQ_1, Q_2, ..., Q_h
  • 单一 Key:KsharedK_{shared}
  • 单一 Value:VsharedV_{shared}

2.2 内存节省#

配置KV 头数KV 缓存内存
MHA (32 heads)3232x
MQA11x
GQA (8 groups)88x

对于 LLaMA 70B(80GB 显存):

  • MHA:KV 缓存需要 80GB+ 显存
  • MQA:大幅减少,可支持更长上下文

2.3 代码实现#

class MultiQueryAttention(nn.Module):
def __init__(self, d_model, n_heads):
super().__init__()
self.n_heads = n_heads
self.head_dim = d_model // n_heads
# 多 Query 头
self.W_q = nn.Linear(d_model, d_model)
# 单一 K/V 头
self.W_k = nn.Linear(d_model, self.head_dim)
self.W_v = nn.Linear(d_model, self.head_dim)
def forward(self, x, kv_cache=None):
# 多 Query
q = self.W_q(x).view(batch, seq, self.n_heads, self.head_dim)
# 单一 K/V
k = self.W_k(x).unsqueeze(1) # [B, 1, seq, head_dim]
v = self.W_v(x).unsqueeze(1)
if kv_cache is not None:
k = torch.cat([kv_cache['k'], k], dim=2)
v = torch.cat([kv_cache['v'], v], dim=2)
# 注意:k/v 会在所有 query 头间共享
attn = scaled_dot_product_attention(q, k, v)
return attn

三、Grouped Query Attention#

3.1 MQA 的问题#

MQA 虽然大幅减少 KV 缓存,但也存在问题:

  • 所有头共享同一个 K/V:可能丢失多样性信息
  • 质量下降:某些任务精度降低

GQA(Grouped Query Attention)提出分组共享的折中方案。

3.2 GQA 原理#

GQA 在 MHA 和 MQA 之间引入分组概念:

graph TB A["输入 X"] A --> B["Q_heads = XW_Q (8头)"] A --> C["K_groups = XW_K (2组)"] A --> D["V_groups = XW_V (2组)"] B --> E["Attention"] C --> E D --> E style C fill:#90EE90 style D fill:#90EE90

关键特性:

  • nquery=8n_{query} = 8 个 Query 头
  • nkv=2n_{kv} = 2 个 Key/Value 组
  • 每个 KV 组被多个 Query 头共享

3.3 配置对比#

注意力类型Query 头数KV 头数内存比
MHAhh1.0
GQAhg (g << h)g/h
MQAh11/h

典型配置(LLaMA 3 8B):

  • MHA:8 个 Query 头,8 个 KV 头
  • GQA:8 个 Query 头,2 个 KV 组(仅 1/4 内存)

3.4 实现#

class GroupedQueryAttention(nn.Module):
def __init__(self, d_model, n_query_heads=8, n_kv_groups=2):
super().__init__()
self.n_query_heads = n_query_heads
self.n_kv_groups = n_kv_groups
self.head_dim = d_model // n_query_heads
# Query 头数 = 8
self.W_q = nn.Linear(d_model, d_model)
# KV 组数 = 2
self.W_k = nn.Linear(d_model, self.head_dim * n_kv_groups)
self.W_v = nn.Linear(d_model, self.head_dim * n_kv_groups)
def forward(self, x, kv_cache=None):
# Query: [B, seq, 8 heads]
q = self.W_q(x).view(batch, seq, self.n_query_heads, self.head_dim)
# KV: [B, seq, 2 groups, head_dim]
k = self.W_k(x).view(batch, seq, self.n_kv_groups, self.head_dim)
v = self.W_v(x).view(batch, seq, self.n_kv_groups, self.head_dim)
# 扩展 KV 组到 Query 头数
# [B, seq, 2 groups, head_dim] -> [B, seq, 8 heads, head_dim]
k = k.repeat_interleave(self.n_query_heads // self.n_kv_groups, dim=2)
v = v.repeat_interleave(self.n_query_heads // self.n_kv_groups, dim=2)
return scaled_dot_product_attention(q, k, v)

四、实际应用#

4.1 模型采用情况#

模型注意力类型说明
Llama 1/2MQA早期版本
Llama 3GQA (8Q, 2KV)显著改进
Mistral 7BGQA (8Q, 2KV)高效长上下文
PaLM 2GQA多种配置
GeminiGQA定制配置

4.2 性能收益#

以 Mistral 7B 为例:

# 上下文长度 32k tokens
context_length = 32768
head_dim = 128
# MHA 内存
mha_kv_memory = 2 * 8 * 32768 * 128 * 4 # bytes (FP16)
# ≈ 268 MB per layer
# GQA 内存 (2 KV heads)
gqa_kv_memory = 2 * 2 * 32768 * 128 * 4 # bytes (FP16)
# ≈ 67 MB per layer (4x 节省)

4.3 推理框架支持#

# transformers 启用 GQA
model = AutoModelForCausalLM.from_pretrained(
"mistralai/Mistral-7B-v0.1",
attn_implementation="flash_attention_2" # 启用优化的注意力实现
)
# vLLM 自动检测并优化 GQA
from vllm import LLM
llm = LLM(model="mistralai/Mistral-7B-v0.1")

五、与 Flash Attention 的结合#

GQA/MQA 与 Flash Attention 结合效果更佳:

# 组合优化
config = {
"attn_implementation": "flash_attention_2",
"use_sliding_window": True,
"kv_cache_dtype": "auto",
}
# Flash Attention 的 tiling 策略天然适配 GQA
# 减少 KV 缓存的 HBM 访问次数

六、总结#

技术KV 头数内存节省精度影响
MHAh1x
GQAg << hg/h很小
MQA11/h可察觉

GQA 已成为现代 LLM 的标准配置:

  • Llama 3、Mistral 等主流模型采用
  • 平衡内存效率与模型质量
  • 是长上下文支持的关键技术

小结#

MQA 通过共享 KV 头将 KV Cache 内存减少到 1/H,但可能损失质量。GQA 在 MHA 和 MQA 之间取折衷,用少量 KV 组实现接近 MHA 的质量。LLaMA 2+、Mistral、Qwen 等主流模型已全部采用 GQA,它已成为 LLM 推理优化的标准组件。配合 Flash Attention 使用,GQA 可显著降低推理延迟和内存占用。

常见问题 FAQ#

6.1 GQA 的 KV head 数应该设多少?#

研究表明 KV heads = num_attention_heads / 4 是较好的起点。例如 32 个注意力头用 8 个 KV head。LLaMA 2 70B 用 8 个 KV head(32 attention heads),Mistral 7B 用 8 个 KV head(32 attention heads)。

6.2 MQA 会导致质量下降吗?#

会,MQA 在某些任务上比 MHA 下降 1-3%。但 GQA 基本弥补了这个差距(下降 <1%),同时保留了大部分推理加速优势。这就是为什么新模型都选择 GQA 而非 MQA。

6.3 GQA 和 Flash Attention 是什么关系?#

互补关系。GQA 减少 KV Cache 大小(减少内存带宽),Flash Attention 减少 attention 计算中的 HBM 访问次数。两者结合效果叠加,是当前 LLM 推理的标准优化组合。

6.4 哪些模型使用了 GQA?#

LLaMA 2(70B)、LLaMA 3(所有规模)、Mistral 7B、Mixtral 8x7B、Qwen 2、DeepSeek V2 等主流模型均采用 GQA。只有少数早期模型(GPT-3、PaLM)使用 MHA。

参考资料#

支持与分享

如果这篇文章对你有帮助,欢迎支持作者或分享给更多人

MQA/GQA:KV 缓存优化与注意力变体
https://blog.souloss.com/posts/machine-learning/llm-paper-history/mqa-and-gqa-attention/
作者
Souloss
发布于
2025-01-28
许可协议
CC BY-NC-SA 4.0

部分信息可能已经过时