LAMB Optimizer
| Article | |
|---|---|
| Topic area | Machine Learning |
| Prerequisites | Adam Optimizer, Stochastic Gradient Descent, Gradient Descent |
Overview
The LAMB optimizer (Layer-wise Adaptive Moments for Batch training) is a first-order stochastic optimization algorithm designed to enable training of deep neural networks with very large mini-batch sizes without loss of generalization. Introduced by You et al. in 2019, LAMB combines the per-parameter adaptive moment estimates of Adam with a per-layer trust ratio inspired by LARS (Layer-wise Adaptive Rate Scaling). The algorithm became prominent after it was used to reduce BERT pre-training time from roughly three days to 76 minutes on 1024 TPUv3 chips, while matching or exceeding the published F1 score on the SQuAD benchmark.[1]
LAMB occupies the intersection of two long-standing concerns in deep learning: how to use the parallelism of modern accelerators (which favours large batches) and how to preserve the implicit regularization that small-batch SGD is believed to provide. Its core insight is that, when the global learning rate is fixed, layers with very different weight norms experience updates of very different effective magnitudes, and that rescaling each layer's update by a trust ratio of weight norm to update norm restores stable training at large batch sizes.
Motivation: Large-Batch Training
Distributed data-parallel training scales the effective batch size proportionally to the number of workers. In principle, doubling the batch size and the learning rate together preserves the per-epoch trajectory of gradient descent, a heuristic known as the linear scaling rule. In practice, this rule breaks down beyond a problem-dependent critical batch size: training either diverges, plateaus at a worse loss, or generalizes poorly. Goyal et al. demonstrated linear scaling for ResNet-50 on ImageNet up to a batch size of 8192 using SGD with momentum, warmup, and careful normalization,[2] but the same recipe did not carry over to Transformer models, which are typically trained with Adam.
LARS, proposed earlier by You et al., addressed the SGD case by introducing a layer-wise trust ratio that adapts the learning rate to each layer's weight norm.[3] LARS pushed ResNet-50 training to a batch size of 32K. However, LARS is built on momentum SGD, and applying it directly to Transformer pre-training gave inferior results. LAMB extends the trust-ratio idea to adaptive moment-based methods, which are the de facto choice for Transformer pre-training.
Algorithmic Formulation
Let $ \theta_t \in \mathbb{R}^d $ denote the parameter vector at step $ t $, partitioned across $ L $ layers as $ \theta_t = (\theta_t^{(1)}, \dots, \theta_t^{(L)}) $. Given a stochastic gradient $ g_t = \nabla_\theta \ell(\theta_t; \xi_t) $ on a mini-batch $ \xi_t $, LAMB performs the following update.
First, it maintains exponential moving averages of the gradient and its square, identical to Adam:
$ {\displaystyle m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t,} $
$ {\displaystyle v_t = \beta_2 v_{t-1} + (1 - \beta_2)\, g_t \odot g_t,} $
with bias-corrected estimates $ \hat m_t = m_t / (1 - \beta_1^t) $ and $ \hat v_t = v_t / (1 - \beta_2^t) $.
Second, it forms a per-coordinate Adam-style update augmented with decoupled weight decay $ \lambda $:
$ {\displaystyle r_t = \frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon} + \lambda\, \theta_t.} $
Third, and crucially, it rescales the update of each layer $ i $ by a layer-wise trust ratio:
$ {\displaystyle \theta_{t+1}^{(i)} = \theta_t^{(i)} - \eta_t \cdot \frac{\phi\!\left(\| \theta_t^{(i)} \|\right)}{\| r_t^{(i)} \|}\, r_t^{(i)},} $
where $ \| \cdot \| $ denotes the L2 norm restricted to the parameters of layer $ i $, $ \eta_t $ is the global learning rate, and $ \phi: \mathbb{R}_{\ge 0} \to \mathbb{R}_{>0} $ is a scaling function. In the canonical implementation, $ \phi(x) = x $ (the identity), optionally clipped to a range such as $ [\phi_{\min}, \phi_{\max}] $ to prevent extreme updates for layers with very small or very large weight norms.
The trust ratio $ \| \theta^{(i)} \| / \| r^{(i)} \| $ ensures that, when $ \phi $ is the identity, the relative change $ \| \Delta \theta^{(i)} \| / \| \theta^{(i)} \| $ equals $ \eta_t $ exactly, regardless of the layer's absolute scale. Layers whose Adam update would be disproportionately large relative to their current weights are damped, and conservatively small updates are amplified.
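This scale invariance can be checked numerically. The following short sketch (assuming $ \phi $ is the identity and ignoring weight decay; the layer sizes and scales are arbitrary illustrative choices) verifies that the relative update norm of a layer equals the global learning rate no matter how large or small the layer's weights are:

```python
import numpy as np

rng = np.random.default_rng(0)
eta = 1e-2  # global learning rate

for scale in (1e-3, 1.0, 1e3):                       # layers with very different weight norms
    theta = scale * rng.standard_normal(256)         # layer weights
    r = rng.standard_normal(256)                     # Adam-style update direction
    trust = np.linalg.norm(theta) / np.linalg.norm(r)  # trust ratio with phi(x) = x
    delta = eta * trust * r                          # LAMB step for this layer
    rel_change = np.linalg.norm(delta) / np.linalg.norm(theta)
    print(f"scale={scale:g}  relative update = {rel_change:.6f}")  # always equal to eta
```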
Pseudocode
input: learning rate eta, betas (b1, b2), epsilon, weight decay lambda, scaling function phi
init: theta_0, m_0 = 0, v_0 = 0
for t = 1, 2, ... do
    sample mini-batch, compute g_t
    m_t = b1 * m_{t-1} + (1 - b1) * g_t
    v_t = b2 * v_{t-1} + (1 - b2) * g_t * g_t
    m_hat = m_t / (1 - b1**t)
    v_hat = v_t / (1 - b2**t)
    r_t = m_hat / (sqrt(v_hat) + epsilon) + lambda * theta_{t-1}
    for each layer i do
        w_norm = ||theta_{t-1}^(i)||
        r_norm = ||r_t^(i)||
        if w_norm > 0 and r_norm > 0:
            trust = phi(w_norm) / r_norm
        else:
            trust = 1
        theta_t^(i) = theta_{t-1}^(i) - eta * trust * r_t^(i)
    end for
end for
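For concreteness, the pseudocode can be translated into a minimal NumPy sketch of a single LAMB step over a dictionary of per-layer parameters. The function name lamb_step, the dictionary-based optimizer state, and the default hyperparameter values are illustrative choices, not a reference implementation; $ \phi $ is taken as the identity.

```python
import numpy as np

def lamb_step(params, grads, state, t, eta=1e-3,
              b1=0.9, b2=0.999, eps=1e-6, lam=0.01):
    """One LAMB update. params, grads, state are dicts keyed by layer name."""
    for name, theta in params.items():
        g = grads[name]
        m, v = state[name]["m"], state[name]["v"]

        # Adam-style first and second moment estimates
        m = b1 * m + (1 - b1) * g
        v = b2 * v + (1 - b2) * g * g
        state[name]["m"], state[name]["v"] = m, v

        # bias correction
        m_hat = m / (1 - b1 ** t)
        v_hat = v / (1 - b2 ** t)

        # per-coordinate update with decoupled weight decay
        r = m_hat / (np.sqrt(v_hat) + eps) + lam * theta

        # layer-wise trust ratio (phi = identity), guarded against zero norms
        w_norm = np.linalg.norm(theta)
        r_norm = np.linalg.norm(r)
        trust = w_norm / r_norm if w_norm > 0 and r_norm > 0 else 1.0

        params[name] = theta - eta * trust * r
    return params, state

# usage on two toy "layers"
rng = np.random.default_rng(0)
params = {"dense": rng.standard_normal((4, 4)), "out": rng.standard_normal(4)}
state = {k: {"m": np.zeros_like(w), "v": np.zeros_like(w)} for k, w in params.items()}
grads = {k: rng.standard_normal(w.shape) for k, w in params.items()}
params, state = lamb_step(params, grads, state, t=1)
```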
Practical Considerations
LAMB is most effective when paired with a learning-rate schedule that includes a warmup phase. The original BERT recipe used a linear warmup over the first few thousand steps followed by polynomial decay; the peak learning rate is markedly higher than the typical Adam range, often $ 10^{-2} $ or above for transformer pre-training, because the trust ratio absorbs the absolute magnitude of updates.
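A schedule of this shape is straightforward to express directly. The sketch below assumes a linear warmup followed by a polynomial decay to zero; the warmup length, peak rate, and decay power are illustrative values rather than the exact BERT recipe settings.

```python
def lamb_lr(step, total_steps, warmup_steps=3000, peak_lr=1e-2, power=1.0):
    """Linear warmup to peak_lr, then polynomial decay toward zero."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return peak_lr * (1.0 - progress) ** power
```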
Weight decay is decoupled from the gradient, in the AdamW sense; folding decay into $ g_t $ would couple it to the second-moment normalization and undo much of the benefit. Bias and normalization parameters (for example batch norm or layer-norm scales and shifts) are conventionally exempted from both weight decay and the trust-ratio rescaling, since their norms are small and the trust ratio can become numerically unstable.
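One common way to implement these exemptions is to partition parameters into groups by name before training. The sketch below is a plain-Python illustration under the assumption that bias and normalization parameters can be recognized by substrings in their names; the group keys, in particular use_trust_ratio, are hypothetical labels for the optimizer to interpret, not part of any specific library API.

```python
def split_param_groups(named_params, decay=0.01):
    """Split (name, param) pairs into a decayed/rescaled group and an exempt group."""
    exempt_keywords = ("bias", "layer_norm", "layernorm", "ln")
    decay_group, exempt_group = [], []
    for name, p in named_params:
        if any(k in name.lower() for k in exempt_keywords):
            exempt_group.append(p)   # no weight decay, no trust-ratio rescaling
        else:
            decay_group.append(p)
    return [
        {"params": decay_group, "weight_decay": decay, "use_trust_ratio": True},
        {"params": exempt_group, "weight_decay": 0.0, "use_trust_ratio": False},
    ]
```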
Numerical stability also requires guarding against zero-norm layers at initialization or after pruning. The reference implementation falls back to a unit trust ratio whenever either the weight norm or the update norm is zero.
Comparisons
Compared with Adam, LAMB introduces only a modest constant-factor overhead per step, dominated by the per-layer norm computations. Its benefits over Adam are negligible at small batch sizes, where Adam already converges well; the gap appears once the batch size exceeds several thousand examples. Compared with LARS, LAMB inherits LARS's layer-wise adaptation but uses adaptive moment estimates rather than momentum, which makes it more suitable for transformer-style architectures where gradient magnitudes vary by orders of magnitude across layers.
The original paper provides a convergence analysis under standard smoothness and bounded-variance assumptions, showing an $ O(1/\sqrt{T}) $ rate to a stationary point for non-convex objectives. The analysis also gives a sense in which LAMB's update is the unique scaling that keeps the per-layer update norm proportional to the weight norm, motivating the algorithm beyond purely empirical grounds.
Limitations
LAMB does not produce the same trained model as small-batch Adam. The implicit regularization differences between small and large batches are not fully neutralized, and downstream task accuracy can be slightly lower at very large batch sizes despite matching pre-training loss. The algorithm also exposes additional hyperparameters: the trust-ratio clip range, the per-parameter-group exemptions, and the choice of $ \phi $. In practice, default values from the BERT recipe transfer well to other transformer pre-training runs but typically require retuning for vision or reinforcement-learning workloads.
Empirically, LAMB has been most successful on transformer pre-training and large-scale supervised learning. Its advantages over well-tuned Adam are smaller for fine-tuning, where batches are typically modest and the optimization landscape near a pre-trained initialization is well-conditioned. For very small models or tabular data, simpler optimizers usually suffice.