Wasserstein Loss/zh
| Article | |
|---|---|
| Topic area | Generative Models |
| Prerequisites | Generative Adversarial Network, Optimal Transport, Lipschitz Continuity, Kullback-Leibler Divergence |
概述
Wasserstein 损失是一种用于训练生成模型的目标函数,它使用推土机距离(EMD),也称为 1-Wasserstein 距离,来衡量两个概率分布之间的距离。该方法由 Arjovsky 等人于 2017 年为生成对抗网络(GAN)引入,它取代了原始 GAN 中隐含的Jensen-Shannon准则,所采用的度量即使在模型分布和数据分布的支集不相交时,仍然定义良好且连续。由此产生的Wasserstein GAN(WGAN)框架显著改善了训练稳定性,并提供了一种其大小与样本质量相关的损失。
Wasserstein 损失目前是在需要进行分布比较且支集可能不重叠时的标准工具。除了 GAN 之外,它还出现在领域自适应、分布鲁棒优化和密度估计中。其核心吸引力是几何性的:它不是询问一个分布与另一个分布重叠多少,而是询问需要移动多少概率质量以及移动多远,才能将一个分布转化为另一个分布。
直觉
考虑排列在一条直线上的两堆土,它们表示两个总质量为单位 1 的分布。推土机距离是通过运输土壤将第一堆转化为第二堆所需的最小成本,其中将一单位质量移动距离 $ d $ 的成本为 $ d $。如果两堆完全相同,则成本为零。如果一堆以 $ 0 $ 为中心,另一堆以 $ \theta $ 为中心,则成本为 $ |\theta| $,与每堆的集中程度无关。
这一性质揭示了KL 散度和Jensen-Shannon 散度在同一场景下的核心弱点。如果两个分布都是位于 $ 0 $ 和 $ \theta $ 处的点质量(Dirac delta),则 KL 在 $ \theta \neq 0 $ 时为无穷大,而 Jensen-Shannon 为常数 $ \log 2 $。两者都没有关于 $ \theta $ 的有用梯度。相比之下,Wasserstein 距离等于 $ |\theta| $,具有信息梯度 $ \mathrm{sign}(\theta) $。这正是生成器在训练初期所面临的情况,此时其输出分布的支集是一个低维流形,几乎肯定与数据流形不重叠。在这种状态下提供梯度的损失可以推动学习,而基于散度的损失则无法做到这一点。
公式化
对于度量空间 $ (\mathcal{X}, d) $ 上的两个概率测度 $ P_r $(真实)和 $ P_g $(生成),1-Wasserstein 距离定义为
$ {\displaystyle W_1(P_r, P_g) = \inf_{\gamma \in \Pi(P_r, P_g)} \mathbb{E}_{(x, y) \sim \gamma} [\, d(x, y) \,]} $
其中 $ \Pi(P_r, P_g) $ 是边际分布为 $ P_r $ 和 $ P_g $ 的联合分布的集合。每个 $ \gamma $ 是一个运输计划:$ \gamma(x, y) $ 指定源中位于 $ x $ 处的多少质量被发送到目标中的 $ y $ 处。
对于高维连续分布,这种原始形式难以处理。Kantorovich-Rubinstein 对偶性提供了一个可行的等价形式:
$ {\displaystyle W_1(P_r, P_g) = \sup_{\|f\|_L \leq 1} \, \mathbb{E}_{x \sim P_r}[f(x)] - \mathbb{E}_{x \sim P_g}[f(x)]} $
其中上确界取遍所有 1-Lipschitz 函数 $ f : \mathcal{X} \to \mathbb{R} $。函数 $ f $ 被称为评论家(而不是判别器),因为它产生一个无界的实值分数,而不是概率。
在 Wasserstein GAN 中,评论家 $ f_w $ 由神经网络参数化,并训练以最大化对偶目标,而生成器 $ g_\theta $ 训练以最小化 $ -\mathbb{E}[f_w(g_\theta(z))] $。极小-极大问题为
$ {\displaystyle \min_\theta \max_{w : \|f_w\|_L \leq 1} \mathbb{E}_{x \sim P_r}[f_w(x)] - \mathbb{E}_{z \sim P_z}[f_w(g_\theta(z))].} $
关键的实际问题是如何对 $ f_w $ 施加 1-Lipschitz 约束。
强制 Lipschitz 约束
权重裁剪
原始的 WGAN 在每次梯度更新后将评论家的每个权重裁剪到一个小区间 $ [-c, c] $,$ c $ 通常取 $ 0.01 $。这保证了 Lipschitz 常数有界,但限制了评论家的表达容量,并使有效约束对 $ c $ 敏感。裁剪后的网络经常饱和或无法学习高频特征。
梯度惩罚
WGAN-GP(Gulrajani 等人,2017)用一种软惩罚替代裁剪,强制评论家的梯度范数近似为 $ 1 $:
$ {\displaystyle \mathcal{L}_{\mathrm{GP}} = \mathbb{E}_{\hat{x} \sim P_{\hat{x}}} \left[ (\|\nabla_{\hat{x}} f_w(\hat{x})\|_2 - 1)^2 \right]} $
其中 $ \hat{x} $ 沿着 $ P_r $ 和 $ P_g $ 中样本对之间的直线均匀采样。评论家的总损失变为对偶目标减去 $ \lambda \mathcal{L}_{\mathrm{GP}} $,$ \lambda $ 通常取 $ 10 $。WGAN-GP 与评论家中的批归一化不兼容,因为该惩罚是逐样本的,但与层归一化或实例归一化配合良好。
谱归一化
谱归一化(Miyato 等人,2018)将每个权重矩阵除以其最大的奇异值,通过幂迭代进行近似。由于复合函数的 Lipschitz 常数以各层 Lipschitz 常数的乘积为界,将每层归一化为谱范数 $ 1 $ 即可将整个网络限制为 1-Lipschitz。这在计算上代价较低,并且与优化器解耦。
其他方法
此外还提出了一致性惩罚、hinge 损失以及直接投影到 Lipschitz 函数类等方法。该选择会与优化器(通常是 Adam 或 RMSProp)、评论家与生成器的更新比率(通常为 5:1)以及架构相互作用。
训练与推理
一个典型的 Wasserstein GAN 训练步骤交替执行以下操作:
- 从真实数据中采样一个小批量,以及一批潜在向量 $ z \sim P_z $。
- 通过上升对偶目标(如果适用,减去梯度惩罚)来更新评论家 $ n_{\mathrm{critic}} $ 次。
- 通过下降 $ -\mathbb{E}[f_w(g_\theta(z))] $ 更新生成器一次。
在原始 WGAN 论文中,评论家在每个生成器步骤都被训练到接近收敛,这与原始 GAN 中刻意压制判别器的做法形成对比。使用Wasserstein 损失时,更强的评论家会提供更好的梯度信号,因为对偶目标提供了 $ W_1 $ 的估计,生成器可以有效地跟随这一估计。
在推理时,损失不起作用。样本质量通过常用的生成度量(Inception 得分、Frechet Inception 距离)来评估,额外的好处是评论家目标的收敛值跟踪 $ W_1 $,可以作为训练时样本质量的代理指标。
变体
- 切片 Wasserstein 距离 在一维投影上计算 $ W_1 $ 并对方向进行平均,利用一维闭式解避免神经评论家的估计。
- Sinkhorn 散度 通过熵正则化器近似 Wasserstein,借助 Sinkhorn 算法实现可微分、GPU 并行的实现方式。
- 能量距离和MMD是基于核的替代方法,它们都具有在支集不相交时仍能良好定义的性质,但具有不同的偏差-方差权衡。
- 相对论评论家修改对偶目标以对真实样本与假样本之间的相对差异进行评分,通常能改善稳定性。
- $ p $-Wasserstein当 $ p > 1 $ 时将运输成本提升到 $ p $ 次幂;只有 $ p = 1 $ 允许 WGAN 中所用的 Kantorovich-Rubinstein 对偶,而 $ p = 2 $ 出现在基于最优运输的密度估计和扩散模型中。
与其他损失的比较
| 损失 | 支集不相交 | 关于 $ \theta $ 是否连续 | 是否有界 |
|---|---|---|---|
| KL 散度 | 无穷 | 否 | 否 |
| Jensen-Shannon | 常数 $ \log 2 $ | 否 | 是 |
| Wasserstein-1 | 有限、几何性 | 是 | 否 |
| MMD(带特征核) | 有限 | 是 | 是 |
当判别器最优时,原始 GAN 最小化的量与 Jensen-Shannon 成正比。这解释了模式坍塌和梯度消失的实证病理:当生成器远离数据时,判别器变得近乎完美,几乎不提供任何有用的信号。Wasserstein 损失从结构上避免了这种状态。
与最小二乘或hinge损失相比,Wasserstein 通常产生更易解释的训练曲线,但要求 Lipschitz 约束以及更高的评论家与生成器更新比率。实际上,许多生产级 GAN 将谱归一化与 hinge 或非饱和损失结合,使类别之间的界限趋于模糊。
局限性
无论以何种方式强制执行,Lipschitz 约束都是近似的。权重裁剪限制了容量;梯度惩罚仅沿着采样的直线被强制执行;谱归一化限制了逐层范数,但网络的真正 Lipschitz 常数可能更小。因此,对偶目标对 $ W_1 $ 的估计带有一个在训练过程中变化的乘性因子,所以损失值在不同运行之间只能粗略地进行比较。
Wasserstein 损失并不能彻底消除模式坍塌;虽然它比原始 GAN 更不易出现这种问题,但仍可能发生部分模式丢失,尤其是评论家较弱时。由于评论家与生成器的比率较大,每次生成器更新的计算开销更高。梯度惩罚还需要二阶梯度(梯度范数的梯度),使得内存与计算开销增加一个常数因子。
1-Wasserstein 距离忽略了 $ p = 2 $ 所捕获的高阶几何结构,而且在非常高维的自然图像上,像素空间中所假设的度量可能与感知相似性不匹配,这正是有时还会额外使用感知损失或 FID 风格的特征空间距离的原因之一。
参考文献
- ↑ Template:Cite arxiv
- ↑ Template:Cite arxiv
- ↑ Template:Cite arxiv
- ↑ Villani, Cedric. Optimal Transport: Old and New. Springer, 2009.
- ↑ Template:Cite arxiv