Attention Rollout

    From Marovi AI
    Topic area Deep Learning
    Prerequisites Transformer, Attention Mechanism, Self-Attention


    Overview

    Attention rollout is a technique for quantifying how much information flows from each input token to each output position across all layers of a Transformer, by recursively multiplying the attention weight matrices of every layer. It was introduced by Abnar and Zuidema in 2020[1] as a remedy for the well-known problem that the raw attention weights in a single layer are a poor proxy for the contribution of each input token to a given prediction, because tokens are repeatedly mixed with their neighbours by Self-Attention in every layer.

    Rollout treats the stack of layers as a graph and computes the path-weighted reachability from any output node back to the inputs. The resulting matrix is widely used as a saliency map for transformer-based models, in particular Vision Transformers (ViTs) and BERT-style encoders, where it is one of the standard tools for interpretability and for producing localisation maps without retraining.

    Motivation

    Single-layer attention weights have an intuitive appeal as explanations: position $ i $ "looks at" position $ j $ with weight $ a_{ij} $. However, in a deep transformer the representation at layer $ \ell $ is already a mixture of all input tokens, so a high $ a_{ij} $ at a late layer does not mean that the prediction at position $ i $ depends mostly on the original input token $ j $. Several studies have argued that "attention is not explanation" precisely because of this contamination across layers.[2][3]

    Attention rollout addresses this by asking a different question: if we trace information backwards through the network, how much of the representation at a given output position originated from each input token? The answer requires composing the per-layer attention matrices, not inspecting them in isolation.

    Formulation

    Setup

    Consider a Transformer with $ L $ layers, sequence length $ n $, and $ H $ attention heads per layer. Let $ A^{(\ell)} \in \mathbb{R}^{n \times n} $ denote the attention weight matrix at layer $ \ell $, where rows correspond to query positions and columns to key positions and each row sums to one. When the layer has multiple heads, the per-head matrices $ A^{(\ell, h)} $ are first averaged:

    $ {\displaystyle A^{(\ell)} = \frac{1}{H} \sum_{h=1}^{H} A^{(\ell, h)}.} $

    Other reductions, such as taking the maximum across heads or a learned weighted sum, are sometimes used.
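    The head reduction can be sketched in NumPy as below. This is a minimal illustration, assuming a layer's attention is stored as an array of shape (H, n, n); the function name `reduce_heads` is hypothetical. It also shows one practical difference between the two reductions: the mean keeps every row summing to one, whereas the elementwise maximum generally does not.

```python
import numpy as np


def reduce_heads(attn: np.ndarray, mode: str = "mean") -> np.ndarray:
    """Collapse per-head attention of shape (H, n, n) into one (n, n) matrix."""
    if mode == "mean":
        # Uniform average over heads, as in the formula above.
        return attn.mean(axis=0)
    if mode == "max":
        # Elementwise maximum over heads; rows no longer sum to one.
        return attn.max(axis=0)
    raise ValueError(f"unknown head reduction: {mode!r}")


# Toy attention for H=4 heads over n=5 tokens: softmax over random scores.
rng = np.random.default_rng(0)
scores = rng.normal(size=(4, 5, 5))
attn = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)

A_mean = reduce_heads(attn, "mean")
A_max = reduce_heads(attn, "max")
print(np.allclose(A_mean.sum(axis=-1), 1.0))  # True: the mean keeps rows stochastic
print(A_max.sum(axis=-1))                      # each row sum is at least 1
```

    Because the maximum over heads can push row sums above one, a max-reduced matrix needs renormalising before it is treated as a distribution.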

    Residual correction

    A pure self-attention layer is not the only path through a transformer block: a residual connection adds the layer's input to its output, so a non-trivial fraction of the signal at each position is simply the position's own previous representation. To account for this, rollout replaces the raw attention matrix with

    $ {\displaystyle \tilde{A}^{(\ell)} = \frac{1}{2}\bigl(A^{(\ell)} + I\bigr),} $

    where $ I $ is the identity matrix. The factor of one half preserves the row-stochastic property: each row of $ \tilde{A}^{(\ell)} $ still sums to one, so it can be interpreted as a probability distribution over input positions. The identity term encodes the residual stream's contribution.
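    A minimal sketch of the residual correction, assuming NumPy and a single (n, n) attention matrix; `residual_correct` is a hypothetical helper name. It adds the identity and renormalises each row, which for a row-stochastic input is exactly $ \tfrac{1}{2}(A + I) $ and also keeps the result row-stochastic if the input came from a non-averaging head reduction.

```python
import numpy as np


def residual_correct(A: np.ndarray) -> np.ndarray:
    """Return A + I with each row renormalised to sum to one.

    For a row-stochastic A every row of A + I sums to 2, so this is
    identical to 0.5 * (A + I).
    """
    B = A + np.eye(A.shape[-1])
    return B / B.sum(axis=-1, keepdims=True)


# Two-token example: half of each row's mass moves onto the diagonal.
A = np.array([[0.2, 0.8],
              [0.6, 0.4]])
A_tilde = residual_correct(A)
print(np.allclose(A_tilde, [[0.6, 0.4], [0.3, 0.7]]))  # True
```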

    Recursive product

    The rollout matrix at layer $ \ell $ is the cumulative product

    $ {\displaystyle R^{(\ell)} = \tilde{A}^{(\ell)} \, \tilde{A}^{(\ell-1)} \cdots \tilde{A}^{(1)} = \prod_{k=\ell}^{1} \tilde{A}^{(k)}.} $

    Entry $ R^{(L)}_{ij} $ is interpreted as the proportion of the representation at output position $ i $ that can be attributed, through composed attention paths, to input position $ j $. Because each $ \tilde{A}^{(\ell)} $ is row-stochastic, so is $ R^{(L)} $, making the result directly visualisable as a heatmap over input tokens.
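    The row-stochastic claim follows in one line. Write $ \mathbf{1} $ for the all-ones vector and set $ R^{(0)} = I $, so that the product can equivalently be built recursively as $ R^{(\ell)} = \tilde{A}^{(\ell)} R^{(\ell-1)} $. Since $ A^{(k)}\mathbf{1} = \mathbf{1} $ for every layer,

    $ {\displaystyle \tilde{A}^{(k)}\mathbf{1} = \tfrac{1}{2}\bigl(A^{(k)}\mathbf{1} + \mathbf{1}\bigr) = \mathbf{1}, \qquad R^{(\ell)}\mathbf{1} = \tilde{A}^{(\ell)} \cdots \tilde{A}^{(2)}\bigl(\tilde{A}^{(1)}\mathbf{1}\bigr) = \tilde{A}^{(\ell)} \cdots \tilde{A}^{(2)}\mathbf{1} = \cdots = \mathbf{1}.} $

    All entries remain non-negative because they are sums of products of non-negative numbers, so each row of $ R^{(\ell)} $ is a probability distribution over input positions.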

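    For a classification ViT or a BERT-style model that uses a special CLS token, the row of $ R^{(L)} $ corresponding to the CLS index gives a saliency map over the input patches or tokens that feed the classification head. This map does not depend on which class is predicted; class-specific maps require the gradient-weighted variants described below.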
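    Turning the CLS row into an image-shaped heatmap can be sketched as follows. This assumes the standard ViT layout of one CLS token followed by $ g \times g $ patch tokens in raster order (for example 197 tokens and a 14×14 grid for ViT-B/16 at 224×224 resolution); the helper name `cls_saliency` and the scaling to a maximum of one are illustrative choices, not part of the method.

```python
import numpy as np


def cls_saliency(R: np.ndarray, cls_index: int = 0) -> np.ndarray:
    """Reshape the CLS row of a rollout matrix R (n, n) into a square heatmap.

    Assumes one CLS token plus g * g patch tokens in raster order.
    """
    row = R[cls_index]
    patches = np.delete(row, cls_index)      # drop the CLS -> CLS entry
    g = int(round(np.sqrt(patches.size)))
    if g * g != patches.size:
        raise ValueError(f"{patches.size} patch tokens do not form a square grid")
    heat = patches.reshape(g, g)
    peak = heat.max()
    return heat / peak if peak > 0 else heat  # scale to [0, 1] for display


# Toy row-stochastic "rollout" for 1 CLS token + a 2x2 patch grid (n = 5).
rng = np.random.default_rng(0)
R = rng.random((5, 5))
R /= R.sum(axis=-1, keepdims=True)
print(cls_saliency(R).shape)  # (2, 2)
```

    In practice the small heatmap is then upsampled to the input resolution for overlay.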

    Algorithm

    The procedure is straightforward to implement on top of any transformer that exposes its attention weights:

    1. Run a forward pass and store $ A^{(\ell, h)} $ for every layer and head.
    2. For each layer, average over heads to obtain $ A^{(\ell)} $.
    3. Add the identity and renormalise to get $ \tilde{A}^{(\ell)} $.
    4. Initialise $ R \leftarrow I $ and multiply $ R \leftarrow \tilde{A}^{(\ell)} R $ from the first layer to the last.
    5. Slice out the row of interest (typically CLS) and reshape to the input grid.
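    The five steps can be sketched in NumPy as below. This is a minimal illustration, assuming the attention weights have already been collected (step 1) as one array of shape (H, n, n) per layer, ordered from the first layer to the last, with the CLS token at index 0; `attention_rollout` is a hypothetical name, and obtaining the weights from a particular framework is left out.

```python
from typing import Sequence

import numpy as np


def attention_rollout(attentions: Sequence[np.ndarray]) -> np.ndarray:
    """Compute R^(L) from per-layer attention arrays of shape (H, n, n)."""
    n = attentions[0].shape[-1]
    eye = np.eye(n)
    R = eye.copy()                                      # step 4: R <- I
    for attn in attentions:                             # first layer to last
        A = attn.mean(axis=0)                           # step 2: average heads
        A_tilde = A + eye                               # step 3: add identity ...
        A_tilde /= A_tilde.sum(axis=-1, keepdims=True)  # ... and renormalise rows
        R = A_tilde @ R                                 # step 4: R <- A_tilde R
    return R


# Toy input: L=3 layers, H=2 heads, n=5 tokens (CLS + 2x2 patches).
rng = np.random.default_rng(0)
scores = rng.normal(size=(3, 2, 5, 5))
attns = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)

R = attention_rollout(list(attns))
heatmap = R[0, 1:].reshape(2, 2)                        # step 5: CLS row -> grid
print(np.allclose(R.sum(axis=-1), 1.0))                 # True
```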

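    The total cost is $ L $ products of $ n \times n $ matrices, i.e. $ O(L n^3) $ with standard matrix multiplication, plus $ O(L H n^2) $ for averaging the heads. When only one row is needed (typically CLS), a row vector can instead be propagated from the last layer to the first, which reduces the cost to $ O(L n^2) $. For the modest sequence lengths typical in image classification (197 tokens for ViT-B/16 at 224×224 resolution), the entire rollout is cheap relative to a forward pass.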
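    The single-row shortcut can be sketched as follows, under the same assumptions as a plain rollout (one (H, n, n) array per layer, first layer first); `rollout_row` is a hypothetical name. It relies on row $ i $ of $ R^{(L)} $ being $ e_i^\top \tilde{A}^{(L)} \tilde{A}^{(L-1)} \cdots \tilde{A}^{(1)} $, so the vector is multiplied by the layers in reverse order and no $ n \times n $ product is ever formed.

```python
from typing import Sequence

import numpy as np


def rollout_row(attentions: Sequence[np.ndarray], index: int = 0) -> np.ndarray:
    """Return row `index` of R^(L) using one vector-matrix product per layer."""
    n = attentions[0].shape[-1]
    eye = np.eye(n)
    v = eye[index].copy()                               # e_i^T
    for attn in reversed(attentions):                   # last layer to first
        A_tilde = attn.mean(axis=0) + eye               # average heads, add identity
        A_tilde /= A_tilde.sum(axis=-1, keepdims=True)  # renormalise rows
        v = v @ A_tilde                                 # O(n^2) per layer
    return v


rng = np.random.default_rng(0)
scores = rng.normal(size=(3, 2, 5, 5))                  # L=3, H=2, n=5
attns = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
cls_row = rollout_row(list(attns), index=0)
print(np.isclose(cls_row.sum(), 1.0))                   # True
```

    Multiplying a row vector by a row-stochastic matrix preserves its sum, so the returned row is still a distribution over input positions.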

    Attention Flow

    The original paper of Abnar and Zuidema also proposes attention flow, a related but more expensive variant. Instead of multiplying the layerwise matrices, attention flow treats the attention graph as a capacitated directed graph and computes a maximum flow from every input node to every output node. Flow is more faithful to bottlenecks in the network, since multiplication can over-count paths that share an edge, but its cost is super-linear in $ n $ and it is rarely used at scale. Rollout is the dominant variant in practice.
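    To make the contrast concrete, here is a toy sketch of a single-pair flow computation. Several choices are assumptions rather than details taken from the text: nodes are (layer, position) pairs, the edge from node $ (\ell-1, j) $ to node $ (\ell, i) $ gets capacity $ \tilde{A}^{(\ell)}_{ij} $ (the residual-corrected weights), and the max-flow solver is the generic Edmonds-Karp algorithm on a dense capacity matrix, which is only practical for tiny $ n $ and $ L $. The function name `attention_flow` is hypothetical, and the returned flow value is not normalised the way a rollout row is.

```python
from collections import deque
from typing import Sequence

import numpy as np


def attention_flow(A_tilde: Sequence[np.ndarray], out_pos: int, in_pos: int) -> float:
    """Max-flow value from input token `in_pos` to output position `out_pos`.

    A_tilde[l - 1] is the (n, n) residual-corrected attention of layer l;
    entry [i, j] is the capacity of the edge from node (l - 1, j) to (l, i).
    """
    L, n = len(A_tilde), A_tilde[0].shape[0]
    N = (L + 1) * n                          # node (l, p) has index l * n + p
    cap = np.zeros((N, N))                   # residual capacities
    for l, A in enumerate(A_tilde, start=1):
        for i in range(n):
            for j in range(n):
                cap[(l - 1) * n + j, l * n + i] = A[i, j]

    source, sink = in_pos, L * n + out_pos
    total = 0.0
    while True:
        # Breadth-first search for a shortest augmenting path (Edmonds-Karp).
        parent = [-1] * N
        parent[source] = source
        queue = deque([source])
        while queue and parent[sink] == -1:
            u = queue.popleft()
            for v in np.nonzero(cap[u] > 1e-12)[0]:
                if parent[v] == -1:
                    parent[v] = u
                    queue.append(v)
        if parent[sink] == -1:
            return total                     # no augmenting path remains
        # Bottleneck along the path, then push that much flow through it.
        bottleneck, v = np.inf, sink
        while v != source:
            bottleneck = min(bottleneck, cap[parent[v], v])
            v = parent[v]
        v = sink
        while v != source:
            u = parent[v]
            cap[u, v] -= bottleneck
            cap[v, u] += bottleneck
            v = u
        total += bottleneck


A_t = np.array([[0.6, 0.4],
                [0.3, 0.7]])
print(attention_flow([A_t, A_t], out_pos=0, in_pos=0))  # approximately 0.9
print((A_t @ A_t)[0, 0])                                # rollout entry, approximately 0.48
```

    On this two-layer toy graph the flow and the rollout entry differ substantially, which illustrates that the two quantities measure different things and should not be compared on the same scale.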

    Applications

    Attention rollout has become a default visualisation tool for several families of models:

    • Vision Transformers: rollout heatmaps, restricted to the CLS row of $ R^{(L)} $ and reshaped to the patch grid, produce class-discriminative localisations that are competitive with gradient-based methods such as Grad-CAM[4] on ImageNet and on weakly supervised segmentation benchmarks.[5]
    • Language models: rollout is used to inspect which input tokens a BERT or RoBERTa encoder draws on for a given prediction, complementing per-head probing studies.
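    • Multimodal transformers: in image-text models with cross-attention layers, composing rollout through those layers reveals which image regions ground a given text token, supporting open-vocabulary segmentation and grounding. Dual-encoder models such as CLIP have no cross-attention between modalities, so there rollout is applied within the image encoder.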
    • Model auditing: by comparing rollout maps before and after a fine-tuning step or a domain shift, practitioners can detect whether a model has shifted its reliance from one input region to another.

    Variants and Extensions

    Several extensions of basic rollout sharpen its faithfulness or class specificity:

    • Gradient-weighted rollout: combines the attention weights with gradients of the predicted class score $ y_c $ with respect to those weights. Each head's map is replaced by $ \bigl(\nabla_{A^{(\ell,h)}} y_c \odot A^{(\ell,h)}\bigr)^{+} $, where $ \odot $ is the elementwise product and $ (\cdot)^{+} $ keeps only positive values, and the average of these maps over heads takes the place of $ A^{(\ell)} $ in the rollout product. Chefer et al. show that this produces sharper, more class-discriminative maps than vanilla rollout in ViTs.[6]
    • Head-specific rollout: instead of averaging heads, the rollout product is computed per head and the results are aggregated, exposing the role of individual heads.
    • Encoder-decoder rollout: for translation and seq2seq models, rollout is computed separately along the encoder self-attention chain, the cross-attention layers, and the decoder self-attention chain, and the matrices are composed end-to-end.
    • Sparse and pruned rollout: entries below a threshold are set to zero before multiplication, focusing the visualisation on dominant paths.
    • Top-k attention rollout: retains only the top-k entries per row of each $ \tilde{A}^{(\ell)} $ before composition, motivated by the observation that attention mass is often concentrated on a few positions, with a long tail of small weights that blurs the composed map.
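    Several of these variants can be sketched with small changes to the basic loop. The NumPy code below is illustrative only, with all function names hypothetical. Assumptions: attention is stored as an array of shape (L, H, n, n); `grads` holds precomputed gradients $ \nabla_{A^{(\ell,h)}} y_c $ of the same shape (obtained, for example, from an autograd framework, which is not shown); rows are renormalised after sparsification so they stay stochastic; and the gradient-weighted maps are fed through the same residual correction as plain rollout, which may differ from the exact update rule used by Chefer et al.

```python
import numpy as np


def residual(A: np.ndarray) -> np.ndarray:
    """Add the identity and renormalise rows (the rollout residual correction)."""
    B = A + np.eye(A.shape[-1])
    return B / B.sum(axis=-1, keepdims=True)


def compose(mats) -> np.ndarray:
    """Apply R <- M R for each layer matrix M in order, starting from R = I."""
    R = np.eye(mats[0].shape[-1])
    for M in mats:
        R = M @ R
    return R


def head_specific_rollout(attns: np.ndarray) -> np.ndarray:
    """attns: (L, H, n, n). Returns one rollout matrix per head, shape (H, n, n)."""
    L, H = attns.shape[:2]
    return np.stack([compose([residual(attns[l, h]) for l in range(L)])
                     for h in range(H)])


def topk_rows(A: np.ndarray, k: int) -> np.ndarray:
    """Keep the k largest entries of each row, zero the rest, renormalise."""
    idx = np.argsort(A, axis=-1)[:, -k:]
    out = np.zeros_like(A)
    np.put_along_axis(out, idx, np.take_along_axis(A, idx, axis=-1), axis=-1)
    return out / out.sum(axis=-1, keepdims=True)


def threshold_rows(A: np.ndarray, tau: float) -> np.ndarray:
    """Zero entries below tau and renormalise. On a residual-corrected matrix
    the diagonal is at least 0.5, so tau <= 0.5 never empties a row."""
    out = np.where(A >= tau, A, 0.0)
    return out / out.sum(axis=-1, keepdims=True)


def gradient_weighted_rollout(attns: np.ndarray, grads: np.ndarray) -> np.ndarray:
    """Clip grad * attention at zero, average heads, then correct and compose."""
    weighted = np.clip(grads * attns, 0.0, None).mean(axis=1)  # (L, n, n)
    return compose([residual(W) for W in weighted])


rng = np.random.default_rng(0)
scores = rng.normal(size=(3, 2, 5, 5))                   # L=3, H=2, n=5
attns = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)

per_head = head_specific_rollout(attns)                  # (2, 5, 5)
sparse = compose([topk_rows(residual(a.mean(axis=0)), k=2) for a in attns])
fake_grads = rng.normal(size=attns.shape)                # placeholder, not real gradients
gw = gradient_weighted_rollout(attns, fake_grads)
```

    With all-ones gradients the gradient-weighted version reduces to ordinary head-averaged rollout, which is a convenient sanity check when wiring in real gradients.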

    Comparison With Other Saliency Methods

    Rollout is one of several families of explanation tools applied to transformers:

    • Compared to raw attention, rollout corrects for the multi-layer mixing that makes single-layer attention misleading.
    • Compared to Grad-CAM and other gradient-based methods, rollout is purely forward-pass and does not require differentiating through the prediction; this makes it cheap and easy to apply to any transformer that exposes its attention weights, but also less class-specific unless gradients are folded in.
    • Compared to integrated gradients and SHAP-style methods, rollout is much faster (a small number of matrix products versus many forward passes, plus backward passes for integrated gradients) but does not satisfy the same axiomatic guarantees.
    • Compared to attention flow, rollout is a tractable approximation that gives qualitatively similar saliency maps in most cases.

    Limitations

    Several caveats apply when using attention rollout as an explanation:

    • No use of the value pathway: rollout only inspects the attention weights and ignores the value projections, which carry the actual content. Two heads with identical attention patterns but very different value matrices would produce indistinguishable rollouts.
    • Identity correction is heuristic: the equal weighting of $ A $ and $ I $ is a convention rather than a derived quantity. The true residual mass depends on the relative norms of the residual and attention contributions, which vary across layers and tokens.
    • Class agnosticism: vanilla rollout gives the same map regardless of the predicted class; gradient-weighted variants are needed for class-discriminative explanations.
    • Faithfulness debates: formal evaluations of attention-based explanations have produced mixed results, and rollout inherits much of this critique. Users should treat rollout maps as informative summaries rather than as ground-truth attributions.
    • Architecture assumptions: rollout in its standard form assumes a stack of identical attention blocks with residual connections. Adapting it to mixture-of-experts, sparse-attention, or routing-based transformers requires care.
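    The value-pathway limitation can be demonstrated with a single toy attention head. In this sketch (NumPy, random weights, all names hypothetical) the same attention matrix is paired with two different value projections: the layer outputs differ, yet the residual-corrected matrix that rollout multiplies is computed from the attention weights alone and is therefore identical in both cases.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 4, 3
X = rng.normal(size=(n, d))                    # token representations entering the layer
W_q, W_k = rng.normal(size=(d, d)), rng.normal(size=(d, d))

# Single-head attention weights: softmax(Q K^T / sqrt(d)) over keys.
scores = (X @ W_q) @ (X @ W_k).T / np.sqrt(d)
A = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)

# Two value projections: the second is sign-flipped and scaled.
W_v1 = rng.normal(size=(d, d))
W_v2 = -3.0 * W_v1
out1, out2 = A @ X @ W_v1, A @ X @ W_v2        # attention outputs differ

# What rollout multiplies depends on A only, so it cannot tell the two apart.
A_tilde = 0.5 * (A + np.eye(n))
print(np.allclose(out1, out2))                 # False
```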

    References

    1. Abnar, S. and Zuidema, W., "Quantifying Attention Flow in Transformers", ACL 2020.
    2. Jain, S. and Wallace, B. C., "Attention is not Explanation", NAACL 2019.
    3. Wiegreffe, S. and Pinter, Y., "Attention is not not Explanation", EMNLP 2019.
    4. Selvaraju, R. R. et al., "Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization", ICCV 2017.
    5. Template:Cite arxiv
    6. Template:Cite arxiv