Distributed Training

    From Marovi AI
    Article
    Topic area Deep Learning
    Prerequisites Stochastic Gradient Descent, Backpropagation


    Overview

    Distributed training is the practice of splitting a single neural network training job across multiple processors, accelerators, or machines so that the workload runs in parallel. It addresses two limits of single-device training: the model may be too large to fit in one device's memory, and the wall-clock time to converge on a large dataset may be unacceptable. Modern foundation models with hundreds of billions of parameters cannot be trained at all without it, and even smaller models routinely use it to shorten training time from weeks to hours. The field combines ideas from numerical optimization, high-performance computing, and systems engineering: how to coordinate stochastic optimization across workers, how to move tensors efficiently between devices, and how to schedule computation so accelerators stay busy.

    The two foundational axes are data parallelism, in which each worker holds a replica of the model and processes a different shard of the batch, and model parallelism, in which the model itself is split across workers. Real-world systems combine both with pipeline parallelism and optimizer-state sharding to scale to thousands of accelerators while keeping arithmetic intensity high and communication tolerable.

    Motivation

    Three pressures push training off a single device. First, parameter count: a model with $ N $ parameters in 16-bit precision needs at least $ 2N $ bytes for weights, another $ 2N $ bytes for gradients, and typically more again for optimizer state, so a 70-billion-parameter model trained in mixed precision with Adam needs roughly 16 bytes per parameter, or over 1 TB of memory, once FP32 master weights and Adam moments are included. Second, data volume: training a large language model on trillions of tokens at single-device throughput would take years. Third, sample efficiency: larger effective batch sizes, made tractable by data parallelism, can stabilize gradients and improve generalization up to a critical batch size beyond which returns diminish.
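
    The memory arithmetic above can be sketched in a few lines. The per-parameter byte counts below are an assumption about a common mixed-precision Adam setup (FP16 weights and gradients, FP32 master weights and two FP32 moments); activations and temporary buffers are ignored, and the function name is hypothetical.

```python
# Rough training-state estimate for mixed-precision Adam (activations excluded).
# The byte counts are assumptions about a typical setup, not measurements
# of any specific framework.
BYTES_PER_PARAM = {
    "fp16_weights": 2,
    "fp16_gradients": 2,
    "fp32_master_weights": 4,
    "fp32_adam_momentum": 4,
    "fp32_adam_variance": 4,
}


def training_state_bytes(num_params: int) -> int:
    """Bytes of weights, gradients, and optimizer state for num_params parameters."""
    return num_params * sum(BYTES_PER_PARAM.values())


if __name__ == "__main__":
    total = training_state_bytes(70_000_000_000)
    print(f"{total / 1e12:.2f} TB")  # prints "1.12 TB"
```

    Under these assumptions the training state alone is far larger than the memory of any single accelerator, before activations are counted.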

    Scaling, however, is not free. Communication between workers competes with computation, and synchronization barriers can leave devices idle. The central engineering question of distributed training is how to overlap, compress, or eliminate that communication while preserving the optimization dynamics of the original algorithm.

    Data parallelism

    In synchronous data-parallel training, each of $ K $ workers holds an identical copy of the parameters $ \theta $ and processes a distinct micro-batch of size $ B $. Each worker computes a local gradient $ g_k = \nabla_\theta L_k(\theta) $, the workers exchange gradients with an all-reduce so that every worker holds the average $ \bar{g} = \frac{1}{K}\sum_{k=1}^{K} g_k $, and each worker applies the same optimizer update locally:

    $ {\displaystyle \theta \leftarrow \theta - \eta \, \bar{g}.} $

    Because all replicas start identical and apply the same update to the same averaged gradient, they remain bit-identical (modulo non-determinism in reductions). The effective batch size is $ K B $, which usually requires a learning-rate warmup and a linear scaling rule (the learning rate grows in proportion to $ K $) to maintain stability.
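
    The synchronous update can be illustrated with a toy example, assuming a scalar linear model $ y = w x $ with squared-error loss and equal-sized shards. The workers are simulated in a single process, and the averaging line stands in for the all-reduce; all function names are hypothetical.

```python
def grad(w, xs, ys):
    """Gradient of mean(0.5 * (w*x - y)**2) over one batch, for scalar w."""
    return sum((w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def make_shards(xs, ys, k):
    """Split the batch into k equal, contiguous shards (one per simulated worker)."""
    size = len(xs) // k
    return [(xs[i * size:(i + 1) * size], ys[i * size:(i + 1) * size])
            for i in range(k)]


def data_parallel_step(w, shards, lr):
    """One synchronous step: local gradients, average, identical update."""
    local_grads = [grad(w, xs, ys) for xs, ys in shards]   # g_k on each worker
    g_bar = sum(local_grads) / len(local_grads)            # stands in for all-reduce
    return w - lr * g_bar                                  # same update everywhere


if __name__ == "__main__":
    xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
    ys = [3.0 * x for x in xs]
    w_parallel = data_parallel_step(0.0, make_shards(xs, ys, 4), lr=0.01)
    w_single = 0.0 - 0.01 * grad(0.0, xs, ys)
    print(w_parallel, w_single)
```

    With equal shard sizes, the average of the per-worker gradients equals the gradient over the full batch of size $ K B $, so the parallel step matches a single-device step on the combined batch up to floating-point rounding.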

    The all-reduce primitive is implemented with a ring or tree algorithm; ring all-reduce moves $ 2(K-1)/K $ times the gradient size per worker and is bandwidth-optimal on uniform interconnects. NCCL, Gloo, and MPI provide tuned implementations.
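
    A minimal single-process simulation of ring all-reduce follows, assuming the buffer length divides evenly into $ K $ chunks. It is a sketch of the textbook two-phase algorithm (reduce-scatter, then all-gather), not the implementation used by NCCL, Gloo, or MPI; the function name and the returned counter are hypothetical.

```python
def ring_all_reduce(buffers):
    """Sum equal-length lists across simulated workers, in place.

    Returns the number of elements each worker sent, so the
    2(K-1)/K traffic factor can be checked.
    """
    k = len(buffers)
    n = len(buffers[0])
    assert n % k == 0, "sketch assumes the length divides evenly into K chunks"
    c = n // k
    sent = 0

    def chunk(idx):
        start = (idx % k) * c
        return slice(start, start + c)

    # Phase 1, reduce-scatter: after K-1 steps worker i holds the full sum
    # of chunk (i + 1) mod K.
    for step in range(k - 1):
        msgs = [list(buffers[i][chunk(i - step)]) for i in range(k)]  # snapshot sends
        for i in range(k):
            dst = (i + 1) % k
            sl = chunk(i - step)
            buffers[dst][sl] = [a + b for a, b in zip(buffers[dst][sl], msgs[i])]
        sent += c

    # Phase 2, all-gather: completed chunks travel once around the ring.
    for step in range(k - 1):
        msgs = [list(buffers[i][chunk(i + 1 - step)]) for i in range(k)]
        for i in range(k):
            dst = (i + 1) % k
            buffers[dst][chunk(i + 1 - step)] = msgs[i]
        sent += c

    return sent


if __name__ == "__main__":
    bufs = [[10 * i + j for j in range(8)] for i in range(4)]
    sent = ring_all_reduce(bufs)
    print(bufs[0], sent)
```

    Each worker sends one chunk of size $ n/K $ in each of the $ 2(K-1) $ steps, which gives the $ 2(K-1)/K $ factor quoted above.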

    An older alternative, the parameter server, has a central process hold the canonical parameters and aggregate worker updates. Asynchronous variants allow workers to push stale gradients without waiting, trading optimization quality for throughput; this approach is mostly historical for dense deep learning but persists in some recommender and embedding workloads where sparsity and load imbalance favor it.

    Model parallelism

    When a single layer's parameters or activations exceed device memory, the layer itself must be partitioned. In tensor parallelism, a matrix multiplication $ Y = X W $ is split along the hidden dimension. In a column-parallel layer, each worker holds a column slab $ W_k $ and computes $ Y_k = X W_k $, and the partial outputs are concatenated; in a row-parallel layer, each worker holds a row slab of $ W $ together with the matching columns of $ X $, and the partial products are summed. Megatron-LM popularized this scheme for transformer blocks by sharding the attention and MLP projections so that each block requires only two all-reduces in the forward pass and two in the backward pass.
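
    The two ways of splitting $ Y = X W $ can be checked on small matrices. The sketch below uses plain Python lists and simulates the workers in a loop; it assumes the split dimension divides evenly by the number of workers, and the helper names are hypothetical.

```python
def matmul(a, b):
    """Dense matrix product of two lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def column_parallel(x, w, k):
    """Split W by columns; worker r computes X @ W_r; outputs are concatenated."""
    step = len(w[0]) // k
    parts = []
    for r in range(k):
        w_r = [row[r * step:(r + 1) * step] for row in w]
        parts.append(matmul(x, w_r))
    # Concatenate each output row across workers (an all-gather in practice).
    return [sum((p[i] for p in parts), []) for i in range(len(x))]


def row_parallel(x, w, k):
    """Split W by rows and X by columns; partial products are summed."""
    step = len(w) // k
    total = None
    for r in range(k):
        x_r = [row[r * step:(r + 1) * step] for row in x]
        w_r = w[r * step:(r + 1) * step]
        part = matmul(x_r, w_r)
        if total is None:
            total = part
        else:
            # Element-wise sum of partial results (an all-reduce in practice).
            total = [[a + b for a, b in zip(ra, rb)] for ra, rb in zip(total, part)]
    return total


if __name__ == "__main__":
    x = [[1, 2, 3, 4], [5, 6, 7, 8]]
    w = [[1, 0, 2, 1], [0, 1, 1, 3], [2, 1, 0, 1], [1, 2, 1, 0]]
    print(matmul(x, w))
    print(column_parallel(x, w, 2))
    print(row_parallel(x, w, 2))
```

    All three calls produce the same matrix. Pairing a column-parallel layer with a row-parallel one means the intermediate activation never has to be gathered, which is why only the final sum needs communication.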

    Tensor parallelism scales well within a node where high-bandwidth links such as NVLink connect a small group of GPUs (typically 4 to 8). Across nodes, the per-step communication cost grows quickly, so tensor parallelism is usually combined with other forms of parallelism rather than used alone.

    Pipeline parallelism

    Pipeline parallelism partitions the network across workers by depth: worker $ k $ holds a contiguous block of layers and forwards activations to worker $ k+1 $. Naively running one micro-batch at a time leaves most workers idle (the "pipeline bubble"). GPipe addresses this by splitting a mini-batch into $ M $ micro-batches and overlapping their forward passes; the bubble fraction is roughly $ (P-1)/(M+P-1) $ for $ P $ stages. One-forward-one-backward (1F1B) schedules, such as PipeDream-Flush, interleave forward and backward passes so that fewer micro-batches' activations are held at once, which reduces activation memory.
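
    The bubble formula can be reproduced by counting idle slots in a simplified schedule. The sketch below assumes every stage takes one time slot per micro-batch and models only the forward pass, where micro-batch $ m $ reaches stage $ p $ at slot $ m + p $; the function names are hypothetical.

```python
def bubble_fraction(num_stages, num_microbatches):
    """Closed-form idle fraction (P - 1) / (M + P - 1)."""
    return (num_stages - 1) / (num_microbatches + num_stages - 1)


def simulated_idle_fraction(num_stages, num_microbatches):
    """Count idle (stage, slot) cells in a GPipe-style forward schedule."""
    total_slots = num_microbatches + num_stages - 1
    busy = [[False] * total_slots for _ in range(num_stages)]
    for stage in range(num_stages):
        for mb in range(num_microbatches):
            busy[stage][mb + stage] = True  # micro-batch mb occupies stage at this slot
    idle = sum(row.count(False) for row in busy)
    return idle / (num_stages * total_slots)


if __name__ == "__main__":
    for m in (1, 4, 12, 32):
        print(m, round(bubble_fraction(4, m), 3), round(simulated_idle_fraction(4, m), 3))
```

    With $ P = 4 $ stages, a single micro-batch leaves three quarters of the schedule idle, while $ M = 12 $ reduces the idle share to one fifth. This is why $ M $ is usually chosen several times larger than $ P $.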

    Pipeline parallelism is sensitive to load balance: a single slow stage stalls the pipeline. Practitioners profile per-stage time and adjust layer assignments, sometimes shifting work between stages or duplicating cheap layers.

    Hybrid and 3D parallelism

    At cluster scale, no single axis suffices. The common arrangement, sometimes called 3D parallelism, composes data, tensor, and pipeline parallelism. Tensor parallelism handles the largest layers within a node, pipeline parallelism partitions the model across nodes, and data parallelism replicates this whole pipeline across replica groups for throughput. Sequence parallelism, which shards activations along the sequence dimension, is a fourth axis used for very long contexts.
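
    One way to picture the composition is as a mapping from a flat rank to a (data, pipeline, tensor) coordinate. The layout below is an assumption chosen for illustration, with the tensor index varying fastest so that a tensor-parallel group occupies consecutive ranks, which would typically share a node. Real frameworks define their own orderings, and the names here are hypothetical.

```python
from typing import List, NamedTuple


class Coord(NamedTuple):
    data: int
    pipeline: int
    tensor: int


def rank_to_coord(rank: int, dp: int, pp: int, tp: int) -> Coord:
    """Map a flat rank to grid coordinates, tensor index fastest (assumed layout)."""
    assert 0 <= rank < dp * pp * tp, "rank outside the dp * pp * tp grid"
    return Coord(
        data=rank // (tp * pp),
        pipeline=(rank // tp) % pp,
        tensor=rank % tp,
    )


def tensor_group(rank: int, tp: int) -> List[int]:
    """Ranks that share this rank's tensor-parallel group (consecutive ranks)."""
    base = rank - rank % tp
    return list(range(base, base + tp))


if __name__ == "__main__":
    # 32 devices: 2 data-parallel replicas x 2 pipeline stages x 8-way tensor parallel.
    print(rank_to_coord(13, dp=2, pp=2, tp=8))
    print(tensor_group(13, tp=8))
```

    In this layout the product of the three degrees must equal the number of devices, and changing any one degree reshuffles which ranks communicate over fast intra-node links.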

    Choosing a partition is a discrete optimization problem driven by memory budgets, interconnect topology, and the model's attention and feed-forward shapes. Tools such as Alpa, GSPMD, and the partitioning routines in DeepSpeed and Megatron search this space automatically.

    ZeRO and sharded optimizer state

    The Zero Redundancy Optimizer (ZeRO) reduces the memory cost of data parallelism by sharding what each replica stores. ZeRO-1 partitions optimizer state across the data-parallel group; ZeRO-2 also partitions gradients; ZeRO-3 partitions the parameters themselves, gathering them on demand for each layer, and PyTorch's Fully Sharded Data Parallel (FSDP) implements the same idea. Because the all-gather and reduce-scatter primitives can be scheduled alongside the forward and backward computation, communication can largely overlap with arithmetic, yielding a per-device memory footprint close to that of model parallelism with the simpler programming model of data parallelism.
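
    The effect of each stage on per-device memory can be estimated with simple arithmetic. The sketch below assumes mixed-precision Adam with 2 bytes per parameter for FP16 weights, 2 for FP16 gradients, and 12 for FP32 optimizer state (master weights plus two moments); activations and communication buffers are ignored, and the function name is hypothetical.

```python
def zero_bytes_per_device(num_params, dp_degree, stage):
    """Approximate model-state bytes per device for ZeRO stage 0 (none) to 3."""
    assert stage in (0, 1, 2, 3)
    params = 2 * num_params   # FP16 weights (assumed)
    grads = 2 * num_params    # FP16 gradients (assumed)
    optim = 12 * num_params   # FP32 master weights + Adam moments (assumed)
    if stage >= 1:
        optim /= dp_degree    # ZeRO-1: shard optimizer state
    if stage >= 2:
        grads /= dp_degree    # ZeRO-2: also shard gradients
    if stage >= 3:
        params /= dp_degree   # ZeRO-3: also shard parameters
    return params + grads + optim


if __name__ == "__main__":
    for stage in range(4):
        gb = zero_bytes_per_device(7_500_000_000, 64, stage) / 1e9
        print(f"stage {stage}: {gb:.1f} GB per device")
```

    For a 7.5-billion-parameter model on 64 data-parallel devices, these assumptions give about 120 GB without sharding and under 2 GB with all three stages, which shows why the later stages matter most as the group grows.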

    Combined with mixed-precision storage and CPU or NVMe offloading, ZeRO has enabled the training of trillion-parameter models on commodity-scale GPU clusters.

    Communication and synchronization

    The dominant cost in distributed training is moving tensors. Several techniques reduce or hide it:

    • Gradient compression, including 1-bit SGD, signSGD, and PowerSGD, reduces the bytes per all-reduce at the cost of introducing bias or variance that must be controlled.
    • Local SGD performs $ H $ steps on each worker before averaging, lowering communication frequency at the price of slight optimization drift.
    • Overlap of computation and communication uses the fact that gradients for early layers are ready while later layers are still computing; frameworks such as PyTorch DDP and Horovod schedule asynchronous all-reduces in buckets so the network is in flight during the backward pass.
    • Topology-aware collectives prefer intra-node, high-bandwidth links for early reduction and only cross slow links once.
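
    The local-update pattern from the list above can be sketched on a toy problem. Here each simulated worker $ k $ minimizes $ \tfrac{1}{2}(\theta - c_k)^2 $, takes $ H $ local steps, and then the workers average their parameters. Exact gradients are used in place of stochastic ones so the result is deterministic; the objective, step size, and function name are all assumptions made for illustration.

```python
def local_sgd(targets, h, rounds, lr=0.1, theta0=0.0):
    """Run `rounds` rounds of h local gradient steps followed by one averaging step.

    Worker k minimizes 0.5 * (theta - targets[k])**2 (toy objective).
    Returns the final shared parameter and the number of communication rounds.
    """
    theta = theta0
    for _ in range(rounds):
        local = []
        for c in targets:              # each iteration simulates one worker
            t = theta                  # start from the shared parameters
            for _ in range(h):
                t -= lr * (t - c)      # local step, no communication
            local.append(t)
        theta = sum(local) / len(local)  # one parameter average per round
    return theta, rounds


if __name__ == "__main__":
    theta, comms = local_sgd([1.0, 2.0, 6.0], h=4, rounds=50)
    print(theta, comms)
```

    With $ H = 4 $ the workers communicate once per four gradient steps rather than after every step. On this toy problem the averaged parameter still approaches the mean of the targets; with stochastic gradients or differing curvature across workers, the local iterates drift apart between averages, which is the cost mentioned above.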

    Practical considerations and limitations

    Distributed training amplifies every small failure. Stragglers, link errors, and silent data corruption become routine at thousands of accelerators, so checkpoint frequency, deterministic restarts, and elastic resharding receive significant engineering attention. Numerical issues such as loss-spike instabilities are more common at large batch sizes and may require gradient clipping, lower learning rates, or recovery from a recent checkpoint.

    The choice of parallelism strategy is rarely portable: a partition tuned for one cluster topology may be far from optimal on another, and changing the model shape can require re-sharding optimizer state. Communication libraries, kernel launch overheads, and per-device memory fragmentation all become first-order concerns. Despite the engineering burden, distributed training is the only known route to current frontier-scale models, and improvements in collective algorithms, sharding strategies, and accelerator interconnects continue to widen the practical envelope.
