AdamW/es
| Article | |
|---|---|
| Topic area | optimization |
| Prerequisites | Adam, Stochastic gradient descent, Weight decay |
Visión general
AdamW es un algoritmo de optimización estocástica para el entrenamiento de redes neuronales que desacopla el decaimiento de pesos de la actualización basada en gradiente de Adam. Introducido por Loshchilov y Hutter en 2017, corrige un defecto de implementación de larga data en optimizadores adaptativos: en Adam estándar, añadir una penalización L2 a la pérdida no produce un verdadero decaimiento de pesos porque el término de penalización se reescala por la tasa de aprendizaje adaptativa por parámetro. AdamW restaura la formulación original de Hanson and Pratt aplicando el decaimiento de pesos como una contracción separada de tasa fija sobre los parámetros después de la actualización adaptativa. El cambio son unas pocas líneas de código, pero mejora consistentemente la generalización y se ha convertido en el optimizador por defecto para modelos basados en transformers, incluidos BERT, modelos de lenguaje estilo GPT y Vision Transformers.[1]
Motivación: la regularización L2 no es decaimiento de pesos en Adam
Para SGD simple, añadir una penalización L2 $ \tfrac{\lambda}{2}\|\theta\|^2 $ a la pérdida es matemáticamente equivalente a multiplicar los parámetros por $ (1 - \eta\lambda) $ en cada paso, donde $ \eta $ es la tasa de aprendizaje. Las dos formulaciones — regularización L2 y decaimiento de pesos — coinciden.
Esta equivalencia se rompe para los métodos adaptativos. Adam escala cada componente del gradiente por una estimación de su segundo momento $ \hat{v}_t $, de modo que la contribución L2 $ \lambda\theta $ añadida al gradiente se divide por $ \sqrt{\hat{v}_t}+\epsilon $ antes de aplicarse. Los parámetros con gradientes históricos grandes (direcciones bien condicionadas) reciben menos regularización que los parámetros con gradientes pequeños, lo cual es lo contrario de lo que el decaimiento de pesos debería hacer. Loshchilov y Hutter demostraron que este acoplamiento provoca que Adam generalice peor que SGD con momento en pruebas de clasificación de imágenes, y que desacoplar el decaimiento de pesos cierra la mayor parte de la brecha.
Algoritmo
Sea $ \theta_t $ el vector de parámetros en el paso $ t $, $ g_t = \nabla_\theta f_t(\theta_{t-1}) $ el gradiente estocástico de la pérdida sobre el minibatch $ t $, y $ \eta_t $ la tasa de aprendizaje (posiblemente con un calendario). AdamW mantiene promedios móviles exponenciales del gradiente y del gradiente al cuadrado con tasas de decaimiento $ \beta_1, \beta_2 \in [0,1) $:
$ {\displaystyle m_t = \beta_1 m_{t-1} + (1-\beta_1) g_t} $
$ {\displaystyle v_t = \beta_2 v_{t-1} + (1-\beta_2) g_t^2} $
Las estimaciones corregidas por sesgo son
$ {\displaystyle \hat{m}_t = \frac{m_t}{1-\beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1-\beta_2^t}.} $
La actualización de los parámetros es entonces
$ {\displaystyle \theta_t = \theta_{t-1} - \eta_t \left( \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda\,\theta_{t-1} \right).} $
El término crucial es $ \lambda\,\theta_{t-1} $, aplicado fuera del denominador adaptativo. En contraste, la actualización original de Adam-con-L2 incorporaría $ \lambda\theta_{t-1} $ en $ g_t $, dando $ \theta_t = \theta_{t-1} - \eta_t (\hat{m}_t + \lambda \theta_{t-1}\cdot\text{scaling})/(\sqrt{\hat{v}_t}+\epsilon) $, donde el término de decaimiento de pesos se escala por el mismo factor adaptativo por parámetro que el gradiente.
Los hiperparámetros por defecto en la mayoría de las implementaciones son $ \beta_1 = 0.9 $, $ \beta_2 = 0.999 $, $ \epsilon = 10^{-8} $, con $ \lambda $ habitualmente en $ [10^{-2}, 10^{-1}] $ para el preentrenamiento de transformers y $ [10^{-4}, 10^{-2}] $ para el ajuste fino.
Decaimiento de pesos desacoplado en la práctica
El desacoplamiento tiene dos consecuencias prácticas. Primero, el decaimiento de pesos óptimo $ \lambda $ es ahora en gran medida independiente de la tasa de aprendizaje $ \eta $, lo que simplifica el ajuste de hiperparámetros — en el Adam original, cambiar la tasa de aprendizaje también cambiaba efectivamente la fuerza de la regularización, forzando barridos conjuntos. Segundo, el $ \lambda $ óptimo para AdamW suele ser de uno a dos órdenes de magnitud mayor que el coeficiente L2 que funcionaba para Adam, porque el escalado adaptativo ya no lo atenúa.
Una sutileza común es si se debe escalar $ \lambda $ por $ \eta_t $ al usar un calendario de tasa de aprendizaje. El artículo original escribe la actualización como $ \theta_t = \theta_{t-1} - \eta_t \hat{m}_t/(\sqrt{\hat{v}_t}+\epsilon) - \eta_t \lambda \theta_{t-1} $, de modo que el decaimiento de pesos se escala por el calendario. Algunas implementaciones (en particular una versión temprana de PyTorch) aplicaban en cambio $ \lambda\theta_{t-1} $ directamente sin el factor $ \eta_t $; esto se considera ahora ampliamente un error, y las implementaciones actuales de PyTorch, JAX y TensorFlow siguen la convención del artículo.
Variantes y extensiones
Varios optimizadores extienden o modifican AdamW:
- Lion (EvoLved Sign Momentum, Chen et al. 2023) — sustituye la estimación del segundo momento por un operador de signo, conserva el decaimiento de pesos desacoplado; usa aproximadamente la mitad de memoria que AdamW.
- AdamW con recorte de gradiente — el recorte por norma global o por capa es estándar para el preentrenamiento de grandes modelos de lenguaje con el fin de controlar los picos de pérdida.
- LAMB (Layer-wise Adaptive Moments) — añade normalización por capa sobre AdamW para tamaños de lote muy grandes (32k+), utilizado en el preentrenamiento de BERT en tiempos récord.
- AdaFactor — factoriza la matriz del segundo momento para ahorrar memoria; admite decaimiento desacoplado.
- Adan y Sophia — métodos inspirados en información de segundo orden que conservan el diseño de decaimiento desacoplado.
La era del entrenamiento en FP16/bfloat16 introdujo hiperparámetros adicionales: en la práctica, $ \epsilon $ suele elevarse a $ 10^{-6} $ o $ 10^{-5} $ en precisión mixta para evitar el subdesbordamiento en $ \sqrt{\hat{v}_t}+\epsilon $.
Comparación con Adam y SGD
Empíricamente, AdamW cierra la brecha de generalización que originalmente motivó a los profesionales a preferir SGD con momento para tareas de visión. En ImageNet, ResNet-50 entrenado con AdamW bien ajustado alcanza una exactitud dentro de 0.1–0.3% de SGD+momento, donde el Adam-con-L2 ingenuo se quedaba 1–2 puntos porcentuales por detrás. Para los transformers, AdamW es esencialmente universal: el escalado adaptativo por parámetro es necesario para manejar el amplio rango dinámico de los gradientes entre las subcapas de atención y de alimentación hacia adelante, y el decaimiento desacoplado evita que los parámetros de embedding y de normalización por capas queden infrarregularizados.
Una heurística útil: si está entrenando un modelo desde cero y la arquitectura tiene LayerNorm o RMSNorm, use AdamW. Si está ajustando finamente un modelo preentrenado, use AdamW con un $ \lambda $ más pequeño y posiblemente un $ \beta_2 $ más pequeño (por ejemplo 0.95) para evitar arrastrar los pesos preentrenados con estimaciones de segundo momento desactualizadas.
Limitaciones
AdamW hereda el coste de memoria de Adam: almacena dos tensores adicionales ($ m_t $, $ v_t $) por parámetro, triplicando el estado del optimizador frente al SGD simple. Para modelos con miles de millones de parámetros, este es un coste dominante, lo que motiva variantes fragmentadas como ZeRO y AdamW de 8 bits, donde el estado del optimizador se cuantiza.
El decaimiento desacoplado no es una panacea. Asume un calendario fijo para $ \lambda $; los calendarios cíclicos o con reinicios cálidos (el SGDR de Loshchilov y Hutter) interactúan de manera no trivial con los búferes del segundo momento, y la mejor práctica sigue siendo calentar la tasa de aprendizaje durante los primeros cientos a unos pocos miles de pasos antes de aplicar el decaimiento de pesos pleno. AdamW también sigue siendo sensible a $ \beta_2 $ en regímenes de pocos datos, donde la estimación del segundo momento es ruidosa; valores como $ \beta_2 = 0.95 $ o $ 0.98 $ son comunes en aprendizaje por refuerzo y aprendizaje continuo.
Por último, la equivalencia entre regularización L2 y decaimiento de pesos no se mantiene para AdamW como tampoco para Adam — ahora son regularizadores diferentes, y reportar "weight decay" sin aclarar si se refiere al $ \lambda $ de AdamW o a un término de pérdida L2 es una fuente común de errores de reproducibilidad.