FlashAttention/es

    From Marovi AI
    This page is a translated version of the page FlashAttention and the translation is 100% complete.
    Other languages:
    Article
    Topic area Deep Learning
    Prerequisites Self-Attention, Softmax, Transformer


    Visión general

    FlashAttention es un algoritmo exacto y consciente de la E/S para calcular la operación de autoatención utilizada en los modelos Transformer. Presentado por Tri Dao y colaboradores en 2022, produce salidas numéricamente equivalentes a una implementación estándar de atención, pero se ejecuta varias veces más rápido y utiliza una memoria lineal en la longitud de la secuencia, en lugar de cuadrática. Su idea central es que, en aceleradores modernos como las GPU, la atención está limitada por el ancho de banda de memoria entre la memoria de alto ancho de banda (HBM) y la SRAM en chip, y no por el rendimiento de coma flotante. Al organizar el cálculo en bloques (tiling) de modo que los resultados intermedios nunca abandonen la SRAM, FlashAttention reduce el tráfico hacia la HBM y el tiempo de reloj de todo sistema basado en Transformer que lo emplea, desde BERT hasta los grandes modelos de estilo GPT y otros modelos de lenguaje.[1]

    El algoritmo se ha convertido en un componente estándar de facto de las pilas de entrenamiento e inferencia de Transformer en producción, y sus versiones sucesivas (FlashAttention-2 y FlashAttention-3) lo han ampliado con un mejor paralelismo y optimizaciones específicas para las generaciones más recientes de GPU.

    Por qué la atención estándar es lenta

    En un Transformer, la atención por producto escalar escalado calcula

    $ {\displaystyle \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\tfrac{Q K^{\top}}{\sqrt{d}}\right) V} $

    donde $ Q, K, V \in \mathbb{R}^{N \times d} $ son las matrices de consulta, clave y valor, $ N $ es la longitud de la secuencia y $ d $ es la dimensión por cabeza. Una implementación ingenua construye la matriz de puntuaciones $ S = Q K^{\top}/\sqrt{d} $, le aplica un softmax por filas para obtener $ P $ y luego la multiplica por $ V $. Tanto $ S $ como $ P $ son de tamaño $ N \times N $ y deben escribirse y leerse desde la HBM.

    Para contextos largos, este tráfico $ O(N^2) $ domina el tiempo de ejecución. En una GPU NVIDIA A100, el ancho de banda de la HBM es aproximadamente un orden de magnitud menor que el de la SRAM, por lo que la atención pasa la mayor parte de sus ciclos esperando a la memoria. Métodos previos de "atención rápida", como la atención dispersa y la atención lineal, reducían el número de FLOP pero a menudo no reducían el tiempo de reloj, porque no atacaban este cuello de botella de memoria.

    El truco del softmax en línea

    FlashAttention se basa en una forma incremental y numéricamente estable del softmax, a veces llamada softmax en línea.[2] Dada una secuencia de valores $ x_1, x_2, \dots, x_N $, se puede mantener un máximo acumulado $ m $ y un normalizador acumulado $ \ell $ tales que, tras procesar todos los valores,

    $ {\displaystyle \mathrm{softmax}(x)_i = \frac{e^{x_i - m}}{\ell}, \qquad m = \max_i x_i, \qquad \ell = \sum_i e^{x_i - m}.} $

    La misma identidad se extiende a la suma ponderada $ \sum_i \mathrm{softmax}(x)_i v_i $ que necesita la atención: cuando llega un nuevo bloque de puntuaciones con máximo local $ m' $, la salida parcial anterior se reescala por $ e^{m_{\text{old}} - m_{\text{new}}} $, se añade la nueva contribución y se actualiza el normalizador en consecuencia. La softmax completa nunca tiene que materializarse en memoria.

    El algoritmo FlashAttention

    FlashAttention aplica el softmax en línea de forma segmentada, en bloques de matrices por bloques. El paso hacia adelante funciona del siguiente modo.

    1. Particionar $ Q $ en bloques de $ B_r $ filas y $ K, V $ en bloques de $ B_c $ filas, eligiendo los tamaños de bloque de modo que el conjunto de trabajo quepa en la SRAM.
    2. Para cada bloque de consultas $ Q_i $, inicializar un bloque de salida $ O_i = 0 $, un máximo acumulado $ m_i = -\infty $ y un normalizador $ \ell_i = 0 $.
    3. Iterar sobre los bloques clave/valor $ (K_j, V_j) $. Cargar $ K_j $ y $ V_j $ en la SRAM. Calcular el bloque de puntuaciones $ S_{ij} = Q_i K_j^{\top}/\sqrt{d} $, el máximo local $ m_{ij} $ y las exponenciales locales $ \tilde P_{ij} = e^{S_{ij} - m_{ij}} $.
    4. Combinar con las estadísticas acumuladas. Actualizar $ m_i^{\text{new}} = \max(m_i, m_{ij}) $, reescalar $ O_i $ y $ \ell_i $ por $ e^{m_i - m_i^{\text{new}}} $ y acumular $ O_i \mathrel{+}= \tilde P_{ij} V_j \cdot e^{m_{ij} - m_i^{\text{new}}} $ y $ \ell_i \mathrel{+}= \mathbf{1}^{\top} \tilde P_{ij} \cdot e^{m_{ij} - m_i^{\text{new}}} $.
    5. Una vez procesados todos los bloques clave/valor, dividir $ O_i $ por $ \ell_i $ y escribirlo de vuelta a la HBM, junto con el log-sum-exp $ L_i = m_i + \log \ell_i $ (utilizado por el paso hacia atrás).

    La matriz completa de tamaño $ N \times N $ nunca se instancia en la HBM, por lo que la memoria desciende de $ O(N^2) $ a $ O(N) $. Los accesos a HBM caen de $ O(N^2 d) $ a $ O(N^2 d^2 / M) $, donde $ M $ es el tamaño de la SRAM, y de ahí proviene la mayor parte de la aceleración.

    Paso hacia atrás y recomputación

    El paso hacia atrás estándar necesita la matriz de atención $ P $, lo que anularía el propósito de FlashAttention si se almacenara. En su lugar, FlashAttention recomputa $ S $ y $ P $ sobre la marcha durante la retropropagación, utilizando únicamente el log-sum-exp $ L $ guardado y el gradiente hacia atrás $ \mathrm{d}O $. La recomputación es barata porque los bloques de puntuaciones residen en la SRAM durante la pasada de recálculo, y cambiar más FLOPs por menos lecturas de HBM es un intercambio favorable en hardware limitado por memoria. El paso de entrenamiento global resulta más rápido que la línea base que materializa la matriz, a pesar de realizar más operaciones aritméticas.

    Variantes y sucesores

    FlashAttention ha pasado por tres versiones publicadas.

    • FlashAttention-1 (2022) introdujo el mosaico (tiling) consciente de la E/S descrito anteriormente y se implementó como un único kernel fusionado de CUDA. Soportaba enmascaramiento causal y dropout, manteniendo la equivalencia bit a bit en FP16 hasta el orden estándar de reducción.
    • FlashAttention-2 (2023) reorganizó la distribución de trabajo de modo que el bucle externo recorre los bloques de consultas y el interno los bloques clave/valor, aumentando el paralelismo entre bloques de hilos y warps. También redujo los FLOPs no relacionados con la multiplicación de matrices y mejoró la utilización en secuencias largas, alcanzando aproximadamente entre el 50 y el 73 % del pico teórico de FLOPs en la A100.[3]
    • FlashAttention-3 (2024) está dirigida a la arquitectura Hopper y emplea Tensor Cores asíncronos, el Tensor Memory Accelerator (TMA) y la especialización por warp para solapar el trabajo de softmax con el de matmul. También admite FP8 con un error de cuantización reducido gracias a un escalado por bloques.[4]

    Entre las ideas relacionadas se encuentran Memory-Efficient Attention de xFormers, que llegó de forma independiente a un enfoque por bloques similar, y PagedAttention, utilizada en vLLM, que segmenta la caché KV durante la inferencia para soportar grandes tamaños de lote y muchas secuencias concurrentes.

    Comparación con otros métodos de atención rápida

    A diferencia de métodos aproximados como Performer o Linformer, FlashAttention calcula la atención exacta. No altera las salidas del modelo ni la dinámica de entrenamiento, por lo que puede sustituirse en cualquier Transformer existente sin necesidad de reentrenarlo. Su aceleración se combina con técnicas que cambian el coste asintótico de la atención, ya que FlashAttention también puede emplearse junto con máscaras dispersas o por bandas. En la práctica, en cargas en las que la atención exacta es aceptable, FlashAttention ha desplazado a las alternativas aproximadas porque ofrece menor tiempo de ejecución, menor consumo de memoria y resultados numéricos idénticos.

    Limitaciones

    Las ganancias de FlashAttention dependen de la diferencia entre el ancho de banda de la SRAM y el de la HBM, así como de la programación cuidadosamente ajustada del kernel. El algoritmo es más beneficioso para secuencias de longitud moderada a larga, dimensiones de cabeza de hasta unos 256 y patrones de enmascaramiento estándar. Las variantes de atención personalizadas — máscaras exóticas, sesgos aprendidos o funciones de puntuación no estándar — pueden requerir un nuevo kernel o caer a una ruta más lenta. Los resultados numéricos solo son equivalentes bit a bit hasta el orden de reducción elegido por el kernel; con formatos de baja precisión como FP8, es necesario un escalado cuidadoso para preservar la exactitud.

    Referencias

    1. Template:Cite arxiv
    2. Milakov, M. and Gimelshein, N., Online Normalizer Calculation for Softmax, 2018.
    3. Template:Cite arxiv
    4. Template:Cite arxiv