Long Short-Term Memory/zh
| Article | |
|---|---|
| Topic area | Deep Learning |
| Prerequisites | Recurrent Neural Network, Backpropagation, Gradient Descent |
概述
长短期记忆(LSTM)是一种循环神经网络架构,旨在学习序列数据中的长程依赖关系,同时避免困扰标准循环网络的梯度消失和梯度爆炸问题。由 Sepp Hochreiter 和 Juergen Schmidhuber 于 1997 年提出,[1]LSTM 单元在循环单元的基础上增加了一个内部细胞状态,由乘性门控加以保护,这些门控学习有选择地读取、写入和擦除信息。在大约二十年的时间里,LSTM 一直是序列建模的主导神经网络架构,为语音识别、机器翻译、手写识别和语言建模等生产系统提供动力,直到在并行训练占主导的领域中被 Transformer 大规模取代。它在低延迟在线推理、时间序列预测、控制以及任何严格的序列因果性和有界记忆是优势而非限制的场景中,仍被广泛使用。
LSTM 的核心思想是常数误差传送带:细胞状态上的一个线性自循环,使梯度能够在任意多个时间步上向后传播而不会发生指数衰减。门控学习何时让信息进入该循环、何时遗忘信息以及何时将其暴露给网络的其余部分。
动机:梯度消失问题
一个简单的循环网络维护一个隐藏状态 $ h_t = \sigma(W_h h_{t-1} + W_x x_t + b) $,并通过时间反向传播(BPTT)进行训练。时间 $ T $ 的损失相对于时间 $ t $ 的隐藏状态的梯度,涉及循环映射的 $ T - t $ 个雅可比矩阵的连乘。当这些雅可比矩阵的谱半径小于 1 时,梯度随时间跨度呈指数级衰减,网络无法学习相隔超过几步的依赖关系;当其大于 1 时,梯度发生爆炸,训练发散。[2]Hochreiter 在 1991 年的论文中将此确定为训练深层循环模型的核心障碍。
LSTM 在架构层面解决这一问题:它们不依赖精细的初始化或归一化,而是引入一条贯穿时间的通路,梯度沿此通路默认以单位雅可比矩阵传播,并通过学习到的门控仅在有用时才扰动这一流动。
架构
标准的 LSTM 单元在时间步 $ t $ 接收前一个隐藏状态 $ h_{t-1} $、前一个细胞状态 $ c_{t-1} $ 和输入向量 $ x_t $,并产生新的隐藏状态和细胞状态。三个 Sigmoid 门控控制信息的流动:
- 遗忘门 $ f_t $ 决定擦除前一个细胞状态的哪些分量。
- 输入门 $ i_t $ 决定将候选更新 $ \tilde{c}_t $ 的哪些分量写入细胞状态。
- 输出门 $ o_t $ 决定将细胞状态的哪些分量作为隐藏状态暴露出来。
每个门控都是 $ [h_{t-1}, x_t] $ 的一个学习到的线性投影后接 Sigmoid 函数,因此其输出位于 $ (0,1) $ 区间,在逐元素乘法下充当软二值掩码。细胞状态以加性方式更新,这一结构性特性保留了梯度流。
形式化定义
设 $ \sigma $ 表示逻辑 Sigmoid 函数,$ \odot $ 表示逐元素乘法。对每个门控,使用权重 $ W_\bullet $、循环权重 $ U_\bullet $ 和偏置 $ b_\bullet $,LSTM 的更新规则为
$ {\displaystyle f_t = \sigma(W_f x_t + U_f h_{t-1} + b_f)} $
$ {\displaystyle i_t = \sigma(W_i x_t + U_i h_{t-1} + b_i)} $
$ {\displaystyle o_t = \sigma(W_o x_t + U_o h_{t-1} + b_o)} $
$ {\displaystyle \tilde{c}_t = \tanh(W_c x_t + U_c h_{t-1} + b_c)} $
$ {\displaystyle c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t} $
$ {\displaystyle h_t = o_t \odot \tanh(c_t)} $
前四个方程可以作为一次宽度为 $ 4d $ 的矩阵乘法来计算,其中 $ d $ 是隐藏层大小,这正是高效实现的写法。细胞状态更新 $ c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t $ 即为常数误差传送带:当 $ f_t \approx 1 $ 时,关于 $ c_{t-1} $ 的梯度沿加性路径向后传播,不会受到非线性挤压。
一种被广泛采用的启发式做法是将遗忘门的偏置 $ b_f $ 初始化为较大的正值(通常为 1 或 2),使得训练开始时 $ f_t \approx 1 $,细胞状态默认倾向于记忆而非遗忘。[3]
训练与推理
LSTM 通过截断的时间反向传播进行训练:在长度为 $ T $ 的窗口上对损失求和,梯度沿展开的网络向后传播,通常配合梯度裁剪以控制该架构未能消除的偶发梯度爆炸事件。常用的优化器如 Adam 或带动量的 SGD 在实践中表现良好。正则化通常将权重衰减与应用于非循环连接的 Dropout 相结合,因为对循环路径直接应用 dropout 会破坏梯度高速通道;诸如变分 dropout 之类的变体在时间维度上使用相同的掩码以保留这一通道。[4]
推理本质上是顺序的:$ c_t $ 依赖于 $ c_{t-1} $,因此每个词元的耗时无法像 Transformer 的前向传播那样在时间轴上并行化。这种顺序依赖是 LSTM 在大规模预训练中失去优势的主要原因,但同一特性也使其每词元开销在序列长度上保持恒定、内存占用有界,这对流式推理和端侧部署颇具吸引力。
变体
已提出许多 LSTM 变体;消融研究表明,大多数仅带来微小提升,而遗忘门和输出激活函数才是起关键作用的组件。[5]
- 窥孔连接允许门控直接读取细胞状态,将 $ U_\bullet h_{t-1} $ 替换为包含 $ c_{t-1} $ 或 $ c_t $ 的项。
- 耦合输入和遗忘门将 $ i_t = 1 - f_t $ 绑定,使门控参数数量减半。
- 门控循环单元(GRU)将细胞状态与隐藏状态合并,使用两个门控而非三个,通常以更少的参数即可与 LSTM 性能相当。
- 双向 LSTM 沿相反方向运行两个 LSTM 并拼接它们的隐藏状态,为标注等非因果任务同时暴露过去和未来的上下文。
- ConvLSTM 将矩阵乘法替换为卷积,保留空间结构,适用于视频和时空预测。
- 堆叠 LSTM 在纵向上堆叠若干 LSTM 层,第 $ \ell $ 层的隐藏状态作为第 $ \ell+1 $ 层的输入;这是 Transformer 出现之前生产级神经机器翻译系统的标准配置。
与相关模型的比较
与普通的 RNN 相比,LSTM 以每步约 4 倍的参数量和计算量为代价,换取大幅改善的梯度流以及在经验上学习间隔数百步依赖的能力。与 GRU 相比,LSTM 拥有单独的细胞状态和一个额外的门;在大多数基准上两者表现接近,GRU 略快,而 LSTM 在长序列上略具表达力。与 Transformer 相比,LSTM 的每词元内存和计算开销恒定,具有严格的因果归纳偏置,但缺乏对任意位置的直接注意力,且无法以同等程度并行训练,这限制了它从大规模预训练中获益的程度。近期的线性注意力和 Mamba 等状态空间模型在更易并行化的形式下,重新探讨了与 LSTM 细胞状态密切相关的思想。[6]
应用
在 2010 年代,LSTM 是序列建模的主力架构。值得关注的部署包括 Google 的神经机器翻译系统、[7]用于大词汇量语音识别的声学模型、[8]手写识别、光学字符识别以及早期的语言模型。它们在时间序列预测、金融建模、异常检测以及在部分观测下运行的强化学习策略中仍然常见,在这些场景下,有界的循环状态本身可作为对历史的有用学习摘要。
局限性
LSTM 并不能消除梯度爆炸,只能缓解梯度消失;梯度裁剪仍是标准做法。其顺序的前向传播阻碍了使 Transformer 在加速器上高效的那种教师强制并行训练,这就是为何将参数规模扩展到十亿乃至万亿词元几乎专属于基于注意力的模型。细胞状态是一个固定宽度的向量,因此 LSTM 原则上无法存储无界量的历史上下文;Transformer 通过关注远端词元解决的长程检索任务,在 LSTM 中只能被压缩到循环状态中。在经验上,在需要锐利、可按内容寻址的查找任务上,LSTM 也倾向于表现不如基于注意力的模型,而在以中短程时间结构为主的任务上仍然具有竞争力甚至更优。
参考文献
- ↑ Hochreiter, S. and Schmidhuber, J., "Long Short-Term Memory", Neural Computation 9(8):1735-1780, 1997.
- ↑ Bengio, Y., Simard, P., and Frasconi, P., "Learning long-term dependencies with gradient descent is difficult", IEEE Transactions on Neural Networks 5(2):157-166, 1994.
- ↑ Jozefowicz, R., Zaremba, W., and Sutskever, I., "An empirical exploration of recurrent network architectures", ICML, 2015.
- ↑ Gal, Y. and Ghahramani, Z., "A theoretically grounded application of dropout in recurrent neural networks", NeurIPS, 2016.
- ↑ Greff, K., Srivastava, R. K., Koutnik, J., Steunebrink, B. R., and Schmidhuber, J., "LSTM: A search space odyssey", IEEE TNNLS 28(10):2222-2232, 2017.
- ↑ Gu, A. and Dao, T., "Mamba: Linear-time sequence modeling with selective state spaces", arXiv:2312.00752, 2023.
- ↑ Wu, Y. et al., "Google's neural machine translation system: Bridging the gap between human and machine translation", arXiv:1609.08144, 2016.
- ↑ Sak, H., Senior, A., and Beaufays, F., "Long short-term memory recurrent neural network architectures for large scale acoustic modeling", Interspeech, 2014.