665 字
2 分钟
Transformer-XL 论文解读:让 Transformer 学会"记忆"更长上下文
Transformer 虽然强大,但有一个致命缺陷:无法处理超过固定长度的上下文。Transformer-XL 通过段级递归机制和相对位置编码解决了这个问题。
本文将深入解析 Transformer-XL 的核心创新。
本文要点
- 固定长度上下文的局限性
- 段级递归机制原理
- 相对位置编码设计
- 与 vanilla Transformer 的对比
一、问题背景:Transformer 的”断章取义”
1.1 固定长度上下文的困扰
┌─────────────────────────────────────────────────────────────┐│ Transformer 固定长度上下文问题 │├─────────────────────────────────────────────────────────────┤│ ││ Vanilla Transformer: ││ • 训练时固定上下文长度(如 512 tokens) ││ • 推理时逐段处理,每段独立 ││ • 段与段之间没有信息传递 ││ ││ 问题: ││ • 无法捕捉跨段的长距离依赖 ││ • 段边界处产生上下文碎片化(Fragmentation) ││ • 生成时可能"忘记"之前的内容 ││ │└─────────────────────────────────────────────────────────────┘# Vanilla Transformer 的处理方式def vanilla_transformer_generate(text, max_length=100): """逐段独立处理,没有跨段信息传递""" segments = split_into_chunks(text, chunk_size=512)
results = [] for seg in segments: # 每段独立编码,不考虑其他段 encoding = transformer_encoder(seg) # 在编码上逐 token 生成 generated = generate_from_encoding(encoding, max_new_tokens) results.append(generated)
return results # 段与段之间完全独立!1.2 上下文碎片化问题
生成示例:上下文碎片化的影响
文本:"我叫小明,来自北京,毕业于清华大学计算机系, 之后在Google工作多年,去年创立了自己的AI公司..."
问题段1(处理"我叫小明"):无法知道后面说的是什么问题段2(处理"毕业于清华"):不记得前面提到过Google
生成结果可能:"我叫小明,毕业于清华大学。" ← 丢失了中间信息二、核心创新:段级递归机制
2.1 递归连接的核心思想
flowchart TB
subgraph Segment N
A1[Segment N 的 Hidden State] --> B1[传入 Segment N+1]
end
subgraph Segment N+1
B1 --> C1[拼接前段信息]
C1 --> D1[扩展上下文]
end
style A1 fill:#4caf50,color:#fff
style D1 fill:#2196f3,color:#fff
┌─────────────────────────────────────────────────────────────┐│ Transformer-XL 核心创新 │├─────────────────────────────────────────────────────────────┤│ ││ 段级递归机制(Segment-Level Recurrence): ││ ││ • 每段处理完后,将其 Hidden State 缓存起来 ││ • 处理下一段时,将前一段的 Hidden State 作为额外上下文 ││ • 类似 RNN 的隐藏状态传递,但用的是 Transformer 的表示 ││ ││ 效果: ││ • 上下文长度从 O(L) 扩展到 O(L×N) ││ • 解决了跨段依赖问题 ││ • 消除上下文碎片化 ││ │└─────────────────────────────────────────────────────────────┘2.2 数学原理
class TransformerXLLayer(nn.Module): """Transformer-XL 的关键:段级递归"""
def __init__(self, d_model, n_heads, d_head, d_ff, seg_len): super().__init__() self.seg_len = seg_len
# 标准 Transformer 组件 self.attention = MultiHeadAttention(d_model, n_heads, d_head) self.feed_forward = FeedForward(d_model, d_ff) self.norm1 = nn.LayerNorm(d_model) self.norm2 = nn.LayerNorm(d_model)
def forward(self, x, prev_state, U=None): """ x: 当前段的输入 [batch, seg_len, d_model] prev_state: 前一段的 Hidden State [batch, seg_len, d_model] """ # 关键:拼接前段状态 # 实际实现中,只使用前段最后一个位置的 state 作为"memory" # 因为最后一个位置包含最完整的序列信息
# 扩展的上下文 x_concat = torch.cat([prev_state[-1:, :, :], x], dim=0)
# 由于位置编码问题,这里需要用相对位置编码 # 详见下一节
# 标准的 Attention + FFN attn_out = self.attention(x_concat, U=U) out = self.feed_forward(self.norm1(attn_out)) out = self.norm2(out)
return out, out # 新的 state2.3 与 RNN 的类比
┌─────────────────────────────────────────────────────────────┐│ Transformer-XL vs RNN │├─────────────────────────────────────────────────────────────┤│ ││ RNN: ││ h_t = f(W · x_t + U · h_{t-1}) ││ 隐藏状态逐时间步传递,梯度消失/爆炸 ││ ││ Transformer-XL: ││ • 用完整的序列表示代替单个隐藏向量 ││ • 保留更丰富的上下文信息 ││ • Attention 机制避免了梯度问题 ││ ││ 优势: ││ • 并行训练(比 RNN 快) ││ • 长距离依赖(比 RNN 强) ││ • 段级记忆(比 vanilla Transformer 久) ││ │└─────────────────────────────────────────────────────────────┘三、相对位置编码
3.1 为什么需要相对位置编码
问题分析:
在 vanilla Transformer 中,位置编码是绝对的:PE(pos) = sin(pos/10000^{2i/d}) 或 PE(pos) = cos(...)
当段级递归时,位置会"重复":
Segment 1: 位置 0, 1, 2, ..., 511Segment 2: 位置 0, 1, 2, ..., 511 ← 同样的位置编码!
这会导致 Attention 分数计算错误:- 位置 0 的 query 和位置 0 的 key 会匹配- 但实际上它们来自不同段,应该有不同的"距离"3.2 相对位置编码设计
def relative_positional_attention(Q, K, V, seg_len, d_model): """ Transformer-XL 的相对位置编码
核心思想:用相对位置替代绝对位置 • query 位置 i 和 key 位置 j 的关系用 (j-i) 表示 • 而不是用 j 本身 """ batch, n_heads, seq_len, d_head = Q.shape
# 相对位置矩阵 # relative_pos[i,j] = j - i,表示 key j 相对于 query i 的位置 relative_pos = torch.arange(seq_len).unsqueeze(0) - torch.arange(seq_len).unsqueeze(1)
# 映射到 d_model 维度 # 生成 2*seg_len-1 个位置编码(正负相对位置) relative_pos_embedding = nn.Parameter( torch.randn(2 * seg_len - 1, d_model) * 0.02 )
# 将相对位置映射到嵌入 relative_pos_emb = relative_pos_embedding[relative_pos + seg_len - 1]
# 相对位置 attention # 不再用绝对位置计算 attention score # 而是考虑 query 和 key 之间的相对距离
# 原始: # score = Q · (K + pos_enc[j])
# 相对位置: # score = Q · K + Q · relative_pos_emb[j - i]
return attention_with_relative_pos(Q, K, V, relative_pos_emb)3.3 四种 attention 分数
┌─────────────────────────────────────────────────────────────┐│ Transformer-XL Attention 分解 │├─────────────────────────────────────────────────────────────┤│ ││ 标准绝对位置编码的 attention: ││ Score(i,j) = Q_i · (K_j + PE_j) ││ ││ 分解为四项(相对位置版本): ││ ││ ① Q_i · K_j 内容交互 ││ ② Q_i · PE_{j-i} 位置-内容 交互 ││ ③ R_{i-j} · K_j 内容-位置 交互(可学习) ││ ④ R_{i-j} · PE_{j-i} 位置-位置 交互 ││ ││ 简化:只用 ① 和 ②,去掉 iii 和 iv ││ 最终:只保留内容交互和相对位置 ││ ││ 优势: ││ • 不需要为每个绝对位置学习位置编码 ││ • 可以泛化到训练时未见过的序列长度 ││ • 更符合"相对位置"的直觉 ││ │└─────────────────────────────────────────────────────────────┘四、完整架构
4.1 Transformer-XL 模型结构
class TransformerXL(nn.Module): """完整的 Transformer-XL 模型"""
def __init__(self, vocab_size, d_model, n_layers, n_heads, d_head, d_ff): super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model) self.layers = nn.ModuleList([ TransformerXLBlock(d_model, n_heads, d_head, d_ff) for _ in range(n_layers) ]) self.ln_f = nn.LayerNorm(d_model) self.head = nn.Linear(d_model, vocab_size, bias=False)
# 权重绑定 self.head.weight = self.embedding.weight
def forward(self, x, mems=None): """ x: [batch, seq_len] mems: 每层的 memory [n_layers, batch, mem_len, d_model] """ seq_len = x.size(1)
# 初始化 memory if mems is None: mems = [None] * self.n_layers
new_mems = [] hidden = self.embedding(x)
for i, layer in enumerate(self.layers): # 关键:传入上一段的 memory hidden, new_mem = layer(hidden, mems[i]) new_mems.append(new_mem)
return self.head(self.ln_f(hidden)), new_mems4.2 训练流程
flowchart TB
subgraph 训练
A1[完整序列] --> B1[切成重叠的段]
B1 --> C1[段1: tokens 0-511]
B1 --> C2[段2: tokens 288-799]
B1 --> C3[段3: tokens 576-1087]
C1 --> D1[计算 loss 1]
C2 --> D2[计算 loss 2]
C3 --> D3[计算 loss 3]
D1 --> E1[缓存 hidden state]
D2 --> E2[缓存 hidden state]
D3 --> E3[缓存 hidden state]
E1 --> C2
E2 --> C3
end
五、实验结果
5.1 性能对比
┌─────────────────────────────────────────────────────────────┐│ Transformer-XL 实验结果 │├─────────────────────────────────────────────────────────────┤│ ││ 依赖长度对比(发现更好的长期依赖): ││ ││ • RNN:80 个单位长度依赖 ││ • vanilla Transformer:450 个单位长度依赖 ││ • Transformer-XL:比 RNN 长 80% 的依赖 ││ • 比 vanilla Transformer 长 450% 的依赖 ││ ││ 困惑度对比(在 WikiText-103 上): ││ ││ • Transformer-XL: 18.3 ││ • 之前最佳: 23.0 ││ • 提升: 20% ││ ││ 推理速度(与 vanilla Transformer 对比): ││ ││ • vanilla Transformer: 基准速度 ││ • Transformer-XL: 快 1800 倍!(因 为 reuse) ││ │└─────────────────────────────────────────────────────────────┘5.2 各数据集表现
bar-chart
title "各数据集困惑度对比(越低越好)"
x-axis ["enwik8", "text8", "WikiText-103", "One Billion Word", "Penn Treebank"]
y-axis "困惑度" 0 --> 70
bar ["0.99", "1.08", "18.3", "21.8", "54.5"]
bar ["1.18", "1.19", "23.0", "26.0", "62.0"]
legend ["Transformer-XL", "之前最佳"]
六、核心创新总结
6.1 Transformer-XL 的两大贡献
┌─────────────────────────────────────────────────────────────┐│ Transformer-XL 核心总结 │├─────────────────────────────────────────────────────────────┤│ ││ 1. 段级递归机制(Segment-Level Recurrence) ││ ││ • 缓存每段的 Hidden State ││ • 作为下一段的额外上下文 ││ • 将有效上下文长度从 O(L) 扩展到 O(L×N) ││ • 消除上下文碎片化 ││ ││ 2. 相对位置编码(Relative Positional Encoding) ││ ││ • 用相对位置替代绝对位置 ││ • 解决段内位置编码重复问题 ││ • 可泛化到更长序列 ││ ││ 影响: ││ • 后续 XLNet、BERT-wwm 等都采用类似技术 ││ • 为长文本处理奠定基础 ││ │└─────────────────────────────────────────────────────────────┘常见问题 FAQ
Q1:Transformer-XL 和 vanilla Transformer 的主要区别是什么?
A:核心区别在于是否有跨段的信息传递。vanilla Transformer 每次只处理固定长度的段,段与段之间独立;Transformer-XL 通过缓存上一段的 Hidden State,让当前段可以看到前一段的信息。
Q2:相对位置编码为什么比绝对位置编码更好?
A:绝对位置编码在段级递归时会导致位置”重复”(如第二段的位置0和第一段的位置0编码相同,但含义不同)。相对位置编码用 query 和 key 之间的相对距离替代绝对位置,更符合 Attention 的本质。
Q3:Transformer-XL 的 memory 会无限增长吗?
A:实际上使用时会有一个最大 memory 长度限制。超出限制时会丢弃最早的 segment,平衡内存和长期依赖。
Q4:Transformer-XL 适用于哪些任务?
A:主要用于需要长文本建模的任务,如语言模型、文本生成、文档级理解等。
Q5:Transformer-XL 对后续模型有什么影响?
A:Transformer-XL 的段级递归和相对位置编码被广泛应用于后续模型,如 XLNet、BERT-wwm、Longformer 等。
小结
Transformer-XL 通过段级递归机制和相对位置编码,解决了固定长度上下文的限制,为长文本处理提供了重要思路。
核心贡献:
┌─────────────────────────────────────────────────────────────┐│ Transformer-XL 核心总结 │├─────────────────────────────────────────────────────────────┤│ ││ 段级递归机制: ││ • 缓存 Hidden State 作为 Memory ││ • 跨段传递信息,解决长距离依赖 ││ • 消除上下文碎片化 ││ ││ 相对位置编码: ││ • 用相对距离替代绝对位置 ││ • 适配段级递归的 Position 重复问题 ││ • 可泛化到更长序列 ││ │└─────────────────────────────────────────────────────────────┘参考资料
支持与分享
如果这篇文章对你有帮助,欢迎支持作者或分享给更多人
Transformer-XL 论文解读:让 Transformer 学会"记忆"更长上下文
https://blog.souloss.com/posts/machine-learning/llm-paper-history/transformer-xl-long-text-model/ 部分信息可能已经过时
相关文章 智能推荐
1
Transformer-XL 论文解读:超越固定长度的注意力机制
AI 深度解读 Transformer-XL 论文——如何通过段级递归机制和相对位置编码,让 Transformer 突破固定长度限制,捕获更长的依赖关系。
2
WebGPT 论文解读:让语言模型学会上网搜索
AI 深度解读 WebGPT 论文——OpenAI 如何训练 GPT-3 使用浏览器进行信息检索和问答。
3
Toolformer 论文解读:让语言模型学会使用工具
AI 深度解读 Toolformer 论文——Meta 如何让语言模型通过自监督学习掌握使用外部工具的能力。
4
Verify Step by Step 论文解读:过程监督让数学推理更强
AI 深度解读 OpenAI Verify Step by Step 论文——比较结果监督和过程监督在数学推理中的效果,证明过程监督能显著提升模型可靠性。
5
Codex 论文解读:AI 编程的开端
AI 深度解读 Codex 论文——OpenAI 如何训练能够编写代码的 GPT 模型,以及 HumanEval 基准的建立。






