Mixed Precision Training

    Topic area Deep Learning
    Prerequisites Neural Networks, Backpropagation, Stochastic Gradient Descent


    Overview

    Mixed precision training is a technique that accelerates the training of deep neural networks by performing most arithmetic operations in a lower-precision floating-point format (typically 16 bits) while keeping a small number of numerically sensitive operations in higher precision (typically 32 bits). Introduced in its modern form by Micikevicius et al. in 2017, it has become the default training regime for large-scale deep learning, powering most contemporary work on convolutional networks, transformers, and large language models. Compared to pure single-precision (FP32) training, mixed precision typically halves memory consumption and delivers two to eight times higher throughput on hardware with dedicated reduced-precision matrix units, while reaching essentially the same final accuracy.

    The approach exploits the observation that neural network training is highly tolerant of numerical noise in most tensors — activations, gradients, and intermediate matrix multiplications — but requires high precision in a few critical places, particularly the master copy of the weights and certain reductions in the loss and optimizer.

    Floating-Point Formats

    Three floating-point formats dominate modern deep learning. The historical baseline is FP32 (IEEE 754 single precision), with one sign bit, eight exponent bits, and 23 mantissa bits, giving a dynamic range of roughly $ 10^{-38} $ to $ 10^{38} $ and about seven decimal digits of precision.

    FP16 (IEEE 754 half precision) uses one sign bit, five exponent bits, and 10 mantissa bits. Its dynamic range is much narrower — roughly $ 6 \times 10^{-5} $ to $ 6.5 \times 10^4 $ — which is the central numerical challenge for mixed precision training. Small gradients can underflow to zero, and large activations can overflow to infinity.

    BF16 (bfloat16), introduced by Google for the TPU and now supported on most modern accelerators, retains FP32's eight exponent bits but truncates the mantissa to seven bits. It therefore has the same dynamic range as FP32 but less precision than FP16; the full exponent range makes it dramatically easier to use as a drop-in replacement for FP32 because underflow and overflow are rare. The trade-off is coarser rounding in each individual operation.
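
    These properties can be read off directly from the format metadata. The following is a small sketch assuming PyTorch is available (any framework that exposes finfo reports the same numbers):

        import torch

        for name, dtype in [("fp32", torch.float32),
                            ("fp16", torch.float16),
                            ("bf16", torch.bfloat16)]:
            info = torch.finfo(dtype)
            # max: largest finite value; tiny: smallest positive normal value;
            # eps: spacing between 1.0 and the next representable number
            print(f"{name}: max={info.max:.3e}  tiny={info.tiny:.3e}  eps={info.eps:.3e}")

    The printout matches the figures above: FP16 tops out near $ 6.5 \times 10^4 $ with fine spacing, while BF16 reaches FP32's $ \sim 10^{38} $ range but with an epsilon eight times coarser than FP16's.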

    A newer family of FP8 formats (E4M3 and E5M2, standardized in 2022) extends the same idea to eight bits, primarily for forward and backward matrix multiplies in very large transformer training. FP8 typically requires per-tensor scaling factors and is used alongside a higher-precision master format.

    The Mixed Precision Recipe

    The canonical recipe from Micikevicius et al. has three components.

    Master weights in FP32. The optimizer maintains a master copy of the model parameters in FP32. Before each forward pass, this master copy is cast down to the low-precision format (FP16 or BF16) to produce the working weights used in the network. After the optimizer step, the FP32 master is updated, not the low-precision copy. This prevents the small parameter updates produced by Stochastic Gradient Descent from being lost to rounding when added to a much larger weight value.

    Concretely, if $ w $ is a weight and $ \Delta w $ is its update, the FP16 representable spacing near $ w \approx 1 $ is roughly $ 2^{-10} \approx 10^{-3} $. Updates smaller than this magnitude — extremely common late in training — would be entirely lost if the addition were performed in FP16.
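
    A two-line experiment makes this concrete. In the numpy sketch below, adding a $ 10^{-4} $ update to a weight of 1.0 is lost entirely in FP16 but preserved in FP32:

        import numpy as np

        w, dw = 1.0, 1e-4                      # weight and a typical late-training update
        w16 = np.float16(w) + np.float16(dw)   # FP16 spacing near 1.0 is 2**-10, so dw rounds away
        w32 = np.float32(w) + np.float32(dw)   # FP32 spacing near 1.0 is 2**-23, so dw survives
        print(w16 == np.float16(1.0))          # True: the update was lost
        print(w32 == np.float32(1.0))          # False: the update was applied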

    Forward and backward in low precision. Activations, weight tensors, and gradients are stored in FP16 or BF16. Matrix multiplications and convolutions execute on dedicated tensor or matrix cores that consume low-precision inputs and accumulate in FP32 internally, then write back a low-precision output. This is where the memory and throughput gains come from.

    Loss scaling. Because FP16 has limited dynamic range, gradient values smaller than about $ 2^{-24} $ underflow to zero. The fix is to multiply the loss by a large scale factor $ S $ before backpropagation:

    $ L_{\mathrm{scaled}} = S \cdot L $

    By the chain rule, every gradient is then scaled by the same factor $ S $, lifting small values out of the underflow region. After the backward pass, gradients are unscaled (divided by $ S $) in FP32 before the optimizer step. With BF16, loss scaling is generally unnecessary because the format inherits FP32's exponent range.
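
    Putting the three components together, a single manually scaled FP16 training step looks roughly like the sketch below. This is a minimal illustration assuming a CUDA device; the tiny linear model, SGD optimizer, toy squared-output loss, and static scale are placeholder choices, not part of the original recipe:

        import torch

        model = torch.nn.Linear(512, 512).cuda()
        master = [p.detach().clone().float() for p in model.parameters()]  # FP32 master copy
        opt = torch.optim.SGD(master, lr=1e-3)    # the optimizer acts on the master weights
        model.half()                              # working weights in FP16
        S = 2.0 ** 16                             # static loss scale

        x = torch.randn(64, 512, device="cuda", dtype=torch.float16)
        loss = model(x).float().pow(2).mean()     # toy loss, computed in FP32
        (S * loss).backward()                     # scaled backward: FP16 grads stay above underflow

        for p, m in zip(model.parameters(), master):
            m.grad = p.grad.float() / S           # unscale into FP32
            p.grad = None
        opt.step()                                # update the FP32 master weights

        with torch.no_grad():                     # recast the master for the next iteration
            for p, m in zip(model.parameters(), master):
                p.copy_(m)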

    Dynamic Loss Scaling

    Choosing a single static value for $ S $ requires knowing the gradient distribution in advance. Modern frameworks instead use dynamic loss scaling, which adjusts $ S $ during training:

    • Start with a large initial value (e.g., $ S = 2^{16} $).
    • After each backward pass, check whether any gradient contains an infinity or NaN.
    • If overflow is detected, skip the optimizer step for that iteration and halve $ S $.
    • If no overflow has been detected for a fixed number of iterations (e.g., 2000), double $ S $.

    This procedure keeps the scale as large as numerically possible without losing iterations to overflow, and it adapts as the gradient magnitudes change over the course of training.
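
    The procedure fits in a few lines of Python. The class below is an illustrative sketch of the logic, not any particular framework's implementation; the initial scale and growth interval follow the values quoted above:

        import torch

        def grads_overflowed(params) -> bool:
            """True if any gradient contains an infinity or NaN."""
            return any(p.grad is not None and not torch.isfinite(p.grad).all()
                       for p in params)

        class DynamicLossScaler:
            def __init__(self, init_scale=2.0 ** 16, growth_interval=2000):
                self.scale = init_scale
                self.growth_interval = growth_interval
                self._good_steps = 0

            def update(self, overflow: bool) -> bool:
                """Adjust the scale; return True if the optimizer step should run."""
                if overflow:
                    self.scale /= 2.0        # overflow: halve the scale and skip this step
                    self._good_steps = 0
                    return False
                self._good_steps += 1
                if self._good_steps >= self.growth_interval:
                    self.scale *= 2.0        # long stretch without overflow: try a larger scale
                    self._good_steps = 0
                return True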

    Operations That Stay in FP32

    A handful of operations are routinely kept in FP32 even within an otherwise mixed-precision graph. These are the ones whose numerical behavior is sensitive to range or to repeated summation:

    • The softmax and log-softmax used in attention and classification heads, where small differences between large logits matter.
    • The cross-entropy loss computation, which combines a softmax with a logarithm of a small number.
    • Batch Normalization statistics — mean, variance, and the running estimates — which accumulate over many samples.
    • Reductions over long axes, such as gradient norms used for clipping.
    • The optimizer state (e.g., Adam's first and second moment estimates), which accumulates over many steps.

    Frameworks expose this distinction through autocast regions or op allow-lists: matrix multiplications and convolutions are downcast automatically, while listed operations remain in FP32.
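
    In PyTorch, for instance, the allow-list behavior can be observed directly: inside an autocast region a matrix multiply returns an FP16 tensor, while a softmax over the same FP32 inputs stays in FP32. A minimal sketch, assuming a CUDA device:

        import torch

        a = torch.randn(32, 32, device="cuda")
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            prod = a @ a                      # matmul is on the low-precision list
            probs = torch.softmax(a, dim=-1)  # softmax is kept in FP32 for numerical safety
        print(prod.dtype)   # torch.float16
        print(probs.dtype)  # torch.float32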

    Implementations

    PyTorch provides mixed precision through torch.cuda.amp (the original FP16 API) and torch.amp (the unified FP16 / BF16 API), combined with GradScaler for loss scaling. TensorFlow exposes the same idea through tf.keras.mixed_precision policies. JAX uses explicit dtype control plus libraries such as Optax for loss scaling.
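
    A typical PyTorch training loop using these APIs looks roughly like the following sketch (the linear model, AdamW optimizer, and random data are placeholders):

        import torch

        model = torch.nn.Linear(1024, 10).cuda()
        optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
        scaler = torch.cuda.amp.GradScaler()          # handles dynamic loss scaling
        loss_fn = torch.nn.CrossEntropyLoss()

        for step in range(100):
            x = torch.randn(64, 1024, device="cuda")
            y = torch.randint(0, 10, (64,), device="cuda")

            optimizer.zero_grad(set_to_none=True)
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                loss = loss_fn(model(x), y)           # forward pass runs in mixed precision
            scaler.scale(loss).backward()             # backward pass on the scaled loss
            scaler.step(optimizer)                    # unscales grads, skips step on overflow
            scaler.update()                           # adjusts the loss scale dynamically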

    NVIDIA's Apex library was the first widely used mixed-precision toolkit and predated the framework-native APIs; it remains historically important as the source of dynamic loss scaling. The closely related TF32 format (used implicitly by Ampere-class GPUs for FP32 matmuls) is sometimes grouped with mixed precision but is technically a separate optimization that runs FP32 inputs through a reduced-precision multiplier.

    Comparison with Pure Low Precision

    Pure FP16 training without master weights or loss scaling typically diverges or stalls because of update underflow and gradient underflow. Pure BF16 training without FP32 master weights often works for moderate-sized models but tends to lose some final accuracy on long training runs, especially late in training when updates become small, because the seven-bit mantissa is too coarse to accumulate small weight updates and Adam moments accurately. Mixed precision restores this accuracy by keeping the optimizer state and the master weights in FP32 while still extracting most of the throughput benefit from the low-precision compute path.

    Limitations and Failure Modes

    Mixed precision is not free of pitfalls. The most common failures include:

    • Persistent NaNs early in training, usually caused by initial activations that exceed the FP16 max value of about $ 6.5 \times 10^4 $. The remedy is BF16, careful initialization, or layer-wise gradient clipping.
    • Silent accuracy loss when an operation that should have been kept in FP32 — for example a softmax over very long sequences — is accidentally executed in FP16. Auditing the autocast policy is the standard fix.
    • Loss-scale collapse, where the dynamic scale falls to one and stays there. This indicates a real numerical problem, not a tuning issue, and usually points to bad data or an unstable model component.
    • Reduced reproducibility across hardware: different generations of tensor cores can produce slightly different bit-level results for the same FP16 matmul, which complicates exact-reproducibility testing.

    For very large models, FP8 introduces additional considerations — per-tensor scales must be tracked and updated — but the high-level structure is the same as the original FP16 recipe.
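
    The per-tensor scaling itself is simple arithmetic: the scale maps the tensor's largest magnitude onto the top of the FP8 range, and the scale is stored alongside the tensor so values can be divided back out after the matmul. The numpy sketch below only simulates the scaling step (no real FP8 dtype is used; 448 is the largest finite E4M3 value):

        import numpy as np

        E4M3_MAX = 448.0                        # largest finite value in the E4M3 format

        def per_tensor_scale(x: np.ndarray) -> float:
            """Scale factor that maps the tensor's largest magnitude onto the FP8 range."""
            amax = np.abs(x).max()
            return E4M3_MAX / max(amax, 1e-12)  # guard against all-zero tensors

        x = np.random.randn(1024).astype(np.float32) * 1e-3   # small-magnitude activations
        s = per_tensor_scale(x)
        x_scaled = x * s       # this is what would be cast to FP8; s is stored alongside
        x_back = x_scaled / s  # dequantize by dividing the scale back out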


    References


    1. P. Micikevicius et al., "Mixed Precision Training," arXiv:1710.03740, 2017.
    2. NVIDIA, "Train With Mixed Precision," NVIDIA Deep Learning Performance Documentation, 2023.