Cross-Entropy Loss/zh
| Article | |
|---|---|
| Topic area | Machine Learning |
| Difficulty | Intermediate |
| Prerequisites | Loss Functions, Softmax Function |
交叉熵损失(也称为对数损失)是机器学习分类任务中使用最广泛的损失函数。它根植于信息论,衡量真实标签分布与模型预测概率分布之间的差异,提供一个平滑、可微的目标,驱动概率分类器做出自信且正确的预测。
信息论基础
熵
离散概率分布 $ p $ 的熵量化其不确定性:
- $ H(p) = -\sum_{k=1}^{K} p_k \log p_k $
对于确定性分布(独热标签),$ H(p) = 0 $。当所有结果等概率时,熵达到最大。
KL 散度
Kullback-Leibler 散度衡量一个分布 $ q $ 与参考分布 $ p $ 之间的差异:
- $ D_{\mathrm{KL}}(p \,\|\, q) = \sum_{k=1}^{K} p_k \log \frac{p_k}{q_k} $
KL 散度是非负的,并且当且仅当 $ p = q $ 时等于零。
交叉熵
分布 $ p $(真实)与 $ q $(预测)之间的交叉熵为:
- $ H(p, q) = -\sum_{k=1}^{K} p_k \log q_k = H(p) + D_{\mathrm{KL}}(p \,\|\, q) $
由于 $ H(p) $ 相对于模型参数是常数,最小化交叉熵等价于最小化 KL 散度——即使预测分布 $ q $ 尽可能接近真实分布 $ p $。
二元交叉熵
对于二元分类,真实标签 $ y \in \{0, 1\} $,预测概率 $ \hat{y} = \sigma(z) $(其中 $ \sigma $ 是 sigmoid 函数):
- $ \mathcal{L}_{\mathrm{BCE}} = -\bigl[y \log \hat{y} + (1 - y) \log(1 - \hat{y})\bigr] $
在包含 $ N $ 个样本的数据集上:
- $ \mathcal{L} = -\frac{1}{N} \sum_{i=1}^{N} \bigl[y_i \log \hat{y}_i + (1 - y_i) \log(1 - \hat{y}_i)\bigr] $
关于logit $ z $的梯度具有优雅简洁的形式 $ \hat{y} - y $,既直观又计算高效。
分类交叉熵
对于具有$ K $个类别的多分类问题,真实标签通常是一个独热向量$ \mathbf{y} $,其中正确类别$ c $对应$ y_c = 1 $。预测概率$ \hat{\mathbf{y}} $通过Softmax Function获得:
- $ \mathcal{L}_{\mathrm{CE}} = -\sum_{k=1}^{K} y_k \log \hat{y}_k = -\log \hat{y}_c $
这简化为正确类别的负对数概率,这就是为什么在此情境下分类交叉熵也被称为负对数似然(negative log-likelihood)。
数值稳定性
Log-Sum-Exp 技巧
直接计算$ \log(\mathrm{softmax}(z_k)) $需要对可能较大的logits取指数,导致数值上溢。log-sum-exp技巧可以避免此问题:
- $ \log \hat{y}_k = z_k - \log \sum_{j=1}^{K} e^{z_j} = z_k - \left(m + \log \sum_{j=1}^{K} e^{z_j - m}\right) $
其中$ m = \max_j z_j $。减去最大logit可确保最大指数为零,从而防止上溢。所有主流的深度学习框架都实现了这一融合操作(例如,PyTorch的CrossEntropyLoss接受原始logits)。
截断(Clamping)
预测概率应当从精确的 0 和 1 处截断开,以避免 $ \log(0) = -\infty $。通常使用一个较小的 epsilon(例如 $ 10^{-7} $)。
标签平滑
标签平滑(Szegedy 等人,2016)将硬独热目标替换为软分布:
- $ y_k^{\mathrm{smooth}} = (1 - \alpha)\, y_k + \frac{\alpha}{K} $
其中$ \alpha $是一个小常数(通常为0.1)。这可防止模型过度自信,改善校准,通常能带来更好的泛化能力。在训练大型图像分类器和transformer模型时已是标准做法。
与其他损失函数的比较
| 损失 | 公式 | 典型用途 |
|---|---|---|
| 交叉熵 | $ -\sum y_k \log \hat{y}_k $ | 分类 |
| 均方误差 | $ \frac{1}{K}\sum(y_k - \hat{y}_k)^2 $ | 回归(不适用于分类) |
| Hinge loss | $ \max(0, 1 - y \cdot z) $ | SVM 类型的分类 |
| Focal loss | $ -(1-\hat{y}_c)^\gamma \log \hat{y}_c $ | 不平衡分类 |
当预测自信地出错时,交叉熵的梯度比 MSE 更陡峭,从而能够更快地纠正较大的错误。
参见
参考文献
- Shannon, C. E. (1948). "A Mathematical Theory of Communication". Bell System Technical Journal.
- Goodfellow, I., Bengio, Y. and Courville, A. (2016). Deep Learning. MIT Press, 第6章.
- Szegedy, C. et al. (2016). "Rethinking the Inception Architecture for Computer Vision". CVPR.
- Lin, T.-Y. et al. (2017). "Focal Loss for Dense Object Detection". ICCV.