Triplet Loss/zh
| Article | |
|---|---|
| Topic area | Deep Learning |
| Prerequisites | Embeddings, Loss Functions, Stochastic Gradient Descent, Convolutional Neural Networks |
概述
三元组损失是一种损失函数,用于训练能够产生向量嵌入的模型,使得语义相似的输入在嵌入空间中彼此靠近,而不相似的输入彼此远离。它作用于由三个样本组成的三元组(一个锚点、一个与锚点同类的正样本以及一个不同类的负样本),并在负样本与锚点的距离比正样本与锚点的距离近不到固定的间隔时对嵌入进行惩罚。三元组损失最初因 FaceNet 系统在人脸识别中的应用而流行,[1]如今已成为度量学习中的标准工具,被广泛用于图像检索、行人重识别、签名验证、音频相似度以及信息检索中的稠密段落检索。
与使用交叉熵损失的分类式训练(将每个输入分配到固定的类别索引)不同,三元组损失直接塑造嵌入空间的几何结构。这使其非常适合开放集问题,即推理时所遇到的类别集合在训练时并不已知,例如识别在训练中从未见过的人脸。
直观理解
考虑一个将输入 $ x $ 映射为向量 $ f(x) \in \mathbb{R}^d $ 的模型,该向量通常通过 $ L^2 $ 归一化被约束在单位超球面上。一个三元组 $ (a, p, n) $ 由以下三部分组成:
- 一个取自某一类别的锚点 $ a $,
- 一个与锚点取自同一类别的正样本 $ p $,以及
- 一个取自不同类别的负样本 $ n $。
我们希望锚点的嵌入与正样本的距离小于其与负样本的距离。三元组损失并非仅作为一种排序约束施加这一要求,而是作为一种间隔条件:正样本必须比负样本至少近上某一数量 $ \alpha > 0 $。若没有间隔,可以通过将所有嵌入塌缩到同一点来平凡地满足损失,因为任何非零的距离差都足以满足条件。间隔以有限的数量将负样本推远,并赋予损失明确的几何尺度。
形式化表述
设 $ d(\cdot, \cdot) $ 为嵌入空间上的一种距离,最常见的是对经过 $ L^2 $ 归一化的嵌入所计算的欧氏距离平方。单个三元组上的三元组损失为
$ {\displaystyle \mathcal{L}(a, p, n) = \max\bigl(0,\; d(f(a), f(p)) - d(f(a), f(n)) + \alpha\bigr).} $
由此产生两种情形。当 $ d(f(a), f(p)) + \alpha \le d(f(a), f(n)) $ 时,损失为零且梯度消失:三元组已经满足约束,对学习不产生贡献。否则损失严格为正,将锚点拉向正样本,同时将其推离负样本。由于嵌入通常被约束在单位球面上,欧氏距离平方与余弦相似度之间仅相差一个仿射变换,因此三元组损失可以等价地用内积的形式书写。
在小批量上的总目标是所选三元组集合上各三元组损失的平均(或求和)。如何选好这些三元组是核心的实际问题。
三元组挖掘
均匀随机地采样三元组的朴素方法在计算上是浪费的:随着训练的推进,大多数随机三元组已经满足间隔条件,产生的梯度为零。这促使了三元组挖掘——即有针对性地选择具有信息量的三元组的做法。
在包含 $ B $ 个嵌入的小批量内通常区分以下三种情形:
- 简单三元组满足 $ d(a, p) + \alpha < d(a, n) $。损失为零,无学习信号。
- 困难三元组满足 $ d(a, n) < d(a, p) $。负样本比正样本更接近锚点:会产生很强的梯度,但这类三元组往往对应于标签噪声或极端离群值,可能使训练不稳定。
- 半困难三元组满足 $ d(a, p) < d(a, n) < d(a, p) + \alpha $。正样本比负样本更接近,但差距不足;损失为正但有界。
FaceNet 引入了在线半困难挖掘:对批次中的每个锚点,选取最难的半困难负样本。这种做法产生稳定的梯度,被广泛认为是使三元组损失得以在大规模上实用化的关键配方。批次困难挖掘[2]最初为行人重识别提出,则改为对批次中的每个锚点选取最难的正样本和最难的负样本。这种方法需要谨慎处理(如课程学习、热身阶段或使用Adam等鲁棒的优化器),但在许多任务上能得到更强的嵌入。挖掘策略可以是在线的(根据当前批次实时计算)或离线的(在整个数据集上预先计算)。
训练与推理
训练通过对 $ f $ 进行反向传播来完成。典型的批次构造方式是采样 $ P $ 个类别,每类 $ K $ 个样本,从而得到 $ B = PK $ 个嵌入和 $ O(B^2) $ 个候选成对距离;这种 $ P\!\times\!K $ 采样器是批次困难挖掘的标准搭配。间隔 $ \alpha $ 是一个超参数;当嵌入被单位归一化时,常用取值范围为 $ [0.1, 0.5] $。嵌入几乎总是经 $ L^2 $ 归一化,这样既可使损失有界,也因为余弦几何往往比无约束的欧氏几何更为稳定。
在推理阶段,模型仅作为嵌入函数使用:两个输入之间的相似度通过 $ d(f(x_1), f(x_2)) $ 计算,下游任务(验证、检索、聚类)直接使用该距离。无需softmax头。
变体与相关损失
若干损失对三元组目标进行了泛化或改进:
- 对比损失,[3]作为其历史前身,作用于样本对而非三元组,并对相似对和不相似对设定固定的目标距离。
- N 对损失[4]在单个softmax式目标中将一个正样本与多个负样本进行比较,提高了样本效率。
- 四元组损失增加了第二个负样本项,以更大的间隔强制类内距离小于类间距离。
- 提升结构化损失联合考虑批次内所有样本对,通过平滑的 log-sum-exp 进行加权。
- 角度损失将欧氏距离替换为基于角度且对尺度不敏感的约束。
- 多重相似性损失与圆损失通过自相似性和相对相似性对样本对加权进行了泛化。
- InfoNCE及其他对比学习目标与三元组损失关系密切:当只有一个正样本和多个负样本时,InfoNCE 的行为类似于带温度的软化 N 对损失。
比较
三元组损失最常被拿来与基于分类的嵌入学习进行比较。诸如 ArcFace 和 CosFace 等方法训练一个显式分类器,并在softmax的对数概率上施加角度间隔,无需显式三元组采样就能凭经验得到强大的人脸嵌入。这类方法完全规避了挖掘问题,但训练时需要固定的类别集合,并且当身份数量非常庞大时扩展性不佳。三元组损失则与身份无关,能自然地扩展到开放集问题,但需要谨慎的挖掘策略,并且对间隔和学习率较为敏感。
与成对的对比损失相比,三元组损失通常在样本效率上更高,因为每个三元组在一次更新中同时编码了吸引和排斥。与 InfoNCE 风格的对比目标相比,三元组损失具有更锐利的间隔几何,但当负样本未被精心选择时往往不太稳定。
局限性
三元组损失的主要实际困难在于:一旦间隔条件被满足,梯度信号便消失,这使得未经筛选的采样效率低下。这正是必须采用挖掘策略的原因。困难挖掘可能放大标签噪声:单个被错误标注的正样本会成为许多锚点的最难配对,从而主导训练。该损失对间隔、嵌入范数以及批次构成都很敏感;糟糕的 $ P, K $ 采样器可能产出完全没有有用三元组的批次。最后,尽管该损失直接优化的是一种排序式目标,却并未在整个数据集上对距离进行显式校准,因此用于验证的绝对阈值必须在留出验证集上事后确定。
参考文献
- ↑ Template:Cite arxiv
- ↑ Template:Cite arxiv
- ↑ Hadsell, Chopra, and LeCun, "Dimensionality Reduction by Learning an Invariant Mapping", CVPR 2006.
- ↑ Sohn, "Improved Deep Metric Learning with Multi-class N-pair Loss Objective", NeurIPS 2016.