Cross-Entropy Loss/zh

    From Marovi AI
    This page is a translated version of the page Cross-Entropy Loss and the translation is 100% complete.
    Other languages:
    Article
    Topic area Machine Learning
    Difficulty Intermediate
    Prerequisites Loss Functions, Softmax Function

    交叉熵损失(也称为对数损失)是机器学习中分类任务最广泛使用的损失函数。它根植于信息论,衡量真实标签分布与模型预测概率分布之间的差异,提供了一个平滑、可微的目标函数,驱动概率分类器做出自信而正确的预测。

    信息论基础

    离散概率分布 $ p $量化其不确定性:

    $ H(p) = -\sum_{k=1}^{K} p_k \log p_k $

    对于确定性分布(one-hot 标签),$ H(p) = 0 $。当所有结果等概率出现时,熵达到最大。

    KL 散度

    Kullback-Leibler 散度衡量一个分布 $ q $ 与参考分布 $ p $ 之间的差异:

    $ D_{\mathrm{KL}}(p \,\|\, q) = \sum_{k=1}^{K} p_k \log \frac{p_k}{q_k} $

    KL 散度是非负的,并且当且仅当 $ p = q $ 时等于零。

    交叉熵

    分布 $ p $(真实)与 $ q $(预测)之间的交叉熵为:

    $ H(p, q) = -\sum_{k=1}^{K} p_k \log q_k = H(p) + D_{\mathrm{KL}}(p \,\|\, q) $

    由于 $ H(p) $ 相对于模型参数是常数,最小化交叉熵等价于最小化 KL 散度——即使预测分布 $ q $ 尽可能接近真实分布 $ p $

    二元交叉熵

    对于二元分类,真实标签 $ y \in \{0, 1\} $,预测概率 $ \hat{y} = \sigma(z) $(其中 $ \sigma $sigmoid 函数):

    $ \mathcal{L}_{\mathrm{BCE}} = -\bigl[y \log \hat{y} + (1 - y) \log(1 - \hat{y})\bigr] $

    在包含 $ N $ 个样本的数据集上:

    $ \mathcal{L} = -\frac{1}{N} \sum_{i=1}^{N} \bigl[y_i \log \hat{y}_i + (1 - y_i) \log(1 - \hat{y}_i)\bigr] $

    关于 logit $ z $ 的梯度具有优雅简洁的形式 $ \hat{y} - y $,既直观又计算高效。

    分类交叉熵

    对于具有 $ K $ 个类别的多分类问题,真实标签通常是一个 one-hot 向量 $ \mathbf{y} $,其中正确类别 $ c $ 对应 $ y_c = 1 $。预测概率 $ \hat{\mathbf{y}} $ 通过 Softmax Function 获得:

    $ \mathcal{L}_{\mathrm{CE}} = -\sum_{k=1}^{K} y_k \log \hat{y}_k = -\log \hat{y}_c $

    这简化为正确类别的负对数概率,这就是为什么在此情境下分类交叉熵也被称为负对数似然(negative log-likelihood)。

    数值稳定性

    Log-Sum-Exp 技巧

    朴素地计算 $ \log(\mathrm{softmax}(z_k)) $ 需要对可能很大的 logits 取指数,会导致溢出。log-sum-exp 技巧避免了这一问题:

    $ \log \hat{y}_k = z_k - \log \sum_{j=1}^{K} e^{z_j} = z_k - \left(m + \log \sum_{j=1}^{K} e^{z_j - m}\right) $

    其中 $ m = \max_j z_j $。减去最大 logit 可确保最大指数为零,从而防止溢出。所有主流深度学习框架都实现了这一融合运算(例如 PyTorch 的 CrossEntropyLoss 接受原始 logits)。

    截断(Clamping)

    预测概率应当从精确的 0 和 1 处截断开,以避免 $ \log(0) = -\infty $。通常使用一个较小的 epsilon(例如 $ 10^{-7} $)。

    标签平滑

    标签平滑(Szegedy 等,2016)用一个软分布替代硬 one-hot 目标:

    $ y_k^{\mathrm{smooth}} = (1 - \alpha)\, y_k + \frac{\alpha}{K} $

    其中 $ \alpha $ 是一个较小的常数(通常为 0.1)。这可以防止模型变得过度自信,改善校准效果,并且通常带来更好的泛化能力。在训练大型图像分类器和 Transformer 模型时,这是标准做法。

    与其他损失函数的比较

    损失 公式 典型用途
    交叉熵 $ -\sum y_k \log \hat{y}_k $ 分类
    均方误差 $ \frac{1}{K}\sum(y_k - \hat{y}_k)^2 $ 回归(不适用于分类)
    Hinge loss $ \max(0, 1 - y \cdot z) $ SVM 类型的分类
    Focal loss $ -(1-\hat{y}_c)^\gamma \log \hat{y}_c $ 不平衡分类

    当预测自信地出错时,交叉熵的梯度比 MSE 更陡峭,从而能够更快地纠正较大的错误。

    参见

    参考文献

    • Shannon, C. E. (1948). "A Mathematical Theory of Communication". Bell System Technical Journal.
    • Goodfellow, I., Bengio, Y. and Courville, A. (2016). Deep Learning. MIT Press, 第6章.
    • Szegedy, C. et al. (2016). "Rethinking the Inception Architecture for Computer Vision". CVPR.
    • Lin, T.-Y. et al. (2017). "Focal Loss for Dense Object Detection". ICCV.