mobile wallpaper 1mobile wallpaper 2mobile wallpaper 3mobile wallpaper 4
2578 字
7 分钟
RWKV:线性注意力的开源替代
2025-02-26

2023 年,Bo Peng 提出了 RWKV(Receptance Weighted Key Value),一种将 Transformer 的并行训练优势与 RNN 的高效推理相结合的架构。RWKV 通过将注意力机制改写为线性时间复杂度的 RNN 形式,实现了 O(n) 的推理效率,同时在语言建模性能上接近同规模的 Transformer。从 RWKV-4 到 RWKV-6(Finch),这个架构持续演进,成为 Mamba/SSM 之外最有力的 Transformer 替代方案之一。

RWKV 证明了:不需要注意力机制的二次复杂度,也能训练出高质量的序列模型。

本文要点#

  • 线性注意力:将 O(n²) 注意力改写为 O(n) RNN
  • 时间衰减机制(Time Decay):通道级别的衰减参数
  • Token Shift:引入上下文信息的巧妙技巧
  • RWKV-4 架构:Time-Mixing 和 Channel-Mixing 模块
  • RWKV-5(Eagle):多头公式化、数据依赖衰减
  • RWKV-6(Finch):数据依赖线性注意力与 Lerp Token Shift
  • 与 Mamba/SSM 的对比分析
  • 性能评估与同规模 Transformer 对比
  • 开源社区与 RWKV Foundation

一、背景:为什么需要替代 Transformer?#

1.1 Transformer 的二次复杂度瓶颈#

标准自注意力的计算复杂度为 O(n²),这带来了两个核心问题:

  1. 训练成本:长序列训练的时间和内存随长度二次增长
  2. 推理效率:KV Cache 随序列长度线性增长,限制批处理大小
graph TD A["Transformer 的问题"] --> B["O(n²) 计算复杂度"] A --> C["KV Cache 线性增长"] B --> D["长序列训练成本高"] C --> E["推理吞吐量受限"] D --> F["需要替代架构"] E --> F F --> G["两个方向"] G --> H["SSM/Mamba<br/>状态空间模型"] G --> I["RWKV<br/>线性注意力 RNN"] style H fill:#2196F3,color:#fff style I fill:#4CAF50,color:#fff

1.2 线性注意力的基本思想#

标准注意力的计算可以写为:

Attention(Q, K, V) = softmax(QK^T / √d) × V

其中 softmax 使得 Q 和 K 必须成对计算,导致 O(n²)。如果去掉 softmax(或用核函数近似),计算可以重排为:

LinearAttention(Q, K, V) = (Q × (K^T × V)) / (Q × (K^T × 1))

这样 K^T × V 可以先计算(O(n × d²)),然后再与 Q 相乘(O(n × d²)),总复杂度降为 O(n × d²)。当 d << n 时,这比 O(n²) 高效得多。

更重要的是,这种重排使得推理时可以维护一个累积状态(类似 RNN),实现 O(1) 的逐 Token 推理。

二、RWKV 核心机制#

2.1 Token Shift#

RWKV 的第一个关键设计是 Token Shift(Token 移位):将当前 Token 和前一个 Token 的线性插值作为输入。

# Token Shift 机制
# 对于时间步 t,输入不是单纯的 x_t
# 而是 x_t 和 x_{t-1} 的线性插值
def token_mixing(x_t, x_t_minus_1, ratio):
"""Token Shift:线性插值当前和前一个 Token"""
return ratio * x_t + (1 - ratio) * x_t_minus_1

Token Shift 的作用是让模型在处理当前 Token 时直接获取前一个 Token 的信息,类似于 n-gram 模型中的上下文效果,但更加灵活。

2.2 时间衰减机制(Time Decay)#

RWKV 的第二个核心设计是时间衰减。在标准注意力中,所有位置的关系通过 softmax 归一化。RWKV 用可学习的衰减参数替代 softmax:

# 简化的 RWKV 注意力计算
# 每个通道有独立的衰减参数 w
w = channel_decay # [hidden_dim],可学习参数,0 < w < 1
# 累积状态随时间衰减
# state_t = w * state_{t-1} + k_t * v_t
# output_t = receptance(wkv_state_t)
# 等价于:对每个位置 j <= t
# weight_{t,j} = w^(t-j) # 越远的位置权重越小
graph LR subgraph "时间衰减示意" T5["位置 t-4<br/>权重 w⁴"] --> T4["位置 t-3<br/>权重 w³"] T4 --> T3["位置 t-2<br/>权重 w²"] T3 --> T2["位置 t-1<br/>权重 w¹"] T2 --> T1["位置 t<br/>权重 w⁰=1"] end T5 --> |"衰减"| T4 T4 --> |"衰减"| T3 T3 --> |"衰减"| T2 T2 --> |"无衰减"| T1 style T1 fill:#4CAF50,color:#fff style T5 fill:#FFCDD2

关键特性:

  • 每个通道有独立的衰减参数 w,而非全局共享
  • w 接近 1 表示”长记忆”(类似长距离注意力)
  • w 接近 0 表示”短记忆”(只关注最近几个 Token)

2.3 WKV 计算#

RWKV 的核心计算被称为 WKV(Weighted Key Value),其 RNN 形式如下:

def rwkv_wkv(r, k, v, w, u, state):
"""
RWKV 的 WKV 计算(单步)
r: receptance,决定如何读取状态 [hidden_dim]
k: key,用于更新状态 [hidden_dim]
v: value,用于更新状态 [hidden_dim]
w: time decay,衰减系数 [hidden_dim]
u: time bonus,当前 Token 的额外权重 [hidden_dim]
state: 累积的 KV 状态 [hidden_dim]
"""
# 当前 Token 对输出的贡献(通过 u 加权)
current = torch.exp(u) * k * v
# 衰减后的历史状态 + 当前贡献
wkv = state * torch.exp(-w) + current
# 更新状态
new_state = wkv
# 通过 receptance 门控输出
output = torch.sigmoid(r) * wkv
return output, new_state

三、RWKV-4 架构#

RWKV-4 是第一个被广泛验证的版本,由交替的 Time-Mixing 和 Channel-Mixing 模块组成。

3.1 整体架构#

graph TD subgraph "RWKV-4 Block" Input["输入 x"] --> TM["Time-Mixing<br/>(替代 Self-Attention)"] TM --> LN1["LayerNorm"] LN1 --> CM["Channel-Mixing<br/>(替代 FFN)"] CM --> LN2["LayerNorm"] LN2 --> Output["输出"] end subgraph "Time-Mixing 内部" R["R (Receptance)"] K["K (Key)"] V["V (Value)"] W["W (Time Decay)"] U["U (Time Bonus)"] R & K & V & W & U --> WKV["WKV 计算"] WKV --> OUT["门控输出"] end style TM fill:#1976D2,color:#fff style CM fill:#FF9800,color:#fff

3.2 Time-Mixing 模块#

Time-Mixing 替代了 Transformer 的 Self-Attention:

  1. 对输入应用 Token Shift,得到当前 Token 和前一个 Token 的混合
  2. 通过线性投影生成 R、K、V、W、U 五个向量
  3. 通过 WKV 计算得到输出
  4. 输出经过门控(Receptance 门)和残差连接

3.3 Channel-Mixing 模块#

Channel-Mixing 替代了 Transformer 的 FFN:

def channel_mixing(x, x_prev, W1, W2):
"""
Channel-Mixing 模块
使用前一个 Token 的信息(通过 Token Shift)
"""
# Token Shift
x_mixed = token_shift(x, x_prev)
# 非线性变换(类似 Square ReLU)
hidden = torch.square(torch.relu(W1(x_mixed)))
output = W2(hidden)
return output

3.4 与 Transformer 的对比#

特性TransformerRWKV-4
训练复杂度O(n² × d)O(n × d²)
推理复杂度(逐 Token)O(n × d)(需读 KV Cache)O(d²)(固定状态)
并行训练完全并行可并行
推理内存KV Cache 随序列线性增长固定大小状态
位置编码需要(RoPE/ALiBi)不需要(内置时间衰减)

四、RWKV-5(Eagle):多头改进#

RWKV-5 在 RWKV-4 的基础上引入了多头机制和更精细的数据依赖衰减。

4.1 核心改进#

  1. 多头公式化:类似 Transformer 的多头注意力,RWKV-5 将隐藏维度分为多个头,每个头有独立的参数
  2. 数据依赖衰减:衰减参数 w 不再是固定的可学习参数,而是由输入数据动态生成
  3. GroupNorm:在 WKV 计算中加入 GroupNorm,提高训练稳定性
# RWKV-5 的数据依赖衰减
# w 不再是固定参数,而是由输入 x 生成
w = softplus(time_decay_proj(x)) # 每步动态计算
# 这使得模型可以根据输入内容调整"记忆长度"
# 对需要长距离依赖的内容 → w 接近 1(长记忆)
# 对只需局部上下文的内容 → w 接近 0(短记忆)

五、RWKV-6(Finch):更精细的控制#

RWKV-6 是目前最新的版本,引入了更精细的数据依赖机制。

5.1 核心改进#

  1. Lerp-based Token Shift:用数据依赖的插值比替代固定的 Token Shift
  2. 细粒度注意力路由:每个 Token 可以动态选择关注不同的历史信息
  3. 改进的初始化:更合理的参数初始化策略
# RWKV-6 的数据依赖 Token Shift
# ratio 不再是固定的,而是由输入动态决定
ratio = sigmoid(lerp_proj(x_t))
x_shifted = ratio * x_t + (1 - ratio) * x_{t-1}
# 这使得模型可以根据当前内容决定
# 多少信息来自当前 Token,多少来自前一个 Token

5.2 模型规模#

RWKV-6 已训练了多种规模的模型:

模型参数量训练数据语言
RWKV-6-1.6B1.6B1.4T Tokens英文
RWKV-6-3B3B2.5T Tokens英文
RWKV-6-7B7B3.5T Tokens多语言
RWKV-6-14B14B5T Tokens多语言

六、RWKV vs Mamba:两大替代方案对比#

graph TD subgraph "Transformer 替代方案" direction TB RWKV_Box["RWKV"] RWKV_Features["线性注意力 RNN<br/>时间衰减机制<br/>Token Shift<br/>O(n) 推理"] Mamba_Box["Mamba/SSM"] Mamba_Features["状态空间模型<br/>选择性扫描<br/>硬件感知算法<br/>O(n) 推理"] RWKV_Box --> RWKV_Features Mamba_Box --> Mamba_Features end style RWKV_Box fill:#4CAF50,color:#fff style Mamba_Box fill:#2196F3,color:#fff
特性RWKV-6Mamba-2
理论基础线性注意力状态空间模型
核心机制时间衰减 + Token Shift选择性扫描
推理效率O(1) 状态更新O(1) 状态更新
训练方式可并行可并行(扫描算法)
硬件优化标准 CUDA定制 CUDA 内核
长序列性能良好优秀
开源社区RWKV Foundation广泛学术支持
多语言支持已有多语言模型主要英文

七、性能评估#

7.1 与同规模 Transformer 的对比#

模型参数量WikiText-103 PPL基准测试
GPT-2 (Transformer)1.5B18.3基准
RWKV-4-1.5B1.5B19.2接近 GPT-2
RWKV-5-1.5B1.5B18.1与 GPT-2 持平
RWKV-6-1.6B1.6B17.5超越 GPT-2
LLaMA-7B (Transformer)7B12.6基准
RWKV-6-7B7B13.8接近 LLaMA

7.2 推理效率对比#

操作Transformer (7B)RWKV-6 (7B)
单 Token 生成(短序列)15ms12ms
单 Token 生成(10K 序列)45ms12ms
单 Token 生成(100K 序列)内存不足12ms
内存占用(10K 序列)8GB KV Cache200MB 状态
内存占用(100K 序列)80GB KV Cache200MB 状态

RWKV 的推理速度和内存占用不随序列长度变化,这是其相对于 Transformer 的核心优势。

八、开源社区#

8.1 RWKV Foundation#

RWKV 项目由 RWKV Foundation 维护,是一个活跃的开源社区:

  1. 模型权重:所有规模模型权重完全开源(Apache 2.0)
  2. 训练框架:提供完整的训练和微调代码
  3. 多语言支持:社区贡献了中文、日文、韩文等多语言版本
  4. 生态工具:WebUI、API 服务、量化工具等

8.2 社区贡献#

  • RWKV-World:多语言版本,支持 100+ 种语言
  • RWKV-Music:音乐生成模型
  • RWKV-Visual:视觉语言模型探索
  • ChatRWKV:类似 ChatGPT 的对话界面

常见问题 FAQ#

8.1 Q1: RWKV 能完全替代 Transformer 吗?#

目前还不能。RWKV 在语言建模上接近同规模 Transformer,但在需要精确长距离依赖的任务上(如需要回溯全文的阅读理解)仍有一定差距。不过,对于大多数实际应用(对话、文本生成、代码补全),RWKV 已经足够好用。

8.2 Q2: RWKV 和 Mamba 应该选哪个?#

选择取决于场景:

  • 需要成熟的生态和多语言支持 → RWKV
  • 需要极致的硬件优化和学术支持 → Mamba
  • 混合架构(Transformer + 线性层)→ 两者都可用
  • 纯推理场景、超长序列 → 两者都适合

8.3 Q3: RWKV 的训练效率如何?#

RWKV 的训练可以像 Transformer 一样并行化(使用时间维度的并行扫描),但训练效率略低于 Transformer(约慢 10-20%),因为 RWKV 的计算模式对 GPU 不如标准注意力友好。RWKV 的优势主要在推理端。

8.4 Q4: RWKV 支持多模态吗?#

目前 RWKV 的多模态支持还在早期阶段。社区有 RWKV-Visual 等探索性项目,但还没有像 LLaVA 那样成熟的多模态方案。这是 RWKV 生态的一个待完善方向。

8.5 Q5: RWKV 如何处理超长上下文?#

RWKV 天然支持超长上下文,因为其推理状态是固定大小的。理论上,RWKV 可以处理无限长度的序列(状态不随序列增长)。但实际效果受限于时间衰减机制——太远的信息会被衰减到几乎为零。RWKV-6 的数据依赖衰减在一定程度上缓解了这个问题。

小结#

RWKV 的核心贡献可以总结为:

  1. 线性注意力 RNN:将注意力机制改写为 O(n) 复杂度的 RNN 形式
  2. 时间衰减机制:每个通道独立的可学习衰减参数,替代 softmax
  3. Token Shift:简洁有效的上下文信息注入方式
  4. 持续演进:从 RWKV-4 到 RWKV-6,架构不断优化
  5. 开源生态:完全开源,社区活跃,多语言支持

RWKV 证明了 Transformer 的二次复杂度不是必需的——通过精心设计的线性注意力机制,可以在保持竞争力的同时实现 O(n) 的推理效率。虽然它还没有在所有任务上超越 Transformer,但作为最有力的替代方案之一,RWKV 为序列建模开辟了新的方向。

对于想深入了解的读者,建议阅读顺序:

  1. 本文(RWKV)→ 理解线性注意力替代方案
  2. 第 47 篇(Mamba/SSM)→ 理解状态空间模型替代方案
  3. 第 18 篇(MQA/GQA)→ 理解 KV Cache 优化

参考资料#

支持与分享

如果这篇文章对你有帮助,欢迎支持作者或分享给更多人

RWKV:线性注意力的开源替代
https://blog.souloss.com/posts/machine-learning/llm-paper-history/rwkv-linear-attention/
作者
Souloss
发布于
2025-02-26
许可协议
CC BY-NC-SA 4.0

部分信息可能已经过时