Loss Functions/zh
| Article | |
|---|---|
| Topic area | Machine Learning |
| Difficulty | Introductory |
損失函數(也稱為代價函數或目標函數)量化模型預測與期望輸出之間的差距。最小化損失函數是機器學習訓練過程的核心目標:優化算法調整模型的參數,使損失儘可能低。
目的
損失函數將模型的預測 $ \hat{y} $ 和真實目標 $ y $ 映射到一個非負實數。形式上,對於單個樣本:
- $ \ell: \mathcal{Y} \times \mathcal{Y} \to \mathbb{R}_{\geq 0} $
在包含 $ N $ 個樣本的數據集上,總損失通常是平均值:
- $ L(\theta) = \frac{1}{N}\sum_{i=1}^{N}\ell\bigl(y_i,\, \hat{y}_i(\theta)\bigr) $
損失函數的選擇編碼了問題的結構——哪些錯誤重要以及應當以多大嚴厲程度對其進行懲罰。選擇不當的損失函數會導致模型優化錯誤的目標。
均方誤差
均方誤差(MSE)是回歸任務的默認損失函數:
- $ L_{\text{MSE}} = \frac{1}{N}\sum_{i=1}^{N}(y_i - \hat{y}_i)^2 $
MSE 以二次方式懲罰較大的誤差,使其對離群值敏感。其梯度計算簡單:
- $ \frac{\partial}{\partial \hat{y}_i} (y_i - \hat{y}_i)^2 = -2(y_i - \hat{y}_i) $
一個密切相關的變體是平均絕對誤差(MAE),$ \frac{1}{N}\sum|y_i - \hat{y}_i| $,它對離群值更具魯棒性,但在零點處梯度不平滑。Huber 損失結合了兩者:對小誤差表現得像 MSE,對大誤差表現得像 MAE。
交叉熵損失
交叉熵損失是分類任務的標準選擇。它衡量預測概率分佈與真實標籤分佈之間的差異。
二元交叉熵
對於具有預測概率 $ p $ 和真實標籤 $ y \in \{0, 1\} $ 的二元分類:
- $ L_{\text{BCE}} = -\frac{1}{N}\sum_{i=1}^{N}\bigl[y_i \log p_i + (1 - y_i)\log(1 - p_i)\bigr] $
當預測概率與真實標籤完全匹配時($ y = 1 $ 時 $ p = 1 $,$ y = 0 $ 時 $ p = 0 $),此損失最小化。
多類交叉熵
對於具有 $ C $ 個類別的多類分類和預測概率向量 $ \hat{\mathbf{y}} $:
- $ L_{\text{CE}} = -\frac{1}{N}\sum_{i=1}^{N}\sum_{c=1}^{C} y_{i,c} \log \hat{y}_{i,c} $
當真實標籤採用 one-hot 編碼時,只有對應於正確類別的項保留下來。
合頁損失
合頁損失(hinge loss)與支持向量機(SVM)和最大間隔分類器相關。對於具有標籤 $ y \in \{-1, +1\} $ 和模型原始輸出 $ s $ 的二元分類問題:
- $ L_{\text{hinge}} = \frac{1}{N}\sum_{i=1}^{N}\max(0,\; 1 - y_i \, s_i) $
當預測具有正確符號且間隔至少為 1 時,合頁損失為零;否則線性增加。由於在合頁點處不可微,因此使用次梯度方法進行優化。
其他常見的損失函數
| 損失 | 公式 | 典型用途 |
|---|---|---|
| Huber | $ \begin{cases}\tfrac{1}{2}(y-\hat{y})^2 & |y-\hat{y}|\leq\delta \\ \delta(|y-\hat{y}|-\tfrac{\delta}{2}) & \text{otherwise}\end{cases} $ | 魯棒回歸 |
| KL 散度 | $ \sum_c p_c \log\frac{p_c}{q_c} $ | 分佈匹配,VAE |
| Focal 損失 | $ -\alpha(1-p_t)^\gamma \log p_t $ | 不平衡分類 |
| CTC 損失 | 在對齊上的動態規劃 | 語音識別、OCR |
| 三元組損失 | $ \max(0,\; d(a,p) - d(a,n) + m) $ | 度量學習、人臉驗證 |
選擇合適的損失函數
合適的損失函數取決於具體任務:
- 回歸 —— MSE 是默認選擇;如果擔心離群值,可切換到 MAE 或 Huber。
- 二元分類 —— 使用 sigmoid 輸出的二元交叉熵。
- 多類分類 —— 使用 softmax 輸出的多類交叉熵。
- 多標籤分類 —— 對每個標籤獨立應用二元交叉熵。
- 排序或檢索 —— 對比損失、三元組損失或 listwise 排序損失。
一個重要的考慮因素是損失是否已校準——即最小化它是否能產生校準良好的預測概率。交叉熵是一種適當的評分規則,能產生校準的概率,而合頁損失則不能。
正則化項
在實際應用中,總目標通常包含一個正則化項,用於懲罰模型的複雜度:
- $ J(\theta) = L(\theta) + \lambda \, R(\theta) $
其中 $ \lambda $ 控制正則化的強度。常見選擇包括 L2 正則化($ R = \|\theta\|_2^2 $)和 L1 正則化($ R = \|\theta\|_1 $)。詳見 Overfitting and Regularization。
參見
- Gradient Descent
- Neural Networks
- Backpropagation
- Overfitting and Regularization
- Stochastic Gradient Descent
參考文獻
- Bishop, C. M. (2006). Pattern Recognition and Machine Learning,第 1 章。Springer。
- Goodfellow, I., Bengio, Y. 與 Courville, A. (2016). Deep Learning,第 6 章與第 8 章。MIT Press。
- Lin, T.-Y. et al. (2017). "Focal Loss for Dense Object Detection". ICCV.
- Murphy, K. P. (2022). Probabilistic Machine Learning: An Introduction. MIT Press.