Linear Attention
| Article | |
|---|---|
| Topic area | Deep Learning |
| Prerequisites | Transformer, Softmax Function |
Overview
Linear attention is a family of attention mechanisms whose time and memory cost scales linearly with the sequence length, in contrast to the quadratic cost of the standard softmax attention used in the original Transformer architecture. The core idea is to replace the exponential similarity function in softmax attention with a kernel that can be written as an inner product of feature maps, then exploit the associativity of matrix multiplication to reorder the computation. The result is an attention layer that, for autoregressive decoding, can be expressed as a linear recurrent neural network with a fixed-size hidden state, making it attractive for long-context language modeling, streaming inference, and on-device deployment.[1]
Linear attention sacrifices some of the expressivity of softmax attention in exchange for asymptotic efficiency. Whether the trade is favorable depends on the workload: for very long sequences and constant-memory inference it is often a clear win, while for short to moderate contexts the constant factors and quality gap can erase the theoretical advantage. Modern variants such as gated linear attention and state-space-style models have narrowed the quality gap considerably and underpin several efficient sequence models of the early 2020s.
Background and motivation
Standard scaled dot-product attention computes, for queries $ Q \in \mathbb{R}^{N \times d} $, keys $ K \in \mathbb{R}^{N \times d} $, and values $ V \in \mathbb{R}^{N \times d} $, the output
$ {\displaystyle \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\tfrac{Q K^\top}{\sqrt{d}}\right) V.} $
The matrix $ Q K^\top $ has shape $ N \times N $, so both the time and memory cost are $ \mathcal{O}(N^2 d) $. For long documents, high-resolution images, or audio waveforms, this quadratic scaling becomes the dominant cost of training and inference, motivating a large literature on efficient attention. Linear attention is one of the simplest and most influential approaches in that literature.
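To make the quadratic cost concrete, here is a minimal, single-head, unbatched PyTorch sketch of the formula above; it is written to expose the $ N \times N $ intermediate rather than to be an efficient implementation.

```python
# Minimal sketch of standard scaled dot-product attention (single head,
# unbatched). The (N, N) score matrix is the source of the quadratic cost.
import torch

def softmax_attention(Q, K, V):
    # Q, K, V: (N, d)
    d = Q.shape[-1]
    scores = Q @ K.T / d**0.5              # (N, N): quadratic in sequence length
    weights = torch.softmax(scores, dim=-1)
    return weights @ V                     # (N, d)

N, d = 1024, 64
Q, K, V = (torch.randn(N, d) for _ in range(3))
out = softmax_attention(Q, K, V)           # materializes an N x N matrix
```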
A second motivation comes from autoregressive decoding. A standard Transformer decoder must, at each step, attend to all previous tokens, requiring storage of a key-value cache that grows with the context length. A linear-attention decoder maintains a fixed-size summary of the past, so per-token decoding cost is constant in the sequence length and memory does not grow.
Kernel reformulation
The starting point for linear attention is to write the (unnormalized) softmax similarity as a kernel $ k(q, k) = \exp(q^\top k / \sqrt{d}) $ and to replace it with a positive-definite kernel that admits an explicit feature map $ \phi : \mathbb{R}^d \to \mathbb{R}^{d'} $ such that
$ {\displaystyle k(q, k) = \phi(q)^\top \phi(k).} $
Substituting and writing the attention output for a single query $ q_i $ (as a row vector) gives
$ {\displaystyle y_i^\top = \frac{\sum_{j=1}^{N} \phi(q_i)^\top \phi(k_j)\, v_j^\top}{\sum_{j=1}^{N} \phi(q_i)^\top \phi(k_j)} = \frac{\phi(q_i)^\top \sum_{j=1}^{N} \phi(k_j)\, v_j^\top}{\phi(q_i)^\top \sum_{j=1}^{N} \phi(k_j)}.} $
The crucial step is the second equality, which uses the associativity of matrix multiplication to pull $ \phi(q_i) $ outside the sums. The two sums
$ {\displaystyle S = \sum_{j=1}^{N} \phi(k_j)\, v_j^\top \in \mathbb{R}^{d' \times d}, \qquad z = \sum_{j=1}^{N} \phi(k_j) \in \mathbb{R}^{d'}} $
do not depend on the query, so they can be precomputed in $ \mathcal{O}(N d' d) $ time and reused across all queries. The full attention output is then computed in $ \mathcal{O}(N d' d) $ rather than $ \mathcal{O}(N^2 d) $, giving the eponymous linear scaling.
The same factorization applies to the normalizing denominator, which becomes $ \phi(q_i)^\top z $. Some implementations omit the denominator entirely and instead rely on layer normalization to stabilize the output magnitudes.
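The factorization translates directly into code. The following non-causal PyTorch sketch follows the equations above; the feature map `phi`, the small epsilon added to the denominator, and the shapes are illustrative assumptions rather than a reference implementation.

```python
# Minimal sketch of non-causal linear attention via an explicit feature map.
# phi maps (N, d) -> (N, d'); no N x N matrix is ever formed.
import torch
import torch.nn.functional as F

def linear_attention(Q, K, V, phi, eps=1e-6):
    Qf, Kf = phi(Q), phi(K)                    # (N, d')
    S = Kf.T @ V                               # (d', d): sum_j phi(k_j) v_j^T
    z = Kf.sum(dim=0)                          # (d',):   sum_j phi(k_j)
    num = Qf @ S                               # (N, d)
    den = (Qf @ z).unsqueeze(-1) + eps         # (N, 1)
    return num / den

# Example feature map (elu(x) + 1, see "Feature maps" below); here d' = d.
phi = lambda x: F.elu(x) + 1

N, d = 4096, 64
Q, K, V = (torch.randn(N, d) for _ in range(3))
out = linear_attention(Q, K, V, phi)
```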
Feature maps
The choice of $ \phi $ determines both the expressivity and the cost of the layer. Several families have been proposed:
- Elementwise positive maps. Katharopoulos et al. introduced $ \phi(x) = \mathrm{elu}(x) + 1 $, a cheap nonnegative map with the same dimension as the input. This is the canonical "linear Transformer" baseline.
- Random feature maps. The Performer model approximates the softmax kernel with positive random features $ \phi(x) \propto \exp(W x - \|x\|^2 / 2) $, where $ W $ contains random orthogonal projections. This recovers softmax attention in expectation while keeping the linear-cost factorization.[2]
- Polynomial feature maps. Choosing $ \phi(x) = (1, x, x \otimes x, \dots) $ yields a polynomial kernel; truncating at low degree gives a tractable linear-attention variant with controlled expressivity.
- Identity map. Setting $ \phi(x) = x $ reduces the layer to plain bilinear attention. This is fast but tends to underperform unless combined with normalization or gating.
In practice the feature map is rarely the limiting factor on quality; the more important design choices are normalization, gating, and how the recurrent state is updated.
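For concreteness, the sketch below spells out three of the maps from the list above. The random-feature map is a simplified version of the positive random features used by Performer-style kernels; the omitted scaling constants and the fixed projection $ W $ are simplifying assumptions.

```python
# Illustrative feature maps. Each takes (..., d) inputs and returns (..., d')
# nonnegative features suitable for the linear-attention factorization.
import torch
import torch.nn.functional as F

def phi_elu(x):
    # Katharopoulos et al.: elementwise nonnegative map, d' = d.
    return F.elu(x) + 1

def phi_identity(x):
    # Plain bilinear attention; usually needs extra normalization or gating.
    return x

def phi_random_features(x, W):
    # Positive random features: exp(W x - ||x||^2 / 2), up to scaling.
    # W: (d', d) with (approximately) orthogonal rows.
    return torch.exp(x @ W.T - 0.5 * (x**2).sum(dim=-1, keepdim=True))
```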
Recurrent form and autoregressive inference
For causal (autoregressive) modeling, the cumulative sums $ S_t $ and $ z_t $ can be updated incrementally:
$ {\displaystyle S_t = S_{t-1} + \phi(k_t)\, v_t^\top, \qquad z_t = z_{t-1} + \phi(k_t).} $
Each new token contributes a rank-one update to the matrix-valued state $ S_t $. The output at time $ t $ is
$ {\displaystyle y_t = \frac{\phi(q_t)^\top S_t}{\phi(q_t)^\top z_t}.} $
This is exactly the form of a linear RNN with a fixed-size hidden state of dimension $ d' \times d $. Per-token decoding cost is $ \mathcal{O}(d' d) $, independent of the sequence length, and there is no growing key-value cache. Memory consumption during inference is constant, which is the property that has driven much of the renewed interest in linear attention for long-context language models.
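A token-by-token decoding loop makes the recurrence explicit. The sketch below is written for clarity, not speed; a real decoder would fuse the update and read-out into a single kernel, and the `elu + 1` feature map and epsilon are again illustrative choices.

```python
# Sketch of the recurrent (decoding-time) form: a rank-one state update per
# token, constant memory, and per-token cost independent of sequence length.
import torch
import torch.nn.functional as F

def phi(x):
    return F.elu(x) + 1

def decode(queries, keys, values, eps=1e-6):
    d = values.shape[-1]
    dp = phi(keys[0]).shape[-1]
    S = torch.zeros(dp, d)              # matrix-valued state, shape (d', d)
    z = torch.zeros(dp)                 # normalizer state, shape (d',)
    outputs = []
    for q, k, v in zip(queries, keys, values):
        fk = phi(k)
        S = S + torch.outer(fk, v)      # S_t = S_{t-1} + phi(k_t) v_t^T
        z = z + fk                      # z_t = z_{t-1} + phi(k_t)
        fq = phi(q)
        outputs.append((fq @ S) / (fq @ z + eps))
    return torch.stack(outputs)         # (N, d)
```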
Training: parallel form
Training would be slow if it relied on the recurrent form, because backpropagation through a long sequential recurrence is hard to parallelize across the time axis. Fortunately, the same algebra that gives the recurrence also gives a parallel form: the cumulative sums can be computed with a prefix scan, or, more commonly, the sequence is split into chunks of moderate size so that the quadratic attention pattern is only materialized within each chunk and GPUs are kept busy. Modern implementations such as the chunkwise parallel kernels in FLA (the Flash Linear Attention library) and those used in RetNet exploit this to train at speeds comparable to softmax Transformers on long sequences while keeping the linear asymptotic cost.
The causal mask introduces a subtlety: the cumulative sums must respect the order of tokens, so a naive vectorization that sums over all $ j $ before applying the mask is incorrect. Correct implementations either use a prefix scan or split the sequence into chunks, combining intra-chunk masked quadratic attention with inter-chunk linear state updates.
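The following sketch illustrates the chunkwise causal form: within each chunk the masked quadratic pattern is materialized (cheap for small chunks), while information from earlier chunks enters through the carried state $ (S, z) $. The chunk size, feature map, and epsilon are illustrative assumptions; library kernels such as those in Flash Linear Attention implement fused, numerically careful versions of the same idea.

```python
# Chunkwise causal linear attention: quadratic within chunks, linear across them.
import torch
import torch.nn.functional as F

def phi(x):
    return F.elu(x) + 1

def chunkwise_causal_linear_attention(Q, K, V, chunk=64, eps=1e-6):
    N, d = V.shape
    Qf, Kf = phi(Q), phi(K)                        # (N, d')
    dp = Qf.shape[-1]
    S = torch.zeros(dp, d)                         # inter-chunk state
    z = torch.zeros(dp)
    causal = torch.tril(torch.ones(chunk, chunk))  # lower-triangular mask
    outputs = []
    for s in range(0, N, chunk):
        q, k, v = Qf[s:s+chunk], Kf[s:s+chunk], V[s:s+chunk]
        m = causal[: q.shape[0], : q.shape[0]]
        # Intra-chunk: masked quadratic attention with the linear kernel.
        scores = q @ k.T * m                       # (C, C)
        intra_num = scores @ v                     # (C, d)
        intra_den = scores.sum(dim=-1)             # (C,)
        # Inter-chunk: contribution of all earlier chunks via the state.
        inter_num = q @ S                          # (C, d)
        inter_den = q @ z                          # (C,)
        den = (intra_den + inter_den).unsqueeze(-1) + eps
        outputs.append((intra_num + inter_num) / den)
        # Fold this chunk's keys and values into the carried state.
        S = S + k.T @ v
        z = z + k.sum(dim=0)
    return torch.cat(outputs, dim=0)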
Variants and extensions
A large number of follow-up models build on the basic linear-attention recipe:
- Performer. Random-feature approximation of the softmax kernel, providing unbiased estimates and theoretical guarantees on the variance of the approximation.
- Linear Transformer with elu+1. The original Katharopoulos et al. formulation, widely used as a baseline.
- RetNet. Replaces the unbounded cumulative sum with a decay factor $ \gamma \in (0, 1) $, giving $ S_t = \gamma S_{t-1} + \phi(k_t)\, v_t^\top $. The decay endows the model with a multi-scale retention property, and the recurrence admits mathematically equivalent parallel, recurrent, and chunkwise forms (a minimal sketch of the decayed update follows this list).[3]
- Gated linear attention (GLA). Replaces the scalar decay with a data-dependent gate, recovering more of the selective behavior of softmax attention while keeping the linear cost.[4]
- Selective state-space models. Architectures such as Mamba, while not strictly linear-attention models, share the linear-recurrent structure and can be expressed in a closely related framework. The two families have substantially converged in the recent literature.
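The update rules above differ only in how the state is carried forward; the read-out $ \phi(q_t)^\top S_t $ is unchanged. The sketch below contrasts the plain cumulative sum with a RetNet-style scalar decay and a GLA-style data-dependent gate; the gate shown here is an illustrative per-feature vector in $ (0, 1) $, not the exact GLA parameterization.

```python
# Illustrative state-update rules for linear-attention variants.
import torch

def plain_update(S, fk, v):
    return S + torch.outer(fk, v)             # unbounded cumulative sum

def decayed_update(S, fk, v, gamma=0.97):
    return gamma * S + torch.outer(fk, v)     # scalar decay, RetNet-style

def gated_update(S, fk, v, g):
    # g: (d',) data-dependent gate in (0, 1), applied per feature row.
    return g.unsqueeze(-1) * S + torch.outer(fk, v)
```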
Comparison with softmax attention
Relative to standard softmax attention, linear attention offers:
- Asymptotic cost. $ \mathcal{O}(N d' d) $ versus $ \mathcal{O}(N^2 d) $ in time and memory. This becomes decisive somewhere around $ N \approx 2{,}000 $ to $ 8{,}000 $ tokens, depending on hardware and constant factors.
- Constant inference memory. No growing KV cache; the state has fixed shape $ d' \times d $.
- Streaming-friendly. New tokens can be incorporated with a rank-one update.
The price paid is:
- Reduced expressivity. Softmax attention can selectively focus on a small number of tokens with a sharp distribution; the bounded-rank state of linear attention cannot reproduce arbitrarily peaked patterns. This shows up empirically as weaker associative recall and copy behavior.
- Sensitivity to feature-map choice. Quality varies substantially with $ \phi $, normalization, and gating, so naive replacements often underperform.
In practice, hybrid architectures that interleave softmax and linear-attention layers are a common compromise.
Limitations and open problems
Linear attention's bounded-rank recurrent state is its defining feature and its main limitation. Tasks that require precise retrieval of an arbitrary past token, such as in-context learning with long demonstration sets, expose this gap most clearly. Several lines of work try to recover the missing capacity: gating, multi-scale decays, larger head counts, and the use of nonlinear update rules such as the delta rule.
A second open question concerns hardware. The recurrent form is compact but inherently sequential, while the parallel form requires careful kernel engineering to match the throughput of fused softmax-attention kernels. Libraries such as Flash Linear Attention have closed much of the gap, but the implementation surface is still less mature than for standard attention.
Finally, linear attention is sometimes criticized as a step backward toward RNNs, and it is true that the recurrent form has all the classic difficulties of training deep recurrences. Modern variants mitigate this with careful initialization, layer normalization, and decay parameterizations, but the issue is real and worth understanding before adopting linear attention in a new system.
References
- ↑ Katharopoulos, A., Vyas, A., Pappas, N., and Fleuret, F. Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention. ICML, 2020.
- ↑ Choromanski, K. et al. Rethinking Attention with Performers. ICLR, 2021.
- ↑ Sun, Y. et al. Retentive Network: A Successor to Transformer for Large Language Models. arXiv, 2023.
- ↑ Yang, S. et al. Gated Linear Attention Transformers with Hardware-Efficient Training. ICML, 2024.