Linear Attention/es

    From Marovi AI
    This page is a translated version of the page Linear Attention and the translation is 100% complete.
    Other languages:
    Article
    Topic area Deep Learning
    Prerequisites Transformer, Softmax Function


    Visión general

    La atención lineal es una familia de mecanismos de atención cuyo coste en tiempo y memoria escala linealmente con la longitud de la secuencia, en contraste con el coste cuadrático de la atención softmax estándar utilizada en la arquitectura original del Transformer. La idea central es reemplazar la función de similitud exponencial de la atención softmax por un kernel que pueda escribirse como un producto interno de aplicaciones de características y, a continuación, aprovechar la asociatividad de la multiplicación de matrices para reordenar el cálculo. El resultado es una capa de atención que, para la decodificación autorregresiva, puede expresarse como una red neuronal recurrente lineal con un estado oculto de tamaño fijo, lo que la hace atractiva para el modelado de lenguaje de contexto largo, la inferencia en flujo y el despliegue en dispositivos.[1]

    La atención lineal sacrifica parte de la expresividad de la atención softmax a cambio de eficiencia asintótica. Si la transacción resulta favorable depende de la carga de trabajo: para secuencias muy largas e inferencia con memoria constante suele ser una clara ventaja, mientras que para contextos cortos o moderados los factores constantes y la brecha de calidad pueden anular la ventaja teórica. Variantes modernas como la atención lineal con compuertas y los modelos de tipo espacio de estados han reducido considerablemente la brecha de calidad y sustentan varios modelos de secuencia eficientes de comienzos de la década de 2020.

    Antecedentes y motivación

    La atención por producto punto escalado estándar calcula, para consultas $ Q \in \mathbb{R}^{N \times d} $, claves $ K \in \mathbb{R}^{N \times d} $ y valores $ V \in \mathbb{R}^{N \times d} $, la salida

    $ {\displaystyle \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\tfrac{Q K^\top}{\sqrt{d}}\right) V.} $

    La matriz $ Q K^\top $ tiene forma $ N \times N $, por lo que tanto el coste en tiempo como en memoria son $ \mathcal{O}(N^2 d) $. Para documentos largos, imágenes de alta resolución o formas de onda de audio, este escalado cuadrático se convierte en el coste dominante del entrenamiento y la inferencia, lo que motiva una amplia literatura sobre atención eficiente. La atención lineal es uno de los enfoques más sencillos e influyentes de esa literatura.

    Una segunda motivación proviene de la decodificación autorregresiva. Un decodificador Transformer estándar debe, en cada paso, atender a todos los tokens previos, lo que requiere almacenar una caché de claves y valores que crece con la longitud del contexto. Un decodificador con atención lineal mantiene un resumen de tamaño fijo del pasado, de modo que el coste de decodificación por token es constante en la longitud de la secuencia y la memoria no crece.

    Reformulación con núcleos

    El punto de partida de la atención lineal consiste en escribir la similitud (sin normalizar) de la softmax como un núcleo $ k(q, k) = \exp(q^\top k / \sqrt{d}) $ y reemplazarlo por un núcleo definido positivo que admita una aplicación explícita de características $ \phi : \mathbb{R}^d \to \mathbb{R}^{d'} $ tal que

    $ {\displaystyle k(q, k) = \phi(q)^\top \phi(k).} $

    Sustituyendo y escribiendo la salida de atención para una sola consulta $ q_i $ se obtiene

    $ {\displaystyle y_i = \frac{\sum_{j=1}^{N} \phi(q_i)^\top \phi(k_j)\, v_j}{\sum_{j=1}^{N} \phi(q_i)^\top \phi(k_j)} = \frac{\phi(q_i)^\top \sum_{j=1}^{N} \phi(k_j)\, v_j^\top}{\phi(q_i)^\top \sum_{j=1}^{N} \phi(k_j)}.} $

    El paso crucial es la segunda igualdad, que utiliza la asociatividad de la multiplicación de matrices para sacar $ \phi(q_i) $ fuera de las sumas. Las dos sumas

    $ {\displaystyle S = \sum_{j=1}^{N} \phi(k_j)\, v_j^\top \in \mathbb{R}^{d' \times d}, \qquad z = \sum_{j=1}^{N} \phi(k_j) \in \mathbb{R}^{d'}} $

    no dependen de la consulta, por lo que pueden precomputarse en tiempo $ \mathcal{O}(N d' d) $ y reutilizarse para todas las consultas. La salida completa de la atención se calcula entonces en $ \mathcal{O}(N d' d) $ en lugar de $ \mathcal{O}(N^2 d) $, lo que da el escalado lineal que da nombre al método.

    El mismo truco se aplica al denominador de la softmax, que se convierte en una normalización por $ \phi(q_i)^\top z $. Algunas implementaciones omiten por completo el denominador y se apoyan en la normalización por capas para estabilizar las magnitudes de salida.

    Mapas de características

    La elección de $ \phi $ determina tanto la expresividad como el coste de la capa. Se han propuesto varias familias:

    • Aplicaciones positivas elemento a elemento. Katharopoulos et al. introdujeron $ \phi(x) = \mathrm{elu}(x) + 1 $, una aplicación no negativa barata con la misma dimensión que la entrada. Esta es la línea base canónica del "Transformer lineal".
    • Aplicaciones de características aleatorias. El modelo Performer aproxima el núcleo de la softmax mediante características aleatorias positivas $ \phi(x) \propto \exp(W x - \|x\|^2 / 2) $, donde $ W $ contiene proyecciones ortogonales aleatorias. Esto recupera la atención softmax en esperanza a la vez que conserva la factorización de coste lineal.[2]
    • Aplicaciones de características polinómicas. Elegir $ \phi(x) = (1, x, x \otimes x, \dots) $ da un núcleo polinómico; truncar a grado bajo proporciona una variante de atención lineal manejable con expresividad controlada.
    • Aplicación identidad. Tomar $ \phi(x) = x $ reduce la capa a una atención bilineal simple. Esta opción es rápida pero suele tener un rendimiento inferior salvo que se combine con normalización o compuertas.

    En la práctica, la aplicación de características rara vez es el factor limitante de la calidad; las decisiones de diseño más importantes son la normalización, las compuertas y cómo se actualiza el estado recurrente.

    Forma recurrente e inferencia autorregresiva

    Para el modelado causal (autorregresivo), las sumas acumuladas $ S_t $ y $ z_t $ pueden actualizarse de forma incremental:

    $ {\displaystyle S_t = S_{t-1} + \phi(k_t)\, v_t^\top, \qquad z_t = z_{t-1} + \phi(k_t).} $

    Cada nuevo token aporta una actualización de rango uno al estado matricial $ S_t $ (matriz). La salida en el instante $ t $ es

    $ {\displaystyle y_t = \frac{\phi(q_t)^\top S_t}{\phi(q_t)^\top z_t}.} $

    Esta es exactamente la forma de una RNN lineal con un estado oculto de tamaño fijo de dimensión $ d' \times d $. El coste de decodificación por token es $ \mathcal{O}(d' d) $, independiente de la longitud de la secuencia, y no hay caché de claves y valores creciente. El consumo de memoria durante la inferencia es constante, propiedad que ha impulsado gran parte del renovado interés por la atención lineal para modelos de lenguaje de contexto largo.

    Entrenamiento: forma paralela

    El entrenamiento sería lento si dependiera de la forma recurrente, porque la retropropagación a través de una recurrencia secuencial larga es difícil de paralelizar a lo largo del eje temporal. Por suerte, el mismo álgebra que da la recurrencia también proporciona una forma paralela: las sumas acumuladas pueden calcularse con un escaneo por prefijos o, más comúnmente, todo el patrón de atención $ N \times N $ puede materializarse en bloques de tamaño moderado para mantener ocupadas a las GPU. Implementaciones modernas como la forma paralela por bloques de FLA (la biblioteca Flash Linear Attention) y los núcleos utilizados en RetNet aprovechan esto para entrenar a velocidades comparables a las de los Transformers con softmax en secuencias largas, manteniendo el coste asintótico lineal.

    La máscara causal introduce una sutileza: las sumas acumuladas deben respetar el orden de los tokens, por lo que una vectorización ingenua que sume sobre todos los $ j $ antes de aplicar la máscara es incorrecta. Las implementaciones correctas o bien usan un escaneo por prefijos, o bien dividen la secuencia en bloques y combinan atención de tipo softmax intra-bloque con actualizaciones lineales inter-bloque.

    Variantes y extensiones

    Un gran número de modelos posteriores se basan en la receta básica de la atención lineal:

    • Performer. Aproximación del núcleo de la softmax mediante características aleatorias, con estimaciones insesgadas y garantías teóricas sobre la varianza de la aproximación.
    • Transformer lineal con elu+1. La formulación original de Katharopoulos et al., ampliamente utilizada como línea base.
    • RetNet. Sustituye la suma acumulada no acotada por un factor de decaimiento $ \gamma \in (0, 1) $, dando $ S_t = \gamma S_{t-1} + \phi(k_t) v_t^\top $. El decaimiento dota al modelo de una propiedad de retención multiescala y es demostrablemente equivalente a un cálculo por bloques.[3]
    • Atención lineal con compuertas (GLA). Sustituye el decaimiento escalar por una compuerta dependiente de los datos, recuperando parte del comportamiento selectivo de la atención softmax mientras conserva el coste lineal.[4]
    • Modelos selectivos de espacio de estados. Arquitecturas como Mamba, aunque no son estrictamente modelos de atención lineal, comparten la estructura recurrente lineal y pueden expresarse en un marco estrechamente relacionado. Las dos familias han convergido sustancialmente en la literatura reciente.

    Comparación con la atención softmax

    Frente a la atención softmax estándar, la atención lineal ofrece:

    • Coste asintótico. $ \mathcal{O}(N d' d) $ frente a $ \mathcal{O}(N^2 d) $ en tiempo y memoria. Esto se vuelve decisivo en torno a $ N \approx 2{,}000 $ a $ 8{,}000 $ tokens, dependiendo del hardware y los factores constantes.
    • Memoria de inferencia constante. Sin caché KV creciente; el estado tiene forma fija $ d' \times d $.
    • Apta para flujo continuo. Los nuevos tokens pueden incorporarse con una actualización de rango uno.

    El precio que se paga es:

    • Expresividad reducida. La atención softmax puede enfocarse selectivamente en un pequeño número de tokens con una distribución muy puntiaguda; el estado de rango acotado de la atención lineal no puede reproducir patrones arbitrariamente puntiagudos. Esto se manifiesta empíricamente como un peor recuerdo asociativo y peor capacidad de copia.
    • Sensibilidad a la elección de la aplicación de características. La calidad varía sustancialmente con $ \phi $, la normalización y las compuertas, por lo que las sustituciones ingenuas suelen tener un rendimiento inferior.

    En la práctica, las arquitecturas híbridas que intercalan capas softmax con capas de atención lineal son un compromiso habitual.

    Limitaciones y problemas abiertos

    El estado recurrente de rango acotado de la atención lineal es a la vez su rasgo definitorio y su principal limitación. Las tareas que requieren la recuperación precisa de un token pasado arbitrario, como el aprendizaje en contexto con conjuntos largos de demostraciones, son las que evidencian con mayor claridad esta brecha. Varias líneas de trabajo intentan recuperar la capacidad faltante: compuertas, decaimientos multiescala, mayor número de cabezas y el uso de reglas de actualización no lineales como la regla delta.

    Una segunda cuestión abierta concierne al hardware. La forma recurrente es compacta pero intrínsecamente secuencial, mientras que la forma paralela exige una cuidadosa ingeniería de núcleos para igualar el rendimiento de los núcleos fusionados de atención softmax. Bibliotecas como Flash Linear Attention han cerrado gran parte de la brecha, pero la superficie de implementación aún es menos madura que para la atención estándar.

    Por último, la atención lineal a veces es criticada como un retroceso hacia las RNN, y es cierto que la forma recurrente arrastra todas las dificultades clásicas de entrenar recurrencias profundas. Las variantes modernas mitigan esto con una inicialización cuidadosa, normalización por capas y parametrizaciones del decaimiento, pero el problema es real y conviene comprenderlo antes de adoptar la atención lineal en un sistema nuevo.

    Referencias

    1. Katharopoulos, A., Vyas, A., Pappas, N., and Fleuret, F. Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention. ICML, 2020.
    2. Choromanski, K. et al. Rethinking Attention with Performers. ICLR, 2021. Template:Cite arxiv
    3. Sun, Y. et al. Retentive Network: A Successor to Transformer for Large Language Models. 2023. Template:Cite arxiv
    4. Yang, S. et al. Gated Linear Attention Transformers with Hardware-Efficient Training. ICML, 2024. Template:Cite arxiv