Loss Functions/zh

    From Marovi AI
    This page is a translated version of the page Loss Functions and the translation is 76% complete.
    Outdated translations are marked like this.
    Other languages:
    Article
    Topic area Machine Learning
    Difficulty Introductory

    損失函數(也稱為代價函數目標函數)量化模型預測與期望輸出之間的差距。最小化損失函數是機器學習訓練過程的核心目標:優化算法調整模型的參數,使損失儘可能低。

    目的

    損失函數將模型的預測 $ \hat{y} $ 和真實目標 $ y $ 映射到一個非負實數。形式上,對於單個樣本:

    $ \ell: \mathcal{Y} \times \mathcal{Y} \to \mathbb{R}_{\geq 0} $

    在包含 $ N $ 個樣本的數據集上,總損失通常是平均值:

    $ L(\theta) = \frac{1}{N}\sum_{i=1}^{N}\ell\bigl(y_i,\, \hat{y}_i(\theta)\bigr) $

    損失函數的選擇編碼了問題的結構——哪些錯誤重要以及應當以多大嚴厲程度對其進行懲罰。選擇不當的損失函數會導致模型優化錯誤的目標。

    均方誤差

    均方誤差(MSE)是回歸任務的默認損失函數:

    $ L_{\text{MSE}} = \frac{1}{N}\sum_{i=1}^{N}(y_i - \hat{y}_i)^2 $

    MSE 以二次方式懲罰較大的誤差,使其對離群值敏感。其梯度計算簡單:

    $ \frac{\partial}{\partial \hat{y}_i} (y_i - \hat{y}_i)^2 = -2(y_i - \hat{y}_i) $

    一個密切相關的變體是平均絕對誤差(MAE),$ \frac{1}{N}\sum|y_i - \hat{y}_i| $,它對離群值更具魯棒性,但在零點處梯度不平滑。Huber 損失結合了兩者:對小誤差表現得像 MSE,對大誤差表現得像 MAE。

    交叉熵損失

    交叉熵損失是分類任務的標準選擇。它衡量預測概率分布與真實標籤分布之間的差異。

    二元交叉熵

    對於具有預測概率 $ p $ 和真實標籤 $ y \in \{0, 1\} $ 的二元分類:

    $ L_{\text{BCE}} = -\frac{1}{N}\sum_{i=1}^{N}\bigl[y_i \log p_i + (1 - y_i)\log(1 - p_i)\bigr] $

    當預測概率與真實標籤完全匹配時($ y = 1 $$ p = 1 $$ y = 0 $$ p = 0 $),此損失最小化。

    多類交叉熵

    對於具有 $ C $ 個類別的多類分類和預測概率向量 $ \hat{\mathbf{y}} $

    $ L_{\text{CE}} = -\frac{1}{N}\sum_{i=1}^{N}\sum_{c=1}^{C} y_{i,c} \log \hat{y}_{i,c} $

    當真實標籤採用 one-hot 編碼時,只有對應於正確類別的項保留下來。

    合頁損失

    合頁損失(hinge loss)與支持向量機(SVM)和最大間隔分類器相關。對於具有標籤 $ y \in \{-1, +1\} $ 和模型原始輸出 $ s $ 的二元分類問題:

    $ L_{\text{hinge}} = \frac{1}{N}\sum_{i=1}^{N}\max(0,\; 1 - y_i \, s_i) $

    當預測具有正確符號且間隔至少為 1 時,合頁損失為零;否則線性增加。由於在合頁點處不可微,因此使用次梯度方法進行優化。

    其他常見的損失函數

    損失 公式 典型用途
    Huber $ \begin{cases}\tfrac{1}{2}(y-\hat{y})^2 & |y-\hat{y}|\leq\delta \\ \delta(|y-\hat{y}|-\tfrac{\delta}{2}) & \text{otherwise}\end{cases} $ 魯棒回歸
    KL 散度 $ \sum_c p_c \log\frac{p_c}{q_c} $ 分布匹配,VAE
    Focal 損失 $ -\alpha(1-p_t)^\gamma \log p_t $ 不平衡分類
    CTC 損失 在對齊上的動態規劃 語音識別、OCR
    三元組損失 $ \max(0,\; d(a,p) - d(a,n) + m) $ 度量學習、人臉驗證

    選擇合適的損失函數

    合適的損失函數取決於具體任務:

    • 回歸 —— MSE 是默認選擇;如果擔心離群值,可切換到 MAE 或 Huber。
    • 二元分類 —— 使用 sigmoid 輸出的二元交叉熵。
    • 多類分類 —— 使用 softmax 輸出的多類交叉熵。
    • 多標籤分類 —— 對每個標籤獨立應用二元交叉熵。
    • 排序或檢索 —— 對比損失、三元組損失或 listwise 排序損失。

    一個重要的考慮因素是損失是否已校準——即最小化它是否能產生校準良好的預測概率。交叉熵是一種適當的評分規則,能產生校準的概率,而合頁損失則不能。

    正則化項

    在實際應用中,總目標通常包含一個正則化項,用於懲罰模型的複雜度:

    $ J(\theta) = L(\theta) + \lambda \, R(\theta) $

    其中 $ \lambda $ 控制正則化的強度。常見選擇包括 L2 正則化($ R = \|\theta\|_2^2 $)和 L1 正則化($ R = \|\theta\|_1 $)。詳見 Overfitting and Regularization

    參見

    參考文獻

    • Bishop, C. M. (2006). Pattern Recognition and Machine Learning,第 1 章。Springer。
    • Goodfellow, I., Bengio, Y. 與 Courville, A. (2016). Deep Learning,第 6 章與第 8 章。MIT Press。
    • Lin, T.-Y. et al. (2017). "Focal Loss for Dense Object Detection". ICCV.
    • Murphy, K. P. (2022). Probabilistic Machine Learning: An Introduction. MIT Press.