Cross-Attention/es
| Article | |
|---|---|
| Topic area | Deep Learning |
| Prerequisites | Attention Mechanism, Self-Attention, Transformer |
Visión general
La atención cruzada es una variante del mecanismo de atención en la que las consultas provienen de una secuencia y las claves y los valores provienen de una secuencia distinta. Es el mecanismo estándar para que un modelo condicione la generación o la representación de un flujo en función del contenido de otro, y constituye la columna vertebral arquitectónica de los Transformers de tipo codificador-decodificador, los modelos con recuperación aumentada y la mayoría de los sistemas multimodales modernos que vinculan texto con imágenes, audio o vídeo.
A diferencia de la autoatención, en la que las consultas, las claves y los valores son proyecciones de la misma entrada, la atención cruzada establece un flujo asimétrico de información desde una secuencia de origen hacia una secuencia de destino. La secuencia de destino formula preguntas; la secuencia de origen aporta la evidencia. Este desacoplamiento hace que la atención cruzada se ajuste de forma natural cuando dos flujos tienen longitudes distintas, modalidades distintas o roles distintos: por ejemplo, los tokens del decodificador atendiendo a las salidas del codificador en traducción automática, o los tokens de texto atendiendo a parches de imagen en un modelo de lenguaje de visión.
La atención cruzada se introdujo como parte de la arquitectura Transformer original[1] y, desde entonces, se ha convertido en un bloque básico reutilizado en modelos de difusión[2], arquitecturas de tipo Perceiver[3] y los modelos multimodales de gran tamaño actuales.
Intuición
Un modelo mental útil consiste en pensar en la atención como una búsqueda blanda y diferenciable en una memoria asociativa. La autoatención es una memoria cuyo contenido son los propios tokens que realizan la búsqueda; la atención cruzada es una memoria cuyo contenido procede de otro lugar. El decodificador de un modelo de traducción, mientras genera la siguiente palabra en alemán, consulta las representaciones del codificador de la oración fuente en inglés para decidir qué palabras debería estar mirando en ese momento. La consulta sabe lo que quiere ("un sintagma nominal que represente al sujeto"); las claves anuncian lo que ofrecen ("soy un sintagma nominal sobre el gato"); y los valores aportan el contenido real que se mezcla en el estado oculto del decodificador.
Dado que las claves y los valores provienen de una fuente externa fija durante toda una etapa de decodificación, la atención cruzada también es el lugar más natural para inyectar señales de condicionamiento en un modelo generativo. Los modelos de difusión de texto a imagen, por ejemplo, tratan las características espaciales de la U-Net de eliminación de ruido como consultas y la indicación de texto codificada como claves y valores, de modo que cada ubicación espacial puede extraer de forma selectiva contenido semántico de la indicación en cada paso de eliminación de ruido.
Formulación
Sea la secuencia de destino (consulta) de longitud $ n $ con tamaño oculto $ d $, lo que da $ X_{\text{tgt}} \in \mathbb{R}^{n \times d} $, y sea la secuencia de origen (clave/valor) de longitud $ m $, lo que da $ X_{\text{src}} \in \mathbb{R}^{m \times d} $. Tres proyecciones lineales aprendidas producen las consultas, las claves y los valores:
$ {\displaystyle Q = X_{\text{tgt}} W_Q, \quad K = X_{\text{src}} W_K, \quad V = X_{\text{src}} W_V} $
con $ W_Q, W_K \in \mathbb{R}^{d \times d_k} $ y $ W_V \in \mathbb{R}^{d \times d_v} $. La atención cruzada con producto escalado por producto escalar calcula entonces
$ {\displaystyle \operatorname{CrossAttn}(X_{\text{tgt}}, X_{\text{src}}) = \operatorname{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V \in \mathbb{R}^{n \times d_v}.} $
El hecho estructural crucial es la asimetría de los dos operandos: la salida tiene la misma longitud que el destino, mientras que su contenido es una combinación convexa de los vectores de valor extraídos del origen. La matriz de atención es rectangular con forma $ n \times m $, no cuadrada como en la autoatención.
En la práctica, la atención cruzada se usa casi siempre en su forma multi-cabeza. Con $ h $ cabezas de dimensión $ d_k = d / h $, las consultas, las claves y los valores se dividen, se aplica la atención por cabeza y las salidas de cada cabeza se concatenan y se proyectan linealmente:
$ {\displaystyle \operatorname{MultiHead}(X_{\text{tgt}}, X_{\text{src}}) = \operatorname{Concat}(\operatorname{head}_1, \ldots, \operatorname{head}_h)\, W_O.} $
Distintas cabezas pueden especializarse: algunas alinean posicionalmente, otras alinean semánticamente y otras actúan como un suavizado casi uniforme.
Uso en Transformers codificador-decodificador
En el Transformer original, cada bloque del decodificador contiene tres subcapas: autoatención enmascarada sobre el destino parcialmente generado, atención cruzada hacia la salida final de la pila del codificador y una red feedforward aplicada por posición. La subcapa de atención cruzada es el único lugar por el que la información fluye desde el origen hacia el destino; si se elimina, el decodificador se convierte en un modelo de lenguaje simple sin visión de la oración fuente.
De aquí se derivan varios puntos prácticos. Como las representaciones del codificador se calculan una vez y se reutilizan en cada paso de decodificación, las claves y los valores pueden almacenarse en caché entre pasos, lo que hace que la atención cruzada sea mucho más barata que la autoatención, que crece con el destino parcial. La mayoría de los decodificadores de producción mantienen una caché KV separada para la autoatención (que crece token a token) y un tensor KV estático y precomputado para la atención cruzada.
La atención cruzada también es donde se aplican las máscaras de relleno del lado del origen: las posiciones rellenadas de la secuencia de origen se enmascaran de forma que la softmax les asigne probabilidad cero. Las máscaras causales, en cambio, son innecesarias en la atención cruzada: el decodificador puede atender a cualquier posición del origen en cualquier paso de decodificación.
Variantes
Diversas variantes amplían o modifican la capa básica de atención cruzada para abordar restricciones específicas.
La atención cruzada con compuertas inserta una compuerta aprendida, frecuentemente inicializada en cero, sobre la salida de la atención cruzada, de modo que una capa de atención cruzada recién añadida no desestabilice un modelo preentrenado. Este es el mecanismo que utiliza Flamingo para injertar contexto visual en un modelo de lenguaje congelado[4], y es un patrón habitual de adaptación multimodal eficiente en parámetros en general.
La atención cruzada de tipo Perceiver utiliza un pequeño conjunto de vectores latentes aprendidos como consultas frente a una secuencia de entrada muy larga, comprimiendo la entrada en una representación de tamaño fijo independientemente de su longitud. Esto rompe la dependencia cuadrática de la autoatención estándar respecto a la longitud de entrada y es lo que permite que la familia Perceiver maneje píxeles brutos, muestras de audio y nubes de puntos sin tokenizadores específicos para cada modalidad.
La atención cruzada en modelos de difusión condiciona una red de eliminación de ruido sobre un texto o un embedding de clase tratando el mapa de características espacial de la red como consultas y el embedding de condicionamiento como claves y valores. Este mismo mecanismo, aplicado en cada capa y en cada paso de eliminación de ruido, es lo que confiere a los modelos de difusión latente su control de grano fino sobre las imágenes generadas.
La atención cruzada de memoria y recuperación generaliza la secuencia de origen a una base de datos de fragmentos recuperados. Arquitecturas como RETRO y los Transformers aumentados con kNN recuperan pasajes vecinos más cercanos y aplican atención cruzada sobre ellos, lo que desacopla la capacidad paramétrica de un modelo del conocimiento al que puede acceder en el momento de la inferencia.
Atención cruzada frente a autoatención
La diferencia entre atención cruzada y autoatención es estructural más que algorítmica: se calcula el mismo producto escalar escalado, pero las claves y los valores provienen de una fuente diferente. De ello se siguen varias consecuencias prácticas.
La matriz de atención es rectangular y, en general, no cuadrada, por lo que el coste es $ O(nm) $ en lugar de $ O(n^2) $; para destinos cortos que atienden a fuentes largas, esto resulta mucho más barato que aplicar autoatención sobre la concatenación de ambos. Las máscaras de relleno se aplican únicamente del lado del origen, y el enmascaramiento causal, cuando se utiliza, se aplica a la propia autoatención del destino y no a la atención cruzada. Como la representación del origen es fija durante la decodificación, sus claves y valores pueden precomputarse una sola vez y reutilizarse, lo que supone una ganancia sustancial en tiempo de inferencia.
Un punto más sutil es que la atención cruzada no necesita codificaciones posicionales en el origen si la representación del origen ya contiene información posicional procedente de un codificador anterior. En entornos multimodales en los que las modalidades de origen y destino tienen estructuras posicionales muy distintas (por ejemplo, parches de imagen 2D como origen y texto 1D como destino), la información posicional suele residir dentro del codificador y no añadirse en la frontera de la atención cruzada.
Limitaciones
La atención cruzada hereda el coste cuadrático de memoria de la atención estándar en el rectángulo $ n \times m $. Cuando el origen es muy largo —documentos extensos, imágenes de alta resolución o audio de horas de duración—, la matriz de atención se convierte en el coste dominante, y se requieren diversas aproximaciones dispersas, de bajo rango o eficientes en memoria[5].
La atención cruzada es también notoriamente frágil ante el cambio de distribución entre los flujos de origen y destino. Un decodificador entrenado para atender a salidas limpias del codificador puede degradarse de forma marcada cuando el codificador se reemplaza o se ajusta finamente, ya que la geometría de las claves puede cambiar en formas que las consultas no anticipaban. El entrenamiento conjunto, las compuertas o un diseño cuidadoso de adaptadores suelen mitigar este problema.
Por último, la atención cruzada no es, por sí sola, una solución al anclaje ni a la alucinación. El mecanismo solo especifica cómo fluye la información; no impone que el destino refleje fielmente al origen. Los modelos entrenados con atención cruzada pueden ignorar —y de hecho lo hacen— su condicionamiento, en particular en configuraciones autorregresivas en las que la autoatención del destino puede dominar la señal de la atención cruzada.