Mixed Precision Training/es
| Article | |
|---|---|
| Topic area | Deep Learning |
| Prerequisites | Neural Networks, Backpropagation, Stochastic Gradient Descent |
Visión general
El entrenamiento con precisión mixta es una técnica que acelera el entrenamiento de redes neuronales profundas al realizar la mayoría de las operaciones aritméticas en un formato de coma flotante de menor precisión (típicamente 16 bits) mientras mantiene un pequeño número de operaciones numéricamente sensibles en mayor precisión (típicamente 32 bits). Introducido en su forma moderna por Micikevicius et al. en 2017, se ha convertido en el régimen de entrenamiento por defecto para el aprendizaje profundo a gran escala, impulsando la mayor parte del trabajo contemporáneo sobre redes convolucionales, transformers y grandes modelos de lenguaje. En comparación con el entrenamiento puramente en precisión simple (FP32), la precisión mixta típicamente reduce a la mitad el consumo de memoria y entrega un rendimiento de dos a ocho veces mayor en hardware con unidades matriciales de precisión reducida dedicadas, mientras alcanza esencialmente la misma exactitud final.
El enfoque explota la observación de que el entrenamiento de redes neuronales es altamente tolerante al ruido numérico en la mayoría de los tensores — activaciones, gradientes y multiplicaciones matriciales intermedias — pero requiere alta precisión en unos pocos lugares críticos, particularmente la copia maestra de los pesos y ciertas reducciones en la pérdida y el optimizador.
Formatos de coma flotante
Tres formatos de coma flotante dominan el aprendizaje profundo moderno. La línea base histórica es FP32 (precisión simple IEEE 754), con un bit de signo, ocho bits de exponente y 23 bits de mantisa, lo que da un rango dinámico de aproximadamente $ 10^{-38} $ a $ 10^{38} $ y unos siete dígitos decimales de precisión.
FP16 (precisión media IEEE 754) usa un bit de signo, cinco bits de exponente y 10 bits de mantisa. Su rango dinámico es mucho más estrecho — aproximadamente $ 6 \times 10^{-5} $ a $ 6.5 \times 10^4 $ — lo que constituye el principal desafío numérico del entrenamiento con precisión mixta. Los gradientes pequeños pueden caer a cero por subdesbordamiento, y las activaciones grandes pueden desbordar a infinito.
BF16 (bfloat16), introducido por Google para la TPU y ahora soportado en la mayoría de los aceleradores modernos, conserva los ocho bits de exponente de FP32 pero trunca la mantisa a siete bits. Tiene el mismo rango dinámico que FP32 y mucha menos precisión que FP16, lo que lo hace dramáticamente más fácil de usar como reemplazo directo de FP32 porque el subdesbordamiento y el desbordamiento son raros. La contrapartida es un redondeo más grueso en cada operación individual.
Una familia más reciente de formatos FP8 (E4M3 y E5M2, estandarizados en 2022) extiende la misma idea a ocho bits, principalmente para multiplicaciones matriciales hacia adelante y hacia atrás en el entrenamiento de transformers muy grandes. FP8 típicamente requiere factores de escalado por tensor y se usa junto con un formato maestro de mayor precisión.
La receta de precisión mixta
La receta canónica de Micikevicius et al. tiene tres componentes.
Pesos maestros en FP32. El optimizador mantiene una copia maestra de los parámetros del modelo en FP32. Antes de cada paso hacia adelante, esta copia maestra se reduce al formato de baja precisión (FP16 o BF16) para producir los pesos de trabajo usados en la red. Después del paso del optimizador, se actualiza el maestro FP32, no la copia de baja precisión. Esto evita que las pequeñas actualizaciones de parámetros producidas por Stochastic Gradient Descent se pierdan por redondeo cuando se suman a un valor de peso mucho mayor.
Concretamente, si $ w $ es un peso y $ \Delta w $ es su actualización, el espaciado representable en FP16 cerca de $ w \approx 1 $ es aproximadamente $ 2^{-10} \approx 10^{-3} $. Las actualizaciones menores que esta magnitud — extremadamente comunes al final del entrenamiento — se perderían por completo si la suma se realizara en FP16.
Adelante y atrás en baja precisión. Las activaciones, los tensores de pesos y los gradientes se almacenan en FP16 o BF16. Las multiplicaciones matriciales y convoluciones se ejecutan en núcleos tensor o matriciales dedicados que consumen entradas de baja precisión y acumulan internamente en FP32, escribiendo después una salida de baja precisión. De aquí provienen las ganancias de memoria y rendimiento.
Escalado de pérdida. Debido a que FP16 tiene un rango dinámico limitado, los valores de gradiente menores a aproximadamente $ 2^{-24} $ caen a cero por subdesbordamiento. La solución consiste en multiplicar la pérdida por un factor de escala grande $ S $ antes de la retropropagación:
$ {\displaystyle L_{\mathrm{scaled}} = S \cdot L} $
Por la regla de la cadena, todo gradiente queda entonces escalado por el mismo factor $ S $, elevando los valores pequeños fuera de la región de subdesbordamiento. Después del paso hacia atrás, los gradientes son desescalados (divididos por $ S $) en FP32 antes del paso del optimizador. Con BF16, el escalado de pérdida generalmente no es necesario porque el formato hereda el rango de exponente de FP32.
Escalado dinámico de pérdida
Elegir un único valor estático para $ S $ requiere conocer la distribución de los gradientes de antemano. Los frameworks modernos en su lugar usan escalado dinámico de pérdida, que ajusta $ S $ durante el entrenamiento:
- Comenzar con un valor inicial grande (p. ej., $ S = 2^{16} $).
- Después de cada paso hacia atrás, comprobar si algún gradiente contiene infinito o NaN.
- Si se detecta desbordamiento, omitir el paso del optimizador para esa iteración y dividir $ S $ a la mitad.
- Si no se ha detectado desbordamiento durante un número fijo de iteraciones (p. ej., 2000), duplicar $ S $.
Este procedimiento mantiene la escala tan grande como sea numéricamente posible sin perder iteraciones por desbordamiento, y se adapta a medida que las magnitudes de los gradientes cambian a lo largo del entrenamiento.
Operaciones que permanecen en FP32
Un puñado de operaciones se mantienen rutinariamente en FP32 incluso dentro de un grafo por lo demás de precisión mixta. Son aquellas cuyo comportamiento numérico es sensible al rango o a la suma repetida:
- El softmax y el log-softmax usados en cabezas de atención y de clasificación, donde importan pequeñas diferencias entre logits grandes.
- El cómputo de pérdida de entropía cruzada, que combina un softmax con un logaritmo de un número pequeño.
- Las estadísticas de Batch Normalization — media, varianza y las estimaciones móviles — que se acumulan sobre muchas muestras.
- Las reducciones sobre ejes largos, como las normas de gradiente usadas para el recorte.
- El estado del optimizador (p. ej., las estimaciones del primer y segundo momento de Adam), que se acumula a lo largo de muchos pasos.
Los frameworks exponen esta distinción mediante regiones autocast o listas de operaciones permitidas: las multiplicaciones matriciales y convoluciones se reducen automáticamente, mientras que las operaciones listadas permanecen en FP32.
Implementaciones
PyTorch proporciona precisión mixta a través de torch.cuda.amp (la API original de FP16) y torch.amp (la API unificada de FP16 / BF16), combinada con GradScaler para el escalado de pérdida. TensorFlow expone la misma idea mediante políticas tf.keras.mixed_precision. JAX usa control explícito del dtype más bibliotecas como Optax para el escalado de pérdida.
La biblioteca Apex de NVIDIA fue el primer kit de herramientas de precisión mixta ampliamente utilizado y precedió a las API nativas de los frameworks; sigue siendo históricamente importante como fuente del escalado dinámico de pérdida. El formato TF32 relacionado (usado implícitamente por las GPU de la generación Ampere para los matmuls en FP32) a veces se agrupa con la precisión mixta, pero técnicamente es una optimización separada que pasa entradas FP32 a través de un multiplicador de precisión reducida.
Comparación con baja precisión pura
El entrenamiento puro en FP16 sin pesos maestros ni escalado de pérdida típicamente diverge o se estanca debido al subdesbordamiento de las actualizaciones y de los gradientes. El entrenamiento puro en BF16 sin pesos maestros suele funcionar para modelos de tamaño moderado pero tiende a perder exactitud final en ejecuciones de entrenamiento largas, especialmente con pérdidas convexas de cola, porque la mantisa de siete bits es demasiado gruesa para acumular con exactitud los pequeños momentos de Adam. La precisión mixta restaura esta exactitud manteniendo el estado del optimizador y los pesos maestros en FP32 mientras sigue extrayendo la mayor parte del beneficio de rendimiento del camino de cómputo de baja precisión.
Limitaciones y modos de fallo
La precisión mixta no está libre de inconvenientes. Los fallos más comunes incluyen:
- NaN persistentes al inicio del entrenamiento, normalmente causados por activaciones iniciales que exceden el valor máximo de FP16 de aproximadamente $ 6.5 \times 10^4 $. La solución es BF16, una inicialización cuidadosa o el recorte de gradiente por capa.
- Pérdida silenciosa de exactitud cuando una operación que debería haberse mantenido en FP32 — por ejemplo un softmax sobre secuencias muy largas — se ejecuta accidentalmente en FP16. La solución estándar es auditar la política de autocast.
- Colapso de la escala de pérdida, donde la escala dinámica cae a uno y permanece allí. Esto indica un problema numérico real, no de ajuste, y suele apuntar a datos defectuosos o a un componente inestable del modelo.
- Reproducibilidad reducida entre hardwares: distintas generaciones de núcleos tensor pueden producir resultados ligeramente diferentes a nivel de bit para el mismo matmul FP16, lo que complica las pruebas de reproducibilidad exacta.
Para modelos muy grandes, FP8 introduce consideraciones adicionales — las escalas por tensor deben ser rastreadas y actualizadas — pero la estructura de alto nivel es la misma que la de la receta original de FP16.
Véase también
Referencias
- ↑ Template:Cite arxiv
- ↑ Template:Cite arxiv
- ↑ Template:Cite arxiv
- ↑ NVIDIA, "Train With Mixed Precision," NVIDIA Deep Learning Performance Documentation, 2023.