FlashAttention

    Topic area: Deep Learning
    Prerequisites: Self-Attention, Softmax, Transformer


    Overview

    FlashAttention is an exact, IO-aware algorithm for computing the self-attention operation used in Transformer models. Introduced by Tri Dao and colleagues in 2022, it produces numerically equivalent outputs to a standard attention implementation while running several times faster and using memory linear in the sequence length, rather than quadratic. Its central insight is that on modern accelerators like GPUs, attention is bottlenecked by memory bandwidth between high-bandwidth memory (HBM) and on-chip SRAM, not by floating-point throughput. By tiling the computation so that intermediate results never leave SRAM, FlashAttention reduces HBM traffic and the wall-clock time of every Transformer-based system that uses it, from BERT to large GPT-style language models.[1]

    The algorithm has become a de facto standard component of production Transformer training and inference stacks, and successive versions (FlashAttention-2 and FlashAttention-3) have extended it with better parallelism and hardware-specific optimizations for newer GPU generations.

    Why standard attention is slow

    In a Transformer, scaled dot-product attention computes

    $ {\displaystyle \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\tfrac{Q K^{\top}}{\sqrt{d}}\right) V} $

    where $ Q, K, V \in \mathbb{R}^{N \times d} $ are query, key, and value matrices, $ N $ is the sequence length, and $ d $ is the per-head dimension. A naive implementation forms the score matrix $ S = Q K^{\top}/\sqrt{d} $, applies a row-wise softmax to obtain $ P $, then multiplies by $ V $. Both $ S $ and $ P $ are $ N \times N $ and must be written to and read from HBM.

    For long contexts, this $ O(N^2) $ traffic dominates running time. On an NVIDIA A100 GPU, HBM bandwidth is roughly an order of magnitude lower than SRAM bandwidth, so attention spends most of its cycles waiting for memory. Earlier "fast attention" methods such as sparse attention and linear attention reduced FLOPs but often did not reduce wall-clock time, because they did not address this memory bottleneck.
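    For concreteness, the naive procedure can be sketched in a few lines of NumPy. This is an illustrative reference implementation, not any particular library's kernel, and the function name is chosen for this article.

    import numpy as np

    def naive_attention(Q, K, V):
        # Reference scaled dot-product attention. Materializes the full
        # N x N score matrix S and probability matrix P, which is exactly
        # the HBM traffic that FlashAttention avoids.
        d = Q.shape[-1]
        S = Q @ K.T / np.sqrt(d)                  # N x N scores
        S_max = S.max(axis=-1, keepdims=True)     # subtract for stability
        P = np.exp(S - S_max)
        P = P / P.sum(axis=-1, keepdims=True)     # row-wise softmax
        return P @ V                              # N x d output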

    The online softmax trick

    FlashAttention relies on a numerically stable, incremental form of softmax sometimes called online softmax.[2] Given a stream of values $ x_1, x_2, \dots, x_N $, one can maintain a running maximum $ m $ and a running normalizer $ \ell $ such that, after seeing all values,

    $ {\displaystyle \mathrm{softmax}(x)_i = \frac{e^{x_i - m}}{\ell}, \qquad m = \max_i x_i, \qquad \ell = \sum_i e^{x_i - m}.} $

    The same identity extends to the weighted sum $ \sum_i \mathrm{softmax}(x)_i v_i $ needed for attention: when a new block of scores arrives with local maximum $ m' $, the previous partial output is rescaled by $ e^{m_{\text{old}} - m_{\text{new}}} $, the new contribution is added, and the normalizer is updated accordingly. The full softmax never has to be materialized in memory.
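    A minimal sketch of this streaming update in NumPy, assuming scalar scores with one value vector per score (names are illustrative):

    import numpy as np

    def online_softmax_weighted_sum(x, v):
        # Stream over scores x[i] and value vectors v[i], maintaining a
        # running maximum m, a running normalizer l, and an unnormalized
        # partial output o. Returns sum_i softmax(x)_i * v[i] without
        # ever storing the full softmax vector.
        m = -np.inf
        l = 0.0
        o = np.zeros_like(v[0], dtype=np.float64)
        for xi, vi in zip(x, v):
            m_new = max(m, xi)
            scale = np.exp(m - m_new)    # rescale previous partial state
            w = np.exp(xi - m_new)       # weight of the incoming score
            o = o * scale + w * vi
            l = l * scale + w
            m = m_new
        return o / l

    The result agrees, up to rounding, with first computing the full softmax and then taking the weighted sum.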

    The FlashAttention algorithm

    FlashAttention applies online softmax in a tiled, block-matrix fashion. The forward pass works as follows.

    1. Partition $ Q $ into blocks of $ B_r $ rows and $ K, V $ into blocks of $ B_c $ rows, with block sizes chosen so that a working set fits in SRAM.
    2. For each query block $ Q_i $, initialize an output block $ O_i = 0 $, a running maximum $ m_i = -\infty $, and a normalizer $ \ell_i = 0 $.
    3. Iterate over key/value blocks $ (K_j, V_j) $. Load $ K_j $ and $ V_j $ into SRAM. Compute the block of scores $ S_{ij} = Q_i K_j^{\top}/\sqrt{d} $, the local maximum $ m_{ij} $, and the local exponentials $ \tilde P_{ij} = e^{S_{ij} - m_{ij}} $.
    4. Combine with the running statistics. Update $ m_i^{\text{new}} = \max(m_i, m_{ij}) $, rescale $ O_i $ and $ \ell_i $ by $ e^{m_i - m_i^{\text{new}}} $, and accumulate $ O_i \mathrel{+}= \tilde P_{ij} V_j \cdot e^{m_{ij} - m_i^{\text{new}}} $ and $ \ell_i \mathrel{+}= \mathbf{1}^{\top} \tilde P_{ij} \cdot e^{m_{ij} - m_i^{\text{new}}} $.
    5. After all key/value blocks have been processed, divide $ O_i $ by $ \ell_i $ and write it back to HBM, along with the log-sum-exp $ L_i = m_i + \log \ell_i $ (used by the backward pass).

    The full $ N \times N $ matrix is never instantiated in HBM, so extra memory drops from $ O(N^2) $ to $ O(N) $. HBM accesses fall from $ \Theta(Nd + N^2) $ for the standard implementation to $ \Theta(N^2 d^2 / M) $, where $ M $ is the SRAM size; since $ d^2 $ is typically much smaller than $ M $, this reduction in memory traffic yields the bulk of the speedup.
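    The five steps above translate directly into NumPy. The sketch below is a sequential model of the kernel, assuming illustrative block sizes and omitting masking and dropout; in the real fused kernel each tile would reside in SRAM.

    import numpy as np

    def flash_attention_forward(Q, K, V, Br=64, Bc=64):
        # Tiled forward pass following steps 1-5. No N x N matrix is
        # ever formed; only Br x Bc tiles of scores exist at a time.
        N, d = Q.shape
        O = np.zeros((N, d))
        L = np.zeros(N)                  # log-sum-exp, saved for backward
        for i in range(0, N, Br):
            Qi = Q[i:i + Br]
            Oi = np.zeros((Qi.shape[0], d))
            mi = np.full(Qi.shape[0], -np.inf)   # running row maxima
            li = np.zeros(Qi.shape[0])           # running normalizers
            for j in range(0, N, Bc):
                Kj, Vj = K[j:j + Bc], V[j:j + Bc]
                Sij = Qi @ Kj.T / np.sqrt(d)      # block of scores
                mij = Sij.max(axis=1)             # local maxima
                Pij = np.exp(Sij - mij[:, None])  # local exponentials
                mi_new = np.maximum(mi, mij)
                alpha = np.exp(mi - mi_new)       # rescale old state
                beta = np.exp(mij - mi_new)       # rescale new block
                Oi = Oi * alpha[:, None] + (Pij @ Vj) * beta[:, None]
                li = li * alpha + Pij.sum(axis=1) * beta
                mi = mi_new
            O[i:i + Br] = Oi / li[:, None]        # final normalization
            L[i:i + Br] = mi + np.log(li)
        return O, L

    The returned O agrees with the naive implementation shown earlier to floating-point precision, which is the sense in which FlashAttention is exact.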

    Backward pass and recomputation

    The standard backward pass needs the attention matrix $ P $, which would defeat the purpose of FlashAttention if stored. Instead, FlashAttention recomputes $ S $ and $ P $ block by block during backpropagation, using $ Q $, $ K $, the saved log-sum-exp $ L $, and the output gradient $ \mathrm{d}O $. Recomputation is cheap because the score blocks live in SRAM during the backward pass, and trading additional FLOPs for fewer HBM reads is a favorable exchange on memory-bound hardware: the overall training step is faster than the materializing baseline despite doing more arithmetic.
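    The key identity is that any tile of $ P $ can be regenerated exactly from the saved statistics: since $ L_i = m_i + \log \ell_i $, each softmax row satisfies $ P_{ij} = e^{S_{ij} - L_i} $. A minimal sketch of this recomputation, with illustrative names:

    import numpy as np

    def recompute_block_probs(Qi, Kj, Li):
        # Regenerate one Br x Bc tile of the attention matrix inside the
        # backward pass. Li holds the per-row log-sum-exp saved by the
        # forward pass, so exp(Sij - Li) reproduces the exact softmax
        # rows even though P was never stored in HBM.
        d = Qi.shape[-1]
        Sij = Qi @ Kj.T / np.sqrt(d)
        return np.exp(Sij - Li[:, None])

    From such tiles the gradients accumulate in the usual way, for example $ \mathrm{d}V_j \mathrel{+}= P_{ij}^{\top} \mathrm{d}O_i $ over all query blocks $ i $.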

    Variants and successors

    FlashAttention has gone through three published versions.

    • FlashAttention-1 (2022) introduced the IO-aware tiling described above and was implemented as a single fused CUDA kernel. It supported causal masking and dropout while remaining exact, matching a standard FP16 implementation up to floating-point reduction order.
    • FlashAttention-2 (2023) reorganized the work distribution so that the outer loop is over query blocks and the inner loop is over key/value blocks, increasing parallelism across thread blocks and warps. It also reduced non-matmul FLOPs and improved utilization on long sequences, reaching about 50–73% of theoretical peak FLOPs on A100.[3]
    • FlashAttention-3 (2024) targets the Hopper architecture, using asynchronous Tensor Cores, the Tensor Memory Accelerator (TMA), and warp-specialization to overlap softmax and matmul work. It also supports FP8 with reduced quantization error through block-wise scaling.[4]

    Related ideas include Memory-Efficient Attention from xFormers, which independently arrived at a similar tiled approach, and PagedAttention used in vLLM, which tiles the KV cache for inference to support large batch sizes and many concurrent sequences.

    Comparison with other fast-attention methods

    Unlike approximate methods such as Performer or Linformer, FlashAttention computes exact attention. It does not change the model's outputs or training dynamics, and so can be substituted into any existing Transformer without retraining. Its speedup can also be combined with techniques that reduce the asymptotic cost of attention: the original paper extends the kernel to block-sparse attention, which skips tiles that the sparsity mask zeroes out entirely. In practice, on workloads where exact attention is acceptable, FlashAttention has displaced approximate alternatives because it offers faster runtime, lower memory, and identical numerics.

    Limitations

    FlashAttention's gains depend on the gap between SRAM and HBM bandwidth and on the kernel's hand-tuned scheduling. The algorithm is most beneficial for moderate to long sequences, head dimensions up to about 256, and standard masking patterns. Custom attention variants (exotic masks, learned biases, or non-standard score functions) may require a new kernel or fall back to a slower path. Numerical results match a standard implementation only up to the reduction order chosen by the kernel, and under low-precision formats such as FP8, careful scaling is needed to preserve accuracy.

    References

    1. Dao, T., Fu, D. Y., Ermon, S., Rudra, A., and Ré, C., FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness, 2022. arXiv:2205.14135.
    2. Milakov, M. and Gimelshein, N., Online Normalizer Calculation for Softmax, 2018. arXiv:1805.02867.
    3. Dao, T., FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning, 2023. arXiv:2307.08691.
    4. Shah, J., Bikshandi, G., Zhang, Y., Thakkar, V., Ramani, P., and Dao, T., FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision, 2024. arXiv:2407.08608.