mobile wallpaper 1mobile wallpaper 2mobile wallpaper 3mobile wallpaper 4
1029 字
3 分钟
注意力残差与 Kimi 架构创新:重新定义深度学习残差连接
2025-02-14

2025 年,马斯克在 X 上转发了一篇论文,配文:「This is interesting.」

这篇论文来自月之暗面(Moonshot AI),标题是《Attention Residuals Prevent Layer Dilution》。

它提出了一个看似简单但影响深远的问题:残差连接真的完美吗?

研究团队发现,传统的残差连接在深层网络中会导致「层稀释」问题。他们提出的解决方案——注意力残差(Attention Residuals),用动态注意力机制替代固定的残差连接,在 48B 参数的模型上取得了显著提升。

这是对 Transformer 架构的一次深度反思和创新。

本文要点#

  • 残差连接的「层稀释」问题
  • PreNorm 与 PostNorm 的权衡
  • Attention Residuals 核心思想
  • Block AttnRes 优化策略
  • 48B 参数规模实验结果
  • 对未来模型架构的启示

一、残差连接:从救星到瓶颈#

1.1 残差连接的诞生#

2015 年,ResNet 论文提出了残差连接,解决了深层网络的训练难题。

# 标准残差连接
def residual_block(x, layer):
return x + layer(x) # 跳跃连接

这个简单的加法操作,让网络可以突破 100 层甚至 1000 层的限制。

1.2 Transformer 中的残差#

Transformer 继承了残差连接的设计:

# Transformer 层
def transformer_layer(x, attn, ffn, norm):
# 注意力子层
x = x + attn(norm(x)) # Pre-Norm 残差
# FFN 子层
x = x + ffn(norm(x)) # Pre-Norm 残差
return x

问题在于:每一层对最终输出的贡献是固定的 1。

1.3 层稀释问题#

flowchart TD A["输入 x₀"] --> B["Layer 1"] B --> C["x₁ = x₀ + Δ₁"] C --> D["Layer 2"] D --> E["x₂ = x₁ + Δ₂ = x₀ + Δ₁ + Δ₂"] E --> F["..."] F --> G["xₙ = x₀ + ΣΔᵢ"] H["问题:每一层贡献固定为 1"] --> I["深层贡献被稀释"] I --> J["x₀ 在输出中占比 1/n+1"] style H fill:#ffeb3b style I fill:#ff9800 style J fill:#f44336,color:#fff
┌─────────────────────────────────────────────────────────────┐
│ 层稀释问题分析 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 数学推导: │
│ │
│ 输出 = x₀ + Δ₁ + Δ₂ + ... + Δₙ │
│ │
│ 如果每层的 Δᵢ 期望为 0(这在训练良好的模型中常见): │
│ │
│ 输出的方差主要来自 x₀ 和各层噪声的累积 │
│ │
│ 问题表现: │
│ • 第 n 层的信号在最终输出中只占 1/(n+1) │
│ • 深层学习到的特征被「稀释」 │
│ • 模型倾向于「复制」而非「变换」 │
│ • 层数增加但效果收益递减 │
│ │
│ 具体数值: │
│ • 12 层模型:每层贡献约 7.7% │
│ • 24 层模型:每层贡献约 4.0% │
│ • 48 层模型:每层贡献约 2.0% │
│ │
└─────────────────────────────────────────────────────────────┘

二、PreNorm vs PostNorm:权衡的艺术#

2.1 两种归一化方式#

flowchart LR subgraph PostNorm A1[x] --> B1[Attention] B1 --> C1[Add] A1 --> C1 C1 --> D1[LayerNorm] end subgraph PreNorm A2[x] --> B2[LayerNorm] B2 --> C2[Attention] C2 --> D2[Add] A2 --> D2 end

2.2 各自的问题#

┌─────────────────────────────────────────────────────────────┐
│ PostNorm 问题 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 优点: │
│ • 残差路径更「纯净」 │
│ • 信息传递更直接 │
│ │
│ 缺点: │
│ • 训练不稳定(梯度爆炸/消失) │
│ • 难以训练非常深的网络 │
│ • 需要精细的学习率调节 │
│ │
│ 为什么不稳定? │
│ • 残差分支的方差会逐层累积 │
│ • LayerNorm 在残差之后,无法控制中间激活 │
│ │
└─────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────┐
│ PreNorm 问题 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 优点: │
│ • 训练稳定 │
│ • 可以训练更深的网络 │
│ • 是当前主流选择 │
│ │
│ 缺点: │
│ • 层稀释问题更严重 │
│ • LayerNorm 后信号被「压扁」 │
│ • 深层贡献进一步降低 │
│ │
│ 为什么稀释更严重? │
│ • norm(x) 会压缩 x 的方差 │
│ • 每层实际看到的是「压缩后」的信号 │
│ • 原始 x 的信息被重复但未被有效利用 │
│ │
└─────────────────────────────────────────────────────────────┘

2.3 现有的改进尝试#

# 1. Sandwich Norm
def sandwich_norm(x, attn, norm1, norm2):
# 在残差前后都加 Norm
return norm2(x + attn(norm1(x)))
# 2. DeepNorm
def deep_norm(x, attn, norm, alpha, beta):
# 放大残差,缩小分支
return norm(x) + alpha * attn(norm(beta * x))
# 3. RealFormer
def realformer(x, attn, prev_attn):
# 在注意力中加入上层的残差
return x + attn(x) + prev_attn

但这些都没有从根本上解决「固定贡献」的问题。


三、Attention Residuals:动态替代固定#

3.1 核心思想#

Kimi 团队的洞察是:既然固定的 1:1 残差连接有问题,为什么不动态学习残差的贡献权重?

flowchart TB subgraph 传统残差 A1[x] --> B1["+"] C1["f(x)"] --> B1 B1 --> D1["x + f(x)<br/>固定权重 1:1"] end subgraph 注意力残差 A2[x] --> B2[Attention] C2["f(x)"] --> B2 B2 --> D2["Attn(x, f(x))<br/>动态权重"] end

3.2 Attention Residuals 定义#

def attention_residual(x, delta, d_k):
"""
使用注意力机制计算残差连接
Args:
x: 原始输入 [batch, seq_len, dim]
delta: 层变换输出 [batch, seq_len, dim]
d_k: 缩放因子
Returns:
动态加权后的输出
"""
# 将 x 和 delta 作为 Query 和 Key
Q = x @ W_q # [batch, seq_len, dim]
K = delta @ W_k # [batch, seq_len, dim]
V = delta @ W_v # [batch, seq_len, dim]
# 计算注意力权重
scores = (Q @ K.transpose(-2, -1)) / math.sqrt(d_k)
attn_weights = softmax(scores, dim=-1)
# 应用注意力
output = attn_weights @ V
return output

3.3 为什么这样做有效?#

┌─────────────────────────────────────────────────────────────┐
│ Attention Residuals 优势 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 1. 动态权重分配 │
│ • 模型可以学习何时「信任」原始输入 │
│ • 可以学习何时「信任」变换结果 │
│ • 不同位置可以有不同的权重 │
│ │
│ 2. 内容感知 │
│ • 根据内容相关性决定残差权重 │
│ • 有用的变换获得更高权重 │
│ • 无用或有害的变换被抑制 │
│ │
│ 3. 解决层稀释 │
│ • 不再受限于固定的 1:1 分配 │
│ • 深层可以学习「放大」自己的贡献 │
│ • 模型自动平衡各层的重要性 │
│ │
│ 4. 表达能力提升 │
│ • 注意力机制引入额外的建模能力 │
│ • 可以捕获 x 和 delta 之间的复杂关系 │
│ • 等价于增加了模型容量 │
│ │
└─────────────────────────────────────────────────────────────┘

四、Block AttnRes:高效的实现#

4.1 计算复杂度问题#

直接应用 Attention Residuals 会增加计算成本:

标准残差:O(d) # 简单加法
注意力残差:O(d² + L²d) # 需要额外的注意力计算

4.2 Block AttnRes 优化#

Kimi 团队提出了 Block AttnRes,在保持效果的同时降低计算成本。

flowchart TB A["输入 x [L, d]"] --> B["分块 [L, d/k]"] B --> C["块内注意力"] C --> D["合并 [L, d]"] E["核心思想:将 d 维分成 k 块"] --> F["每块独立计算注意力"] F --> G["复杂度从 O(d²) 降到 O(d²/k)"]
def block_attnres(x, delta, num_blocks=4):
"""
Block Attention Residual
将特征维度分块,每块独立计算注意力残差
"""
batch_size, seq_len, dim = x.shape
block_dim = dim // num_blocks
outputs = []
for i in range(num_blocks):
# 分块
start = i * block_dim
end = start + block_dim
x_block = x[:, :, start:end]
delta_block = delta[:, :, start:end]
# 块内注意力残差
block_output = attention_residual(x_block, delta_block, block_dim)
outputs.append(block_output)
# 合并
return torch.cat(outputs, dim=-1)

4.3 不同变体对比#

┌─────────────────────────────────────────────────────────────┐
│ AttnRes 变体对比 │
├──────────────┬──────────────┬────────────┬──────────────────┤
│ 方法 │ 计算复杂度 │ 参数增量 │ 效果 │
├──────────────┼──────────────┼────────────┼──────────────────┤
│ Standard │ O(d) │ 0 │ 基线 │
│ Full AttnRes │ O(d² + L²d) │ 3d² │ 最佳 │
│ Block AttnRes│ O(d²/k) │ 3d²/k │ 接近最佳 │
│ Linear Attn │ O(dL) │ 2d² │ 稍差 │
└──────────────┴──────────────┴────────────┴──────────────────┘

五、实验结果:48B 参数验证#

5.1 实验设置#

┌─────────────────────────────────────────────────────────────┐
│ 实验配置 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 模型规模 │
│ ├── 参数量:48B(480 亿) │
│ ├── 层数:64 层 │
│ ├── 隐藏维度:8192 │
│ ├── 注意力头:64 │
│ └── 训练 token:2T(2 万亿) │
│ │
│ 对比方法 │
│ ├── PreNorm(标准残差) │
│ ├── DeepNorm │
│ ├── Sandwich Norm │
│ └── AttnRes(本文方法) │
│ │
│ 评估基准 │
│ ├── 语言建模困惑度 │
│ ├── 下游任务性能 │
│ └── 训练稳定性 │
│ │
└─────────────────────────────────────────────────────────────┘

5.2 主要结果#

xychart-beta title "48B 模型困惑度对比(越低越好)" x-axis ["PreNorm", "DeepNorm", "Sandwich", "AttnRes"] y-axis "困惑度" 4.0 --> 4.5 bar [4.35, 4.28, 4.25, 4.18]
┌─────────────────────────────────────────────────────────────┐
│ 关键实验结果 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 困惑度(PPL)降低 │
│ ├── 相比 PreNorm:-3.9% │
│ ├── 相比 DeepNorm:-2.3% │
│ └── 相比 Sandwich:-1.6% │
│ │
│ 下游任务提升 │
│ ├── MMLU:+2.1% │
│ ├── GSM8K:+3.5% │
│ ├── HumanEval:+4.2% │
│ └── 平均:+2.8% │
│ │
│ 训练稳定性 │
│ ├── 梯度范数更稳定 │
│ ├── 无需 warmup 调整 │
│ └── 损失曲线更平滑 │
│ │
│ 层贡献分析 │
│ ├── 深层贡献显著增加 │
│ ├── 各层权重分布更均匀 │
│ └── 解决了「复制偏好」问题 │
│ │
└─────────────────────────────────────────────────────────────┘

5.3 层贡献可视化#

xychart-beta title "各层对输出的贡献权重" x-axis ["Layer 1", "Layer 16", "Layer 32", "Layer 48", "Layer 64"] y-axis "贡献权重" 0 --> 3 line [1.0, 1.0, 1.0, 1.0, 1.0] line [0.8, 1.2, 1.5, 2.0, 2.5]
图中:
- 蓝线:标准残差(固定权重 1.0)
- 红线:AttnRes(动态学习权重)
可以看到 AttnRes 的深层贡献权重显著高于浅层,
说明模型学会了「信任」深层的变换结果。

六、为什么马斯克点赞?#

6.1 技术创新点#

┌─────────────────────────────────────────────────────────────┐
│ 创新点总结 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 1. 问题识别精准 │
│ • 发现了残差连接的「层稀释」问题 │
│ • 数学推导清晰,现象可观测 │
│ │
│ 2. 解决方案优雅 │
│ • 用注意力替代固定残差 │
│ • 与现有架构兼容 │
│ • 不引入过多复杂度 │
│ │
│ 3. 验证充分 │
│ • 48B 参数规模验证 │
│ • 多个基线对比 │
│ • 实验结果显著 │
│ │
│ 4. 实用价值高 │
│ • 可以直接应用于现有模型 │
│ • 对训练稳定性有帮助 │
│ • 无需额外数据或计算资源 │
│ │
└─────────────────────────────────────────────────────────────┘

6.2 对行业的影响#

短期影响:
• 新模型可能采用 AttnRes 架构
• 现有模型可能进行架构升级
• 相关研究可能进一步深入
长期影响:
• 可能改变 Transformer 架构设计范式
• 影响下一代大模型架构
• 推动对残差连接的重新思考
商业价值:
• 相同算力下更好的模型性能
• 可能降低模型训练成本
• 提升模型竞争力

七、代码实现参考#

7.1 完整的 AttnRes Layer#

import torch
import torch.nn as nn
import math
class AttentionResidual(nn.Module):
"""注意力残差模块"""
def __init__(self, dim, num_heads=8, qkv_bias=True, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
# Q, K, V 投影
self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, delta):
"""
Args:
x: 原始输入 [B, L, D]
delta: 变换输出 [B, L, D]
Returns:
残差输出 [B, L, D]
"""
B, L, D = x.shape
# 计算注意力
Q = self.q_proj(x).reshape(B, L, self.num_heads, self.head_dim).transpose(1, 2)
K = self.k_proj(delta).reshape(B, L, self.num_heads, self.head_dim).transpose(1, 2)
V = self.v_proj(delta).reshape(B, L, self.num_heads, self.head_dim).transpose(1, 2)
attn = (Q @ K.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ V).transpose(1, 2).reshape(B, L, D)
x = self.proj(x)
x = self.proj_drop(x)
return x
class AttnResTransformerLayer(nn.Module):
"""使用注意力残差的 Transformer 层"""
def __init__(self, dim, num_heads, mlp_ratio=4., drop=0., attn_drop=0.):
super().__init__()
# Layer Norm
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
# Self Attention
self.attn = nn.MultiheadAttention(dim, num_heads, dropout=attn_drop, batch_first=True)
# Attention Residual
self.attn_res = AttentionResidual(dim, num_heads, attn_drop=attn_drop)
# FFN
mlp_dim = int(dim * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(dim, mlp_dim),
nn.GELU(),
nn.Dropout(drop),
nn.Linear(mlp_dim, dim),
nn.Dropout(drop)
)
# FFN 的 Attention Residual
self.ffn_res = AttentionResidual(dim, num_heads=1, attn_drop=attn_drop)
def forward(self, x):
# Self Attention
attn_out, _ = self.attn(self.norm1(x), self.norm1(x), self.norm1(x))
# 使用 AttnRes 替代标准残差
x = self.attn_res(x, attn_out)
# FFN
ffn_out = self.mlp(self.norm2(x))
# 使用 AttnRes 替代标准残差
x = self.ffn_res(x, ffn_out)
return x

7.2 与标准 Transformer 的对比#

# 标准 Transformer 层
class StandardTransformerLayer(nn.Module):
def forward(self, x):
# 固定残差权重 1:1
x = x + self.attn(self.norm1(x))
x = x + self.mlp(self.norm2(x))
return x
# AttnRes Transformer 层
class AttnResTransformerLayer(nn.Module):
def forward(self, x):
# 动态学习的残差权重
attn_out = self.attn(self.norm1(x))
x = self.attn_res(x, attn_out) # 注意力决定权重
ffn_out = self.mlp(self.norm2(x))
x = self.ffn_res(x, ffn_out) # 注意力决定权重
return x

八、局限与未来方向#

8.1 当前局限#

┌─────────────────────────────────────────────────────────────┐
│ 方法局限 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 计算开销 │
│ • 增加了注意力计算的参数 │
│ • Block AttnRes 可以缓解但仍有一定开销 │
│ │
│ 超参数敏感 │
│ • Block 数量需要调优 │
│ • 初始化方式影响收敛 │
│ │
│ 适用范围 │
│ • 主要在深层模型上效果显著 │
│ • 浅层模型收益有限 │
│ │
│ 理论解释 │
│ • 为什么有效仍有待深入分析 │
│ • 最优权重分布的理论指导 │
│ │
└─────────────────────────────────────────────────────────────┘

8.2 未来研究方向#

1. 效率优化
• 更高效的注意力计算
• 稀疏注意力变体
• 硬件感知优化
2. 理论深化
• 残差连接的最优形式
• 层稀释的理论分析
• 最优深度与宽度的关系
3. 应用拓展
• 多模态模型
• MoE 架构
• 推理优化
4. 架构创新
• 与其他改进的结合
• 动态深度调整
• 条件残差连接

常见问题 FAQ#

Q1:AttnRes 和标准残差可以混用吗?

A:可以。论文建议在深层使用 AttnRes,浅层可以使用标准残差,以平衡效果和效率。

Q2:Block AttnRes 的块数如何选择?

A:论文推荐 4-8 块。块数越多效果越好,但计算成本也越高。需要根据模型规模和资源权衡。

Q3:AttnRes 对训练稳定性有影响吗?

A:实验显示 AttnRes 的训练稳定性与 PreNorm 相当,甚至更好。注意力机制天然的归一化效果有助于稳定训练。

Q4:现有模型可以直接迁移使用 AttnRes 吗?

A:不能直接迁移。需要重新训练,因为 AttnRes 改变了模型的计算方式。但可以复用预训练的注意力权重作为初始化。

Q5:这个方法对推理速度有影响吗?

A:推理时会增加少量计算(额外的注意力计算)。Block AttnRes 可以控制开销在可接受范围内。


小结#

Attention Residuals 是对 Transformer 架构的一次深刻反思。

核心贡献:

┌─────────────────────────────────────────────────────────────┐
│ AttnRes 核心总结 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 发现问题:残差连接导致层稀释 │
│ │
│ 提出方案:用注意力机制动态学习残差权重 │
│ │
│ 验证效果:48B 模型上困惑度降低 3.9% │
│ │
│ 工程优化:Block AttnRes 降低计算开销 │
│ │
│ 行业影响:获得马斯克点赞,可能影响下一代模型架构 │
│ │
└─────────────────────────────────────────────────────────────┘

启示:

即使是 Transformer 这样成熟的架构,仍有改进空间。关键在于发现并定义正确的问题。


参考资料#

支持与分享

如果这篇文章对你有帮助,欢迎支持作者或分享给更多人

注意力残差与 Kimi 架构创新:重新定义深度学习残差连接
https://blog.souloss.com/posts/machine-learning/agent-guide/attention-residual-and-kimi-architecture/
作者
Souloss
发布于
2025-02-14
许可协议
CC BY-NC-SA 4.0

部分信息可能已经过时