Flow Matching/zh

    From Marovi AI
    This page is a translated version of the page Flow Matching and the translation is 100% complete.
    Other languages:
    Article
    Topic area generative-models
    Prerequisites Diffusion Models, Optimal Transport, Neural Ordinary Differential Equations


    概述

    Flow Matching 是一种用于 Continuous Normalizing Flows 的免模拟训练框架,其中神经网络直接回归到一个目标的时间相关向量场,该场将一个简单的先验分布传输到数据分布。该方法由 Lipman、Chen、Ben-Hamu、Nickel 和 Le 于 2022 年提出,推广并统一了若干早期方法,包括 Score MatchingRectified Flow,并已成为图像、视频、音频和分子生成建模的主流范式。与流模型的经典极大似然训练相比,后者需要昂贵的轨迹模拟;与 Denoising Diffusion Probabilistic Models 相比,后者需要随机微分方程的形式化,Flow Matching 提供了一个确定性且概念简洁的替代方案:选取一条连接噪声与数据的概率路径,导出生成该路径的向量场,并以均方误差损失学习该向量场。

    直觉

    一个连续的归一化流描述了分布空间中的一条曲线,通过常微分方程 (ODE) 将样本从时间 $ t=0 $ 处的初始概率密度 $ p_0 $ 传输到时间 $ t=1 $ 处的目标密度 $ p_1 $。对于每个时刻 $ t \in [0,1] $,一个 Vector Field $ u_t(x) $ 给出了位置 $ x $ 处的瞬时速度;沿时间积分该向量场,便将样本从先验推送到数据分布。

    核心挑战在于,一般而言我们无法直接观测到一个能将可处理的先验传输到经验数据分布的向量场。Flow Matching 绕开了这一难题,通过对单个数据点进行条件化、分段地构造该路径。对于固定的数据样本 $ x_1 $,可以容易地写出从一个噪声样本到 $ x_1 $ 的光滑路径,并读出生成它的速度。在噪声与数据的联合采样下对这些逐样本的速度取平均,便得到将整个总体从先验驱动到数据的无条件速度场。关键洞见是:将神经网络回归到条件速度,在期望意义下即可恢复无条件速度,从而完全无需计算边缘密度。

    概率路径与向量场

    一条概率路径是按时间索引的密度族 $ \{p_t\}_{t \in [0,1]} $,其中 $ p_0 $ 是所选定的先验(通常是一个标准的 Gaussian Distribution),$ p_1 $ 是数据分布。当 Continuity Equation 成立时,向量$ u_t $ 生成该路径:

    $ {\displaystyle \frac{\partial p_t(x)}{\partial t} + \nabla \cdot (p_t(x)\, u_t(x)) = 0.} $

    等价地,从 $ p_0 $ 抽取并由 ODE $ dx/dt = u_t(x) $ 演化的样本,在每个中间时刻都服从 $ p_t $。多种向量场可以生成同一条路径,因此需要额外的结构(例如直线性,或关于某种传输代价的最优性)来挑出一个优选的向量场。

    条件 Flow Matching

    直接回归 $ u_t $ 不可行,因为 $ u_t $ 依赖于未知的边缘密度。条件 Flow Matching (CFM) 的目标通过对目标样本 $ x_1 $ 进行条件化来解决此问题。对于选定的条件路径 $ p_t(x \mid x_1) $(例如,一个均值在 $ t=0 $ 时为 $ 0 $、在 $ t=1 $ 时线性插值到 $ x_1 $高斯分布)以及生成它的条件向量$ u_t(x \mid x_1) $,损失为

    $ {\displaystyle \mathcal{L}_{\mathrm{CFM}}(\theta) = \mathbb{E}_{t,\, x_1,\, x \sim p_t(\cdot \mid x_1)}\!\left[\, \lVert v_\theta(t, x) - u_t(x \mid x_1) \rVert^2 \right],} $

    其中 $ v_\theta $ 是所学习的向量场,$ t $$ [0,1] $ 上均匀采样,$ x_1 $ 从数据中采样。Lipman 等人证明,该目标关于 $ \theta $ 的梯度与对边缘场 $ u_t $ 的回归相同,尽管该边缘场不可处理。关键的设计选择是条件路径;常用选择包括保方差的高斯路径、爆方差路径,以及 Optimal Transport 位移线性插值 $ x_t = (1-t)\, x_0 + t\, x_1 $,后者产生异常简洁的回归目标 $ u_t(x \mid x_0, x_1) = x_1 - x_0 $

    训练与推断

    训练仅需采样一个时刻、一个噪声向量和一个数据点;以闭式计算条件速度;并最小化平方误差。训练期间无需模拟 ODE,无需辅助分数网络,也无需跟踪变分下界小批量由相互独立的三元组 $ (t, x_0, x_1) $ 组成,其中 $ x_0 $先验中抽取,$ x_1 $ 从数据集中抽取。

    在推断时,通过在从先验中抽取的初始条件下,将所学到的 ODE $ dx/dt = v_\theta(t, x) $$ t=0 $ 积分到 $ t=1 $ 来生成样本。可以使用任何黑盒 ODE 求解器;常见选择包括自适应 Runge-Kutta 方法以及定步长 Euler 或 Heun's Method 积分器。由于使用线性(最优传输)插值训练的 Flow Matching 往往产生近乎直线的轨迹,生成样本通常只需少量求解器步骤,这与可能需要数十至数百步的扩散模型形成对比。

    变体

    Flow Matching 的若干变体会调整条件路径、$ x_0 $$ x_1 $ 之间的耦合,或训练过程:

    • Rectified Flow(Liu 等,2022)使用与 OT-CFM 相同的线性插值进行训练,然后在自身被拉直的轨迹上迭代地重新训练模型,产生越来越直的流,从而支持单步或少步采样
    • Stochastic Interpolants(Albergo 与 Vanden-Eijnden,2023)将该框架推广至允许随机动力学,将基于流和基于扩散的生成建模统一在一个插值形式之下。
    • Optimal Transport Conditional Flow Matching(Tong 等,2023)用基于小批量Optimal Transport 耦合替代 $ x_0 $$ x_1 $ 的独立耦合,从而锐化噪声与数据之间的对齐并降低路径曲率。
    • Multisample Flow Matching(Pooladian 等,2023)提出了相关的批量耦合视角,并对相应估计量给出了理论分析。
    • Riemannian Flow Matching 将该构造扩展到流形上的数据,用测地线插值代替欧氏插值,并采用流形感知的 ODE 积分器。
    • Discrete Flow Matching 用连续时间马尔可夫链取代 ODE,将该框架适配于分类数据。

    与扩散模型的比较

    扩散模型与 Flow Matching 紧密相关:两者都学习从噪声到数据的时间相关变换,并且都可以表述为对目标场的回归问题。它们的差别在于过程与参数化的选择。扩散模型通过随机的正向与逆向过程加以表述,并学习 Score Function $ \nabla \log p_t(x) $;其训练对应于 Flow Matching 家族中一种特定的保方差高斯路径。Flow Matching 在 ODE 层面上是纯确定性的,将路径视为自由的设计选择,并参数化速度而非分数。经验上,OT 风格的 Flow Matching 产生更直的轨迹并支持更快的采样,而扩散的随机性在某些情形下可以提升样本多样性。基于分数的扩散采样器可以被重新解释为概率流 ODE 的 ODE 积分器,从而精确揭示了这两族之间的数学桥梁。

    局限性

    Flow Matching 继承了 Continuous Normalizing Flows 的常见困难:在轨迹弯曲或刚性时,推断阶段的 ODE 积分可能代价高昂;精确的对数似然计算需要 Hutchinson Trace Estimator 或昂贵的雅可比求值;而在高维流形上则可能需要精心选择先验,以免浪费建模容量。该框架还假设存在一条可处理的条件路径,这在欧氏空间中是直接的,但在流形、图或离散空间上则更加微妙。条件化、Classifier-Free Guidance 与无似然评估可以从扩散迁移到 Flow Matching,但有时需要谨慎调整,因为其底层对象是一个向量场而非分数。

    应用

    Flow Matching 已被应用于高分辨率图像生成,包括将 OT-CFM 扩展到数十亿参数的文本到图像模型;语音与音频合成,其中直线轨迹支持实时生成;在 $ \mathrm{SE}(3) $ 流形上的蛋白质与分子结构生成;以及机器人领域的轨迹生成。许多近期的大规模生成系统都采用 rectified-flow 或 OT-CFM 训练,因其简洁性与少步推断的特性。

    参考文献

    Cite error: <ref> tag with name "lipman2022" defined in <references> has group attribute "" which does not appear in prior text.
    Cite error: <ref> tag with name "liu2022" defined in <references> has group attribute "" which does not appear in prior text.
    Cite error: <ref> tag with name "albergo2023" defined in <references> has group attribute "" which does not appear in prior text.
    Cite error: <ref> tag with name "tong2023" defined in <references> has group attribute "" which does not appear in prior text.
    Cite error: <ref> tag with name "pooladian2023" defined in <references> has group attribute "" which does not appear in prior text.