Cross-Attention

    Topic area: Deep Learning
    Prerequisites: Attention Mechanism, Self-Attention, Transformer


    Overview

    Cross-attention is a variant of the attention mechanism in which the queries come from one sequence and the keys and values come from a different sequence. It is the standard mechanism for letting a model condition the generation or representation of one stream on the content of another, and it is the architectural backbone of encoder-decoder Transformers, retrieval-augmented models, and most modern multimodal systems that bind text to images, audio, or video.

    In contrast to self-attention, where queries, keys, and values are projections of the same input, cross-attention establishes an asymmetric flow of information from a source sequence to a target sequence. The target sequence asks questions; the source sequence supplies the evidence. This decoupling makes cross-attention a natural fit whenever two streams have different lengths, different modalities, or different roles — for example, decoder tokens attending to encoder outputs in machine translation, or text tokens attending to image patches in a vision-language model.

    Cross-attention was introduced as part of the original Transformer architecture[1] and has since become a basic building block reused across diffusion models[2], perceiver-style architectures[3], and contemporary multimodal large language models.

    Intuition

    A useful mental model is to think of attention as soft, differentiable lookup in an associative memory. Self-attention is a memory whose contents are the very tokens doing the lookup; cross-attention is a memory whose contents come from somewhere else. The decoder of a translation model, while generating the next German word, queries the encoder's representations of the source English sentence to decide which words it should be looking at right now. The query knows what it wants ("a noun phrase that represents the subject"); the keys advertise what they offer ("I am a noun phrase about the cat"); the values supply the actual content that gets mixed into the decoder's hidden state.

    Because keys and values come from a fixed external source for the duration of a decoding step, cross-attention is also the most natural place to inject conditioning signals into a generative model. Text-to-image diffusion models, for example, treat the denoising U-Net's spatial features as queries and the encoded text prompt as keys and values, so that every spatial location can selectively pull semantic content from the prompt at every denoising step.

    Formulation

    Let the target (query) sequence have length $ n $ with hidden size $ d $, giving $ X_{\text{tgt}} \in \mathbb{R}^{n \times d} $, and let the source (key/value) sequence have length $ m $, giving $ X_{\text{src}} \in \mathbb{R}^{m \times d} $. Three learned linear projections produce the queries, keys, and values:

    $ {\displaystyle Q = X_{\text{tgt}} W_Q, \quad K = X_{\text{src}} W_K, \quad V = X_{\text{src}} W_V} $

    with $ W_Q, W_K \in \mathbb{R}^{d \times d_k} $ and $ W_V \in \mathbb{R}^{d \times d_v} $. Scaled dot-product cross-attention then computes

    $ {\displaystyle \operatorname{CrossAttn}(X_{\text{tgt}}, X_{\text{src}}) = \operatorname{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V \in \mathbb{R}^{n \times d_v}.} $

    The crucial structural fact is the asymmetry of the two operands: the output has the same length as the target, while its content is a convex combination of value vectors drawn from the source. The attention matrix is rectangular with shape $ n \times m $, not square as in self-attention.
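
    The formulation maps directly onto a few lines of code. The following is a minimal sketch in PyTorch (an illustrative choice of framework; the shapes follow the notation above, and random matrices stand in for learned projections):

    ```python
    import torch
    import torch.nn.functional as F

    def cross_attention(x_tgt, x_src, W_q, W_k, W_v):
        """Scaled dot-product cross-attention. x_tgt: (n, d), x_src: (m, d)."""
        Q = x_tgt @ W_q                         # (n, d_k) queries from the target
        K = x_src @ W_k                         # (m, d_k) keys from the source
        V = x_src @ W_v                         # (m, d_v) values from the source
        scores = Q @ K.T / K.shape[-1] ** 0.5   # (n, m): rectangular, not square
        weights = F.softmax(scores, dim=-1)     # each target row is a distribution over source positions
        return weights @ V                      # (n, d_v): target length, source content

    n, m, d = 5, 9, 16
    x_tgt, x_src = torch.randn(n, d), torch.randn(m, d)
    W_q, W_k, W_v = torch.randn(d, d), torch.randn(d, d), torch.randn(d, d)
    out = cross_attention(x_tgt, x_src, W_q, W_k, W_v)   # shape (5, 16)
    ```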

    In practice cross-attention is almost always used in its multi-head form. With $ h $ heads of dimension $ d_k = d / h $, the queries, keys, and values are split, attention is applied per head, and the head outputs are concatenated and linearly projected:

    $ {\displaystyle \operatorname{MultiHead}(X_{\text{tgt}}, X_{\text{src}}) = \operatorname{Concat}(\operatorname{head}_1, \ldots, \operatorname{head}_h)\, W_O.} $

    Different heads can specialize: some align positionally, others align semantically, and still others act as near-uniform smoothing.
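
    In a framework with a built-in multi-head attention module, cross-attention is simply the call in which the queries come from one stream and the keys and values from the other. A sketch using PyTorch's nn.MultiheadAttention (batch size and dimensions are illustrative):

    ```python
    import torch
    from torch import nn

    d_model, n_heads = 512, 8
    mha = nn.MultiheadAttention(d_model, n_heads, batch_first=True)

    x_tgt = torch.randn(2, 5, d_model)   # (batch, n, d) target stream
    x_src = torch.randn(2, 9, d_model)   # (batch, m, d) source stream

    # Queries from the target; keys and values from the source.
    out, attn = mha(query=x_tgt, key=x_src, value=x_src)
    print(out.shape)    # torch.Size([2, 5, 512]) -- same length as the target
    print(attn.shape)   # torch.Size([2, 5, 9])   -- rectangular, averaged over heads
    ```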

    Use in encoder-decoder Transformers

    In the original Transformer, every decoder block contains three sublayers: masked self-attention over the partially generated target, cross-attention into the encoder stack's final output, and a position-wise feedforward network. The cross-attention sublayer is the only place where information flows from source to target; remove it, and the decoder becomes a plain language model with no view of the source sentence.

    A few practical points follow from this. Because the encoder representations are computed once and reused for every decoding step, the keys and values can be cached across steps, which makes cross-attention much cheaper than the self-attention that grows with the partial target. Most production decoders maintain a separate KV cache for self-attention (which grows token by token) and a static, precomputed KV tensor for cross-attention.
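
    A sketch of that split, with hypothetical projection matrices: the cross-attention keys and values are computed once per source sequence, and each decoding step only projects a fresh query.

    ```python
    import torch

    def precompute_cross_kv(enc_out, W_k, W_v):
        # Run once per source sequence; reused at every decoding step.
        return enc_out @ W_k, enc_out @ W_v              # (m, d_k), (m, d_v)

    def cross_attend_step(h_t, K_src, V_src, W_q):
        # h_t: (1, d) hidden state of the newest target token.
        q = h_t @ W_q                                    # the only per-step projection
        w = torch.softmax(q @ K_src.T / K_src.shape[-1] ** 0.5, dim=-1)
        return w @ V_src                                 # (1, d_v)

    d, m = 16, 9
    enc_out = torch.randn(m, d)
    W_q, W_k, W_v = torch.randn(d, d), torch.randn(d, d), torch.randn(d, d)
    K_src, V_src = precompute_cross_kv(enc_out, W_k, W_v)
    for _ in range(3):                                   # three decoding steps, same K/V
        out = cross_attend_step(torch.randn(1, d), K_src, V_src, W_q)
    ```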

    Cross-attention is also where source-side padding masks are applied: padded positions in the source sequence are masked out so that the softmax assigns them zero probability. Causal masks, by contrast, are unnecessary in cross-attention — the decoder is allowed to attend to any source position at any decoding step.
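
    A minimal sketch of source-side padding (shapes are illustrative): positions beyond each sequence's true length are set to negative infinity before the softmax, so they receive exactly zero attention weight.

    ```python
    import torch

    batch, n, m = 4, 5, 9
    scores = torch.randn(batch, n, m)            # raw cross-attention logits
    src_lens = torch.tensor([9, 7, 4, 9])        # true source lengths before padding

    pad = torch.arange(m)[None, :] >= src_lens[:, None]           # (batch, m), True = padded
    scores = scores.masked_fill(pad[:, None, :], float("-inf"))   # broadcast over all queries
    weights = torch.softmax(scores, dim=-1)      # padded columns get weight exactly 0
    ```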

    Variants

    Several variants extend or modify the basic cross-attention layer to address specific constraints.

    Gated cross-attention inserts a learned gate, often initialized at zero, on the cross-attention output so that a freshly added cross-attention layer does not destabilize a pretrained model. This is the mechanism Flamingo uses to graft visual context onto a frozen language model[4], and it is a common pattern for parameter-efficient multimodal adaptation more generally.
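
    A sketch of the tanh-gated pattern (a simplified stand-in, not Flamingo's exact block): with the gate parameter initialized at zero, the layer starts out as an identity mapping, so the pretrained residual stream is untouched until training opens the gate.

    ```python
    import torch
    from torch import nn

    class GatedCrossAttention(nn.Module):
        def __init__(self, d_model, n_heads):
            super().__init__()
            self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
            self.gate = nn.Parameter(torch.zeros(1))   # tanh(0) = 0: block starts inert

        def forward(self, x_tgt, x_src):
            attended, _ = self.attn(x_tgt, x_src, x_src)
            return x_tgt + torch.tanh(self.gate) * attended   # gated residual update
    ```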

    Perceiver-style cross-attention uses a small set of learned latent vectors as queries against a very long input sequence, compressing the input into a fixed-size representation independently of its length. This breaks the quadratic dependence of standard self-attention on input length and is what allows the Perceiver family to handle raw pixels, audio samples, and point clouds without modality-specific tokenizers.
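
    A sketch of the latent-query pattern (sizes are illustrative): a small learned array supplies the queries, so the output has a fixed size regardless of input length.

    ```python
    import torch
    from torch import nn

    class LatentCrossAttention(nn.Module):
        def __init__(self, n_latents, d_model, n_heads):
            super().__init__()
            self.latents = nn.Parameter(torch.randn(n_latents, d_model))
            self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)

        def forward(self, x_src):                  # x_src: (batch, m, d), m may be huge
            q = self.latents.unsqueeze(0).expand(x_src.size(0), -1, -1)
            out, _ = self.attn(q, x_src, x_src)    # cost O(n_latents * m), not O(m^2)
            return out                             # (batch, n_latents, d): fixed size

    pooled = LatentCrossAttention(64, 256, 8)(torch.randn(2, 10_000, 256))  # (2, 64, 256)
    ```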

    Cross-attention in diffusion models conditions a denoising network on a text or class embedding by treating the network's spatial feature map as queries and the conditioning embedding as keys and values. The same mechanism, applied at every layer and every denoising step, is what gives latent diffusion models their fine-grained controllability over generated images.
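
    A sketch of the pattern (the shapes and the shared channel width are illustrative, not any particular model's configuration): the spatial feature map is flattened so that each location becomes one query row, attended into the prompt tokens, and reshaped back.

    ```python
    import torch
    from torch import nn

    attn = nn.MultiheadAttention(embed_dim=320, num_heads=8, batch_first=True)

    feat = torch.randn(1, 320, 32, 32)    # (batch, channels, H, W) denoiser features
    text = torch.randn(1, 77, 320)        # (batch, tokens, channels) encoded prompt

    b, c, h, w = feat.shape
    q = feat.flatten(2).transpose(1, 2)   # (batch, H*W, c): one query per location
    out, _ = attn(q, text, text)          # each location pulls content from the prompt
    feat = out.transpose(1, 2).reshape(b, c, h, w)   # back to a spatial map
    ```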

    Memory and retrieval cross-attention generalizes the source sequence to a retrieved chunk database. Architectures such as RETRO and kNN-augmented Transformers retrieve nearest-neighbor passages and cross-attend into them, which decouples a model's parametric capacity from the knowledge it can access at inference time.
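
    At the level of tensor shapes this is ordinary cross-attention; what changes is where the source comes from. A sketch with hypothetical sizes, in which a random tensor stands in for an encoded set of retrieved chunks:

    ```python
    import torch
    from torch import nn

    attn = nn.MultiheadAttention(512, 8, batch_first=True)

    context = torch.randn(1, 16, 512)       # current target tokens act as queries
    chunks = torch.randn(1, 4, 64, 512)     # 4 retrieved chunks of 64 encoded tokens each
    memory = chunks.flatten(1, 2)           # (1, 256, 512): one flat external memory

    out, _ = attn(context, memory, memory)  # cross-attend into the retrieved evidence
    ```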

    Cross-attention versus self-attention

    The difference between cross-attention and self-attention is structural rather than algorithmic: the same scaled dot-product is computed, but the keys and values come from a different source. Several practical consequences follow.

    The attention matrix is rectangular and typically not square, so the cost is $ O(nm) $ rather than $ O(n^2) $; for short targets attending to long sources, this is much cheaper than self-attending over the concatenation of the two. Padding masks are applied on the source side only, and causal masking, when used, applies to the target's own self-attention rather than to cross-attention. Because the source representation is fixed during decoding, its keys and values can be precomputed once and reused, which is a substantial inference-time win.
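
    For concreteness, with $ n = 16 $ decoder tokens attending to a source of $ m = 4096 $ tokens, the cross-attention matrix has $ 16 \times 4096 = 65{,}536 $ entries, whereas self-attention over the concatenation of the two streams would form a $ (16 + 4096)^2 \approx 1.69 \times 10^7 $ entry matrix, roughly 260 times larger.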

    A subtler point is that cross-attention does not need positional encodings on the source if the source representation already contains positional information from an earlier encoder. In multimodal settings where source and target modalities have very different positional structures (e.g., 2D image patches as source, 1D text as target), positional information typically lives inside the encoder rather than being added at the cross-attention boundary.

    Limitations

    Cross-attention inherits the memory cost of dense attention, which here scales with the full $ n \times m $ attention matrix. When the source is very long — long documents, high-resolution images, or hour-long audio — the attention matrix becomes the dominant cost, and various sparse, low-rank, or memory-efficient approximations[5] are needed.

    Cross-attention is also notoriously fragile to distribution shift between the source and target streams. A decoder trained to attend to clean encoder outputs can degrade sharply when the encoder is replaced or fine-tuned, because the geometry of the keys may change in ways the queries did not anticipate. Joint training, gating, or careful adapter design typically mitigates this.

    Finally, cross-attention is not, on its own, a solution to grounding or hallucination. The mechanism only specifies how information flows; it does not enforce that the target faithfully reflects the source. Models trained with cross-attention can and do ignore their conditioning, particularly in autoregressive setups where the target's self-attention can dominate the cross-attention signal.

    References