Optimal Transport Applications

Study Group on Optimal Transport

Daniel Yukimura

Classic problems:

Monge's problem (1781):

\min\limits_{T:\hspace{1mm} T(X)\sim \nu} \mathbb{E}_{X\sim \mu} \left[ c(X, T(X)) \right]
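For the quadratic cost in one dimension, the Monge map is the monotone rearrangement, so on samples it amounts to sorting. A minimal numpy sketch on toy Gaussian data (the distributions are arbitrary illustrative choices):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(0.0, 1.0, size=1000)   # sample from mu
y = rng.normal(3.0, 0.5, size=1000)   # sample from nu

# With c(x, y) = |x - y|^2 on the line, the optimal T is monotone:
# it sends the i-th smallest x to the i-th smallest y.
x_sorted, y_sorted = np.sort(x), np.sort(y)
monge_cost = np.mean((x_sorted - y_sorted) ** 2)
print(f"empirical Monge cost: {monge_cost:.3f}")
```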

Logistics and economics:

  • Tolstoi (1920)
  • Hitchcock, Kantorovich, and Koopmans (1940s)

Kantorovich problem:

\min\limits_{\pi\in \Gamma(\mu, \nu)} \displaystyle \int_{\mathcal{X}\times \mathcal{Y}} c(x,y)\, d\pi(x,y)

Dual problem:

\max\limits_{f(x)+g(y)\leq c(x,y)} \displaystyle \int_{\mathcal{X}} f(x)\, d\mu(x) + \int_{\mathcal{Y}} g(y)\, d\nu(y)
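For empirical measures, both the Kantorovich problem and its dual are finite linear programs. A minimal sketch of the primal using scipy.optimize.linprog (sizes, points, and weights are arbitrary illustrative choices):

```python
import numpy as np
from scipy.optimize import linprog

rng = np.random.default_rng(0)
n, m = 5, 6
x, y = rng.normal(size=(n, 1)), rng.normal(size=(m, 1))
a, b = np.full(n, 1.0 / n), np.full(m, 1.0 / m)   # weights of mu and nu
C = (x - y.T) ** 2                                # cost matrix c(x_i, y_j)

# Marginal constraints: row sums of pi equal a, column sums equal b.
A_eq = np.zeros((n + m, n * m))
for i in range(n):
    A_eq[i, i * m:(i + 1) * m] = 1.0   # i-th row sum
for j in range(m):
    A_eq[n + j, j::m] = 1.0            # j-th column sum
res = linprog(C.ravel(), A_eq=A_eq, b_eq=np.concatenate([a, b]),
              bounds=(0, None))
pi = res.x.reshape(n, m)               # optimal coupling
print("Kantorovich cost:", res.fun)
```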

Classic problems:

  • Assignment and Routing problems
  • Operations Research
  • Economics
    • Labor market
    • Incomplete Econometrics
    • Quantile methods
    • Contract theory
    • etc...
  • Robust Optimization
\sup\limits_{\mathcal{W}(P, P_0) \leq \delta} \mathbb{E}_P \left[ f(X) \right]
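As a concrete instance of the assignment problems listed above, scipy's Hungarian-type solver matches n workers to n tasks at minimal total cost (toy data, purely illustrative):

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

rng = np.random.default_rng(1)
workers = rng.random((4, 2))   # 4 workers at random 2-D locations
tasks = rng.random((4, 2))     # 4 task sites
cost = np.linalg.norm(workers[:, None, :] - tasks[None, :, :], axis=-1)

row, col = linear_sum_assignment(cost)   # optimal one-to-one matching
print("assignment:", list(zip(row, col)))
print("total cost:", cost[row, col].sum())
```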

OT in Nature:

Biology:

  • Liquid transport in plants.
  • Single-cell gene expression.

Physical systems:

  • Meteorology:
    • Semigeostrophic equations
  • Crystals:
    • Isoperimetric problem

(cf. the work of Alessio Figalli)

and more:

  • Astrophysics
  • Geology

OT for Machine Learning

Learning probability distributions:

Generative Modeling:

\textbf{Unsupervised learning:}
\textit{Observations:}\hspace{2mm} \beta = \frac{1}{n}\sum\limits_{i=1}^n \delta_{x_i}
\textit{Param. model:}\hspace{2mm} \theta \rightarrow \alpha_\theta
\textit{Density fitting:}\hspace{2mm} \min\limits_{\theta} D(\alpha_\theta, \beta)

Generative modeling:

\text{Find } G: \mathbb{R}^q \rightarrow \mathbb{R}^n \text{ s.t. } G_\# \mu = \nu_{\text{data}}
[Diagram: G maps the latent space \mathcal{Z}, with law \mu, to the data space \mathcal{X}, with law \nu_{\text{data}}.]

Generative Adversarial Networks


\min\limits_{G}\max\limits_{D} \mathbb{E}_{X\sim \nu_{\text{data}}} [\log{D(X)}] + \mathbb{E}_{Z\sim \mu}[\log{(1-D(G(Z)))}]

Wasserstein GANs:

\min\limits_{G} \mathcal{W} (G_\# \mu, \nu_{\text{data}})
\mathcal{W} (\nu_0, \nu_1) = \sup\limits_{\|f\|_{L}\leq 1} \mathbb{E}_{X\sim \nu_0}[f(X)] - \mathbb{E}_{X\sim \nu_1}[f(X)]

Kantorovich-Rubinstein duality

In practice, the critic f is restricted to a parametric family \{f_\gamma\}_{\gamma\in\Gamma}, e.g. a neural network:

\max\limits_{\gamma\in\Gamma} \mathbb{E}_{X\sim \nu_{\text{data}}}[f_\gamma(X)] - \mathbb{E}_{Z\sim \mu}[f_\gamma( G_\theta(Z) )]

\nabla_\theta \mathcal{W}(\nu_{\text{data}}, \nu_\theta) = - \mathbb{E}_{Z\sim \mu} [\nabla_\theta f^*(G_\theta(Z))]
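A minimal WGAN training sketch in PyTorch on 1-D toy data, using the weight clipping of the original WGAN paper to keep the critic approximately Lipschitz; the architecture, clipping range, and hyperparameters are illustrative assumptions, not the talk's:

```python
import torch
import torch.nn as nn

def sample_data(n):                      # nu_data: a shifted 1-D Gaussian
    return torch.randn(n, 1) * 0.5 + 2.0

G = nn.Sequential(nn.Linear(1, 32), nn.ReLU(), nn.Linear(32, 1))   # generator
f = nn.Sequential(nn.Linear(1, 32), nn.ReLU(), nn.Linear(32, 1))   # critic f_gamma
opt_G = torch.optim.RMSprop(G.parameters(), lr=5e-4)
opt_f = torch.optim.RMSprop(f.parameters(), lr=5e-4)

for step in range(1000):
    # Critic: maximize E[f(X)] - E[f(G(Z))] over weight-clipped critics.
    for _ in range(5):
        x, z = sample_data(256), torch.randn(256, 1)
        loss_f = -(f(x).mean() - f(G(z).detach()).mean())
        opt_f.zero_grad(); loss_f.backward(); opt_f.step()
        with torch.no_grad():
            for p in f.parameters():
                p.clamp_(-0.1, 0.1)      # crude Lipschitz constraint
    # Generator: descend -E[f(G(Z))], the stochastic version of the gradient above.
    z = torch.randn(256, 1)
    loss_G = -f(G(z)).mean()
    opt_G.zero_grad(); loss_G.backward(); opt_G.step()
```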

GAN progress on image synthesis:

  • Progressive growing GANs, 2017
  • StyleGAN2, 2019
  • StyleGAN3, 2021

Applications of generative modeling:

  • Natural Language Processing - GPT-3, 2020 (175 billion parameters)
  • AlphaFold - protein folding, 2021

More applications of OT in ML

  • Reinforcement Learning
    • Imitation Learning
    • Multi-agent policy transfer
  • Style Transfer
  • Normalizing Flows
  • Mean field simulations
  • Motion correction
  • Theoretical Deep Learning
  • ...

Context:

W_2(P_0, P_1) = \displaystyle\inf_{\gamma \in \Gamma(P_0, P_1)} \sqrt{ \int_{\mathbb{R}^d \times \mathbb{R}^d} \|x-y\|^2 d\gamma(x,y) }

Wasserstein distance

X_1, \dots, X_n \overset{\text{i.i.d.}}{\sim} P_0, \hspace{3mm} Y_1, \dots, Y_n \overset{\text{i.i.d.}}{\sim} P_1

\Rightarrow \hspace{2mm} W_2(\hat P_0, \hat P_1)

Plug-in estimator

  • computable via linear programming

  • [Sommerfeld and Munk, 2017]:
\bullet \hspace{3mm} \Delta_n = \left|W_2(\hat P_0, \hat P_1) - W_2(P_0, P_1) \right|
\Delta_n \asymp n^{-1/2}, \hspace{3mm} \text{if }P_0\neq P_1
\Delta_n \asymp n^{-1/4}, \hspace{3mm} \text{if }P_0 = P_1

only for finite support!

  • [Dobrić and Yukich, 1995]:
\Delta_n \asymp n^{-1/d}, \hspace{3mm} d \geq 3

Transport rank (TR):

\textbf{Definition:} \text{ Given } \gamma \in \Gamma(P_0, P_1), \text{ the TR of }\gamma \text{ is the smallest }k\in\mathbb{Z} \text{ s.t.}
\gamma = \displaystyle \sum_{j=1}^k \lambda_j \left ( Q_j^0 \otimes Q_j^1 \right)
\text{where } Q_j^0, Q_j^1 \in \mathcal{P}\left(\mathbb{R}^d\right) \text{ and } \lambda_j \geq 0.

\bullet \hspace{3mm} \text{When }P_0 \text{ and }P_1\text{ have finite supp., the TR of } \gamma\in\Gamma(P_0, P_1) \text{ coincides with the rank of }\gamma\text{ viewed as a matrix.}

A coupling written in this form with small k is called a factored coupling.

\bullet \hspace{3mm} \text{Let } \Gamma_k(P_0,P_1) \text{ denote the couplings with TR at most }k.
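For finitely supported marginals, the matrix-rank correspondence is easy to check numerically. A small numpy sketch building a coupling of transport rank k (random toy factors, purely illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
n, k = 50, 3

# gamma = sum_j lambda_j (Q_j^0 ⊗ Q_j^1), each Q_j^ell a probability
# vector over n support points.
lam = np.ones(k) / k
Q0 = rng.dirichlet(np.ones(n), size=k)   # rows: Q_1^0, ..., Q_k^0
Q1 = rng.dirichlet(np.ones(n), size=k)
gamma = sum(lam[j] * np.outer(Q0[j], Q1[j]) for j in range(k))

# Generically the transport rank equals the rank of gamma as a matrix.
print("matrix rank:", np.linalg.matrix_rank(gamma))   # -> 3
```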

Regularization via Factored Couplings:

Idea: in practice, optimal couplings can be well approximated by assuming the distributions have a small number of pieces that move nearly independently.

\textbf{Def 2:} \text{ A }\textbf{soft cluster}\text{ of a prob. }P \text{ is a sub-prob. }C \text{ of total mass }\lambda\in [0,1]\text{ s.t. } 0 \leq C \leq P.
\text{We say that a collection } C_1,\dots, C_k \text{ of soft clusters of } P \text{ is a } \textbf{partition} \text{ of }P \text{ if } C_1+\dots+C_k = P.
\textbf{centroid} \text{ of } C: \mu(C) = \frac{1}{\lambda} \int x dC(x)
\textbf{Prop:} \text{ Let }\gamma\in\Gamma(P_0, P_1) \text{ and let } C_1^0,\dots, C_k^0 \text{ and } C_1^1,\dots, C_k^1 \text{ be the induced partitions of }P_0 \text{ and }P_1. \text{ Then}
\displaystyle\int \|x-y\|^2 d\gamma(x,y) = \displaystyle\sum_{j=1}^k \left( \lambda_j \|\mu(C_j^0) - \mu(C_j^1)\|^2 + \displaystyle\sum_{\ell\in\{0,1\}} \displaystyle \int \|x - \mu(C_j^\ell)\|^2 d C_j^\ell (x) \right)
\textbf{Def 3:} \text{ The cost of a factored coupling }\gamma\in\Gamma(P_0, P_1) \text{ is}
\text{cost}(\gamma) = \displaystyle\sum_{j=1}^k \lambda_j \|\mu(C_j^0) - \mu(C_j^1)\|^2

Regularized optimal coupling:

\underset{\gamma\in\Gamma_k(\hat P_0, \hat P_1)}{\text{argmin}} \displaystyle \int \|x-y\|^2 d\gamma(x,y)

Questions:

  • How to solve the above problem efficiently?
  • Is the solution robust?

k-Wasserstein barycenters

\mathcal{D}_k = \left\{ \sum\limits_{j=1}^k \alpha_j \delta_{x_j}: \alpha_j \geq 0, \sum\limits_{j=1}^k \alpha_j = 1, x_j\in\mathbb{R}^d \right\}

prob. supported on k points

\bar{P} = \underset{P\in \mathcal{D}_k}{\text{argmin}} \displaystyle\sum_{j=1}^N W_2^2 (P, P_j)

k-Wasserstein barycenter

H = \underset{P\in \mathcal{D}_k}{\text{argmin}} \left\{ W_2^2(P, \hat P_0)+ W_2^2(P, \hat P_1) \right\}
\bullet \hspace{3mm} \exists \text{ efficient procedure for finding the barycenters}
\bullet \hspace{3mm} \gamma_0 = \text{OC}(\hat P_0, H), \hspace{3mm} \gamma_1 = \text{OC}(\hat P_1, H) \hspace{4mm} (\text{OC} = \text{optimal coupling})
\bullet \hspace{3mm} \text{supp}(H) = \{z_1, \dots, z_k\}
\gamma_0 = \displaystyle\sum_{j=1}^k \gamma_0(\cdot | z_j) H(z_j), \hspace{3mm} \gamma_1 = \displaystyle\sum_{j=1}^k \gamma_1(\cdot | z_j) H(z_j)

Induced factored coupling:

\Rightarrow \gamma_H(A\times B) = \displaystyle\sum_{j=1}^k H(z_j) \gamma_0(A | z_j) \gamma_1(B | z_j)
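A numpy sketch of this composition through the hubs, together with the transport-map estimator \widehat T defined below. The couplings here are random soft assignments standing in for the optimal couplings OC(\hat P_0, H) and OC(\hat P_1, H), purely for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
k, n = 3, 200
X = rng.normal(size=(n, 2))                          # sample from P0_hat
Y = rng.normal(size=(n, 2)) + np.array([4.0, 0.0])   # sample from P1_hat

# Stand-ins for gamma_0 and gamma_1: rows index hubs z_j, columns index
# points; every column carries mass 1/n and both share hub weights h.
W = rng.dirichlet(np.ones(k), size=n)   # soft hub assignments, rows sum to 1
G0 = W.T / n
G1 = W[rng.permutation(n)].T / n        # same column sums of W => same h
h = G0.sum(axis=1)                      # hub weights H(z_j)

# Induced factored coupling: gamma_H[i, l] = sum_j G0[j, i] G1[j, l] / h[j]
gamma_H = G0.T @ np.diag(1.0 / h) @ G1

# Soft clusters C_j^0(X_i) = G0[j, i]; centroids mu(C_j^ell)
mu0 = (G0 @ X) / h[:, None]
mu1 = (G1 @ Y) / h[:, None]

# Estimated transport map: each X_i is shifted by a weighted centroid move.
w = G0 / G0.sum(axis=0, keepdims=True)  # normalize cluster weights per point
T_hat = X + w.T @ (mu1 - mu0)
```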

\textbf{Prop:} \text{ The partitions } C_1^0,\dots, C_k^0 \text{ and } C_1^1,\dots, C_k^1 \text{ induced by } H \text{ are the minimizers of}
\displaystyle\sum_{j=1}^k \left( \frac{\lambda_j}{2} \|\mu(C_j^0) - \mu(C_j^1)\|^2 + \displaystyle\sum_{\ell\in\{0,1\}} \displaystyle \int \|x - \mu(C_j^\ell)\|^2 d C_j^\ell (x) \right)
  • Estimator for the squared Wasserstein distance:
\widehat W := \text{cost}(\gamma_H)
  • Estimated transport map:
\widehat T (X_i) = X_i + \frac{1}{\sum\limits_{j=1}^k C_j^0(X_i)} \displaystyle\sum_{j=1}^k C_j^0(X_i) (\mu(C_j^1) - \mu(C_j^0))

\text{How to estimate } \gamma_H?
H = \underset{P\in \mathcal{D}_k}{\text{argmin}} \left\{ W_2^2(P, \hat P_0)+ W_2^2(P, \hat P_1) \right\}
\bullet \hspace{3mm} \text{It is separately convex in } \mathcal{H} = \{z_1,\dots, z_k\} \text{ and } (\gamma_0, \gamma_1)
\Rightarrow \text{ admits alternating optimization.}

\text{Updating hubs:}
z_j = \dfrac{\displaystyle\sum_{i=1}^n \gamma_0(z_j, X_i)X_i + \displaystyle\sum_{i=1}^n \gamma_1(z_j, Y_i)Y_i }{ \displaystyle\sum_{i=1}^n \gamma_0(z_j, X_i) + \displaystyle\sum_{i=1}^n \gamma_1(z_j, Y_i) }

\text{Updating }(\gamma_0, \gamma_1)\text{:}
\bullet \hspace{2mm} \text{Add entropic regularization:}
-\varepsilon \displaystyle\sum_{i,j} (\gamma_0)_{j,i} \log{\left((\gamma_0)_{j,i}\right)} -\varepsilon \displaystyle\sum_{i,j} (\gamma_1)_{j,i} \log{\left((\gamma_1)_{j,i}\right)}
\bullet \hspace{2mm} \text{Use Sinkhorn iterations, as in the sketch below.}
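A minimal numpy implementation of the Sinkhorn iteration for the entropy-regularized coupling update (the regularization strength and iteration count are arbitrary illustrative choices):

```python
import numpy as np

def sinkhorn(a, b, C, eps=0.05, iters=500):
    """Entropy-regularized OT between histograms a, b with cost matrix C."""
    K = np.exp(-C / eps)                 # Gibbs kernel
    u = np.ones_like(a)
    for _ in range(iters):
        v = b / (K.T @ u)                # enforce column marginals
        u = a / (K @ v)                  # enforce row marginals
    return u[:, None] * K * v[None, :]   # the coupling gamma

rng = np.random.default_rng(0)
x, y = rng.normal(size=(8, 1)), rng.normal(size=(9, 1)) + 1.0
C = (x - y.T) ** 2
gamma = sinkhorn(np.full(8, 1 / 8), np.full(9, 1 / 9), C)
print("regularized transport cost:", (gamma * C).sum())
```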
\textbf{Theorem:} \text{ Let }P\in\mathcal{P}(\mathbb{R}^d) \text{ be supported on the unit ball and let }\hat P \text{ be the empirical measure of a sample of size } n. \text{ Then, with prob. at least }1-\delta,
\displaystyle\sup_{\rho\in\mathcal{D}_k} \left| W_2^2(\rho, \hat P) - W_2^2(\rho, P) \right| \lesssim \sqrt{ \dfrac{k^3 d \log{k} + \log{(1/\delta)} }{n} }
\bullet \hspace{2mm} \text{Recovers the } n^{-1/2} \text{ bound}.
\bullet \hspace{2mm} \text{Can be extended to compact support}.
\bullet \hspace{2mm} \text{It doesn't provide rates for }\widehat{W}.

Experiments:

Fragmented hypercube (synthetic):

P_0 = \text{Unif}([-1,1]^d)
P_1 = T_\#(P_0), \hspace{3mm} T(X) = X + 2 \text{sign}(X)\odot(e_1+e_2)

[Figure: couplings on the fragmented hypercube computed with plain OT, the barycenter hubs H, and factored OT (FOT); see the data-generation sketch below.]
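Generating the fragmented-hypercube data takes only a few lines; a numpy sketch following the formulas above (d = 2 and n = 500 are arbitrary choices):

```python
import numpy as np

rng = np.random.default_rng(0)
d, n = 2, 500
e12 = np.zeros(d)
e12[:2] = 1.0                               # e_1 + e_2

X = rng.uniform(-1.0, 1.0, size=(n, d))     # sample from P_0 = Unif([-1,1]^d)
Y0 = rng.uniform(-1.0, 1.0, size=(n, d))
Y = Y0 + 2.0 * np.sign(Y0) * e12            # sample from P_1 = T_# P_0
# T splits the cube into four pieces along the first two coordinates.
```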


Experiments:

Batch correction for single cell RNA data:

  • Each coordinate represents the expression level of the corresponding gene.
\bullet \hspace{3mm} \text{Dim}\sim 10^4
\bullet \hspace{3mm} \text{Number of cells}\sim 10^2\text{ to }10^6
  • Haematopoietic (blood) datasets from different sources.


Sketch proof of the Theorem:

\mathcal{P}_{n} = \{ S\subset \mathbb{R}^d: S \text{ is the intersection of }n\text{ closed half-spaces} \}
\text{Given }c\in \mathbb{R}^d \text{ and } S\in\mathcal{P}_{k-1}, \text{ define}
f_{c, S}(x) := \|x - c\|^2 \boldsymbol{1}_{x\in S}

\textbf{Prop:} \text{ Let } P,Q\in\mathcal{P}(\mathbb{R}^d) \text{ be supported on the unit ball. Then}
\displaystyle\sup_{\rho\in \mathcal{D}_k} |W_2^2(\rho, P) - W_2^2(\rho, Q)| \leq 5 k \displaystyle\sup_{c: \|c\|\leq 1,\ S\in\mathcal{P}_{k-1}} | \mathbb{E}_P f_{c, S} - \mathbb{E}_Q f_{c, S} |
\textbf{Prop:} \ \exists \text{ univ. constant } C \text{ s.t. if } P \text{ is supp. on the unit ball and } X_1, \dots, X_n\sim P \text{ are i.i.d., then}
\mathbb{E} \displaystyle\sup_{c: \|c\|\leq 1,\ S\in\mathcal{P}_{k-1}} | \mathbb{E}_P f_{c, S} - \mathbb{E}_{\hat P} f_{c, S} | \leq C \sqrt{\dfrac{k d \log{k}}{n} }

\Rightarrow \hspace{2mm} \mathbb{E} \displaystyle\sup_{\rho\in \mathcal{D}_k} |W_2^2(\rho, P) - W_2^2(\rho, \hat P)| \lesssim \sqrt{\dfrac{k^3 d \log{k}}{n} }

\text{bounded differences ineq. } \Rightarrow \text{ high-prob. result}