mobile wallpaper 1mobile wallpaper 2mobile wallpaper 3mobile wallpaper 4
1323 字
4 分钟
Mamba 与 SSM:挑战 Transformer 的新架构
2025-04-06

Transformer 的注意力机制虽然强大,但其二次复杂度在长序列建模中始终是瓶颈。Mamba 通过引入选择性状态空间模型(Selective SSM),以线性复杂度实现了媲美 Transformer 的序列建模能力,被认为是自 2017 年以来最具革命性的架构创新之一。

本文将详细解读 SSM 的理论基础、Mamba 的核心创新、硬件感知算法设计以及与 Transformer 的全面对比。

本文要点#

  • 状态空间模型(SSM)基础:从经典控制理论到深度学习
  • HiPPO 框架与 S4 模型:连续信号离散化的关键突破
  • Mamba 核心创新:选择性状态空间机制
  • 输入依赖参数化:让 SSM 具备动态建模能力
  • 硬件感知并行扫描算法:GPU 友好的高效实现
  • Mamba 架构设计:用 SSM 替代 MLP 层
  • 与 Transformer 的全面对比
  • Mamba-2 与结构化状态空间对偶(SSD)理论
  • 实际应用场景与未来展望

一、背景:Transformer 的长序列困境#

flowchart TB A["Transformer 的长序列困境"] --> B["注意力复杂度"] A --> C["KV Cache 问题"] A --> D["序列长度的限制"] B --> B1["自注意力: O(N²) 计算复杂度"] B --> B2["序列长度翻倍 → 计算量翻四倍"] B --> B3["难以处理超长上下文"] C --> C1["KV Cache 随序列线性增长"] C --> C2["推理时内存占用巨大"] C --> C3["Batch Size 受限"] D --> D1["GPT-4: 128K 上下文"] D --> D2["Claude: 200K 上下文"] D --> D3["更长的序列需要全新的架构"] style A fill:#f44336,color:#fff
# 各种序列建模范式的复杂度对比
sequence_models = {
"Transformer": {"complexity": "O(N²·d)", "status": "当前主流"},
"RNN/LSTM": {"complexity": "O(N·d²)", "status": "已被取代"},
"SSM/S4": {"complexity": "O(N·d·logN)", "status": "Mamba 前身"},
"Mamba": {"complexity": "O(N·d)", "status": "Transformer 挑战者"},
}

二、SSM 基础:从控制理论到深度学习#

2.1 经典状态空间模型#

状态空间模型(SSM)核心由两个方程组成:

  • 状态方程h'(t) = A · h(t) + B · x(t) — 隐藏状态如何演化
  • 输出方程y(t) = C · h(t) + D · x(t) — 如何从状态生成输出
import torch, torch.nn as nn, math
class SimpleSSM(nn.Module):
"""经典 SSM: h[t]=A·h[t-1]+B·x[t], y[t]=C·h[t]+D·x[t]"""
def __init__(self, d_model, d_state=16):
super().__init__()
self.d_state = d_state
self.A = nn.Parameter(torch.randn(d_state, d_state))
self.B = nn.Parameter(torch.randn(d_state, 1))
self.C = nn.Parameter(torch.randn(1, d_state))
self.D = nn.Parameter(torch.randn(1))
def forward(self, x):
batch, seq_len, _ = x.shape
h = torch.zeros(batch, self.d_state, device=x.device)
outputs = []
for t in range(seq_len):
h = h @ self.A.T + x[:, t, :] @ self.B.T
outputs.append(h @ self.C.T + self.D * x[:, t, :])
return torch.stack(outputs, dim=1)

2.2 SSM 的信号流#

flowchart LR subgraph SSM["状态空间模型信号流"] X["输入 x(t)"] -->|"B 矩阵"| STATE["隐藏状态 h(t)"] STATE -->|"A 矩阵<br/>状态转移"| STATE STATE -->|"C 矩阵"| Y["输出 y(t)"] X -->|"D 矩阵<br/>直通"| Y end subgraph DISK["离散化过程"] C1["连续 SSM"] -->|"零阶保持 ZOH"| C2["离散 SSM"] C2 -->|"Ā = f(A, Δ)"| C3["可递归计算"] end style SSM fill:#e3f2fd style DISK fill:#fff3e0

2.3 HiPPO 框架与 S4#

S4 模型使用 HiPPO(High-order Polynomial Projection Operators)初始化状态矩阵 A,赋予 SSM 长程记忆能力:

def hippo_initializer(N):
"""HiPPO-LegS 矩阵:将信号投影到正交多项式基"""
A = torch.zeros(N, N)
for n in range(N):
for k in range(N):
if n > k:
A[n, k] = math.sqrt((2 * n + 1) * (2 * k + 1))
elif n == k:
A[n, k] = n + 1
return -A
# S4 核心创新:HiPPO 初始化 + 结构化矩阵分解 + 训练时卷积模式

2.4 S4 的局限#

flowchart TB A["S4 的局限"] --> B["线性时不变 LTI"] A --> C["固定参数"] A --> D["无法选择性遗忘"] B --> B1["A, B, C 矩阵与输入无关"] C --> C1["无法根据内容调整行为"] D --> D1["所有历史信息被均匀压缩"] D --> D2["无法像注意力一样关注重要 Token"] D --> D3["在语言建模中表现不如 Transformer"] style A fill:#ff9800,color:#fff

三、Mamba 核心:选择性状态空间机制#

3.1 核心洞察:让参数依赖输入#

Mamba 打破 SSM 的线性时不变(LTI)约束,让状态空间参数成为输入的函数:

flowchart TB A["Mamba: 选择性 SSM"] --> B["核心思想"] A --> C["实现方式"] A --> D["效果"] B --> B1["让 SSM 参数依赖于输入"] B --> B2["选择性传播或遗忘信息"] B --> B3["类似注意力的动态权重,但线性复杂度"] C --> C1["B = Linear(x) — 输入选择"] C --> C2["C = Linear(x) — 输出选择"] C --> C3["Δ = softplus(Linear(x)) — 步长"] D --> D1["重要信息: 大 Δ → 精确保留"] D --> D2["噪声/填充: 小 Δ → 快速遗忘"] D --> D3["实现内容感知的序列建模"] style A fill:#4caf50,color:#fff

3.2 选择性 SSM 的实现#

class SelectiveSSM(nn.Module):
"""Mamba S6: B, C, Δ 都是输入 x 的函数"""
def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
super().__init__()
self.d_inner = int(expand * d_model)
self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
self.conv1d = nn.Conv1d(self.d_inner, self.d_inner,
kernel_size=d_conv, groups=self.d_inner)
# 关键:输入依赖参数投影
self.x_proj = nn.Linear(self.d_inner, d_state * 2 + 1, bias=False)
self.dt_proj = nn.Linear(1, self.d_inner, bias=True)
A = torch.arange(1, d_state + 1, dtype=torch.float32)
self.A_log = nn.Parameter(torch.log(A.repeat(self.d_inner, 1)))
self.D = nn.Parameter(torch.ones(self.d_inner))
self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
def forward(self, x):
batch, seq_len, _ = x.shape
xz = self.in_proj(x)
x_b, z = xz.chunk(2, dim=-1)
# 因果卷积 + SiLU
x_b = torch.silu(
self.conv1d(x_b.transpose(1, 2))[:, :, :seq_len].transpose(1, 2))
# 选择性参数(核心!B, C, Δ 依赖于输入)
params = self.x_proj(x_b)
delta = torch.softplus(self.dt_proj(params[:, :, :1].transpose(1, 2)))
B = params[:, :, 1:params.shape[-1]//2 + 1]
C = params[:, :, params.shape[-1]//2 + 1:]
# 离散化 + 扫描 + 门控输出
A = -torch.exp(self.A_log)
y = self._scan(x_b, delta, A, B, C, self.D)
return self.out_proj(y * torch.silu(z))
def _scan(self, x, delta, A, B, C, D):
"""并行扫描: h[t] = Ā[t]·h[t-1] + B̄[t]·x[t], y[t] = C[t]·h[t]"""
dA = torch.exp(delta.unsqueeze(-1) * A.unsqueeze(0).unsqueeze(0))
dB = delta.unsqueeze(-1) * B.unsqueeze(2)
h = torch.zeros(x.shape[0], self.d_inner, A.shape[1], device=x.device)
ys = []
for t in range(x.shape[1]):
h = dA[:, t] * h + dB[:, t] * x[:, t].unsqueeze(-1)
ys.append(torch.sum(h * C[:, t].unsqueeze(1), dim=-1))
return torch.stack(ys, dim=1) + D * x

3.3 选择性机制的工作原理#

flowchart TB subgraph SELECTIVE["选择性 SSM"] INPUT["输入 x₁...xₜ"] --> STEP["每个时间步 t"] STEP --> DELTA["Δₜ: 步长,决定保留程度"] STEP --> BT["Bₜ: 输入选择"] STEP --> CT["Cₜ: 输出选择"] DELTA --> UPDATE["hₜ = Āₜ·hₜ₋₁ + B̄ₜ·xₜ"] BT --> UPDATE UPDATE --> OUTPUT["yₜ = Cₜ·hₜ + D·xₜ"] CT --> OUTPUT end subgraph ANALOGY["与注意力的类比"] A1["注意力: Q·K^T 动态权重"] A2["Mamba: Δ,B,C 动态信息流"] A3["共同点: 输入依赖的动态机制"] end style SELECTIVE fill:#e8f5e9 style ANALOGY fill:#e3f2fd

四、硬件感知并行扫描#

选择性 SSM 打破了 LTI 约束,无法使用 S4 的卷积模式。Mamba 通过硬件感知并行扫描解决:

# 并行扫描:O(N) 递归 → O(log N) 并行步数
# CUDA 优化策略:
hardware_optimizations = {
"scan_kernel": "自定义 CUDA 内核实现并行扫描",
"memory_reuse": "中间状态仅在 SRAM 中计算,不写 HBM",
"kernel_fusion": "离散化 + 扫描 + 输出合并为单一内核",
"recomputation": "反向传播时重算中间值,减少显存占用",
}
flowchart LR subgraph NAIVE["朴素递归"] N1["逐时间步计算"] N2["O(N) 无法并行"] N3["大量 HBM 读写"] end subgraph MAMBA_HW["Mamba 硬件感知"] M1["并行扫描算法"] M2["O(log N) 并行步数"] M3["SRAM 中间计算"] end NAIVE -->|"显著加速"| MAMBA_HW style NAIVE fill:#ffcdd2 style MAMBA_HW fill:#c8e6c9

五、Mamba 架构设计#

Mamba 将 SSM 集成为完整模块,替代 Transformer 中的 MLP 层:

flowchart TB subgraph TRANSFORMER_BLOCK["Transformer Block"] T_IN["x"] --> T_NORM1["RMSNorm"] T_NORM1 --> T_ATTN["Multi-Head Attention"] T_ATTN --> T_ADD1["+ 残差"] T_ADD1 --> T_NORM2["RMSNorm"] T_NORM2 --> T_MLP["MLP"] T_MLP --> T_ADD2["+ 残差"] end subgraph MAMBA_BLOCK["Mamba Block"] M_IN["x"] --> M_NORM["RMSNorm"] M_NORM --> M_PROJ["投影 1x→2x + 分叉"] M_PROJ --> M_CONV["因果卷积 + SiLU"] M_CONV --> M_SSM["选择性 SSM"] M_SSM --> M_GATE["× SiLU(z) 门控"] M_GATE --> M_OUT["+ 残差"] end style TRANSFORMER_BLOCK fill:#e3f2fd style MAMBA_BLOCK fill:#e8f5e9
class MambaBlock(nn.Module):
def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
super().__init__()
self.norm = nn.LayerNorm(d_model)
self.ssm = SelectiveSSM(d_model, d_state, d_conv, expand)
def forward(self, x):
return x + self.ssm(self.norm(x))
class MambaModel(nn.Module):
"""Mamba LM: 多层 Mamba Block + Embedding + LM Head"""
def __init__(self, vocab_size, d_model=1024, n_layers=48, **kwargs):
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
self.layers = nn.ModuleList([MambaBlock(d_model, **kwargs)
for _ in range(n_layers)])
self.norm_f = nn.LayerNorm(d_model)
self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
def forward(self, input_ids):
x = self.embedding(input_ids)
for layer in self.layers:
x = layer(x)
return self.lm_head(self.norm_f(x))

六、与 Transformer 的全面对比#

flowchart TB subgraph COMPLEXITY["复杂度分析 N=序列长度, d=模型维度"] A --> B["Transformer"] A --> C["Mamba"] B --> B1["训练: O(N²·d)"] B --> B2["推理: KV Cache 线性增长"] C --> C1["训练: O(N·d)"] C --> C2["推理: 固定状态大小"] end subgraph KEY["关键差异"] D1["Transformer: 训练并行,推理沉重"] D2["Mamba: 训练并行,推理高效"] end style COMPLEXITY fill:#e3f2fd style KEY fill:#fff3e0

6.1 性能对比#

# 语言建模 Perplexity(越低越好)
performance = {
"WikiText-103": {"Transformer++": 17.7, "Mamba-130M": 17.1, "Mamba-1.3B": 11.2},
"Pile": {"Transformer++": 9.8, "Mamba": 8.8},
"发现": [
"同等参数量下 Mamba 优于 Transformer++",
"推理速度快 5 倍,内存不随序列增长"
]
}
维度TransformerMamba
训练并行度高度并行并行扫描,高度并行
推理复杂度O(N)(KV Cache)O(1)(固定状态)
长序列建模受限于上下文窗口理论上无限长
内存占用KV Cache 线性增长固定状态大小
复述能力优秀较弱(压缩表示)

七、实验结果#

results = {
"预训练": "The Pile 300B tokens,所有尺寸优于 Transformer++",
"下游任务": {
"常识推理": "与同级别 Transformer 持平或更优",
"阅读理解": "显著优于 S4 等先前 SSM",
},
"长序列": "Long Range Arena 大幅领先,书级别建模超越所有基线"
}

八、Mamba-2:结构化状态空间对偶#

Mamba-2 通过 SSD 理论统一了 SSM 与 Attention:

mamba2_innovation = {
"SSD 理论": "选择性 SSM = 结构化半可分矩阵上的特定注意力",
"统一视角": "SSM ↔ 半可分矩阵 ↔ 结构化注意力",
"训练速度": "比 Mamba-1 快 2-8 倍",
"混合架构": "更容易与 Transformer 注意力层混合"
}
flowchart TB A["SSD: 统一视角"] --> B["SSM 递归视图"] A --> C["注意力矩阵视图"] A --> D["半可分矩阵分解"] B --> E["hₜ = Āₜhₜ₋₁ + B̄ₜxₜ"] C --> F["Y = M·X (M 是结构化掩码)"] D --> G["O(N) 而非 O(N²)"] E --> H["SSM 和 Attention 是同一事物的不同视图"] F --> H G --> H style A fill:#9c27b0,color:#fff style H fill:#4caf50,color:#fff

九、应用场景#

applications = {
"语言建模": "大规模 LLM 预训练,线性复杂度推理高效",
"基因组学": "DNA 序列建模,处理百万级碱基",
"视觉": "VMamba / Vision Mamba,高分辨率图像",
"音频": "语音合成、音乐生成",
"混合架构": "Jamba: Transformer + Mamba 各取所长"
}
flowchart TB subgraph JAMBA["Jamba 混合架构"] L1["Mamba"] --> L2["Mamba"] --> L3["Attention"] L3 --> L4["Mamba"] --> L5["Mamba"] --> L6["Attention"] end subgraph ADV["混合优势"] A1["Mamba: 高效长序列"] A2["Attention: 精确检索"] end JAMBA --> ADV style JAMBA fill:#e8f5e9 style ADV fill:#e3f2fd

常见问题 FAQ#

Q1:Mamba 能完全取代 Transformer 吗?

A:目前还不能。Mamba 在语言建模上表现出色,但在精确检索特定 Token 的任务上 Transformer 更有优势。更可能的发展方向是混合架构(如 Jamba)。

Q2:选择性 SSM 与普通 SSM 的关键区别是什么?

A:普通 SSM(如 S4)的参数 A、B、C 固定不变。选择性 SSM 让这些参数成为输入的函数,模型可以动态决定保留或遗忘信息。

Q3:Mamba 的推理速度为什么比 Transformer 快?

A:Transformer 推理时需要维护不断增长的 KV Cache,内存带宽成为瓶颈。Mamba 的状态大小固定(d_state 维度),推理时间和内存都是常数级别。

Q4:什么是硬件感知算法,为什么 Mamba 需要它?

A:选择性 SSM 因参数输入依赖,无法使用 S4 的卷积模式并行训练。硬件感知算法通过自定义 CUDA 内核、SRAM 中间计算、内核融合实现高效执行。

Q5:Mamba-2 相比 Mamba-1 有哪些改进?

A:Mamba-2 通过 SSD 理论统一了 SSM 和 Attention 视角,实现 2-8 倍训练加速。引入结构化矩阵分解减少参数量,混合架构更易实现。

Q6:Mamba 适合处理多长的序列?

A:理论上可处理任意长度序列。实际在数万到数十万 Token 上表现良好,百万级超长序列 SSM 仍是最有前景的方案。

Q7:Mamba 在代码生成任务上表现如何?

A:接近 Transformer,但精确复述长段代码时略弱(压缩表示不如注意力的精确存储)。短上下文任务两者差异不大。


小结#

Mamba 通过选择性状态空间模型,为序列建模开辟了全新方向:

flowchart TB A["Mamba 核心总结"] --> B["核心创新"] A --> C["关键优势"] A --> D["架构设计"] A --> E["未来方向"] B --> B1["选择性 SSM: 输入依赖参数"] B --> B2["硬件感知并行扫描"] B --> B3["Mamba Block 替代 MLP"] C --> C1["训练: O(N·d) 线性复杂度"] C --> C2["推理: O(1) 常数状态"] C --> C3["内存: 不随序列增长"] D --> D1["因果卷积 + 选择性 SSM + 门控"] D --> D2["Mamba-2: SSD 统一理论"] D --> D3["混合架构: Mamba + Attention"] E --> E1["更大规模的预训练验证"] E --> E2["多模态扩展"] E --> E3["与 Transformer 深度融合"] style A fill:#4caf50,color:#fff

参考资料#

支持与分享

如果这篇文章对你有帮助,欢迎支持作者或分享给更多人

Mamba 与 SSM:挑战 Transformer 的新架构
https://blog.souloss.com/posts/machine-learning/llm-paper-history/mamba-and-ssm-state-space-model/
作者
Souloss
发布于
2025-04-06
许可协议
CC BY-NC-SA 4.0

部分信息可能已经过时