Long Short-Term Memory/zh
| Article | |
|---|---|
| Topic area | Deep Learning |
| Prerequisites | Recurrent Neural Network, Backpropagation, Gradient Descent |
概述
長短期記憶(LSTM)是一種循環神經網絡架構,旨在學習序列數據中的長程依賴關係,同時避免困擾標準循環網絡的梯度消失和梯度爆炸問題。由 Sepp Hochreiter 和 Juergen Schmidhuber 於 1997 年提出,[1]LSTM 單元在循環單元的基礎上增加了一個內部細胞狀態,由乘性門控加以保護,這些門控學習有選擇地讀取、寫入和擦除信息。在大約二十年的時間裡,LSTM 一直是序列建模的主導神經網絡架構,為語音識別、機器翻譯、手寫識別和語言建模等生產系統提供動力,直到在並行訓練占主導的領域中被 Transformer 大規模取代。它在低延遲在線推理、時間序列預測、控制以及任何嚴格的序列因果性和有界記憶是優勢而非限制的場景中,仍被廣泛使用。
LSTM 的核心思想是常數誤差傳送帶:細胞狀態上的一個線性自循環,使梯度能夠在任意多個時間步上向後傳播而不會發生指數衰減。門控學習何時讓信息進入該循環、何時遺忘信息以及何時將其暴露給網絡的其餘部分。
動機:梯度消失問題
一個簡單的循環網絡維護一個隱藏狀態 $ h_t = \sigma(W_h h_{t-1} + W_x x_t + b) $,並通過時間反向傳播(BPTT)進行訓練。時間 $ T $ 的損失相對於時間 $ t $ 的隱藏狀態的梯度,涉及循環映射的 $ T - t $ 個雅可比矩陣的連乘。當這些雅可比矩陣的譜半徑小於 1 時,梯度隨時間跨度呈指數級衰減,網絡無法學習相隔超過幾步的依賴關係;當其大於 1 時,梯度發生爆炸,訓練發散。[2]Hochreiter 在 1991 年的論文中將此確定為訓練深層循環模型的核心障礙。
LSTM 在架構層面解決這一問題:它們不依賴精細的初始化或歸一化,而是引入一條貫穿時間的通路,梯度沿此通路默認以單位雅可比矩陣傳播,並通過學習到的門控僅在有用時才擾動這一流動。
架構
標準的 LSTM 單元在時間步 $ t $ 接收前一個隱藏狀態 $ h_{t-1} $、前一個細胞狀態 $ c_{t-1} $ 和輸入向量 $ x_t $,並產生新的隱藏狀態和細胞狀態。三個 Sigmoid 門控控制信息的流動:
- 遺忘門 $ f_t $ 決定擦除前一個細胞狀態的哪些分量。
- 輸入門 $ i_t $ 決定將候選更新 $ \tilde{c}_t $ 的哪些分量寫入細胞狀態。
- 輸出門 $ o_t $ 決定將細胞狀態的哪些分量作為隱藏狀態暴露出來。
每個門控都是 $ [h_{t-1}, x_t] $ 的一個學習到的線性投影后接 Sigmoid 函數,因此其輸出位於 $ (0,1) $ 區間,在逐元素乘法下充當軟二值掩碼。細胞狀態以加性方式更新,這一結構性特性保留了梯度流。
形式化定義
設 $ \sigma $ 表示邏輯 Sigmoid 函數,$ \odot $ 表示逐元素乘法。對每個門控,使用權重 $ W_\bullet $、循環權重 $ U_\bullet $ 和偏置 $ b_\bullet $,LSTM 的更新規則為
$ {\displaystyle f_t = \sigma(W_f x_t + U_f h_{t-1} + b_f)} $
$ {\displaystyle i_t = \sigma(W_i x_t + U_i h_{t-1} + b_i)} $
$ {\displaystyle o_t = \sigma(W_o x_t + U_o h_{t-1} + b_o)} $
$ {\displaystyle \tilde{c}_t = \tanh(W_c x_t + U_c h_{t-1} + b_c)} $
$ {\displaystyle c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t} $
$ {\displaystyle h_t = o_t \odot \tanh(c_t)} $
前四個方程可以作為一次寬度為 $ 4d $ 的矩陣乘法來計算,其中 $ d $ 是隱藏層大小,這正是高效實現的寫法。細胞狀態更新 $ c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t $ 即為常數誤差傳送帶:當 $ f_t \approx 1 $ 時,關於 $ c_{t-1} $ 的梯度沿加性路徑向後傳播,不會受到非線性擠壓。
一種被廣泛採用的啟發式做法是將遺忘門的偏置 $ b_f $ 初始化為較大的正值(通常為 1 或 2),使得訓練開始時 $ f_t \approx 1 $,細胞狀態默認傾向於記憶而非遺忘。[3]
訓練與推理
LSTM 通過截斷的時間反向傳播進行訓練:在長度為 $ T $ 的窗口上對損失求和,梯度沿展開的網絡向後傳播,通常配合梯度裁剪以控制該架構未能消除的偶發梯度爆炸事件。常用的優化器如 Adam 或帶動量的 SGD 在實踐中表現良好。正則化通常將權重衰減與應用於非循環連接的 Dropout 相結合,因為對循環路徑直接應用 dropout 會破壞梯度高速通道;諸如變分 dropout 之類的變體在時間維度上使用相同的掩碼以保留這一通道。[4]
推理本質上是順序的:$ c_t $ 依賴於 $ c_{t-1} $,因此每個詞元的耗時無法像 Transformer 的前向傳播那樣在時間軸上並行化。這種順序依賴是 LSTM 在大規模預訓練中失去優勢的主要原因,但同一特性也使其每詞元開銷在序列長度上保持恆定、內存占用有界,這對流式推理和端側部署頗具吸引力。
變體
已提出許多 LSTM 變體;消融研究表明,大多數僅帶來微小提升,而遺忘門和輸出激活函數才是起關鍵作用的組件。[5]
- 窺孔連接允許門控直接讀取細胞狀態,將 $ U_\bullet h_{t-1} $ 替換為包含 $ c_{t-1} $ 或 $ c_t $ 的項。
- 耦合輸入和遺忘門將 $ i_t = 1 - f_t $ 綁定,使門控參數數量減半。
- 門控循環單元(GRU)將細胞狀態與隱藏狀態合併,使用兩個門控而非三個,通常以更少的參數即可與 LSTM 性能相當。
- 雙向 LSTM 沿相反方向運行兩個 LSTM 並拼接它們的隱藏狀態,為標註等非因果任務同時暴露過去和未來的上下文。
- ConvLSTM 將矩陣乘法替換為卷積,保留空間結構,適用於視頻和時空預測。
- 堆疊 LSTM 在縱向上堆疊若干 LSTM 層,第 $ \ell $ 層的隱藏狀態作為第 $ \ell+1 $ 層的輸入;這是 Transformer 出現之前生產級神經機器翻譯系統的標準配置。
與相關模型的比較
與普通的 RNN 相比,LSTM 以每步約 4 倍的參數量和計算量為代價,換取大幅改善的梯度流以及在經驗上學習間隔數百步依賴的能力。與 GRU 相比,LSTM 擁有單獨的細胞狀態和一個額外的門;在大多數基準上兩者表現接近,GRU 略快,而 LSTM 在長序列上略具表達力。與 Transformer 相比,LSTM 的每詞元內存和計算開銷恆定,具有嚴格的因果歸納偏置,但缺乏對任意位置的直接注意力,且無法以同等程度並行訓練,這限制了它從大規模預訓練中獲益的程度。近期的線性注意力和 Mamba 等狀態空間模型在更易並行化的形式下,重新探討了與 LSTM 細胞狀態密切相關的思想。[6]
應用
在 2010 年代,LSTM 是序列建模的主力架構。值得關注的部署包括 Google 的神經機器翻譯系統、[7]用於大詞彙量語音識別的聲學模型、[8]手寫識別、光學字符識別以及早期的語言模型。它們在時間序列預測、金融建模、異常檢測以及在部分觀測下運行的強化學習策略中仍然常見,在這些場景下,有界的循環狀態本身可作為對歷史的有用學習摘要。
局限性
LSTM 並不能消除梯度爆炸,只能緩解梯度消失;梯度裁剪仍是標準做法。其順序的前向傳播阻礙了使 Transformer 在加速器上高效的那種教師強制並行訓練,這就是為何將參數規模擴展到十億乃至萬億詞元幾乎專屬於基於注意力的模型。細胞狀態是一個固定寬度的向量,因此 LSTM 原則上無法存儲無界量的歷史上下文;Transformer 通過關注遠端詞元解決的長程檢索任務,在 LSTM 中只能被壓縮到循環狀態中。在經驗上,在需要銳利、可按內容尋址的查找任務上,LSTM 也傾向於表現不如基於注意力的模型,而在以中短程時間結構為主的任務上仍然具有競爭力甚至更優。
參考文獻
- ↑ Hochreiter, S. and Schmidhuber, J., "Long Short-Term Memory", Neural Computation 9(8):1735-1780, 1997.
- ↑ Bengio, Y., Simard, P., and Frasconi, P., "Learning long-term dependencies with gradient descent is difficult", IEEE Transactions on Neural Networks 5(2):157-166, 1994.
- ↑ Jozefowicz, R., Zaremba, W., and Sutskever, I., "An empirical exploration of recurrent network architectures", ICML, 2015.
- ↑ Gal, Y. and Ghahramani, Z., "A theoretically grounded application of dropout in recurrent neural networks", NeurIPS, 2016.
- ↑ Greff, K., Srivastava, R. K., Koutnik, J., Steunebrink, B. R., and Schmidhuber, J., "LSTM: A search space odyssey", IEEE TNNLS 28(10):2222-2232, 2017.
- ↑ Gu, A. and Dao, T., "Mamba: Linear-time sequence modeling with selective state spaces", arXiv:2312.00752, 2023.
- ↑ Wu, Y. et al., "Google's neural machine translation system: Bridging the gap between human and machine translation", arXiv:1609.08144, 2016.
- ↑ Sak, H., Senior, A., and Beaufays, F., "Long short-term memory recurrent neural network architectures for large scale acoustic modeling", Interspeech, 2014.