FLOW MATCHING FOR GENERATIVE MODELING 阅读笔记

论文提出了一种新的生成模型。论文的目的是给定一个目标分布,有目标分布的一定量的样本,但是不知道目标分布的概率密度函数,学习一个模型能生成服从目标分布的新样本。
Flow Matching (FM)是一种训练连续标准化流Continuous Normalizing Flow (CNF)的方法。
FM是一种通用的方法。FM可以用于训练扩散路径,用FM训练扩散路径更稳定。FM也可以用于训练其他路径,一个例子是训练最优传输(OT)位移插值定义的条件概率路径,这些路径比扩散路径更有效,提供更快的训练和采样,从而获得更好的泛化效果。

核心的思想是把无条件估计问题的转换为有条件的问题的来学习。作者说是从denoised score matching得到的启发:

We first show that we can construct such target vector fields through per-example (i.e., conditional) formulations. Then, inspired by denoising score matching, we show that a per-example training objective, termed Conditional Flow Matching (CFM), provides equivalent gradients and does not require explicit knowledge of the intractable target vector field.

连续标准化流

数据点 x ∈ R d \pmb x \in \mathbb R^d xRd,时变概率密度路径 p : [ 0 , 1 ] × R d → R > 0 p:[0,1] \times \mathbb R^d \rightarrow \mathbb R_{>0} p:[0,1]×RdR>0,时变向量场 v t : [ 0 , 1 ] × R d → R d v_t:[0,1] \times \mathbb R^d \rightarrow \mathbb R^d vt:[0,1]×RdRd
流flow把一个分布映射成另一个分布,可以通过常微分方程用 v t v_t vt构建flow ϕ : [ 0 , 1 ] × R d → R d \phi:[0,1] \times \mathbb R^d \rightarrow \mathbb R^d ϕ:[0,1]×RdRd
d ϕ t ( x ) d t = v t ( ϕ t ( x ) ) ϕ 0 ( x ) = x (1) \frac{d\phi_t(\pmb x)}{dt}=v_t(\phi_t(\pmb x)) \tag{1} \\ \phi_0(\pmb x)=\pmb x dtdϕt(x)=vt(ϕt(x))ϕ0(x)=x(1)时变向量场可以用神经网络 v t ( x ; θ ) v_t(\pmb x; \theta) vt(x;θ)来建模,这样构建的flow ϕ t \phi_t ϕt叫做连续标准化流(Continuous Normalizing Flow,CNF)。CNF通常用于把一个简单的分布 p 0 p_0 p0变成一个复杂的分布 p 1 p_1 p1,其符合push-forward方程:
p t ( x ) = [ ϕ t ] ⋆ p 0 ( x ) = p 0 ( ϕ t − 1 ( x ) ) det ⁡ [ ∂ ϕ t − 1 ∂ x ( x ) ] p_t(x)=[\phi_t]_\star p_0(x)=p_0(\phi_t^{-1}(x))\det[\frac{\partial \phi_t^{-1}}{\partial x}(x)] pt(x)=[ϕt]p0(x)=p0(ϕt1(x))det[xϕt1(x)]我们的目标是采样服从复杂目标分布的样本,方法是首先随机采样服从简单分布的噪声样本 x ∼ N ( 0 , I ) \pmb x \sim \mathcal N (\pmb 0, \pmb I) xN(0,I),然后使用ODE求解器在区间 t ∈ [ 0 , 1 ] t \in [0, 1] t[0,1]上使用训练得到的向量场 v t v_t vt求解方程(1)得到服从目标分布的样本 ϕ 1 ( x ) \phi_1(\pmb x) ϕ1(x)。所以主要的问题是如何学习 v t ( x ; θ ) v_t(\pmb x; \theta) vt(x;θ)

Flow Matching(FM)

x 1 \pmb x_1 x1表示服从未知的目标分布 q ( x 1 ) q(\pmb x_1) q(x1)的随机变量,我们不知道 q ( x 1 ) q(\pmb x_1) q(x1)的密度函数,但可以获得服从 q ( x 1 ) q(\pmb x_1) q(x1)的样本。用 p t p_t pt表示概率密度路径, p 0 p_0 p0服从标准高斯分布, p 1 p_1 p1近似 q q q
Flow Matching的训练目标是学习 v t v_t vt,损失函数是 L F M ( θ ) = E t , p t ( x ) ∥ v t ( x ; θ ) − u t ( x ) ∥ 2 \mathcal L_{FM}(\theta)=\mathbb E_{t,p_t(\pmb x)}\|v_t(\pmb x; \theta)-u_t(\pmb x)\|^2 LFM(θ)=Et,pt(x)vt(x;θ)ut(x)2流匹配的损失函数很简单,但在实践中没法使用,因为我们不知道如何定义合适的 p t p_t pt u t u_t ut

Conditional Flow Matching(CFM)

为了解决上面的问题,考虑条件流匹配。条件流匹配的损失函数是 L C F M ( θ ) = E t , q ( x 1 ) , p t ( x ∣ x 1 ) ∥ v t ( x ; θ ) − u t ( x ∣ x 1 ) ∥ 2 \mathcal L_{CFM}(\theta)=\mathbb E_{t,q(\pmb x_1),p_t(\pmb x|\pmb x_1)}\|v_t(\pmb x; \theta)-u_t(\pmb x|\pmb x_1)\|^2 LCFM(θ)=Et,q(x1),pt(xx1)vt(x;θ)ut(xx1)2与流匹配的目标不同,条件流匹配的目标允许我们轻松地对无偏估计进行采样,只要我们可以从 p t ( x ∣ x 1 ) p_t(\pmb x|\pmb x_1) pt(xx1) 有效地采样并计算 u t ( x ∣ x 1 ) u_t(\pmb x|\pmb x_1) ut(xx1),这两者都可以很容易地完成,因为它们是对每个样本定义的。
论文中证明了优化CFM目标等价于优化FM目标(从期望的角度)。所以,剩下的问题是如何设计合适的条件概率路径 p t ( x ∣ x 1 ) p_t(\pmb x|\pmb x_1) pt(xx1)和向量场 u t ( x ∣ x 1 ) u_t(\pmb x|\pmb x_1) ut(xx1)

条件概率路径和条件向量场

上面的讨论是通用的,并没有规定条件概率路径和条件向量场的形式。为了简单,作者讨论的是高斯条件概率路径:
p t ( x ∣ x 1 ) = N ( x ∣ μ t ( x 1 ) , σ t ( x 1 ) 2 I ) p_t(\pmb x|\pmb x_1)=\mathcal N(\pmb x| \mu_t(\pmb x_1), \sigma_t(\pmb x_1)^2\pmb I) pt(xx1)=N(xμt(x1),σt(x1)2I)其中 μ 0 ( x 1 ) = 0 \mu_0(\pmb x_1)=0 μ0(x1)=0 σ 0 ( x 1 ) = 1 \sigma_0(\pmb x_1)=1 σ0(x1)=1 μ 1 ( x 1 ) = x 1 \mu_1(\pmb x_1)=\pmb x_1 μ1(x1)=x1 σ 1 ( x 1 ) = σ min ⁡ \sigma_1(\pmb x_1)=\sigma_{\min} σ1(x1)=σmin
有无数的向量场可以产生给定的概率路径,这里作者讨论的是最简单的典型变换。
考虑条件flow:
ψ t ( x ) = σ t ( x 1 ) x + μ t ( x 1 ) \psi_t(\pmb x)= \sigma_t(\pmb x_1)\pmb x + \mu_t(\pmb x_1) ψt(x)=σt(x1)x+μt(x1)对应的条件向量场可以通过求解方程得到,并有封闭解:
u t ( x ∣ x 1 ) = σ t ′ ( x 1 ) σ t ( x 1 ) ( x − μ t ( x 1 ) ) + μ t ′ ( x 1 ) u_t(\pmb x|\pmb x_1)=\frac{\sigma'_t(\pmb x_1)}{\sigma_t(\pmb x_1)}(x-\mu_t(\pmb x_1))+\mu'_t(\pmb x_1) ut(xx1)=σt(x1)σt(x1)(xμt(x1))+μt(x1)优化的损失函数是 L C F M ( θ ) = E t , q ( x 1 ) , p ( x 0 ) ∥ v t ( ψ t ( x 0 ) ; θ ) − u t ( ψ t ( x 0 ) ∣ x 1 ) ∥ 2 \mathcal L_{CFM}(\theta)=\mathbb E_{t,q(\pmb x_1),p(\pmb x_0)}\|v_t(\psi_t(\pmb x_0); \theta)-u_t(\psi_t(\pmb x_0)|\pmb x_1)\|^2 LCFM(θ)=Et,q(x1),p(x0)vt(ψt(x0);θ)ut(ψt(x0)x1)2

### FLOW MATCHING生成模型的代码实现 FLOW MATCHING是一种基于流(flow)的生成模型方法,其核心思想是通过一系列可逆变换将简单的分布转换为目标分布。这种方法的优势在于它可以通过精确计算似然函数来进行训练,并且可以高效地生成样本。 以下是基于Python的一个简单FLOW MATCHING生成模型的代码实现示例: ```python import torch import torch.nn as nn import torch.optim as optim from torch.distributions import MultivariateNormal class FlowMatchingModel(nn.Module): def __init__(self, dim=2, hidden_dim=32): super(FlowMatchingModel, self).__init__() self.fc1 = nn.Linear(dim, hidden_dim) self.fc2 = nn.Linear(hidden_dim, hidden_dim) self.fc3 = nn.Linear(hidden_dim, dim) def forward(self, z): h = torch.relu(self.fc1(z)) h = torch.relu(self.fc2(h)) return self.fc3(h) def train_flow_matching(model, data_loader, optimizer, num_epochs=100): criterion = nn.MSELoss() prior = MultivariateNormal(torch.zeros(data_loader.dataset[0].shape), torch.eye(data_loader.dataset[0].shape)) for epoch in range(num_epochs): total_loss = 0 for batch_data in data_loader: z_prior = prior.sample((batch_data.shape[0],)).to(batch_data.device) transformed_z = model(z_prior) loss = criterion(transformed_z, batch_data) optimizer.zero_grad() loss.backward() optimizer.step() total_loss += loss.item() if (epoch + 1) % 10 == 0: print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(data_loader):.4f}') # 初始化模型和优化器 model = FlowMatchingModel().cuda() optimizer = optim.Adam(model.parameters(), lr=0.001) # 假设我们有一个数据加载器 `data_loader` 提供目标分布的数据 train_flow_matching(model, data_loader, optimizer) ``` 上述代码定义了一个简单的多层感知机作为变换模型,并使用均方误差损失来匹配源分布和目标分布[^1]。此代码仅为演示目的设计,在实际应用中可能需要更复杂的架构和正则化技术。 #### 关键点解释 - **可逆变换**:FLOW MATCHING的核心在于构建一组参数化的可逆变换 \( f_\theta \),使得输入噪声经过这些变换后接近真实数据分布。 - **似然估计**:由于变换是可逆的,因此可以直接最大化数据的对数似然函数进行训练。 - **复杂交互建模**:类似于变分自编码器(VAE)和生成对抗网络(GAN),FLOW MATCHING也能够捕捉高维数据中的复杂依赖关系。 ### 注意事项 在实际部署过程中,还需要考虑以下几个方面: - 数据预处理:确保输入数据被标准化至适当范围。 - 超参数调整:包括学习率、层数、隐藏单元数量等都需要仔细调优。 - 计算资源需求:对于大规模数据集或高维度特征向量,可能需要GPU加速支持。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值