Wasserstein Loss
| Article | |
|---|---|
| Topic area | Generative Models |
| Prerequisites | Generative Adversarial Network, Optimal Transport, Lipschitz Continuity, Kullback-Leibler Divergence |
Overview
The Wasserstein loss is an objective function for training generative models that measures the distance between two probability distributions using the earth mover's distance (EMD), also known as the 1-Wasserstein distance. Introduced for generative adversarial networks (GANs) by Arjovsky et al. in 2017, it replaced the Jensen-Shannon criterion implicit in the original GAN with a metric that remains well-defined and continuous even when the model and data distributions have disjoint supports. The resulting Wasserstein GAN (WGAN) framework substantially improved training stability and provided a loss whose magnitude correlates with sample quality.
The Wasserstein loss is now a standard tool whenever distributional comparison is needed and the supports may not overlap. Beyond GANs, it appears in domain adaptation, distributionally robust optimization, and density estimation. Its core appeal is geometric: rather than asking how much one distribution overlaps another, it asks how much probability mass must be moved, and how far, to transform one into the other.
Intuition
Consider two piles of dirt arranged on a line, representing two distributions of unit total mass. The earth mover's distance is the minimum cost of transforming the first pile into the second by transporting dirt, where moving a unit of mass over a distance $ d $ incurs cost $ d $. If the piles are identical, the cost is zero. If the second pile is a copy of the first shifted by $ \theta $, the cost is $ |\theta| $, no matter how the mass within each pile is spread.
This last property exposes the central weakness of KL divergence and Jensen-Shannon divergence in the same scenario. If both distributions are point masses (Dirac deltas) at $ 0 $ and $ \theta $, KL is infinite for $ \theta \neq 0 $, and Jensen-Shannon is the constant $ \log 2 $. Neither has a useful gradient with respect to $ \theta $. The Wasserstein distance, by contrast, equals $ |\theta| $ with informative gradient $ \mathrm{sign}(\theta) $. This is exactly the situation a generator faces early in training, when its output distribution is supported on a low-dimensional manifold that almost certainly does not overlap the data manifold. A loss that provides gradient in this regime can drive learning where divergence-based losses cannot.
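A quick numerical check of the point-mass example, using SciPy's closed-form one-dimensional Wasserstein distance (a minimal sketch; the degenerate KL and Jensen-Shannon values are stated in a comment rather than computed):

```python
from scipy.stats import wasserstein_distance

# Two point masses ("piles of dirt"): one at 0, one at theta.
for theta in [0.5, 1.0, 2.0]:
    w1 = wasserstein_distance([0.0], [theta])      # closed form in one dimension
    print(f"theta = {theta:.1f}   W1 = {w1:.1f}")  # W1 equals |theta|
# KL between the two point masses is infinite and Jensen-Shannon is log 2
# for any theta != 0, so neither provides a gradient in theta.
```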
Formulation
For two probability measures $ P_r $ (real) and $ P_g $ (generated) on a metric space $ (\mathcal{X}, d) $, the 1-Wasserstein distance is
$ {\displaystyle W_1(P_r, P_g) = \inf_{\gamma \in \Pi(P_r, P_g)} \mathbb{E}_{(x, y) \sim \gamma} [\, d(x, y) \,]} $
where $ \Pi(P_r, P_g) $ is the set of joint distributions with marginals $ P_r $ and $ P_g $. Each $ \gamma $ is a transport plan: $ \gamma(x, y) $ specifies how much mass at $ x $ in the source is sent to $ y $ in the target.
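For a toy discrete case, the primal problem is just a small linear program over transport plans. The sketch below, using SciPy's `linprog`, is illustrative only and does not scale beyond tiny supports:

```python
import numpy as np
from scipy.optimize import linprog

# P_r puts mass a_i at x_i, P_g puts mass b_j at y_j; solve for the optimal
# transport plan gamma minimizing sum_ij gamma_ij * d(x_i, y_j).
x = np.array([0.0, 1.0]);  a = np.array([0.5, 0.5])   # real
y = np.array([2.0, 3.0]);  b = np.array([0.5, 0.5])   # generated
D = np.abs(x[:, None] - y[None, :])                    # ground costs d(x_i, y_j)

n, m = len(x), len(y)
A_eq = np.zeros((n + m, n * m))
for i in range(n):
    A_eq[i, i * m:(i + 1) * m] = 1.0   # row sums:    sum_j gamma_ij = a_i
for j in range(m):
    A_eq[n + j, j::m] = 1.0            # column sums: sum_i gamma_ij = b_j

res = linprog(D.ravel(), A_eq=A_eq, b_eq=np.concatenate([a, b]),
              bounds=(0, None), method="highs")
print("W1 =", res.fun)   # 2.0: every unit of mass travels exactly 2
```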
This primal form is intractable for high-dimensional continuous distributions. The Kantorovich-Rubinstein duality provides a workable equivalent:
$ {\displaystyle W_1(P_r, P_g) = \sup_{\|f\|_L \leq 1} \, \mathbb{E}_{x \sim P_r}[f(x)] - \mathbb{E}_{x \sim P_g}[f(x)]} $
where the supremum is over all 1-Lipschitz functions $ f : \mathcal{X} \to \mathbb{R} $. The function $ f $ is referred to as the critic (rather than discriminator) because it produces an unbounded real-valued score instead of a probability.
In a Wasserstein GAN, the critic $ f_w $ is parameterized by a neural network and trained to maximize the dual objective, while the generator $ g_\theta $ is trained to minimize $ -\mathbb{E}[f_w(g_\theta(z))] $. The min-max problem is
$ {\displaystyle \min_\theta \max_{w : \|f_w\|_L \leq 1} \mathbb{E}_{x \sim P_r}[f_w(x)] - \mathbb{E}_{z \sim P_z}[f_w(g_\theta(z))].} $
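A minimal PyTorch sketch of the two objectives, assuming `f_w` is a critic network and `real`/`fake` are sample batches; the Lipschitz constraint is handled separately, as discussed next:

```python
import torch

def critic_loss(f_w, real: torch.Tensor, fake: torch.Tensor) -> torch.Tensor:
    # The critic maximizes E[f(real)] - E[f(fake)]; return the negation so a
    # standard optimizer can minimize it. The Lipschitz constraint on f_w is
    # enforced elsewhere (clipping, gradient penalty, or spectral norm).
    return f_w(fake).mean() - f_w(real).mean()

def generator_loss(f_w, fake: torch.Tensor) -> torch.Tensor:
    # The generator minimizes -E[f(g(z))]: it moves mass toward regions the
    # critic currently scores highly.
    return -f_w(fake).mean()
```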
The critical practical question is how to enforce the 1-Lipschitz constraint on $ f_w $.
Enforcing the Lipschitz Constraint
Weight Clipping
The original WGAN clips each weight of the critic to a small interval $ [-c, c] $ after every gradient update, with $ c $ typically $ 0.01 $. This guarantees a bounded Lipschitz constant but restricts the critic's expressive capacity and makes the effective constraint sensitive to the choice of $ c $. Clipped critics frequently saturate, with weights piling up at $ \pm c $, or collapse to overly simple functions that miss fine structure in the data.
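A minimal sketch of the clipping step, assuming the critic is an ordinary `torch.nn.Module`:

```python
import torch

def clip_critic_weights(critic: torch.nn.Module, c: float = 0.01) -> None:
    # Project every critic parameter back into [-c, c] after each update,
    # as in the original WGAN.
    with torch.no_grad():
        for p in critic.parameters():
            p.clamp_(-c, c)
```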
Gradient Penalty
WGAN-GP (Gulrajani et al., 2017) replaces clipping with a soft penalty that pushes the gradient norm of the critic toward $ 1 $:
$ {\displaystyle \mathcal{L}_{\mathrm{GP}} = \mathbb{E}_{\hat{x} \sim P_{\hat{x}}} \left[ (\|\nabla_{\hat{x}} f_w(\hat{x})\|_2 - 1)^2 \right]} $
where $ \hat{x} $ is sampled uniformly along straight lines between pairs drawn from $ P_r $ and $ P_g $. The critic's objective becomes the dual objective minus $ \lambda \mathcal{L}_{\mathrm{GP}} $, with $ \lambda $ typically $ 10 $. WGAN-GP is incompatible with batch normalization in the critic, because batch normalization makes each sample's score depend on the whole batch while the penalty is defined per sample; layer or instance normalization works well instead.
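A sketch of the penalty in PyTorch, assuming image- or vector-shaped batches; the helper name `gradient_penalty` is illustrative:

```python
import torch

def gradient_penalty(critic: torch.nn.Module,
                     real: torch.Tensor,
                     fake: torch.Tensor) -> torch.Tensor:
    # Sample x_hat uniformly on straight lines between real and fake samples.
    eps = torch.rand(real.size(0), *([1] * (real.dim() - 1)), device=real.device)
    x_hat = (eps * real + (1.0 - eps) * fake).requires_grad_(True)
    scores = critic(x_hat)
    # Gradient of the critic's output with respect to its input, per sample.
    grads, = torch.autograd.grad(scores.sum(), x_hat, create_graph=True)
    grad_norm = grads.flatten(start_dim=1).norm(2, dim=1)
    return ((grad_norm - 1.0) ** 2).mean()

# Critic loss with the penalty (lambda typically 10):
#   f_w(fake).mean() - f_w(real).mean() + 10.0 * gradient_penalty(f_w, real, fake)
```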
Spectral Normalization
Spectral normalization (Miyato et al., 2018) divides each weight matrix by its largest singular value, approximated by power iteration. Since the Lipschitz constant of a composition is bounded by the product of the layers' Lipschitz constants, normalizing every layer to spectral norm $ 1 $ caps the network at $ 1 $-Lipschitz, provided the activations are themselves $ 1 $-Lipschitz. This is computationally cheap and independent of the optimizer.
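A sketch using PyTorch's built-in `torch.nn.utils.spectral_norm` wrapper; the layer sizes here are arbitrary:

```python
import torch.nn as nn
from torch.nn.utils import spectral_norm

# Each linear layer is rescaled to spectral norm 1 (one power-iteration step
# per forward pass); LeakyReLU is itself 1-Lipschitz, so the whole critic is.
critic = nn.Sequential(
    spectral_norm(nn.Linear(784, 256)), nn.LeakyReLU(0.2),
    spectral_norm(nn.Linear(256, 256)), nn.LeakyReLU(0.2),
    spectral_norm(nn.Linear(256, 1)),
)
```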
Other Approaches
Consistency penalties, hinge losses, and direct projection onto Lipschitz function classes have all been proposed. The choice interacts with the optimizer (typically Adam or RMSProp), the critic-to-generator update ratio (often 5:1), and the architecture.
Training and Inference
A typical Wasserstein GAN training step alternates the following (a code sketch follows the list):
- Sample a minibatch from the real data and a batch of latent vectors $ z \sim P_z $.
- Update the critic $ n_{\mathrm{critic}} $ times by ascending the dual objective (and subtracting the gradient penalty if applicable).
- Update the generator once by descending $ -\mathbb{E}[f_w(g_\theta(z))] $.
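A compact sketch of this alternation for the weight-clipped variant, with illustrative names (`wgan_step`, `data_iter`) and hyperparameters taken from the original paper:

```python
import torch

def wgan_step(critic, generator, c_opt, g_opt, data_iter,
              latent_dim: int = 128, n_critic: int = 5, clip: float = 0.01):
    # One generator update of weight-clipped WGAN training; `data_iter` is
    # assumed to yield batches of real samples.
    for _ in range(n_critic):
        real = next(data_iter)
        fake = generator(torch.randn(real.size(0), latent_dim)).detach()
        c_loss = critic(fake).mean() - critic(real).mean()   # -(dual objective)
        c_opt.zero_grad()
        c_loss.backward()
        c_opt.step()
        with torch.no_grad():                                 # Lipschitz constraint
            for p in critic.parameters():
                p.clamp_(-clip, clip)
    z = torch.randn(real.size(0), latent_dim)
    g_loss = -critic(generator(z)).mean()
    g_opt.zero_grad()
    g_loss.backward()
    g_opt.step()
```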
In the original WGAN paper the critic is trained to near-convergence at every generator step, in contrast to the original GAN, where the discriminator is deliberately held back. With the Wasserstein loss, a stronger critic gives a better gradient signal, because the dual objective then provides an estimate of $ W_1 $ that the generator can usefully follow.
At inference time the loss has no role. Sample quality is assessed by the usual generative metrics (Inception Score, Fréchet Inception Distance), with the additional benefit that the converged value of the critic objective tracks $ W_1 $ and can serve as a training-time proxy for sample quality.
Variants
- Sliced Wasserstein Distance computes $ W_1 $ on one-dimensional projections and averages over directions, exploiting the closed-form 1D solution to avoid neural critic estimation (a sketch follows this list).
- Sinkhorn Divergence approximates Wasserstein with an entropic regularizer, enabling a differentiable, GPU-parallel implementation via the Sinkhorn algorithm.
- Energy distance and MMD are kernel-based alternatives that share the property of remaining defined under disjoint support, with different bias-variance trade-offs.
- Relativistic critics modify the dual objective to score relative differences between real and fake samples, often improving stability.
- $ p $-Wasserstein for $ p > 1 $ raises the transport cost to power $ p $; only $ p = 1 $ admits the Kantorovich-Rubinstein dual used in WGAN, but $ p = 2 $ arises in optimal-transport-based density estimation and diffusion models.
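The sliced estimator mentioned above admits a very short implementation; a minimal NumPy sketch for two equal-size samples (names and defaults are illustrative):

```python
import numpy as np

def sliced_wasserstein(x: np.ndarray, y: np.ndarray, n_proj: int = 50,
                       rng=None) -> float:
    # Monte Carlo sliced W1 for two equal-size point clouds of shape (n, d):
    # project onto random unit directions and apply the sorted 1-D closed form.
    rng = rng or np.random.default_rng()
    total = 0.0
    for _ in range(n_proj):
        theta = rng.standard_normal(x.shape[1])
        theta /= np.linalg.norm(theta)
        px, py = np.sort(x @ theta), np.sort(y @ theta)
        total += np.abs(px - py).mean()        # 1-D W1 of the two empiricals
    return total / n_proj
```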
Comparison with Other Losses
| Loss | Value under disjoint support | Continuous in $ \theta $ | Bounded |
|---|---|---|---|
| KL divergence | Infinite | No | No |
| Jensen-Shannon | Constant $ \log 2 $ | No | Yes |
| Wasserstein-1 | Finite, geometric | Yes | No |
| MMD (with characteristic kernel) | Finite | Yes | Yes |
The original GAN minimizes the Jensen-Shannon divergence (up to an affine transformation) when the discriminator is optimal. This explains the vanishing-gradient pathology: when the generator is far from the data, the discriminator becomes near-perfect and supplies almost no useful signal, a regime also implicated in mode collapse. The Wasserstein loss avoids this regime structurally.
Compared with least-squares or hinge losses, Wasserstein typically yields more interpretable training curves but demands the Lipschitz constraint and a higher critic-to-generator update ratio. In practice many production GANs combine spectral normalization with hinge or non-saturating losses, blurring the categorical distinction.
Limitations
The Lipschitz constraint, however enforced, is approximate. Weight clipping caps capacity; gradient penalty is enforced only along sampled lines; spectral normalization bounds layer-wise norms but the network's true Lipschitz constant may be smaller. The dual objective therefore estimates $ W_1 $ up to a multiplicative factor that varies during training, so the loss value is comparable across runs only loosely.
Wasserstein loss does not eliminate mode collapse outright; although WGANs are less prone to it than the original GAN, partial mode dropping can still occur, especially with weak critics. Computational cost is higher per generator update because of the larger critic-to-generator ratio, and the gradient penalty requires differentiating through a gradient norm (a second-order derivative), increasing memory and compute by a constant factor.
The 1-Wasserstein distance penalizes transport linearly in distance and so ignores the quadratic cost structure that $ p = 2 $ captures. Moreover, on very high-dimensional natural images the ground metric assumed in pixel space may not match perceptual similarity, which is one reason perceptual losses or FID-style feature-space distances are sometimes used in addition.
References
- Arjovsky, Martin; Chintala, Soumith; Bottou, Léon. "Wasserstein GAN." arXiv:1701.07875, 2017.
- Gulrajani, Ishaan; Ahmed, Faruk; Arjovsky, Martin; Dumoulin, Vincent; Courville, Aaron. "Improved Training of Wasserstein GANs." arXiv:1704.00028, 2017.
- Miyato, Takeru; Kataoka, Toshiki; Koyama, Masanori; Yoshida, Yuichi. "Spectral Normalization for Generative Adversarial Networks." arXiv:1802.05957, 2018.
- Villani, Cédric. Optimal Transport: Old and New. Springer, 2009.