SameTime WMT 专题:Phase 1 从 RNN 记忆到 LSTM 门控
“不要跳过推车直接开跑车。先造一辆吱嘎作响的木板车,体会它为什么散架,再理解锻造淬火的钢架好在哪。”
Phase 1 拆分为两步
Phase 1 原计划走 “RNN Seq2Seq”,但代码里全是 nn.LSTM——中间缺了一环。
现在拆成两步:
| Phase 1.0 | Phase 1.1 | |
|---|---|---|
| 目录 | phase1_0_rnn/ |
phase1_1_lstm/ |
| 模型 | 2层 BiRNN + 2层 RNN | 2层 BiLSTM + 2层 LSTM |
| 状态 | 只有 h_t | h_t + c_t(细胞状态) |
| 门控 | 无(纯 tanh) | 三扇门(f/i/o) |
| 梯度路径 | 连乘 tanh(W) → 消失 | 加性直通 → 保持 |
| 参数 | ~8M | ~16M |
| 目的 | 理解记忆原理 | 体会门控解决了什么 |
Phase 1.0:Vanilla RNN — 推车
RNN 如何"记忆"
h_t = tanh(W_hh · h_{t-1} + W_ih · x_t + b)
每个时间步,RNN 把 上一个隐藏状态 h_{t-1}(历史)和 当前输入 x_t(现在)揉在一起,过一层 tanh,得到新的隐藏状态 h_t。
直观类比:h_t 像一张便签。你每读一个词,擦掉便签上的一部分旧内容,写上当前词的摘要。
- 读第 1 个词 → h_1 记着 “I”
- 读第 5 个词 → h_5 记着一些语法信息 + 部分语义
- 读第 50 个词 → h_50 中,第 1 个词的信息早已被 tanh(W) 连乘 50 次碾成粉末
tanh 的罪:|tanh'(…)| ≤ 1,但实际值通常 « 1。乘以一个权重矩阵 W(通常元素也 < 1),50 步连乘后 ≈ 0。
BPTT:为什么梯度消失了
反向传播需要计算 ∂L/∂W。对于时间步 t=1 的单词,梯度需要穿过 T-1 个时间步才能到达损失函数:
∂L/∂h_1 = ∂L/∂h_T · (Π_{k=1}^{T-1} ∂h_{k+1}/∂h_k)
其中 ∂h_{k+1}/∂h_k = W_hh^T · diag(tanh'(…))
连乘项中 W_hh 的特征值如果 < 1 → 梯度消失;如果 > 1 → 梯度爆炸。tanh 的导数范围是 (0, 1],进一步压低梯度。
结论:50 个词的句子,前 10 个词 ≈ 没有梯度信号。训练只学到了靠 target 端最近的词。
代码级变化(vs. Phase 0)
# Phase 0: 只处理 target
class DummyModel(nn.Module):
def forward(self, src, tgt_in):
return self.out(self.embed(tgt_in))
# Phase 1.0: Encoder 处理 src,Decoder 处理 tgt
class Seq2Seq(nn.Module):
def forward(self, src, tgt, src_len):
enc_out, hidden = self.encoder(src, src_len) # 编码源句
logits, _ = self.decoder(tgt, hidden) # hidden 是唯一的信息传递者
return logits
注意 nn.RNN 只返回 hidden(没有 cell),Decorder 签名更简洁。代价就是"便签"容量极其有限。
Phase 1.1:LSTM — 锻造淬火
LSTM 的三扇门 + 细胞状态
f_t = σ(W_f · [h_{t-1}, x_t]) # 遗忘门:丢掉多少旧记忆
i_t = σ(W_i · [h_{t-1}, x_t]) # 输入门:写入多少新信息
o_t = σ(W_o · [h_{t-1}, x_t]) # 输出门:暴露多少给外面
c̃_t = tanh(W_c · [h_{t-1}, x_t]) # 候选记忆
c_t = f_t ⊙ c_{t-1} + i_t ⊙ c̃_t # ← 这里是关键
h_t = o_t ⊙ tanh(c_t)
关键洞察:c_t 的更新是 加性更新——f_t * 旧记忆 + i_t * 新信息。没有连乘 tanh!
- 忘了 0.3 旧记忆 + 写入 0.7 新信息 = 新细胞状态
- 梯度通过 c_{t-1}→c_t 可以"直线传递"(LSTM 论文称 “Constant Error Carousel”)
- 遗忘门 f_t 可以在训练中学到 “保留重要信息” vs “丢弃无关信息”
类比:c_t 是一张有"编辑权限"的便签。RNN 只能整张擦掉重写,LSTM 可以选择性地划掉某些字、补上新字。因此 50 步后,第一步的"签名"仍然可能清晰可辨。
RNN vs LSTM 实战对比
| Phase 1.0 (RNN) | Phase 1.1 (LSTM) | |
|---|---|---|
| 隐状态 | h_t (1 份) | (hidden, cell) (2 份) |
| 参数 | H×H (单矩阵) | H×H×4 (三扇门 + 候选) |
| Decoder 接口 | decoder(tgt, hidden) |
decoder(tgt, (enc_out, (hidden, cell))) |
| 长句梯度 | 趋近于 0 | 可保持 |
| BLEU 预期 | 很低 | 高于 RNN,但 < 5(仍信息瓶颈) |
代码级变化
# Phase 1.0: RNN
self.rnn = nn.RNN(input_size, hidden_size, num_layers, ...)
output, hidden = self.rnn(embedded, hidden)
# hidden: (num_layers, B, H)
# Phase 1.1: LSTM
self.rnn = nn.LSTM(input_size, hidden_size, num_layers, ...)
output, (hidden, cell) = self.rnn(embedded, (hidden, cell))
# hidden: (num_layers, B, H)
# cell: (num_layers, B, H) ← 细胞状态,梯度高速通道
PyTorch 层面: nn.RNN → nn.LSTM 仅改一行。但 LSTM 需要额外维护 cell state,Decoder 的 forward 签名从 (tgt, hidden) 变成 (tgt, (hidden, cell))。
提问环节
Q1: RNN 记忆的本质
RNN 把历史信息"折叠"进一个向量 h_t。这更像"压缩/蒸馏"还是更像"遗忘/丢弃"?如果用信息论的语言描述:h_t 的信息容量由 hidden_size 决定——256 维的向量能无损"记住"多长的句子?
HM:参考 Phase 0 Q2 的 hash 碰撞理论——如果 h_t 是源语言的 hash 摘要,Decoder 需要从这个固定大小 hash 中"解压"出目标语言。信息容量随句子变长指数下降,这就是信息瓶颈。
Q2: tanh 的驯服
RNN 用 tanh 激活 → 输出在 (-1, 1) 之间 → 导数最大为 1。LSTM 用 sigmoid (遗忘/输入/输出门) + tanh (候选记忆 + 输出门)。
为什么门控函数用 sigmoid 而不是 tanh?如果所有门都换成 tanh 会发生什么?
Q3: 细胞状态的加法
LSTM 的核心创新是 c_t = f_t*c_{t-1} + i_t*c̃_t——加法而非连乘。除了梯度直通,这种"选择性遗忘 + 选择性写入"与你之前提到的"最低频率过滤"(min_freq=2)有什么本质相似之处?
Q4: RNN→LSTM 的参数膨胀
LSTM 比 RNN 参数多了约 4 倍。在 IWSLT14 这种小数据集(160K 句)上,这个参数冗余是浪费还是必要?如果用 GRU(2 扇门,参数约为 LSTM 的 3/4),BLEU 会不会差不多?
Q5: 信息瓶颈仍然存在
即使换上 LSTM,“上下文向量 c”(Encoder 最后 hidden state)仍然是固定 512 维。Phase 2 引入 Attention 后,Decoder 可以"跳过 c,直接看 Encoder 所有时间步的输出"。
从 LSTM 的 c_t(细胞内记忆)到 Attention 的 c(动态加权上下文),这两个 “c” 从命名到作用有什么本质不同?LSTM 的 c_t 能不能替代 Attention?如果不能,为什么?
Q6: 训练速度与算力
| Phase | 参数量 | 每 epoch 时间(预估) | 梯度瓶颈 |
|---|---|---|---|
| 1.0 (RNN) | ~8M | 快 | 消失严重 |
| 1.1 (LSTM) | ~16M | 中 | 缓解 |
| 2.0 (LSTM+Attn) | ~20M | 慢 | Attention 的计算代价 |
在显存限制下(GTX 3090 24GB),如果 batch_size 从 64 降到 16 才能跑 LSTM,低 batch 引入的噪音和 LSTM 门控的稳定性之间如何权衡?
May the Code be with us.
License: GPLv3
本文《SameTime》系列采用 GNU 通用公共许可证第三版 (GNU General Public License v3.0) 协议进行开源发布与分发。允许任何形式的复制、修改和分发,但必须继承相同的开源协议,承认在算力宇宙中所有的迭代与变异。