2025 年,马斯克在 X 上转发了一篇论文,配文:「This is interesting.」
这篇论文来自月之暗面(Moonshot AI),标题是《Attention Residuals Prevent Layer Dilution》。
它提出了一个看似简单但影响深远的问题:残差连接真的完美吗?
研究团队发现,传统的残差连接在深层网络中会导致「层稀释」问题。他们提出的解决方案——注意力残差(Attention Residuals),用动态注意力机制替代固定的残差连接,在 48B 参数的模型上取得了显著提升。
这是对 Transformer 架构的一次深度反思和创新。
本文要点
- 残差连接的「层稀释」问题
- PreNorm 与 PostNorm 的权衡
- Attention Residuals 核心思想
- Block AttnRes 优化策略
- 48B 参数规模实验结果
- 对未来模型架构的启示
一、残差连接:从救星到瓶颈
1.1 残差连接的诞生
2015 年,ResNet 论文提出了残差连接,解决了深层网络的训练难题。
# 标准残差连接def residual_block(x, layer): return x + layer(x) # 跳跃连接这个简单的加法操作,让网络可以突破 100 层甚至 1000 层的限制。
1.2 Transformer 中的残差
Transformer 继承了残差连接的设计:
# Transformer 层def transformer_layer(x, attn, ffn, norm): # 注意力子层 x = x + attn(norm(x)) # Pre-Norm 残差 # FFN 子层 x = x + ffn(norm(x)) # Pre-Norm 残差 return x问题在于:每一层对最终输出的贡献是固定的 1。
1.3 层稀释问题
┌─────────────────────────────────────────────────────────────┐│ 层稀释问题分析 │├─────────────────────────────────────────────────────────────┤│ ││ 数学推导: ││ ││ 输出 = x₀ + Δ₁ + Δ₂ + ... + Δₙ ││ ││ 如果每层的 Δᵢ 期望为 0(这在训练良好的模型中常见): ││ ││ 输出的方差主要来自 x₀ 和各层噪声的累积 ││ ││ 问题表现: ││ • 第 n 层的信号在最终输出中只占 1/(n+1) ││ • 深层学习到的特征被「稀释」 ││ • 模型倾向于「复制」而非「变换」 ││ • 层数增加但效果收益递减 ││ ││ 具体数值: ││ • 12 层模型:每层贡献约 7.7% ││ • 24 层模型:每层贡献约 4.0% ││ • 48 层模型:每层贡献约 2.0% ││ │└─────────────────────────────────────────────────────────────┘二、PreNorm vs PostNorm:权衡的艺术
2.1 两种归一化方式
2.2 各自的问题
┌─────────────────────────────────────────────────────────────┐│ PostNorm 问题 │├─────────────────────────────────────────────────────────────┤│ ││ 优点: ││ • 残差路径更「纯净」 ││ • 信息传递更直接 ││ ││ 缺点: ││ • 训练不稳定(梯度爆炸/消失) ││ • 难以训练非常深的网络 ││ • 需要精细的学习率调节 ││ ││ 为什么不稳定? ││ • 残差分支的方差会逐层累积 ││ • LayerNorm 在残差之后,无法控制中间激活 ││ │└─────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────┐│ PreNorm 问题 │├─────────────────────────────────────────────────────────────┤│ ││ 优点: ││ • 训练稳定 ││ • 可以训练更深的网络 ││ • 是当前主流选择 ││ ││ 缺点: ││ • 层稀释问题更严重 ││ • LayerNorm 后信号被「压扁」 ││ • 深层贡献进一步降低 ││ ││ 为什么稀释更严重? ││ • norm(x) 会压缩 x 的方差 ││ • 每层实际看到的是「压缩后」的信号 ││ • 原始 x 的信息被重复但未被有效利用 ││ │└─────────────────────────────────────────────────────────────┘2.3 现有的改进尝试
# 1. Sandwich Normdef sandwich_norm(x, attn, norm1, norm2): # 在残差前后都加 Norm return norm2(x + attn(norm1(x)))
# 2. DeepNormdef deep_norm(x, attn, norm, alpha, beta): # 放大残差,缩小分支 return norm(x) + alpha * attn(norm(beta * x))
# 3. RealFormerdef realformer(x, attn, prev_attn): # 在注意力中加入上层的残差 return x + attn(x) + prev_attn但这些都没有从根本上解决「固定贡献」的问题。
三、Attention Residuals:动态替代固定
3.1 核心思想
Kimi 团队的洞察是:既然固定的 1:1 残差连接有问题,为什么不动态学习残差的贡献权重?
3.2 Attention Residuals 定义
def attention_residual(x, delta, d_k): """ 使用注意力机制计算残差连接
Args: x: 原始输入 [batch, seq_len, dim] delta: 层变换输出 [batch, seq_len, dim] d_k: 缩放因子
Returns: 动态加权后的输出 """ # 将 x 和 delta 作为 Query 和 Key Q = x @ W_q # [batch, seq_len, dim] K = delta @ W_k # [batch, seq_len, dim] V = delta @ W_v # [batch, seq_len, dim]
# 计算注意力权重 scores = (Q @ K.transpose(-2, -1)) / math.sqrt(d_k) attn_weights = softmax(scores, dim=-1)
# 应用注意力 output = attn_weights @ V
return output3.3 为什么这样做有效?
┌─────────────────────────────────────────────────────────────┐│ Attention Residuals 优势 │├─────────────────────────────────────────────────────────────┤│ ││ 1. 动态权重分配 ││ • 模型可以学习何时「信任」原始输入 ││ • 可以学习何时「信任」变换结果 ││ • 不同位置可以有不同的权重 ││ ││ 2. 内容感知 ││ • 根据内容相关性决定残差权重 ││ • 有用的变换获得更高权重 ││ • 无用或有害的变换被抑制 ││ ││ 3. 解决层稀释 ││ • 不再受限于固定的 1:1 分配 ││ • 深层可以学习「放大」自己的贡献 ││ • 模型自动平衡各层的重要性 ││ ││ 4. 表达能力提升 ││ • 注意力机制引入额外的建模能力 ││ • 可以捕获 x 和 delta 之间的复杂关系 ││ • 等价于增加了模型容量 ││ │└─────────────────────────────────────────────────────────────┘四、Block AttnRes:高效的实现
4.1 计算复杂度问题
直接应用 Attention Residuals 会增加计算成本:
标准残差:O(d) # 简单加法注意力残差:O(d² + L²d) # 需要额外的注意力计算4.2 Block AttnRes 优化
Kimi 团队提出了 Block AttnRes,在保持效果的同时降低计算成本。
def block_attnres(x, delta, num_blocks=4): """ Block Attention Residual
将特征维度分块,每块独立计算注意力残差 """ batch_size, seq_len, dim = x.shape block_dim = dim // num_blocks
outputs = [] for i in range(num_blocks): # 分块 start = i * block_dim end = start + block_dim
x_block = x[:, :, start:end] delta_block = delta[:, :, start:end]
# 块内注意力残差 block_output = attention_residual(x_block, delta_block, block_dim) outputs.append(block_output)
# 合并 return torch.cat(outputs, dim=-1)4.3 不同变体对比
┌─────────────────────────────────────────────────────────────┐│ AttnRes 变体对比 │├──────────────┬──────────────┬────────────┬──────────────────┤│ 方法 │ 计算复杂度 │ 参数增量 │ 效果 │├──────────────┼──────────────┼────────────┼──────────────────┤│ Standard │ O(d) │ 0 │ 基线 ││ Full AttnRes │ O(d² + L²d) │ 3d² │ 最佳 ││ Block AttnRes│ O(d²/k) │ 3d²/k │ 接近最佳 ││ Linear Attn │ O(dL) │ 2d² │ 稍差 │└──────────────┴──────────────┴────────────┴──────────────────┘五、实验结果:48B 参数验证
5.1 实验设置
┌─────────────────────────────────────────────────────────────┐│ 实验配置 │├─────────────────────────────────────────────────────────────┤│ ││ 模型规模 ││ ├── 参数量:48B(480 亿) ││ ├── 层数:64 层 ││ ├── 隐藏维度:8192 ││ ├── 注意力头:64 ││ └── 训练 token:2T(2 万亿) ││ ││ 对比方法 ││ ├── PreNorm(标准残差) ││ ├── DeepNorm ││ ├── Sandwich Norm ││ └── AttnRes(本文方法) ││ ││ 评估基准 ││ ├── 语言建模困惑度 ││ ├── 下游任务性能 ││ └── 训练稳定性 ││ │└─────────────────────────────────────────────────────────────┘5.2 主要结果
┌─────────────────────────────────────────────────────────────┐│ 关键实验结果 │├─────────────────────────────────────────────────────────────┤│ ││ 困惑度(PPL)降低 ││ ├── 相比 PreNorm:-3.9% ││ ├── 相比 DeepNorm:-2.3% ││ └── 相比 Sandwich:-1.6% ││ ││ 下游任务提升 ││ ├── MMLU:+2.1% ││ ├── GSM8K:+3.5% ││ ├── HumanEval:+4.2% ││ └── 平均:+2.8% ││ ││ 训练稳定性 ││ ├── 梯度范数更稳定 ││ ├── 无需 warmup 调整 ││ └── 损失曲线更平滑 ││ ││ 层贡献分析 ││ ├── 深层贡献显著增加 ││ ├── 各层权重分布更均匀 ││ └── 解决了「复制偏好」问题 ││ │└─────────────────────────────────────────────────────────────┘5.3 层贡献可视化
图中:- 蓝线:标准残差(固定权重 1.0)- 红线:AttnRes(动态学习权重)
可以看到 AttnRes 的深层贡献权重显著高于浅层,说明模型学会了「信任」深层的变换结果。六、为什么马斯克点赞?
6.1 技术创新点
┌─────────────────────────────────────────────────────────────┐│ 创新点总结 │├─────────────────────────────────────────────────────────────┤│ ││ 1. 问题识别精准 ││ • 发现了残差连接的「层稀释」问题 ││ • 数学推导清晰,现象可观测 ││ ││ 2. 解决方案优雅 ││ • 用注意力替代固定残差 ││ • 与现有架构兼容 ││ • 不引入过多复杂度 ││ ││ 3. 验证充分 ││ • 48B 参数规模验证 ││ • 多个基线对比 ││ • 实验结果显著 ││ ││ 4. 实用价值高 ││ • 可以直接应用于现有模型 ││ • 对训练稳定性有帮助 ││ • 无需额外数据或计算资源 ││ │└─────────────────────────────────────────────────────────────┘6.2 对行业的影响
短期影响:• 新模型可能采用 AttnRes 架构• 现有模型可能进行架构升级• 相关研究可能进一步深入
长期影响:• 可能改变 Transformer 架构设计范式• 影响下一代大模型架构• 推动对残差连接的重新思考
商业价值:• 相同算力下更好的模型性能• 可能降低模型训练成本• 提升模型竞争力七、代码实现参考
7.1 完整的 AttnRes Layer
import torchimport torch.nn as nnimport math
class AttentionResidual(nn.Module): """注意力残差模块"""
def __init__(self, dim, num_heads=8, qkv_bias=True, attn_drop=0., proj_drop=0.): super().__init__() self.num_heads = num_heads self.head_dim = dim // num_heads self.scale = self.head_dim ** -0.5
# Q, K, V 投影 self.q_proj = nn.Linear(dim, dim, bias=qkv_bias) self.k_proj = nn.Linear(dim, dim, bias=qkv_bias) self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, delta): """ Args: x: 原始输入 [B, L, D] delta: 变换输出 [B, L, D] Returns: 残差输出 [B, L, D] """ B, L, D = x.shape
# 计算注意力 Q = self.q_proj(x).reshape(B, L, self.num_heads, self.head_dim).transpose(1, 2) K = self.k_proj(delta).reshape(B, L, self.num_heads, self.head_dim).transpose(1, 2) V = self.v_proj(delta).reshape(B, L, self.num_heads, self.head_dim).transpose(1, 2)
attn = (Q @ K.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn)
x = (attn @ V).transpose(1, 2).reshape(B, L, D) x = self.proj(x) x = self.proj_drop(x)
return x
class AttnResTransformerLayer(nn.Module): """使用注意力残差的 Transformer 层"""
def __init__(self, dim, num_heads, mlp_ratio=4., drop=0., attn_drop=0.): super().__init__()
# Layer Norm self.norm1 = nn.LayerNorm(dim) self.norm2 = nn.LayerNorm(dim)
# Self Attention self.attn = nn.MultiheadAttention(dim, num_heads, dropout=attn_drop, batch_first=True)
# Attention Residual self.attn_res = AttentionResidual(dim, num_heads, attn_drop=attn_drop)
# FFN mlp_dim = int(dim * mlp_ratio) self.mlp = nn.Sequential( nn.Linear(dim, mlp_dim), nn.GELU(), nn.Dropout(drop), nn.Linear(mlp_dim, dim), nn.Dropout(drop) )
# FFN 的 Attention Residual self.ffn_res = AttentionResidual(dim, num_heads=1, attn_drop=attn_drop)
def forward(self, x): # Self Attention attn_out, _ = self.attn(self.norm1(x), self.norm1(x), self.norm1(x)) # 使用 AttnRes 替代标准残差 x = self.attn_res(x, attn_out)
# FFN ffn_out = self.mlp(self.norm2(x)) # 使用 AttnRes 替代标准残差 x = self.ffn_res(x, ffn_out)
return x7.2 与标准 Transformer 的对比
# 标准 Transformer 层class StandardTransformerLayer(nn.Module): def forward(self, x): # 固定残差权重 1:1 x = x + self.attn(self.norm1(x)) x = x + self.mlp(self.norm2(x)) return x
# AttnRes Transformer 层class AttnResTransformerLayer(nn.Module): def forward(self, x): # 动态学习的残差权重 attn_out = self.attn(self.norm1(x)) x = self.attn_res(x, attn_out) # 注意力决定权重
ffn_out = self.mlp(self.norm2(x)) x = self.ffn_res(x, ffn_out) # 注意力决定权重 return x八、局限与未来方向
8.1 当前局限
┌─────────────────────────────────────────────────────────────┐│ 方法局限 │├─────────────────────────────────────────────────────────────┤│ ││ 计算开销 ││ • 增加了注意力计算的参数 ││ • Block AttnRes 可以缓解但仍有一定开销 ││ ││ 超参数敏感 ││ • Block 数量需要调优 ││ • 初始化方式影响收敛 ││ ││ 适用范围 ││ • 主要在深层模型上效果显著 ││ • 浅层模型收益有限 ││ ││ 理论解释 ││ • 为什么有效仍有待深入分析 ││ • 最优权重分布的理论指导 ││ │└─────────────────────────────────────────────────────────────┘8.2 未来研究方向
1. 效率优化 • 更高效的注意力计算 • 稀疏注意力变体 • 硬件感知优化
2. 理论深化 • 残差连接的最优形式 • 层稀释的理论分析 • 最优深度与宽度的关系
3. 应用拓展 • 多模态模型 • MoE 架构 • 推理优化
4. 架构创新 • 与其他改进的结合 • 动态深度调整 • 条件残差连接常见问题 FAQ
Q1:AttnRes 和标准残差可以混用吗?
A:可以。论文建议在深层使用 AttnRes,浅层可以使用标准残差,以平衡效果和效率。
Q2:Block AttnRes 的块数如何选择?
A:论文推荐 4-8 块。块数越多效果越好,但计算成本也越高。需要根据模型规模和资源权衡。
Q3:AttnRes 对训练稳定性有影响吗?
A:实验显示 AttnRes 的训练稳定性与 PreNorm 相当,甚至更好。注意力机制天然的归一化效果有助于稳定训练。
Q4:现有模型可以直接迁移使用 AttnRes 吗?
A:不能直接迁移。需要重新训练,因为 AttnRes 改变了模型的计算方式。但可以复用预训练的注意力权重作为初始化。
Q5:这个方法对推理速度有影响吗?
A:推理时会增加少量计算(额外的注意力计算)。Block AttnRes 可以控制开销在可接受范围内。
小结
Attention Residuals 是对 Transformer 架构的一次深刻反思。
核心贡献:
┌─────────────────────────────────────────────────────────────┐│ AttnRes 核心总结 │├─────────────────────────────────────────────────────────────┤│ ││ 发现问题:残差连接导致层稀释 ││ ││ 提出方案:用注意力机制动态学习残差权重 ││ ││ 验证效果:48B 模型上困惑度降低 3.9% ││ ││ 工程优化:Block AttnRes 降低计算开销 ││ ││ 行业影响:获得马斯克点赞,可能影响下一代模型架构 ││ │└─────────────────────────────────────────────────────────────┘启示:
即使是 Transformer 这样成熟的架构,仍有改进空间。关键在于发现并定义正确的问题。
参考资料
支持与分享
如果这篇文章对你有帮助,欢迎支持作者或分享给更多人
部分信息可能已经过时






