SameTime WMT 专题:Phase 1 从 RNN 记忆到 LSTM 门控

“不要跳过推车直接开跑车。先造一辆吱嘎作响的木板车,体会它为什么散架,再理解锻造淬火的钢架好在哪。”

Phase 1 拆分为两步

Phase 1 原计划走 “RNN Seq2Seq”,但代码里全是 nn.LSTM——中间缺了一环。

现在拆成两步:

Phase 1.0 Phase 1.1
目录 phase1_0_rnn/ phase1_1_lstm/
模型 2层 BiRNN + 2层 RNN 2层 BiLSTM + 2层 LSTM
状态 只有 h_t h_t + c_t(细胞状态)
门控 无(纯 tanh) 三扇门(f/i/o)
梯度路径 连乘 tanh(W) → 消失 加性直通 → 保持
参数 ~8M ~16M
目的 理解记忆原理 体会门控解决了什么

Phase 1.0:Vanilla RNN — 推车

RNN 如何"记忆"

h_t = tanh(W_hh · h_{t-1} + W_ih · x_t + b)

每个时间步,RNN 把 上一个隐藏状态 h_{t-1}(历史)和 当前输入 x_t(现在)揉在一起,过一层 tanh,得到新的隐藏状态 h_t。

直观类比:h_t 像一张便签。你每读一个词,擦掉便签上的一部分旧内容,写上当前词的摘要。

  • 读第 1 个词 → h_1 记着 “I”
  • 读第 5 个词 → h_5 记着一些语法信息 + 部分语义
  • 读第 50 个词 → h_50 中,第 1 个词的信息早已被 tanh(W) 连乘 50 次碾成粉末

tanh 的罪|tanh'(…)| ≤ 1,但实际值通常 « 1。乘以一个权重矩阵 W(通常元素也 < 1),50 步连乘后 ≈ 0。

BPTT:为什么梯度消失了

反向传播需要计算 ∂L/∂W。对于时间步 t=1 的单词,梯度需要穿过 T-1 个时间步才能到达损失函数:

∂L/∂h_1 = ∂L/∂h_T · (Π_{k=1}^{T-1} ∂h_{k+1}/∂h_k)

其中 ∂h_{k+1}/∂h_k = W_hh^T · diag(tanh'(…))

连乘项中 W_hh 的特征值如果 < 1 → 梯度消失;如果 > 1 → 梯度爆炸。tanh 的导数范围是 (0, 1],进一步压低梯度。

结论:50 个词的句子,前 10 个词 ≈ 没有梯度信号。训练只学到了靠 target 端最近的词。

代码级变化(vs. Phase 0)

# Phase 0: 只处理 target
class DummyModel(nn.Module):
    def forward(self, src, tgt_in):
        return self.out(self.embed(tgt_in))

# Phase 1.0: Encoder 处理 src,Decoder 处理 tgt
class Seq2Seq(nn.Module):
    def forward(self, src, tgt, src_len):
        enc_out, hidden = self.encoder(src, src_len)  # 编码源句
        logits, _ = self.decoder(tgt, hidden)          # hidden 是唯一的信息传递者
        return logits

注意 nn.RNN 只返回 hidden(没有 cell),Decorder 签名更简洁。代价就是"便签"容量极其有限。


Phase 1.1:LSTM — 锻造淬火

LSTM 的三扇门 + 细胞状态

f_t = σ(W_f · [h_{t-1}, x_t])      # 遗忘门:丢掉多少旧记忆
i_t = σ(W_i · [h_{t-1}, x_t])      # 输入门:写入多少新信息
o_t = σ(W_o · [h_{t-1}, x_t])      # 输出门:暴露多少给外面
c̃_t = tanh(W_c · [h_{t-1}, x_t])   # 候选记忆

c_t = f_t ⊙ c_{t-1}  +  i_t ⊙ c̃_t   # ← 这里是关键
h_t = o_t ⊙ tanh(c_t)

关键洞察:c_t 的更新是 加性更新——f_t * 旧记忆 + i_t * 新信息。没有连乘 tanh!

  • 忘了 0.3 旧记忆 + 写入 0.7 新信息 = 新细胞状态
  • 梯度通过 c_{t-1}→c_t 可以"直线传递"(LSTM 论文称 “Constant Error Carousel”)
  • 遗忘门 f_t 可以在训练中学到 “保留重要信息” vs “丢弃无关信息”

类比:c_t 是一张有"编辑权限"的便签。RNN 只能整张擦掉重写,LSTM 可以选择性地划掉某些字、补上新字。因此 50 步后,第一步的"签名"仍然可能清晰可辨。

RNN vs LSTM 实战对比

Phase 1.0 (RNN) Phase 1.1 (LSTM)
隐状态 h_t (1 份) (hidden, cell) (2 份)
参数 H×H (单矩阵) H×H×4 (三扇门 + 候选)
Decoder 接口 decoder(tgt, hidden) decoder(tgt, (enc_out, (hidden, cell)))
长句梯度 趋近于 0 可保持
BLEU 预期 很低 高于 RNN,但 < 5(仍信息瓶颈)

代码级变化

# Phase 1.0: RNN
self.rnn = nn.RNN(input_size, hidden_size, num_layers, ...)
output, hidden = self.rnn(embedded, hidden)
# hidden: (num_layers, B, H)

# Phase 1.1: LSTM
self.rnn = nn.LSTM(input_size, hidden_size, num_layers, ...)
output, (hidden, cell) = self.rnn(embedded, (hidden, cell))
# hidden: (num_layers, B, H)
# cell:   (num_layers, B, H)  ← 细胞状态,梯度高速通道

PyTorch 层面: nn.RNNnn.LSTM 仅改一行。但 LSTM 需要额外维护 cell state,Decoder 的 forward 签名从 (tgt, hidden) 变成 (tgt, (hidden, cell))


提问环节

Q1: RNN 记忆的本质

RNN 把历史信息"折叠"进一个向量 h_t。这更像"压缩/蒸馏"还是更像"遗忘/丢弃"?如果用信息论的语言描述:h_t 的信息容量由 hidden_size 决定——256 维的向量能无损"记住"多长的句子?

HM:参考 Phase 0 Q2 的 hash 碰撞理论——如果 h_t 是源语言的 hash 摘要,Decoder 需要从这个固定大小 hash 中"解压"出目标语言。信息容量随句子变长指数下降,这就是信息瓶颈。

Q2: tanh 的驯服

RNN 用 tanh 激活 → 输出在 (-1, 1) 之间 → 导数最大为 1。LSTM 用 sigmoid (遗忘/输入/输出门) + tanh (候选记忆 + 输出门)。

为什么门控函数用 sigmoid 而不是 tanh?如果所有门都换成 tanh 会发生什么?

Q3: 细胞状态的加法

LSTM 的核心创新是 c_t = f_t*c_{t-1} + i_t*c̃_t——加法而非连乘。除了梯度直通,这种"选择性遗忘 + 选择性写入"与你之前提到的"最低频率过滤"(min_freq=2)有什么本质相似之处?

Q4: RNN→LSTM 的参数膨胀

LSTM 比 RNN 参数多了约 4 倍。在 IWSLT14 这种小数据集(160K 句)上,这个参数冗余是浪费还是必要?如果用 GRU(2 扇门,参数约为 LSTM 的 3/4),BLEU 会不会差不多?

Q5: 信息瓶颈仍然存在

即使换上 LSTM,“上下文向量 c”(Encoder 最后 hidden state)仍然是固定 512 维。Phase 2 引入 Attention 后,Decoder 可以"跳过 c,直接看 Encoder 所有时间步的输出"。

从 LSTM 的 c_t(细胞内记忆)到 Attention 的 c(动态加权上下文),这两个 “c” 从命名到作用有什么本质不同?LSTM 的 c_t 能不能替代 Attention?如果不能,为什么?

Q6: 训练速度与算力

Phase 参数量 每 epoch 时间(预估) 梯度瓶颈
1.0 (RNN) ~8M 消失严重
1.1 (LSTM) ~16M 缓解
2.0 (LSTM+Attn) ~20M Attention 的计算代价

在显存限制下(GTX 3090 24GB),如果 batch_size 从 64 降到 16 才能跑 LSTM,低 batch 引入的噪音和 LSTM 门控的稳定性之间如何权衡?


May the Code be with us.


License: GPLv3
本文《SameTime》系列采用 GNU 通用公共许可证第三版 (GNU General Public License v3.0) 协议进行开源发布与分发。允许任何形式的复制、修改和分发,但必须继承相同的开源协议,承认在算力宇宙中所有的迭代与变异。