p can be a simple distribution, e.g. a standard Gaussian N(0, I).
Mix all the easy conditional paths together to get the hard marginal path for free.
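Written out, the mixing can be sketched as follows. The notation is an assumption, since the slide's own formula is not shown here: q is the data distribution, p_t(x | x_1) is the conditional path for one data point x_1, and u_t(x | x_1) is its conditional velocity.

```latex
p_t(x) = \int p_t(x \mid x_1)\, q(x_1)\, \mathrm{d}x_1,
\qquad
u_t(x) = \int u_t(x \mid x_1)\, \frac{p_t(x \mid x_1)\, q(x_1)}{p_t(x)}\, \mathrm{d}x_1 .
```

The marginal velocity u_t(x) is a posterior-weighted average of the conditional velocities. Neither integral is evaluated in practice; training only ever samples from the conditional path.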
Theorem 2
Conditional Flow Matching works with any conditional probability path
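For reference, the objective this statement refers to is usually written as below. This is a sketch in assumed notation (v_\theta for the model, u_t(x | x_1) for the conditional target velocity), not a transcription of the slide.

```latex
\mathcal{L}_{\mathrm{CFM}}(\theta)
= \mathbb{E}_{t \sim \mathcal{U}[0,1],\; x_1 \sim q,\; x \sim p_t(\cdot \mid x_1)}
\left\| v_\theta(t, x) - u_t(x \mid x_1) \right\|^2
```

Only samples from p_t(x | x_1) and the closed-form conditional velocity are needed, which is why the choice of conditional path is left open.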
Theorem 3
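import torch

for x1 in dataloader:
    x0 = torch.randn_like(x1)                  # sample noise x0 ~ N(0, I)
    t = torch.rand(x1.shape[0], 1)             # sample time t ~ U[0, 1], one per batch element
    xt = sigma_t * x0 + mu_t * x1              # depends on p_t choice (sigma_t, mu_t evaluated at t)
    ut = d_mu_t * x1 + d_sigma_t * x0          # corresponding target velocity (time derivatives at t)
    loss = ((model(t, xt) - ut) ** 2).mean()   # MSE loss
    optimizer.zero_grad()                      # clear gradients left over from the previous step
    loss.backward()
    optimizer.step()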
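As one concrete instance of the "depends on p_t choice" lines, here is a minimal sketch for the linear path mu_t = t, sigma_t = 1 - t. The helper names `linear_path` and `interpolate` are hypothetical and not from the slides; the functions use plain arithmetic only, so they accept Python floats as well as tensors that broadcast against t.

```python
def linear_path(t):
    """Hypothetical schedule helper: returns (mu_t, sigma_t, d_mu_t, d_sigma_t)."""
    mu_t = t            # weight on the data sample x1
    sigma_t = 1.0 - t   # weight on the noise sample x0
    d_mu_t = 1.0        # d(mu_t)/dt
    d_sigma_t = -1.0    # d(sigma_t)/dt
    return mu_t, sigma_t, d_mu_t, d_sigma_t


def interpolate(x0, x1, t):
    """Return (xt, ut) for the linear path; works on floats or broadcastable tensors."""
    mu_t, sigma_t, d_mu_t, d_sigma_t = linear_path(t)
    xt = sigma_t * x0 + mu_t * x1        # point on the conditional path at time t
    ut = d_mu_t * x1 + d_sigma_t * x0    # target velocity, equal to x1 - x0 for this path
    return xt, ut
```

With this choice the path starts at the noise sample (t = 0), ends at the data sample (t = 1), and the regression target is the constant x1 - x0.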
@torch.no_grad()
def generate(model, shape, steps=100):
    x = torch.randn(shape)                      # x0 ~ N(0, I)
    dt = 1.0 / steps
    for i in range(steps):
        t = torch.full((shape[0], 1), i * dt)
        v = model(t, x)                         # predict velocity
        x = ODEStep(x, v, t, dt)                # depends on ODE solver
    return x                                    # x1 ~ q
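The simplest choice for the solver step is explicit Euler. The sketch below is an illustration under stated assumptions, not the slides' implementation: `euler_step` mirrors the `ODEStep(x, v, t, dt)` signature, and `gaussian_velocity` is a hypothetical stand-in for a trained model. It is the exact marginal velocity for the linear path (mu_t = t, sigma_t = 1 - t) when the data distribution is a 1-D Gaussian N(m, s^2), for which the exact flow maps x0 to m + s * x0.

```python
def euler_step(x, v, t, dt):
    """One explicit Euler step of dx/dt = v (t is unused, kept to mirror ODEStep)."""
    return x + dt * v


def gaussian_velocity(t, x, m=2.0, s=0.5):
    """Toy stand-in for a trained model: exact velocity when q = N(m, s^2) in 1-D."""
    var_t = (1.0 - t) ** 2 + (t * s) ** 2          # variance of x_t on the linear path
    return m + (t * s ** 2 - (1.0 - t)) / var_t * (x - t * m)


def generate_1d(x0, steps=1000):
    """Integrate from t = 0 to t = 1 with Euler steps, starting at the noise sample x0."""
    x, dt = x0, 1.0 / steps
    for i in range(steps):
        t = i * dt
        x = euler_step(x, gaussian_velocity(t, x), t, dt)
    return x
```

Because `euler_step` only uses addition and scalar multiplication, the same function also works on tensors inside `generate`. Higher-order solvers such as midpoint or Runge-Kutta reduce the discretization error at the cost of extra model evaluations per step.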
All equations