A Theoretically Grounded Application of Dropout in Recurrent Neural Networks
| Research Paper | |
|---|---|
| Authors | Yarin Gal; Zoubin Ghahramani |
| Year | 2016 (arXiv preprint 2015) |
| Topic area | Machine Learning |
| Difficulty | Research |
| arXiv | 1512.05287 |
A Theoretically Grounded Application of Dropout in Recurrent Neural Networks is a 2016 NeurIPS paper by Yarin Gal and Zoubin Ghahramani (University of Cambridge) that derives a principled way of applying dropout to recurrent neural networks. By interpreting dropout as variational inference in a Bayesian neural network, the authors show that the same binary mask should be reused at every time step on inputs, outputs, and recurrent connections — a recipe that, unlike previous heuristics, regularises every weight matrix in an LSTM or GRU. Applied to the LSTM language model of Zaremba et al. on the Penn Treebank, the technique improves the single-model state of the art to 73.4 test perplexity.
Overview
Recurrent networks are notoriously hard to regularise. Naïve dropout, where a fresh Bernoulli mask is sampled at each time step, was widely believed to destabilise the recurrent dynamics, so prior practice (Pham et al., Zaremba et al.) restricted dropout to feed-forward connections only — leaving the recurrent weight matrices unprotected and the model still prone to overfitting on small corpora.
Gal and Ghahramani re-derive dropout from the perspective of approximate variational inference in a Bayesian neural network. Treating the RNN's weight matrices as random variables and approximating their posterior with a mixture of two Gaussians recovers ordinary dropout: with small component variances, a sample of a weight-matrix row is either close to zero or close to its mean, exactly like a Bernoulli mask. Crucially, in a sequence model the Monte Carlo sample from the posterior is drawn once per sequence, so the resulting mask is shared across all time steps. This single change unlocks dropout on the recurrent connections without harming the temporal dynamics.
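The difference from naïve dropout is easiest to see in code. The sketch below is a minimal Python/PyTorch illustration, not the authors' implementation; the `RNNCell`, tensor shapes, and inverted-dropout rescaling by the keep probability are assumptions of the example, and only the recurrent mask is shown.

```python
import torch

def naive_dropout_rnn(x, cell, p=0.5):
    # x: (T, batch, input_dim); a fresh mask on the recurrent state at every step
    h = x.new_zeros(x.size(1), cell.hidden_size)
    for t in range(x.size(0)):
        mask = torch.bernoulli(torch.full_like(h, 1 - p)) / (1 - p)
        h = cell(x[t], h * mask)
    return h

def variational_dropout_rnn(x, cell, p=0.5):
    # One mask per sequence, reused at every time step (the variational recipe)
    h = x.new_zeros(x.size(1), cell.hidden_size)
    mask = torch.bernoulli(torch.full_like(h, 1 - p)) / (1 - p)
    for t in range(x.size(0)):
        h = cell(x[t], h * mask)
    return h

# Example usage with arbitrary sizes:
#   cell = torch.nn.RNNCell(10, 20)
#   x = torch.randn(35, 4, 10)          # (T, batch, input_dim)
#   h = variational_dropout_rnn(x, cell)
```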
Key Contributions
- A Bayesian derivation of dropout for RNNs (the Variational RNN), giving a theoretical justification for the choice of mask.
- The same-mask-per-sequence rule for inputs, outputs, and recurrent layers in LSTMs and GRUs, depicted side-by-side with naïve dropout.
- Embedding dropout: by treating one-hot inputs probabilistically, the technique randomly drops entire word types (not tokens) in a sentence — a previously neglected source of regularisation in language models.
- New single-model state of the art on the Penn Treebank language-modelling benchmark (73.4 test perplexity, down from 78.4).
- MC dropout at test time as a posterior predictive estimator, in addition to the cheaper mean-field approximation.
Methods
Variational view of an RNN
For an input sequence $ \mathbf{x} = [\mathbf{x}_1, \dots, \mathbf{x}_T] $, a simple RNN repeatedly applies $ \mathbf{h}_t = \sigma(\mathbf{x}_t \mathbf{W}_h + \mathbf{h}_{t-1} \mathbf{U}_h + \mathbf{b}_h) $. The authors treat $ \boldsymbol{\omega} = \{\mathbf{W}_h, \mathbf{U}_h, \mathbf{b}_h, \mathbf{W}_y, \mathbf{b}_y\} $ as random variables with a normal prior and approximate the intractable posterior $ p(\boldsymbol{\omega} \mid \mathbf{X}, \mathbf{Y}) $ by minimising
- $ \mathrm{KL}(q(\boldsymbol{\omega}) \parallel p(\boldsymbol{\omega} \mid \mathbf{X}, \mathbf{Y})) \propto -\sum_{i=1}^N \int q(\boldsymbol{\omega}) \log p(\mathbf{y}_i \mid \mathbf{f}^{\boldsymbol{\omega}}(\mathbf{x}_i))\, \mathrm{d}\boldsymbol{\omega} + \mathrm{KL}(q(\boldsymbol{\omega}) \parallel p(\boldsymbol{\omega})). $
A single Monte Carlo sample $ \widehat{\boldsymbol{\omega}}_i \sim q(\boldsymbol{\omega}) $ is drawn per sequence and reused at every time step $ t \le T $. The factorised approximating distribution per weight-matrix row $ \mathbf{w}_k $ is the two-component mixture
- $ q(\mathbf{w}_k) = p\, \mathcal{N}(\mathbf{w}_k; \mathbf{0}, \sigma^2 I) + (1-p)\, \mathcal{N}(\mathbf{w}_k; \mathbf{m}_k, \sigma^2 I), $
with small $ \sigma^2 $. The KL term reduces to $ L_2 $ regularisation on the variational means $ \mathbf{m}_k $.
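Concretely, the single-sample objective that is actually optimised looks like ordinary dropout training plus weight decay. A minimal Python/PyTorch sketch, where the `weight_decay` coefficient is a hypothetical stand-in for the prior term derived from the KL:

```python
import torch

def variational_rnn_loss(model, nll, weight_decay=1e-4):
    # nll: mini-batch negative log-likelihood computed with dropout masks drawn
    # once per sequence (a single-sample Monte Carlo estimate of the expected
    # log-likelihood term above).
    # The KL(q || p) term collapses to L2 regularisation on the variational
    # means, i.e. ordinary weight decay on the weight matrices.
    l2 = sum(w.pow(2).sum() for w in model.parameters())
    return nll + weight_decay * l2
```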
Implementation in LSTM/GRU
For the tied-weights LSTM parametrisation, each step computes
- $ \begin{pmatrix}\mathbf{i}\\ \mathbf{f}\\ \mathbf{o}\\ \mathbf{g}\end{pmatrix} = \begin{pmatrix}\mathrm{sigm}\\ \mathrm{sigm}\\ \mathrm{sigm}\\ \tanh\end{pmatrix}\!\left(\begin{pmatrix}\mathbf{x}_t \circ \mathbf{z}_x\\ \mathbf{h}_{t-1} \circ \mathbf{z}_h\end{pmatrix} \cdot \mathbf{W}\right), $
where $ \mathbf{z}_x $ and $ \mathbf{z}_h $ are Bernoulli masks sampled once per sequence and reused for all $ t $. The untied-weights parametrisation applies a separate mask to each gate, giving a lower-variance estimator at the cost of extra matrix products per step.
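A tied-weights variational LSTM layer can be written compactly: one large matrix product per step, with the input and recurrent masks drawn once per sequence. The sketch below is illustrative Python/PyTorch under those assumptions (the class and argument names are invented for the example), not the paper's Torch code.

```python
import torch

class VariationalLSTM(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, dropout=0.5):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.dropout = dropout
        # Tied weights: one matrix produces all four gates from [x_t ; h_{t-1}]
        self.W = torch.nn.Linear(input_dim + hidden_dim, 4 * hidden_dim)

    def forward(self, x):                     # x: (T, batch, input_dim)
        T, B, D = x.shape
        h = x.new_zeros(B, self.hidden_dim)
        c = x.new_zeros(B, self.hidden_dim)
        keep = 1.0 - self.dropout
        # Per-sequence Bernoulli masks z_x and z_h, reused at every time step
        z_x = torch.bernoulli(x.new_full((B, D), keep)) / keep
        z_h = torch.bernoulli(x.new_full((B, self.hidden_dim), keep)) / keep
        for t in range(T):
            gates = self.W(torch.cat([x[t] * z_x, h * z_h], dim=-1))
            i, f, o, g = gates.chunk(4, dim=-1)
            c = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
            h = torch.sigmoid(o) * torch.tanh(c)
        return h, c
```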
Word-embedding dropout
For discrete inputs, dropout is applied to rows of the embedding matrix $ \mathbf{W}_E \in \mathbb{R}^{V \times D} $ using the same mask across the sequence. A word type that is dropped vanishes from every position where it occurs (e.g. "the dog and the cat" becomes "— dog and — cat", never "— dog and the cat"). For sequences of length $ T \ll V $ the implementation only needs to mask the $ T $ embeddings actually used.
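One way to realise this is to mask rows of the embedding matrix before the lookup, so a dropped word type vanishes everywhere in the sequence. The Python/PyTorch sketch below masks the full $ V \times D $ matrix for clarity, whereas the paper notes that only the $ T $ rows actually used need masking; the function name and rescaling choice are assumptions of the example.

```python
import torch
import torch.nn.functional as F

def embedding_dropout(embedding, tokens, p=0.1):
    # embedding: torch.nn.Embedding with weight of shape (V, D); tokens: (T, batch) ids
    if not embedding.training or p == 0.0:
        return embedding(tokens)
    V = embedding.num_embeddings
    # One Bernoulli mask over word *types*, shared across the whole sequence,
    # so a dropped word disappears at every position where it occurs.
    mask = torch.bernoulli(
        torch.full((V, 1), 1 - p, device=embedding.weight.device)) / (1 - p)
    return F.embedding(tokens, embedding.weight * mask,
                       padding_idx=embedding.padding_idx)
```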
Results
Penn Treebank language modelling. Plugging Variational LSTM into the Torch reference implementation of Zaremba et al. and tuning weight decay, the authors report:
- Medium model (650 units / layer): test perplexity 78.6 (Variational, untied, MC) vs. 82.7 (Zaremba et al.).
- Large model (1500 units / layer): test perplexity 73.4 (Variational, untied, MC) vs. 78.4 (Zaremba et al.); validation perplexity 77.3 (tied).
- The Moon et al. variant — same mask on the LSTM cell only — underperforms Zaremba et al. unless combined with the new embedding dropout, and even then trails the variational variant.
- A 10-model ensemble of Variational LSTMs reaches 68.7 test perplexity, matching a 38-model ensemble of Zaremba et al.
Sentiment analysis (Cornell film reviews). On 5000 reviews truncated to 200-token segments, Variational LSTM and Variational GRU are the only models that do not overfit and achieve the lowest test error among LSTM/GRU baselines.
Ablations. Combining recurrent-layer dropout ($ p_U $) with embedding dropout ($ p_E $) is necessary: with $ p_E = 0 $, increasing $ p_U $ worsens overfitting because the unregularised embedding layer dominates. With $ p_E = 0.5 $, higher $ p_U $ behaves as expected and improves robustness. Weight decay continues to play a meaningful role under variational dropout (it corresponds to the prior), in contrast to common practice with naïve dropout. The cheap dropout approximation (scale the weights by the retention probability $ 1 - p $ at test time) is a good proxy for MC dropout in this setting.
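Both test-time options can be sketched in a few lines. The fragment below is illustrative Python/PyTorch, not the paper's code; the model's softmax output and the number of samples are assumptions made for the example.

```python
import torch

@torch.no_grad()
def mc_dropout_predict(model, x, n_samples=100):
    # MC dropout: keep the masks switched on at test time and average the
    # predictive distribution over several stochastic forward passes.
    model.train()  # leaves dropout active
    probs = torch.stack([torch.softmax(model(x), dim=-1) for _ in range(n_samples)])
    return probs.mean(dim=0)

@torch.no_grad()
def standard_approximation_predict(model, x):
    # Cheap approximation: switch dropout off, i.e. use the scaled mean weights.
    model.eval()
    return torch.softmax(model(x), dim=-1)
```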
Impact
Variational dropout for RNNs became the default regulariser for recurrent language models and was rapidly absorbed into mainstream deep-learning toolkits (Keras, PyTorch, TensorFlow), often under the names recurrent dropout or variational dropout. The 73.4 Penn Treebank perplexity stood as the single-model benchmark to beat for several years and was a key ingredient in subsequent state-of-the-art results (e.g. AWD-LSTM, mixture-of-softmaxes). More broadly, the paper consolidated the Bayesian view of dropout introduced by Gal and Ghahramani's earlier work, reinforcing MC dropout as a practical tool for uncertainty estimation in deep learning.
See also
- Recurrent Neural Networks
- Dropout
- Dropout A Simple Way to Prevent Overfitting
- Overfitting and Regularization
- Word Embeddings
- Neural Networks
- Adam A Method for Stochastic Optimization
- Backpropagation
References
- Gal, Y. and Ghahramani, Z. (2016). A Theoretically Grounded Application of Dropout in Recurrent Neural Networks. Advances in Neural Information Processing Systems 29.
- Gal, Y. and Ghahramani, Z. (2016). Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning. ICML.
- Srivastava, N. et al. (2014). Dropout: A Simple Way to Prevent Neural Networks from Overfitting. JMLR.
- Zaremba, W., Sutskever, I., and Vinyals, O. (2014). Recurrent Neural Network Regularization. arXiv:1409.2329.
- Hochreiter, S. and Schmidhuber, J. (1997). Long Short-Term Memory. Neural Computation.
- Cho, K. et al. (2014). Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation. EMNLP.
- Marcus, M. P., Marcinkiewicz, M. A., and Santorini, B. (1993). Building a Large Annotated Corpus of English: The Penn Treebank. Computational Linguistics.
- Moon, T., Choi, H., Lee, H., and Song, I. (2015). RNNDROP: A Novel Dropout for RNNs in ASR. ASRU.
- Pang, B. and Lee, L. (2005). Seeing Stars: Exploiting Class Relationships for Sentiment Categorization. ACL.