Learning Monge maps by lifting and constraining the relative entropy gradient flow

Théo Dumont
Univ. Gustave Eiffel
theo.dumont@univ-eiffel.fr

T.D., Théo Lacombe, François-Xavier Vialard. Learning Monge maps by lifting and constraining Wasserstein gradient flows
(on arXiv on Monday!).

\left\{\begin{aligned} &\ S_0=\operatorname{id} \\ &\ \partial_t S_t = v_t\circ S_t \end{aligned}\right.

Flow map

-\nabla_{\!{}_W}D(S_t{}_*\varrho_0)
\operatorname{Tan}_{\nabla\phi}K=\overline{\big\{\nabla p\mid \exists t>0,\,\phi+tp\text{ convex} \big\}}^{L^2_{\varrho_0}}

Brenier's theorem

Let \(\varrho_0\in\mathcal P(\mathbb R^d)\) s.t. \(\varrho_0\ll \mathrm dx\). Then for any \(\gamma\in\mathcal P(\mathbb R^d)\), OT admits a unique solution \(T_{\varrho_0}^{\gamma}\) (defined \(\varrho_0\)-a.e.), and it is the gradient of a convex function \(\phi:\mathbb R^d\to\mathbb R\), that is, \[T_{\varrho_0}^{\gamma}=\nabla \phi.\]
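In dimension one, gradients of convex functions are exactly the nondecreasing maps, so the OT map is the monotone rearrangement. The snippet below is a small illustration of ours in Python/NumPy, not code from the paper. For two uniform empirical measures with the same number of atoms, it checks that pairing sorted samples attains the minimal quadratic cost found by brute force over all bijections. Empirical measures are not absolutely continuous, so this illustrates the discrete analogue rather than Brenier's theorem itself. The sample sizes and distributions are arbitrary choices.

```python
import itertools
import numpy as np

def monotone_matching_cost(x, y):
    """Quadratic cost of pairing the i-th smallest x with the i-th smallest y."""
    return np.mean((np.sort(x) - np.sort(y)) ** 2)

def brute_force_cost(x, y):
    """Minimal quadratic cost over all bijections between equal-size samples."""
    return min(
        np.mean((x - y[list(perm)]) ** 2)
        for perm in itertools.permutations(range(len(y)))
    )

rng = np.random.default_rng(0)
x = rng.normal(size=6)        # atoms of rho_0 (uniform weights)
y = rng.exponential(size=6)   # atoms of gamma (uniform weights)
print(monotone_matching_cost(x, y), brute_force_cost(x, y))
```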

\displaystyle \operatorname{OT}(\varrho_0,\gamma)= \inf_{T\in\mathcal T(\varrho_0,\gamma)} \int_{\mathbb R^d} \|x-T(x)\|^2\,\mathrm d\varrho_0(x)
\text{where }\mathcal T(\varrho_0,\gamma)\coloneqq\{T\in L^2_{\varrho_0}(\mathbb R^d,\mathbb R^d)\mid T_*\varrho_0=\gamma\}.
K=\{\nabla\phi\in L^2_{\varrho_{{}_0}}(\mathbb R^d,\mathbb R^d)\mid\phi\text{ convex}\}.
\partial_t\varrho_t=-\operatorname{div}(\varrho_tv_t),\quad\text{with }v_t=-\nabla_{\!{}_W}D(\varrho_t)

Wasserstein gradient flow

\partial_t\varrho_t=\Delta\varrho_t+\operatorname{div}(\varrho_t \nabla V)

(Fokker-Planck)

\operatorname{KL}(\varrho\,|\,\gamma)=\int_{\mathbb R^d}\log\Big(\frac{\mathrm d\varrho}{\mathrm d\gamma}\Big)\,\mathrm d\varrho,
\gamma=e^{-V}\,\mathrm dx,

Convergence of the entropy flow

Suppose that \(\gamma\) is log-concave. Then the Wasserstein gradient flow of the relative entropy converges to \(\gamma\).

Sub-optimality of the flow map

The limit \(S_\infty\) of the flow map \(S_t\) of the relative entropy is, in general, not the optimal transport map between \(\varrho_0\) and \(\gamma\), even though it pushes \(\varrho_0\) onto \(\gamma\).
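A Gaussian example makes the sub-optimality concrete. For illustration, we assume \(\varrho_0=\mathcal N(0,\Sigma_0)\) and \(\gamma=\mathcal N(0,\Sigma_\gamma)\), i.e. \(V(x)=\tfrac12x^\top\Sigma_\gamma^{-1}x\) up to a constant. Under these assumptions the flow stays Gaussian, the velocity \(v_t(x)=(\Sigma_t^{-1}-\Sigma_\gamma^{-1})x\) is linear, and the flow map is \(S_t(x)=M_tx\). The sketch integrates \(\partial_tM_t=(\Sigma_t^{-1}-\Sigma_\gamma^{-1})M_t\) with explicit Euler. It then compares \(M_\infty\) with the Brenier map \(T=\Sigma_0^{-1/2}(\Sigma_0^{1/2}\Sigma_\gamma\Sigma_0^{1/2})^{1/2}\Sigma_0^{-1/2}\), which is symmetric positive definite. A linear map is the gradient of a convex function iff its matrix is symmetric positive semidefinite, so a nonzero asymmetry of \(M_\infty\) certifies that it is not optimal. The covariances, step size and horizon are arbitrary choices.

```python
import numpy as np

def sqrtm_psd(S):
    """Square root of a symmetric positive definite matrix via eigh."""
    w, U = np.linalg.eigh(S)
    return U @ np.diag(np.sqrt(w)) @ U.T

def entropy_flow_map(Sigma0, Sigma_g, t_max=20.0, dt=1e-2):
    """Linear part M_t of the flow map S_t(x) = M_t x of the KL gradient flow,
    for rho_0 = N(0, Sigma0) and gamma = N(0, Sigma_g), by explicit Euler."""
    M = np.eye(Sigma0.shape[0])                 # S_0 = id
    Prec_g = np.linalg.inv(Sigma_g)             # Hessian of V
    for _ in range(int(round(t_max / dt))):
        Sigma_t = M @ Sigma0 @ M.T              # covariance of rho_t = (S_t)_* rho_0
        A = np.linalg.inv(Sigma_t) - Prec_g     # v_t(x) = A x = -grad log rho_t - grad V
        M = M + dt * A @ M                      # d/dt S_t = v_t o S_t
    return M

def gaussian_ot_map(Sigma0, Sigma_g):
    """Brenier map x -> T x between N(0, Sigma0) and N(0, Sigma_g)."""
    R = sqrtm_psd(Sigma0)
    R_inv = np.linalg.inv(R)
    return R_inv @ sqrtm_psd(R @ Sigma_g @ R) @ R_inv

Sigma0 = np.array([[2.0, 0.0], [0.0, 0.5]])
Sigma_g = np.array([[1.0, 0.6], [0.6, 1.0]])
M_inf = entropy_flow_map(Sigma0, Sigma_g)
T_ot = gaussian_ot_map(Sigma0, Sigma_g)
print("pushforward error:", np.linalg.norm(M_inf @ Sigma0 @ M_inf.T - Sigma_g))
print("asymmetry of M_inf:", np.linalg.norm(M_inf - M_inf.T))
print("distance to the OT map:", np.linalg.norm(M_inf - T_ot))
```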

S_t{}_*\varrho_0=\varrho_t,\text{ and in particular }S_\infty{}_*\varrho_0=\gamma.
\theta_{k+1}\in\argmin_{\theta\in\Theta} \int_{\mathbb R^d}\Big\|\tilde v_k\circ T_{\theta_k}-\frac{T_{\theta}-T_{\theta_k}}\tau\Big\|^2 \,\mathrm d\varrho_0
\partial_t T_t = \argmin_{w_t\in \operatorname{Tan}_{T_t}\!K} \int_{\mathbb R^d}\|\tilde v_t\circ T_t-w_t\|^2\,\mathrm d\varrho_0
\partial_t\varrho_t=-\operatorname{div}(\varrho_t\, w_t\circ T_t^{-1})
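At each step, the parameterized scheme for \(\theta_{k+1}\) is a least-squares problem, since it amounts to fitting \(T_\theta\) to \(T_{\theta_k}+\tau\,\tilde v_k\circ T_{\theta_k}\) in \(L^2_{\varrho_0}\). Below is a minimal sketch under strong simplifying assumptions of ours:

- \(d=1\), with \(\varrho_0=\mathcal N(0,1)\) approximated by samples in the integral;
- \(D=\operatorname{KL}(\cdot\,|\,\gamma)\) with \(\gamma=\mathcal N(m,s^2)\);
- the affine family \(T_\theta(x)=ax+b\), \(a\geq0\), i.e. the gradient of the convex \(\phi_\theta(x)=ax^2/2+bx\).

Because \((T_\theta)_*\varrho_0=\mathcal N(b,a^2)\) here, the velocity \(\tilde v_k=-\nabla\log\varrho-\nabla V\) is available in closed form. With richer parameterizations it would have to be estimated.

```python
import numpy as np

# Assumptions (illustrative only): d = 1, rho_0 = N(0, 1), gamma = N(m, s^2),
# D = KL(. | gamma), and T_theta(x) = a x + b with a >= 0.
m, s = 1.0, 0.5
tau, n_steps = 1e-2, 2000

rng = np.random.default_rng(0)
x = rng.normal(size=500)                      # samples of rho_0 used for the integral
X = np.stack([x, np.ones_like(x)], axis=1)    # design matrix for theta = (a, b)

def velocity(y, a, b):
    """v(y) = -grad log rho(y) - grad V(y) for rho = (T_theta)_* rho_0 = N(b, a^2)."""
    return (y - b) / a**2 - (y - m) / s**2

a, b = 1.0, 0.0                               # T_0 = id
for _ in range(n_steps):
    T = a * x + b
    target = T + tau * velocity(T, a, b)      # T_{theta_k} + tau * v_k o T_{theta_k}
    (a_new, b_new), *_ = np.linalg.lstsq(X, target, rcond=None)
    if a_new < 0:                             # constrained optimum lies on a = 0 (stay in K);
        a_new, b_new = 0.0, target.mean()     # never triggered for these parameters
    a, b = a_new, b_new

print(a, b)  # approaches (s, m), i.e. the OT map x -> m + s x
```

In this toy case the least-squares fit is exact at every step, so the iterates follow an explicit Euler discretization of the flow on \((a,b)\).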

Proposition.

\(K\) is convex and closed in \(L^2_{ \varrho_0}(\mathbb R^d,\mathbb R^d)\).
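In \(d=1\) the set \(K\) is easy to work with. Take \(\varrho_0\) empirical with atoms \(x_1<\dots<x_n\) and weights \(w_i\). Then \(K\) is the set of nondecreasing vectors \((T(x_1),\dots,T(x_n))\), a closed convex cone, and the \(L^2_{\varrho_0}\) projection onto it is weighted isotonic regression. The sketch below computes it with the pool-adjacent-violators algorithm. The discrete \(\varrho_0\) is our simplification, chosen for illustration.

```python
import numpy as np

def project_onto_K_1d(values, weights):
    """Weighted L^2 projection of (T(x_1), ..., T(x_n)), x_1 < ... < x_n,
    onto nondecreasing vectors, by pool-adjacent-violators (PAV)."""
    blocks = []  # each block: [weighted mean, total weight, number of atoms]
    for v, w in zip(values, weights):
        blocks.append([float(v), float(w), 1])
        # Merge the last two blocks while they violate monotonicity.
        while len(blocks) > 1 and blocks[-2][0] > blocks[-1][0]:
            m2, w2, c2 = blocks.pop()
            m1, w1, c1 = blocks.pop()
            blocks.append([(w1 * m1 + w2 * m2) / (w1 + w2), w1 + w2, c1 + c2])
    return np.concatenate([np.full(c, m) for m, _, c in blocks])

weights = np.full(5, 0.2)                  # uniform empirical rho_0 on 5 sorted atoms
T = np.array([0.0, 2.0, 1.0, 3.0, 2.5])    # values of a map that is not nondecreasing
P = project_onto_K_1d(T, weights)
print(P)  # approximately [0, 1.5, 1.5, 2.75, 2.75]
```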

The idea: lift the functional \(D\) to the Hilbert space \(L^2_{\varrho_0}(\mathbb R^d,\mathbb R^d)\), and constrain the resulting gradient flow to stay in the closed convex set \(K\) of optimal transport maps.

Theorem 1. Existence of solutions

Let \(D:\mathcal P(\mathbb R^d)\to\mathbb R\) be l.s.c., differentiable, and
\(\lambda\)-convex along generalized geodesics with anchor point \(\varrho_0\).
Then, for every \(t_{\text{max}}>0\), there exists a solution \((T_t)_t\in H^1([0,t_{\text{max}}],K)\) to (Cons.GF).

\text{geodesic:}\qquad\quad\varrho_t=[(1-t)\operatorname{id}+t T_{\varrho_0}^{\varrho_1}]_*\varrho_0
\text{gen. geodesic:}\quad\varrho_t=[(1-t)T_{\varrho}^{\gamma_0}+t T_{\varrho}^{\gamma_1}]_*\varrho

Theorem 2. Convergence

Under the same assumptions, suppose also \(\lambda>0\). Then \[\forall t\geq0,\quad D(T_t{}_*\varrho_0)-D(\gamma)\leq e^{-2\lambda t}\big(D(\varrho_0)-D(\gamma)\big)\] and \[\forall t\geq0,\quad \|T_t-T_{\varrho_0}^{\gamma}\|^2_{L^2_{\varrho_0}}\leq\frac{4}{\lambda} e^{-2\lambda t}\big(D(\varrho_0)-D(\gamma)\big)\]

Corollary. For the relative entropy

Let \(D(\varrho)\coloneqq \operatorname{KL}(\varrho\,|\,\gamma)\), where \(\gamma\) is some \(\lambda\)-log-concave measure with \(\lambda>0\).
Then (Cons.GF) admits a solution and it converges to the OT map \(T_{\varrho_0}^\gamma\) in \(O(e^{-2\lambda t})\).

Lifting functionals

 

  • \(F\) constant on fibers \(\pi^{-1}(\varrho)=\{T\mid T_*\varrho_0=\varrho\}\)
  • \(D\) l.s.c. \(\implies\) \(F\) l.s.c.
  • \(\argmin_{\mathcal P(\mathbb R^d)} D=\pi(\argmin_K F)\)
  • \(\nabla F(T)=\nabla_{\!{}_W}D(T_*\varrho_0)\circ T\)
  • \(D\) convex along gen. geodesics \(\implies\) \(F\) convex on \(K\)
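These properties can be sanity-checked numerically when \(\varrho_0\) is a uniform empirical measure on \(n\) atoms. A map \(T\) is then stored as its values \(T(x_i)\), and \(\langle S,T\rangle_{L^2_{\varrho_0}}=\frac1n\sum_iS(x_i)\cdot T(x_i)\). For illustration only, the sketch assumes \(D(\varrho)=\int V\,\mathrm d\varrho+\frac12\iint W(x-y)\,\mathrm d\varrho(x)\,\mathrm d\varrho(y)\), with a hypothetical potential \(V\) and an even kernel \(W\). For this \(D\), \(\nabla_{\!{}_W}D(\varrho)=\nabla V+\nabla W*\varrho\). The sketch makes two checks:

- it compares a finite-difference directional derivative of \(F\) with \(\langle\nabla_{\!{}_W}D(T_*\varrho_0)\circ T,H\rangle_{L^2_{\varrho_0}}\);
- it checks that permuting the values of \(T\), which leaves \(T_*\varrho_0\) unchanged, leaves \(F\) unchanged.

```python
import numpy as np

rng = np.random.default_rng(1)
n, d = 50, 2
X0 = rng.normal(size=(n, d))   # atoms of rho_0, uniform weights 1/n

def V(y):
    """Hypothetical potential energy density."""
    return 0.5 * np.sum(y**2, axis=-1) + np.sin(y[..., 0])

def grad_V(y):
    g = y.copy()
    g[..., 0] += np.cos(y[..., 0])
    return g

def W(z):
    """Even interaction kernel (hypothetical choice)."""
    return np.exp(-0.5 * np.sum(z**2, axis=-1))

def grad_W(z):
    return -z * W(z)[..., None]

def F(T):
    """Lifted functional F(T) = D(T_* rho_0), with T stored as its values T(x_i)."""
    diffs = T[:, None, :] - T[None, :, :]
    return V(T).mean() + 0.5 * W(diffs).mean()

def grad_F(T):
    """Claimed L^2_{rho_0} gradient: (grad_W D)(T_* rho_0) o T."""
    diffs = T[:, None, :] - T[None, :, :]
    return grad_V(T) + grad_W(diffs).mean(axis=1)

def inner(S, T):
    """Inner product of L^2_{rho_0} for the uniform empirical rho_0."""
    return np.sum(S * T) / n

T = X0 + 0.3 * np.tanh(X0)      # some map
H = rng.normal(size=(n, d))     # perturbation direction
eps = 1e-6
fd = (F(T + eps * H) - F(T - eps * H)) / (2 * eps)
perm = rng.permutation(n)
print(fd, inner(grad_F(T), H))  # agree up to finite-difference error
print(F(T), F(T[perm]))         # F is constant on the fibers of pi
```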

Can we build a flow that remains an OT map?


The constrained gradient flow

The constrained gradient flow: theory

Implementation

Let \(D:\mathcal P(\mathbb R^d)\to\mathbb R\).
The lifted functional \(F\) is defined as \(F=D\circ\pi\),
where \(\pi:T\mapsto T_*\varrho_0\).

Theorem. This is a Wasserstein natural gradient descent in \(\Theta\).

Some properties:

\begin{aligned} \Theta&\longrightarrow K =\{\nabla\phi\mid\phi\text{ convex}\} \\ \theta&\longmapsto T_\theta \end{aligned}
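As one concrete instance of such a parameterization, we assume that "LSE" denotes a log-sum-exp of affine functions, \(\phi_\theta(x)=\log\sum_j e^{\langle a_j,x\rangle+c_j}\) with \(\theta=(A,c)\). This function is convex, so \(T_\theta=\nabla\phi_\theta\) lies in \(K\) for every \(\theta\). Its gradient is a softmax-weighted average of the \(a_j\), so its image stays in their convex hull. The sketch evaluates it and checks the monotonicity \(\langle T(x)-T(y),x-y\rangle\geq0\) that every gradient of a convex function satisfies. Sizes and random parameters are arbitrary.

```python
import numpy as np

def grad_lse(x, A, c):
    """Gradient of phi(x) = log sum_j exp(<a_j, x> + c_j); rows of A are the a_j."""
    z = x @ A.T + c                          # shape (n, J)
    z -= z.max(axis=-1, keepdims=True)       # numerically stable softmax
    p = np.exp(z)
    p /= p.sum(axis=-1, keepdims=True)
    return p @ A                             # softmax-weighted average of the a_j

rng = np.random.default_rng(0)
d, J = 2, 8
A, c = rng.normal(size=(J, d)), rng.normal(size=J)   # theta = (A, c)
x, y = rng.normal(size=(1000, d)), rng.normal(size=(1000, d))
# Gradients of convex functions are monotone maps.
mono = np.sum((grad_lse(x, A, c) - grad_lse(y, A, c)) * (x - y), axis=-1)
print(mono.min())
```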

Natural gradient descent/flow

L:\Theta\xrightarrow{\boldsymbol\varrho}\mathcal P(\mathbb R^d)\xrightarrow{D}\mathbb R
  • classical gradient flow: \[\partial_t\theta_t=-\nabla_\theta D(\varrho_\theta)\]
  • Wasserstein natural gradient flow: \[\partial_t\theta_t=-(G_\theta)^{-1}\nabla_\theta D(\varrho_\theta)\] where \(G_\theta\) is the pullback of the Wasserstein metric by \(\boldsymbol\varrho:\theta\mapsto\varrho_\theta\) \[G_\theta(\delta\theta,\delta\theta)=g_{\text{Wass}}(d_\theta\boldsymbol\varrho[\delta\theta],d_\theta\boldsymbol\varrho[\delta\theta])\]
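For a toy case, we assume the family \(\varrho_\theta=\mathcal N(\mu,e^{2s})\) on \(\mathbb R\) with \(\theta=(\mu,s)\), and \(D=\operatorname{KL}(\cdot\,|\,\mathcal N(0,1))\). Write \(\varrho_\theta\) as the pushforward of \(\mathcal N(0,1)\) by \(z\mapsto\mu+e^sz\). A perturbation \(\delta\theta\) then moves particles with velocity \(\delta\mu+\delta s\,(x-\mu)\), which is a gradient field, so \(G_\theta=\operatorname{diag}(1,e^{2s})\). The sketch runs the Wasserstein natural gradient flow with explicit Euler. It compares the result with the exact Wasserstein gradient flow of the KL, which stays in this Gaussian family (Ornstein-Uhlenbeck): \(\mu_t=\mu_0e^{-t}\) and \(\sigma_t^2=1+e^{-2t}(\sigma_0^2-1)\).

```python
import math

def kl_to_std_normal(mu, s):
    """D(rho_theta) = KL(N(mu, e^{2s}) | N(0, 1)), with theta = (mu, s) and s = log(sigma)."""
    return 0.5 * (math.exp(2 * s) + mu**2 - 1.0) - s

def euclidean_grad(mu, s):
    """nabla_theta D(rho_theta)."""
    return mu, math.exp(2 * s) - 1.0

def wasserstein_metric_diag(mu, s):
    """G_theta = diag(1, sigma^2): pullback of the W2 metric through z -> mu + e^s z."""
    return 1.0, math.exp(2 * s)

def flow(mu, s, t_max, dt, natural=True):
    """Explicit Euler for d theta/dt = -G^{-1} grad D (natural) or -grad D (classical)."""
    for _ in range(int(round(t_max / dt))):
        g_mu, g_s = euclidean_grad(mu, s)
        G_mu, G_s = wasserstein_metric_diag(mu, s) if natural else (1.0, 1.0)
        mu, s = mu - dt * g_mu / G_mu, s - dt * g_s / G_s
    return mu, s

mu0, sigma0, t = 2.0, 3.0, 1.0
mu_nat, s_nat = flow(mu0, math.log(sigma0), t, 1e-4)
mu_exact = mu0 * math.exp(-t)
sigma_exact = math.sqrt(1.0 + math.exp(-2 * t) * (sigma0**2 - 1.0))
print(mu_nat, mu_exact)
print(math.exp(s_nat), sigma_exact)
```

Passing `natural=False` gives the classical gradient flow in the same coordinates. Its \(s\)-equation is \(\partial_ts=1-e^{2s}\) instead of \(e^{-2s}-1\), so it follows a different trajectory.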

Optimal transport maps

Optimal transport problem (Monge)

Let \(K\) be the set of optimal transport maps:

\varrho_0,\gamma\in\mathcal P(\mathbb R^d)

\(\mathcal P(\mathbb R^d)\): probability measures with finite second moment

Flow map of the relative entropy

For the relative entropy / Kullback-Leibler divergence

and the gradient flow is

quadratic optimization problem

Parameterize \(K\): for instance with \(\nabla\)ICNN (gradients of input convex neural networks) or \(\nabla\)LSE.

Choices for \(D\):

  • If \(\gamma\propto e^{-V}\): \(D=\operatorname{KL}(\cdot\,|\,\gamma)\)
  • If \(\gamma\) is known through samples: \(D=\text{MMD}(\cdot\,|\,\gamma)\)
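When \(\gamma\) is known through samples, one simple way to approximate (Cons.GF) in \(d=1\) is the projected gradient step \(T\leftarrow\operatorname{proj}_K(T-\tau\nabla F(T))\). This explicit-Euler-plus-projection discretization is our own choice for the sketch, not necessarily the scheme used in the paper. We further assume:

- \(\varrho_0\) is uniform empirical with sorted atoms, so that \(\operatorname{proj}_K\) is isotonic regression;
- \(D=\text{MMD}^2\) with a Gaussian kernel \(k\) of bandwidth \(h\), whose Wasserstein gradient is \(\nabla_{\!{}_W}D(\varrho)=2\nabla\big(\int k(\cdot,y)\,\mathrm d\varrho(y)-\int k(\cdot,y)\,\mathrm d\gamma(y)\big)\).

We do not claim that this \(D\) satisfies the convexity assumptions of the theorems above.

```python
import numpy as np

def pav(values):
    """Uniform-weight L^2 projection onto nondecreasing vectors (K in 1-D)."""
    blocks = []  # each block: [mean, number of atoms]
    for v in values:
        blocks.append([float(v), 1])
        while len(blocks) > 1 and blocks[-2][0] > blocks[-1][0]:
            m2, c2 = blocks.pop()
            m1, c1 = blocks.pop()
            blocks.append([(c1 * m1 + c2 * m2) / (c1 + c2), c1 + c2])
    return np.concatenate([np.full(c, m) for m, c in blocks])

def k(u, v, h=1.0):
    """Gaussian kernel matrix k(u_i, v_j)."""
    return np.exp(-((u[:, None] - v[None, :]) ** 2) / (2 * h**2))

def mmd2(T, y, h=1.0):
    """MMD^2 between the uniform measures on the values of T and on y."""
    return k(T, T, h).mean() - 2 * k(T, y, h).mean() + k(y, y, h).mean()

def lifted_grad(T, y, h=1.0):
    """nabla F(T) = (nabla_W MMD^2)(T_* rho_0) o T, evaluated at the atoms."""
    def d1(u, v):  # derivative of k(u, v) with respect to u
        return -(u[:, None] - v[None, :]) / h**2 * k(u, v, h)
    return 2 * (d1(T, T).mean(axis=1) - d1(T, y).mean(axis=1))

rng = np.random.default_rng(0)
x = np.sort(rng.normal(size=200))              # sorted atoms of rho_0
y = rng.normal(loc=2.0, scale=0.5, size=300)   # samples of gamma
T, tau = x.copy(), 0.05                        # T_0 = id
mmd_start = mmd2(T, y)
for _ in range(500):
    T = pav(T - tau * lifted_grad(T, y))       # projected gradient step, stays in K
print(mmd_start, mmd2(T, y))
```

For a step size below the inverse smoothness constant of \(F\), this projected scheme does not increase \(F\). The iterate is also a nondecreasing map, i.e. an element of \(K\), at every step.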

(Cons.GF)

Let \(D:\mathcal P(\mathbb R^d)\to\mathbb R\) be some functional on \(\mathcal P(\mathbb R^d)\), minimal at \(\gamma\).

\left\{\begin{aligned} &\ T_0=\operatorname{id} \\ &\ \partial_t T_t = \operatorname{proj}_{\operatorname{Tan}_{T_t}\!K}[\tilde v_t\circ T_t], \end{aligned}\right.

Constrained gradient flow

-\nabla_{\!{}_W}D(T_t{}_*\varrho_0)


the vector field is

v_t=-\nabla\log\varrho_t-\nabla V
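For completeness, the vector field follows from the first variation of the relative entropy, assuming \(\varrho\) has a smooth positive density and \(\gamma=e^{-V}\,\mathrm dx\). Plugging it into the continuity equation recovers the Fokker-Planck equation:

```latex
\begin{aligned}
\operatorname{KL}(\varrho\,|\,\gamma)&=\int_{\mathbb R^d}\varrho\log\varrho\,\mathrm dx+\int_{\mathbb R^d}V\,\mathrm d\varrho,
\qquad
\frac{\delta\operatorname{KL}(\cdot\,|\,\gamma)}{\delta\varrho}(\varrho)=\log\varrho+V+1,\\
v_t&=-\nabla_{\!{}_W}\operatorname{KL}(\varrho_t\,|\,\gamma)=-\nabla\big(\log\varrho_t+V\big)=-\nabla\log\varrho_t-\nabla V,\\
\partial_t\varrho_t&=-\operatorname{div}(\varrho_tv_t)=\operatorname{div}(\nabla\varrho_t+\varrho_t\nabla V)=\Delta\varrho_t+\operatorname{div}(\varrho_t\nabla V).
\end{aligned}
```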


[AGS, 2008], [LS, 2022]

\pi:T\mapsto T_*\varrho_0

[POSTER] Optimal entropy flow

By Théo Dumont
