Wasserstein Loss/zh
| Article | |
|---|---|
| Topic area | Generative Models |
| Prerequisites | Generative Adversarial Network, Optimal Transport, Lipschitz Continuity, Kullback-Leibler Divergence |
概述
Wasserstein 損失是一種用於訓練生成模型的目標函數,它使用推土機距離(EMD),也稱為 1-Wasserstein 距離,來衡量兩個概率分佈之間的距離。該方法由 Arjovsky 等人於 2017 年為生成對抗網絡(GAN)引入,它取代了原始 GAN 中隱含的Jensen-Shannon準則,所採用的度量即使在模型分佈和數據分佈的支集不相交時,仍然定義良好且連續。由此產生的Wasserstein GAN(WGAN)框架顯著改善了訓練穩定性,並提供了一種其大小與樣本質量相關的損失。
Wasserstein 損失目前是在需要進行分佈比較且支集可能不重疊時的標準工具。除了 GAN 之外,它還出現在領域自適應、分佈魯棒優化和密度估計中。其核心吸引力是幾何性的:它不是詢問一個分佈與另一個分佈重疊多少,而是詢問需要移動多少概率質量以及移動多遠,才能將一個分佈轉化為另一個分佈。
直覺
考慮排列在一條直線上的兩堆土,它們表示兩個總質量為單位 1 的分佈。推土機距離是通過運輸土壤將第一堆轉化為第二堆所需的最小成本,其中將一單位質量移動距離 $ d $ 的成本為 $ d $。如果兩堆完全相同,則成本為零。如果一堆以 $ 0 $ 為中心,另一堆以 $ \theta $ 為中心,則成本為 $ |\theta| $,與每堆的集中程度無關。
這一性質揭示了KL 散度和Jensen-Shannon 散度在同一場景下的核心弱點。如果兩個分佈都是位於 $ 0 $ 和 $ \theta $ 處的點質量(Dirac delta),則 KL 在 $ \theta \neq 0 $ 時為無窮大,而 Jensen-Shannon 為常數 $ \log 2 $。兩者都沒有關於 $ \theta $ 的有用梯度。相比之下,Wasserstein 距離等於 $ |\theta| $,具有信息梯度 $ \mathrm{sign}(\theta) $。這正是生成器在訓練初期所面臨的情況,此時其輸出分佈的支集是一個低維流形,幾乎肯定與數據流形不重疊。在這種狀態下提供梯度的損失可以推動學習,而基於散度的損失則無法做到這一點。
公式化
對於度量空間 $ (\mathcal{X}, d) $ 上的兩個概率測度 $ P_r $(真實)和 $ P_g $(生成),1-Wasserstein 距離定義為
$ {\displaystyle W_1(P_r, P_g) = \inf_{\gamma \in \Pi(P_r, P_g)} \mathbb{E}_{(x, y) \sim \gamma} [\, d(x, y) \,]} $
其中 $ \Pi(P_r, P_g) $ 是邊際分佈為 $ P_r $ 和 $ P_g $ 的聯合分佈的集合。每個 $ \gamma $ 是一個運輸計劃:$ \gamma(x, y) $ 指定源中位於 $ x $ 處的多少質量被發送到目標中的 $ y $ 處。
對於高維連續分佈,這種原始形式難以處理。Kantorovich-Rubinstein 對偶性提供了一個可行的等價形式:
$ {\displaystyle W_1(P_r, P_g) = \sup_{\|f\|_L \leq 1} \, \mathbb{E}_{x \sim P_r}[f(x)] - \mathbb{E}_{x \sim P_g}[f(x)]} $
其中上確界取遍所有 1-Lipschitz 函數 $ f : \mathcal{X} \to \mathbb{R} $。函數 $ f $ 被稱為評論家(而不是判別器),因為它產生一個無界的實值分數,而不是概率。
在 Wasserstein GAN 中,評論家 $ f_w $ 由神經網絡參數化,並訓練以最大化對偶目標,而生成器 $ g_\theta $ 訓練以最小化 $ -\mathbb{E}[f_w(g_\theta(z))] $。極小-極大問題為
$ {\displaystyle \min_\theta \max_{w : \|f_w\|_L \leq 1} \mathbb{E}_{x \sim P_r}[f_w(x)] - \mathbb{E}_{z \sim P_z}[f_w(g_\theta(z))].} $
關鍵的實際問題是如何對 $ f_w $ 施加 1-Lipschitz 約束。
強制 Lipschitz 約束
權重裁剪
原始的 WGAN 在每次梯度更新後將評論家的每個權重裁剪到一個小區間 $ [-c, c] $,$ c $ 通常取 $ 0.01 $。這保證了 Lipschitz 常數有界,但限制了評論家的表達容量,並使有效約束對 $ c $ 敏感。裁剪後的網絡經常飽和或無法學習高頻特徵。
梯度懲罰
WGAN-GP(Gulrajani 等人,2017)用一種軟懲罰替代裁剪,強制評論家的梯度範數近似為 $ 1 $:
$ {\displaystyle \mathcal{L}_{\mathrm{GP}} = \mathbb{E}_{\hat{x} \sim P_{\hat{x}}} \left[ (\|\nabla_{\hat{x}} f_w(\hat{x})\|_2 - 1)^2 \right]} $
其中 $ \hat{x} $ 沿着 $ P_r $ 和 $ P_g $ 中樣本對之間的直線均勻採樣。評論家的總損失變為對偶目標減去 $ \lambda \mathcal{L}_{\mathrm{GP}} $,$ \lambda $ 通常取 $ 10 $。WGAN-GP 與評論家中的批歸一化不兼容,因為該懲罰是逐樣本的,但與層歸一化或實例歸一化配合良好。
譜歸一化
譜歸一化(Miyato 等人,2018)將每個權重矩陣除以其最大的奇異值,通過冪迭代進行近似。由於複合函數的 Lipschitz 常數以各層 Lipschitz 常數的乘積為界,將每層歸一化為譜範數 $ 1 $ 即可將整個網絡限制為 1-Lipschitz。這在計算上代價較低,並且與優化器解耦。
其他方法
此外還提出了一致性懲罰、hinge 損失以及直接投影到 Lipschitz 函數類等方法。該選擇會與優化器(通常是 Adam 或 RMSProp)、評論家與生成器的更新比率(通常為 5:1)以及架構相互作用。
訓練與推理
一個典型的 Wasserstein GAN 訓練步驟交替執行以下操作:
- 從真實數據中採樣一個小批量,以及一批潛在向量 $ z \sim P_z $。
- 通過上升對偶目標(如果適用,減去梯度懲罰)來更新評論家 $ n_{\mathrm{critic}} $ 次。
- 通過下降 $ -\mathbb{E}[f_w(g_\theta(z))] $ 更新生成器一次。
在原始 WGAN 論文中,評論家在每個生成器步驟都被訓練到接近收斂,這與原始 GAN 中刻意壓制判別器的做法形成對比。使用Wasserstein 損失時,更強的評論家會提供更好的梯度信號,因為對偶目標提供了 $ W_1 $ 的估計,生成器可以有效地跟隨這一估計。
在推理時,損失不起作用。樣本質量通過常用的生成度量(Inception 得分、Frechet Inception 距離)來評估,額外的好處是評論家目標的收斂值跟蹤 $ W_1 $,可以作為訓練時樣本質量的代理指標。
變體
- 切片 Wasserstein 距離 在一維投影上計算 $ W_1 $ 並對方向進行平均,利用一維閉式解避免神經評論家的估計。
- Sinkhorn 散度 通過熵正則化器近似 Wasserstein,藉助 Sinkhorn 算法實現可微分、GPU 並行的實現方式。
- 能量距離和MMD是基於核的替代方法,它們都具有在支集不相交時仍能良好定義的性質,但具有不同的偏差-方差權衡。
- 相對論評論家修改對偶目標以對真實樣本與假樣本之間的相對差異進行評分,通常能改善穩定性。
- $ p $-Wasserstein當 $ p > 1 $ 時將運輸成本提升到 $ p $ 次冪;只有 $ p = 1 $ 允許 WGAN 中所用的 Kantorovich-Rubinstein 對偶,而 $ p = 2 $ 出現在基於最優運輸的密度估計和擴散模型中。
與其他損失的比較
| 損失 | 支集不相交 | 關於 $ \theta $ 是否連續 | 是否有界 |
|---|---|---|---|
| KL 散度 | 無窮 | 否 | 否 |
| Jensen-Shannon | 常數 $ \log 2 $ | 否 | 是 |
| Wasserstein-1 | 有限、幾何性 | 是 | 否 |
| MMD(帶特徵核) | 有限 | 是 | 是 |
當判別器最優時,原始 GAN 最小化的量與 Jensen-Shannon 成正比。這解釋了模式坍塌和梯度消失的實證病理:當生成器遠離數據時,判別器變得近乎完美,幾乎不提供任何有用的信號。Wasserstein 損失從結構上避免了這種狀態。
與最小二乘或hinge損失相比,Wasserstein 通常產生更易解釋的訓練曲線,但要求 Lipschitz 約束以及更高的評論家與生成器更新比率。實際上,許多生產級 GAN 將譜歸一化與 hinge 或非飽和損失結合,使類別之間的界限趨於模糊。
局限性
無論以何種方式強制執行,Lipschitz 約束都是近似的。權重裁剪限制了容量;梯度懲罰僅沿着採樣的直線被強制執行;譜歸一化限制了逐層範數,但網絡的真正 Lipschitz 常數可能更小。因此,對偶目標對 $ W_1 $ 的估計帶有一個在訓練過程中變化的乘性因子,所以損失值在不同運行之間只能粗略地進行比較。
Wasserstein 損失並不能徹底消除模式坍塌;雖然它比原始 GAN 更不易出現這種問題,但仍可能發生部分模式丟失,尤其是評論家較弱時。由於評論家與生成器的比率較大,每次生成器更新的計算開銷更高。梯度懲罰還需要二階梯度(梯度範數的梯度),使得內存與計算開銷增加一個常數因子。
1-Wasserstein 距離忽略了 $ p = 2 $ 所捕獲的高階幾何結構,而且在非常高維的自然圖像上,像素空間中所假設的度量可能與感知相似性不匹配,這正是有時還會額外使用感知損失或 FID 風格的特徵空間距離的原因之一。
參考文獻
- ↑ Template:Cite arxiv
- ↑ Template:Cite arxiv
- ↑ Template:Cite arxiv
- ↑ Villani, Cedric. Optimal Transport: Old and New. Springer, 2009.
- ↑ Template:Cite arxiv