mobile wallpaper 1mobile wallpaper 2mobile wallpaper 3mobile wallpaper 4
665 字
2 分钟
Transformer-XL 论文解读:让 Transformer 学会"记忆"更长上下文
2025-03-12

Transformer 虽然强大,但有一个致命缺陷:无法处理超过固定长度的上下文。Transformer-XL 通过段级递归机制和相对位置编码解决了这个问题。

本文将深入解析 Transformer-XL 的核心创新。

本文要点#

  • 固定长度上下文的局限性
  • 段级递归机制原理
  • 相对位置编码设计
  • 与 vanilla Transformer 的对比

一、问题背景:Transformer 的”断章取义”#

1.1 固定长度上下文的困扰#

┌─────────────────────────────────────────────────────────────┐
│ Transformer 固定长度上下文问题 │
├─────────────────────────────────────────────────────────────┤
│ │
│ Vanilla Transformer: │
│ • 训练时固定上下文长度(如 512 tokens) │
│ • 推理时逐段处理,每段独立 │
│ • 段与段之间没有信息传递 │
│ │
│ 问题: │
│ • 无法捕捉跨段的长距离依赖 │
│ • 段边界处产生上下文碎片化(Fragmentation) │
│ • 生成时可能"忘记"之前的内容 │
│ │
└─────────────────────────────────────────────────────────────┘
# Vanilla Transformer 的处理方式
def vanilla_transformer_generate(text, max_length=100):
"""逐段独立处理,没有跨段信息传递"""
segments = split_into_chunks(text, chunk_size=512)
results = []
for seg in segments:
# 每段独立编码,不考虑其他段
encoding = transformer_encoder(seg)
# 在编码上逐 token 生成
generated = generate_from_encoding(encoding, max_new_tokens)
results.append(generated)
return results # 段与段之间完全独立!

1.2 上下文碎片化问题#

生成示例:上下文碎片化的影响
文本:"我叫小明,来自北京,毕业于清华大学计算机系,
之后在Google工作多年,去年创立了自己的AI公司..."
问题段1(处理"我叫小明"):无法知道后面说的是什么
问题段2(处理"毕业于清华"):不记得前面提到过Google
生成结果可能:
"我叫小明,毕业于清华大学。" ← 丢失了中间信息

二、核心创新:段级递归机制#

2.1 递归连接的核心思想#

flowchart TB subgraph Segment N A1[Segment N 的 Hidden State] --> B1[传入 Segment N+1] end subgraph Segment N+1 B1 --> C1[拼接前段信息] C1 --> D1[扩展上下文] end style A1 fill:#4caf50,color:#fff style D1 fill:#2196f3,color:#fff
┌─────────────────────────────────────────────────────────────┐
│ Transformer-XL 核心创新 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 段级递归机制(Segment-Level Recurrence): │
│ │
│ • 每段处理完后,将其 Hidden State 缓存起来 │
│ • 处理下一段时,将前一段的 Hidden State 作为额外上下文 │
│ • 类似 RNN 的隐藏状态传递,但用的是 Transformer 的表示 │
│ │
│ 效果: │
│ • 上下文长度从 O(L) 扩展到 O(L×N) │
│ • 解决了跨段依赖问题 │
│ • 消除上下文碎片化 │
│ │
└─────────────────────────────────────────────────────────────┘

2.2 数学原理#

class TransformerXLLayer(nn.Module):
"""Transformer-XL 的关键:段级递归"""
def __init__(self, d_model, n_heads, d_head, d_ff, seg_len):
super().__init__()
self.seg_len = seg_len
# 标准 Transformer 组件
self.attention = MultiHeadAttention(d_model, n_heads, d_head)
self.feed_forward = FeedForward(d_model, d_ff)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
def forward(self, x, prev_state, U=None):
"""
x: 当前段的输入 [batch, seg_len, d_model]
prev_state: 前一段的 Hidden State [batch, seg_len, d_model]
"""
# 关键:拼接前段状态
# 实际实现中,只使用前段最后一个位置的 state 作为"memory"
# 因为最后一个位置包含最完整的序列信息
# 扩展的上下文
x_concat = torch.cat([prev_state[-1:, :, :], x], dim=0)
# 由于位置编码问题,这里需要用相对位置编码
# 详见下一节
# 标准的 Attention + FFN
attn_out = self.attention(x_concat, U=U)
out = self.feed_forward(self.norm1(attn_out))
out = self.norm2(out)
return out, out # 新的 state

2.3 与 RNN 的类比#

┌─────────────────────────────────────────────────────────────┐
│ Transformer-XL vs RNN │
├─────────────────────────────────────────────────────────────┤
│ │
│ RNN: │
│ h_t = f(W · x_t + U · h_{t-1}) │
│ 隐藏状态逐时间步传递,梯度消失/爆炸 │
│ │
│ Transformer-XL: │
│ • 用完整的序列表示代替单个隐藏向量 │
│ • 保留更丰富的上下文信息 │
│ • Attention 机制避免了梯度问题 │
│ │
│ 优势: │
│ • 并行训练(比 RNN 快) │
│ • 长距离依赖(比 RNN 强) │
│ • 段级记忆(比 vanilla Transformer 久) │
│ │
└─────────────────────────────────────────────────────────────┘

三、相对位置编码#

3.1 为什么需要相对位置编码#

问题分析:
在 vanilla Transformer 中,位置编码是绝对的:
PE(pos) = sin(pos/10000^{2i/d}) 或 PE(pos) = cos(...)
当段级递归时,位置会"重复":
Segment 1: 位置 0, 1, 2, ..., 511
Segment 2: 位置 0, 1, 2, ..., 511 ← 同样的位置编码!
这会导致 Attention 分数计算错误:
- 位置 0 的 query 和位置 0 的 key 会匹配
- 但实际上它们来自不同段,应该有不同的"距离"

3.2 相对位置编码设计#

def relative_positional_attention(Q, K, V, seg_len, d_model):
"""
Transformer-XL 的相对位置编码
核心思想:用相对位置替代绝对位置
• query 位置 i 和 key 位置 j 的关系用 (j-i) 表示
• 而不是用 j 本身
"""
batch, n_heads, seq_len, d_head = Q.shape
# 相对位置矩阵
# relative_pos[i,j] = j - i,表示 key j 相对于 query i 的位置
relative_pos = torch.arange(seq_len).unsqueeze(0) - torch.arange(seq_len).unsqueeze(1)
# 映射到 d_model 维度
# 生成 2*seg_len-1 个位置编码(正负相对位置)
relative_pos_embedding = nn.Parameter(
torch.randn(2 * seg_len - 1, d_model) * 0.02
)
# 将相对位置映射到嵌入
relative_pos_emb = relative_pos_embedding[relative_pos + seg_len - 1]
# 相对位置 attention
# 不再用绝对位置计算 attention score
# 而是考虑 query 和 key 之间的相对距离
# 原始:
# score = Q · (K + pos_enc[j])
# 相对位置:
# score = Q · K + Q · relative_pos_emb[j - i]
return attention_with_relative_pos(Q, K, V, relative_pos_emb)

3.3 四种 attention 分数#

┌─────────────────────────────────────────────────────────────┐
│ Transformer-XL Attention 分解 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 标准绝对位置编码的 attention: │
│ Score(i,j) = Q_i · (K_j + PE_j) │
│ │
│ 分解为四项(相对位置版本): │
│ │
│ ① Q_i · K_j 内容交互 │
│ ② Q_i · PE_{j-i} 位置-内容 交互 │
│ ③ R_{i-j} · K_j 内容-位置 交互(可学习) │
│ ④ R_{i-j} · PE_{j-i} 位置-位置 交互 │
│ │
│ 简化:只用 ① 和 ②,去掉 iii 和 iv │
│ 最终:只保留内容交互和相对位置 │
│ │
│ 优势: │
│ • 不需要为每个绝对位置学习位置编码 │
│ • 可以泛化到训练时未见过的序列长度 │
│ • 更符合"相对位置"的直觉 │
│ │
└─────────────────────────────────────────────────────────────┘

四、完整架构#

4.1 Transformer-XL 模型结构#

class TransformerXL(nn.Module):
"""完整的 Transformer-XL 模型"""
def __init__(self, vocab_size, d_model, n_layers, n_heads, d_head, d_ff):
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
self.layers = nn.ModuleList([
TransformerXLBlock(d_model, n_heads, d_head, d_ff)
for _ in range(n_layers)
])
self.ln_f = nn.LayerNorm(d_model)
self.head = nn.Linear(d_model, vocab_size, bias=False)
# 权重绑定
self.head.weight = self.embedding.weight
def forward(self, x, mems=None):
"""
x: [batch, seq_len]
mems: 每层的 memory [n_layers, batch, mem_len, d_model]
"""
seq_len = x.size(1)
# 初始化 memory
if mems is None:
mems = [None] * self.n_layers
new_mems = []
hidden = self.embedding(x)
for i, layer in enumerate(self.layers):
# 关键:传入上一段的 memory
hidden, new_mem = layer(hidden, mems[i])
new_mems.append(new_mem)
return self.head(self.ln_f(hidden)), new_mems

4.2 训练流程#

flowchart TB subgraph 训练 A1[完整序列] --> B1[切成重叠的段] B1 --> C1[段1: tokens 0-511] B1 --> C2[段2: tokens 288-799] B1 --> C3[段3: tokens 576-1087] C1 --> D1[计算 loss 1] C2 --> D2[计算 loss 2] C3 --> D3[计算 loss 3] D1 --> E1[缓存 hidden state] D2 --> E2[缓存 hidden state] D3 --> E3[缓存 hidden state] E1 --> C2 E2 --> C3 end

五、实验结果#

5.1 性能对比#

┌─────────────────────────────────────────────────────────────┐
│ Transformer-XL 实验结果 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 依赖长度对比(发现更好的长期依赖): │
│ │
│ • RNN:80 个单位长度依赖 │
│ • vanilla Transformer:450 个单位长度依赖 │
│ • Transformer-XL:比 RNN 长 80% 的依赖 │
│ • 比 vanilla Transformer 长 450% 的依赖 │
│ │
│ 困惑度对比(在 WikiText-103 上): │
│ │
│ • Transformer-XL: 18.3 │
│ • 之前最佳: 23.0 │
│ • 提升: 20% │
│ │
│ 推理速度(与 vanilla Transformer 对比): │
│ │
│ • vanilla Transformer: 基准速度 │
│ • Transformer-XL: 快 1800 倍!(因 为 reuse) │
│ │
└─────────────────────────────────────────────────────────────┘

5.2 各数据集表现#

bar-chart title "各数据集困惑度对比(越低越好)" x-axis ["enwik8", "text8", "WikiText-103", "One Billion Word", "Penn Treebank"] y-axis "困惑度" 0 --> 70 bar ["0.99", "1.08", "18.3", "21.8", "54.5"] bar ["1.18", "1.19", "23.0", "26.0", "62.0"] legend ["Transformer-XL", "之前最佳"]

六、核心创新总结#

6.1 Transformer-XL 的两大贡献#

┌─────────────────────────────────────────────────────────────┐
│ Transformer-XL 核心总结 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 1. 段级递归机制(Segment-Level Recurrence) │
│ │
│ • 缓存每段的 Hidden State │
│ • 作为下一段的额外上下文 │
│ • 将有效上下文长度从 O(L) 扩展到 O(L×N) │
│ • 消除上下文碎片化 │
│ │
│ 2. 相对位置编码(Relative Positional Encoding) │
│ │
│ • 用相对位置替代绝对位置 │
│ • 解决段内位置编码重复问题 │
│ • 可泛化到更长序列 │
│ │
│ 影响: │
│ • 后续 XLNet、BERT-wwm 等都采用类似技术 │
│ • 为长文本处理奠定基础 │
│ │
└─────────────────────────────────────────────────────────────┘

常见问题 FAQ#

Q1:Transformer-XL 和 vanilla Transformer 的主要区别是什么?

A:核心区别在于是否有跨段的信息传递。vanilla Transformer 每次只处理固定长度的段,段与段之间独立;Transformer-XL 通过缓存上一段的 Hidden State,让当前段可以看到前一段的信息。

Q2:相对位置编码为什么比绝对位置编码更好?

A:绝对位置编码在段级递归时会导致位置”重复”(如第二段的位置0和第一段的位置0编码相同,但含义不同)。相对位置编码用 query 和 key 之间的相对距离替代绝对位置,更符合 Attention 的本质。

Q3:Transformer-XL 的 memory 会无限增长吗?

A:实际上使用时会有一个最大 memory 长度限制。超出限制时会丢弃最早的 segment,平衡内存和长期依赖。

Q4:Transformer-XL 适用于哪些任务?

A:主要用于需要长文本建模的任务,如语言模型、文本生成、文档级理解等。

Q5:Transformer-XL 对后续模型有什么影响?

A:Transformer-XL 的段级递归和相对位置编码被广泛应用于后续模型,如 XLNet、BERT-wwm、Longformer 等。


小结#

Transformer-XL 通过段级递归机制和相对位置编码,解决了固定长度上下文的限制,为长文本处理提供了重要思路。

核心贡献:

┌─────────────────────────────────────────────────────────────┐
│ Transformer-XL 核心总结 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 段级递归机制: │
│ • 缓存 Hidden State 作为 Memory │
│ • 跨段传递信息,解决长距离依赖 │
│ • 消除上下文碎片化 │
│ │
│ 相对位置编码: │
│ • 用相对距离替代绝对位置 │
│ • 适配段级递归的 Position 重复问题 │
│ • 可泛化到更长序列 │
│ │
└─────────────────────────────────────────────────────────────┘

参考资料#

支持与分享

如果这篇文章对你有帮助,欢迎支持作者或分享给更多人

Transformer-XL 论文解读:让 Transformer 学会"记忆"更长上下文
https://blog.souloss.com/posts/machine-learning/llm-paper-history/transformer-xl-long-text-model/
作者
Souloss
发布于
2025-03-12
许可协议
CC BY-NC-SA 4.0

部分信息可能已经过时