AdamW

    Topic area: optimization
    Prerequisites: Adam, Stochastic gradient descent, Weight decay


    Overview

    AdamW is a stochastic optimization algorithm for training neural networks that decouples weight decay from the gradient-based update of Adam. Introduced by Loshchilov and Hutter in 2017, it corrects a long-standing implementation flaw in adaptive optimizers: in standard Adam, adding an L2 penalty to the loss does not produce true weight decay because the penalty term is rescaled by the per-parameter adaptive learning rate. AdamW restores the original Hanson and Pratt formulation by applying weight decay as a separate, fixed-rate shrinkage of the parameters after the adaptive update. The change is a few lines of code, but it consistently improves generalization and has become the default optimizer for transformer-based models, including BERT, GPT-style language models, and Vision Transformers.[1]

    Motivation: L2 regularization is not weight decay in Adam

    For plain SGD, adding an L2 penalty $ \tfrac{\lambda}{2}\|\theta\|^2 $ to the loss is mathematically equivalent to multiplying the parameters by $ (1 - \eta\lambda) $ at each step, where $ \eta $ is the learning rate. The two formulations — L2 regularization and weight decay — coincide.
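    The equivalence is easy to verify numerically. The sketch below (NumPy, with arbitrary illustrative values) applies both formulations to the same parameters and checks that they agree; all names are hypothetical.

    ```python
    import numpy as np

    rng = np.random.default_rng(0)
    theta = rng.normal(size=5)     # current parameters
    grad = rng.normal(size=5)      # gradient of the unpenalized loss
    eta, lam = 0.1, 0.01

    # (a) L2 regularization: add lambda * theta to the gradient.
    theta_l2 = theta - eta * (grad + lam * theta)

    # (b) Weight decay: shrink the parameters, then take the plain step.
    theta_wd = (1 - eta * lam) * theta - eta * grad

    assert np.allclose(theta_l2, theta_wd)  # identical for plain SGD
    ```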

    This equivalence breaks for adaptive methods. Adam scales each gradient component by an estimate of its second moment $ \hat{v}_t $, so the L2 contribution $ \lambda\theta $ appended to the gradient is divided by $ \sqrt{\hat{v}_t}+\epsilon $ before being applied. Parameters with large historical gradients (well-conditioned directions) receive less regularization than parameters with small gradients, which is the opposite of what weight decay is supposed to do. Loshchilov and Hutter showed that this coupling causes Adam to generalize worse than SGD with momentum on image classification benchmarks, and that decoupling weight decay closes most of the gap.
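    To see the attenuation concretely: ignoring momentum smoothing, the effective per-parameter decay rate under Adam-with-L2 is roughly $ \eta\lambda / (\sqrt{\hat{v}_t}+\epsilon) $. A small sketch with illustrative values:

    ```python
    import numpy as np

    eta, lam, eps = 1e-3, 0.1, 1e-8
    v_hat = np.array([1e-6, 1e-2, 1.0])   # per-parameter second-moment estimates
    effective_decay = eta * lam / (np.sqrt(v_hat) + eps)
    print(effective_decay)  # ~[1e-1, 1e-3, 1e-4]: large-gradient params decay least
    ```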

    Algorithm

    Let $ \theta_t $ denote the parameter vector at step $ t $, $ g_t = \nabla_\theta f_t(\theta_{t-1}) $ the stochastic gradient of the loss on minibatch $ t $, and $ \eta_t $ the (possibly scheduled) learning rate. AdamW maintains exponential moving averages of the gradient and the squared gradient with decay rates $ \beta_1, \beta_2 \in [0,1) $:

    $ {\displaystyle m_t = \beta_1 m_{t-1} + (1-\beta_1) g_t} $

    $ {\displaystyle v_t = \beta_2 v_{t-1} + (1-\beta_2) g_t^2} $

    The bias-corrected estimates are

    $ {\displaystyle \hat{m}_t = \frac{m_t}{1-\beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1-\beta_2^t}.} $

    The parameter update is then

    $ {\displaystyle \theta_t = \theta_{t-1} - \eta_t \left( \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda\,\theta_{t-1} \right).} $

    The crucial term is $ \lambda\,\theta_{t-1} $, applied outside the adaptive denominator. By contrast, Adam-with-L2 folds the penalty into the gradient before the moment estimates are formed, $ g_t \leftarrow g_t + \lambda\theta_{t-1} $, so the decay contribution propagates into both $ \hat{m}_t $ and $ \hat{v}_t $ and is ultimately divided by $ \sqrt{\hat{v}_t}+\epsilon $, scaled by the same per-parameter adaptive factor as the gradient.

    Default hyperparameters in most implementations are $ \beta_1 = 0.9 $, $ \beta_2 = 0.999 $, $ \epsilon = 10^{-8} $, with $ \lambda $ typically in $ [10^{-2}, 10^{-1}] $ for transformer pretraining and $ [10^{-4}, 10^{-2}] $ for fine-tuning.
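    The full update is only a few lines. The single-tensor sketch below (NumPy) mirrors the equations above using the default hyperparameters quoted in the text; it is for exposition, not a substitute for a library optimizer.

    ```python
    import numpy as np

    def adamw_step(theta, g, m, v, t, eta=1e-3, beta1=0.9, beta2=0.999,
                   eps=1e-8, lam=1e-2):
        """One AdamW update; t is the 1-based step count."""
        m = beta1 * m + (1 - beta1) * g        # first-moment EMA
        v = beta2 * v + (1 - beta2) * g**2     # second-moment EMA
        m_hat = m / (1 - beta1**t)             # bias correction
        v_hat = v / (1 - beta2**t)
        # Decoupled decay: lam * theta sits outside the adaptive denominator.
        theta = theta - eta * (m_hat / (np.sqrt(v_hat) + eps) + lam * theta)
        return theta, m, v
    ```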

    Decoupled weight decay in practice

    The decoupling has two practical consequences. First, the optimal weight decay $ \lambda $ is now largely independent of the learning rate $ \eta $, which simplifies hyperparameter tuning — in the original Adam, changing the learning rate effectively changed the regularization strength as well, forcing joint sweeps. Second, the optimal $ \lambda $ for AdamW is typically one to two orders of magnitude larger than the L2 coefficient that worked for Adam, because the adaptive scaling no longer attenuates it.

    A common subtlety is whether to scale $ \lambda $ by $ \eta_t $ when using a learning-rate schedule. The original paper writes the update as $ \theta_t = \theta_{t-1} - \eta_t \hat{m}_t/(\sqrt{\hat{v}_t}+\epsilon) - \eta_t \lambda \theta_{t-1} $, so weight decay is scaled by the schedule. Some implementations (notably an early PyTorch version) instead applied $ \lambda\theta_{t-1} $ directly without the $ \eta_t $ factor; this is now widely considered a bug, and current PyTorch, JAX, and TensorFlow implementations follow the paper's convention.
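    In code, the two conventions differ only in whether the decay term carries the $ \eta_t $ factor. A schematic sketch, where the names are illustrative and `update` stands for the precomputed $ \hat{m}_t/(\sqrt{\hat{v}_t}+\epsilon) $:

    ```python
    def decay_step(theta, update, eta_t, lam, scale_by_lr=True):
        if scale_by_lr:
            # Paper convention: decay follows the learning-rate schedule.
            return theta - eta_t * update - eta_t * lam * theta
        # Unscaled variant: regularization stays at full strength even as
        # eta_t anneals toward zero, which is why it is treated as a bug.
        return theta - eta_t * update - lam * theta
    ```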

    Variants and extensions

    Several optimizers extend or modify AdamW:

    • Lion (EvoLved Sign Momentum, Chen et al. 2023) — drops the second-moment estimate and updates with the sign of an interpolated momentum, retaining decoupled weight decay; storing a single momentum buffer gives it about half the optimizer state of AdamW.
    • AdamW with gradient clipping — global-norm or per-layer clipping is standard for large language model pretraining to control loss spikes.
    • LAMB (Layer-wise Adaptive Moments) — adds layer-wise normalization on top of AdamW for very large batch sizes (32k+), used in record-time BERT pretraining.
    • AdaFactor — factorizes the second-moment matrix to save memory; supports decoupled decay.
    • Adan (adaptive Nesterov momentum) and Sophia (a second-order-inspired method using a lightweight estimate of the Hessian diagonal) — both retain the decoupled-decay design.

    Mixed-precision (FP16/bfloat16) training adds a further consideration: in practice, $ \epsilon $ is often raised to $ 10^{-6} $ or $ 10^{-5} $ to keep the denominator $ \sqrt{\hat{v}_t}+\epsilon $ away from underflow.
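    In PyTorch this is just a constructor argument; the values below are illustrative, not a recommendation:

    ```python
    import torch

    model = torch.nn.Linear(512, 512)
    # Larger eps keeps sqrt(v_hat) + eps representable under mixed precision.
    opt = torch.optim.AdamW(model.parameters(), lr=3e-4, eps=1e-6,
                            weight_decay=0.1)
    ```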

    Comparison to Adam and SGD

    Empirically, AdamW closes the generalization gap that originally motivated practitioners to prefer SGD with momentum for vision tasks. On ImageNet, ResNet-50 trained with well-tuned AdamW reaches accuracy within 0.1–0.3% of SGD+momentum, where naive Adam-with-L2 lagged by 1–2 percentage points. For transformers, AdamW is essentially universal: the adaptive per-parameter scaling is necessary to handle the wide dynamic range of gradients across attention and feed-forward sublayers, and the decoupled decay keeps embedding and layer-norm parameters from being under-regularized.

    A useful heuristic: if you are training a model from scratch and the architecture has LayerNorm or RMSNorm, use AdamW. If you are fine-tuning a pre-trained model, use AdamW with a smaller $ \lambda $ and possibly a smaller $ \beta_2 $ (e.g. 0.95) to avoid washing out the pre-trained weights with stale second-moment estimates.
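    A sketch of such a fine-tuning setup in PyTorch, following the widespread (but not universal) convention of exempting biases and normalization gains from decay; all values are illustrative assumptions, not tuned settings:

    ```python
    import torch

    def make_finetune_optimizer(model, lr=2e-5):
        decay, no_decay = [], []
        for p in model.parameters():
            if not p.requires_grad:
                continue
            # 1-D tensors are biases and norm gains; exempt them from decay.
            (no_decay if p.ndim < 2 else decay).append(p)
        return torch.optim.AdamW(
            [{"params": decay, "weight_decay": 1e-3},
             {"params": no_decay, "weight_decay": 0.0}],
            lr=lr, betas=(0.9, 0.95))
    ```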

    Limitations

    AdamW inherits Adam's memory cost: it stores two extra tensors ($ m_t $, $ v_t $) per parameter, so parameters plus optimizer state occupy roughly three times the memory of plain SGD, which keeps no per-parameter state. For models in the billions of parameters this is a dominant cost, motivating sharded variants like ZeRO and 8-bit AdamW, where the optimizer state is quantized.
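    A back-of-envelope calculation makes the scale concrete (the 7-billion-parameter model size is illustrative):

    ```python
    n_params = 7e9
    state_bytes = 2 * 4 * n_params              # m and v, fp32, 4 bytes each
    print(f"{state_bytes / 2**30:.0f} GiB")     # ~52 GiB just for m and v
    ```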

    The decoupled decay is not a panacea. It assumes a fixed schedule for $ \lambda $; cyclic or warm-restart schedules (Loshchilov and Hutter's SGDR) interact non-trivially with the second-moment buffers, and best practice is still to warm up the learning rate over the first few hundred to few thousand steps before applying full weight decay. AdamW also remains sensitive to $ \beta_2 $ in low-data regimes, where the second-moment estimate is noisy; values like $ \beta_2 = 0.95 $ or $ 0.98 $ are common in reinforcement learning and continual learning.

    Finally, the equivalence between L2 regularization and weight decay does not hold for AdamW any more than it does for Adam — they are now different regularizers, and reporting "weight decay" without clarifying whether it refers to the AdamW $ \lambda $ or an L2 loss term is a common source of reproducibility errors.

    References