Statistical Optimal Transport via Factored Couplings

Study Group on Optimal Transport

Daniel Yukimura


Contributions:

  • Estimation of the Wasserstein distance in high dimensions, using couplings of low transport rank.
  • Efficient implementation.
  • Theoretical and empirical evidence.

Context:

Wasserstein distance

W_2(P_0, P_1) = \displaystyle\inf_{\gamma \in \Gamma(P_0, P_1)} \sqrt{ \int_{\mathbb{R}^d \times \mathbb{R}^d} \|x-y\|^2 d\gamma(x,y) }
X_1, \dots, X_n \overset{\text{i.i.d.}}{\sim} P_0
Y_1, \dots, Y_n \overset{\text{i.i.d.}}{\sim} P_1

Plug-in estimator

\Rightarrow \hspace{2mm} W_2(\hat P_0, \hat P_1)
  • computed by linear programming
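
The plug-in estimator can be sketched in a few lines. This is only an illustration, not the implementation from the talk: it assumes both samples have the same size n and carry uniform weights, in which case an optimal coupling can be taken to be a permutation and the linear program reduces to an assignment problem. The function name `plugin_w2` is made up for this example.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def plugin_w2(X, Y):
    """Plug-in estimate W_2(P0_hat, P1_hat) for two samples of equal size n."""
    # Pairwise squared Euclidean distances, shape (n, n).
    cost = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(axis=-1)
    # Uniform weights 1/n on both sides: the optimal coupling is a permutation,
    # so solving the assignment problem solves the transport LP.
    rows, cols = linear_sum_assignment(cost)
    return np.sqrt(cost[rows, cols].mean())


rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
shift = np.array([1.0, 2.0, 2.0])          # Euclidean norm 3
w2_shift = plugin_w2(X, X + shift)         # translated copy of the same sample
w2_self = plugin_w2(X, X)
```

For a translated copy of the same point cloud the identity matching is optimal, so the estimate equals the norm of the shift.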

Context:

  • [Sommerfeld and Munk, 2017]: rates for the estimation error
\Delta_n = \left|W_2(\hat P_0, \hat P_1) - W_2(P_0, P_1) \right|
\Delta_n \asymp n^{-1/2}, \hspace{3mm} \text{if }P_0\neq P_1
\Delta_n \asymp n^{-1/4}, \hspace{3mm} \text{if }P_0 = P_1
\text{(only for finite support!)}

  • [Dobrić and Yukich, 1995]:
\Delta_n \asymp n^{-1/d}, \hspace{3mm} d \geq 3

Transport rank (TR):

\textbf{Definition:} \text{ Given } \gamma \in \Gamma(P_0, P_1), \text{ the TR of }\gamma
\text{ is the smallest integer }k \text{ s.t.}
\gamma = \displaystyle \sum_{j=1}^k \lambda_j \left ( Q_j^0 \otimes Q_j^1 \right)
\text{where } Q_j^0, Q_j^1 \in \mathcal{P}\left(\mathbb{R}^d\right), \ \lambda_j \geq 0 \text{ and } \sum_{j=1}^k \lambda_j = 1.
\text{A coupling written in this form is called a }\textbf{factored coupling}.
\bullet \hspace{3mm} \text{ When }P_0 \text{ and }P_1\text{ have finite supp., the TR of } \gamma\in\Gamma(P_0, P_1)
\text{coincides with the nonnegative rank of }\gamma\text{ viewed as a matrix.}
\bullet \hspace{3mm} \Gamma_k(P_0,P_1) \text{ denotes the set of couplings with TR at most }k.
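
For finitely supported measures the definition can be made concrete with matrices. The sketch below builds a coupling from k product terms and checks its marginals and rank. The sizes, the random weights and the variable names are assumptions chosen for illustration. Note that `np.linalg.matrix_rank` returns the ordinary matrix rank, which is only a lower bound on the number of nonnegative product terms needed.

```python
import numpy as np

rng = np.random.default_rng(0)
k, n0, n1 = 3, 8, 10                       # illustrative sizes (assumed)

lam = rng.dirichlet(np.ones(k))            # weights lambda_j, summing to 1
Q0 = rng.dirichlet(np.ones(n0), size=k)    # row j is Q_j^0 on n0 atoms
Q1 = rng.dirichlet(np.ones(n1), size=k)    # row j is Q_j^1 on n1 atoms

# gamma = sum_j lambda_j * (Q_j^0 outer Q_j^1), an n0 x n1 coupling matrix.
gamma = np.einsum("j,ja,jb->ab", lam, Q0, Q1)

P0 = gamma.sum(axis=1)                     # first marginal
P1 = gamma.sum(axis=0)                     # second marginal
rank = np.linalg.matrix_rank(gamma)        # at most k by construction
```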

Regularization via Factored Couplings:

Idea: Optimal couplings arising in practice can be well approximated by assuming that the distributions consist of a small number of pieces moving nearly independently.

\textbf{Def 2:} \text{ A }\textbf{soft cluster}\text{ of a prob. }P \text{ is a sub-prob. }C
\text{of total mass }\lambda\in [0,1]\text{ s.t. } 0 \leq C \leq P.
\text{We say that a collection } C_1,\dots, C_k \text{ of soft clusters of } P
\text{ is a } \textbf{partition} \text{ of }P \text{ if } C_1+\dots+C_k = P.
\textbf{centroid} \text{ of } C: \mu(C) = \frac{1}{\lambda} \int x \, dC(x)
\textbf{Prop:} \text{ Let }\gamma\in\Gamma_k(P_0, P_1) \text{ be a factored coupling and let } C_1^0,\dots, C_k^0 \text{ and }
C_1^1,\dots, C_k^1 \text{ be the induced partitions of }P_0 \text{ and }P_1, \text{ i.e. } C_j^\ell = \lambda_j Q_j^\ell. \text{ Then}
\displaystyle\int \|x-y\|^2d\gamma(x,y) = \displaystyle\sum_{j=1}^k \left( \lambda_j \|\mu(C_j^0) - \mu(C_j^1)\|^2 + \displaystyle\sum_{\ell\in\{0,1\}} \displaystyle \int \|x - \mu(C_j^\ell)\|^2 d C_j^\ell (x) \right)
\textbf{Def 3:} \text{ The cost of a factored coupling }\gamma\in\Gamma_k(P_0, P_1) \text{ is}
\text{cost}(\gamma) = \displaystyle\sum_{j=1}^k \lambda_j \|\mu(C_j^0) - \mu(C_j^1)\|^2
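
The decomposition can be checked numerically on a small discrete example. The sketch below is an illustration under assumed sizes and random data: it builds a factored coupling between two point clouds, evaluates both sides of the identity, and computes the cost from Def 3. It uses the fact that the centroid of C_j = lambda_j Q_j is the mean of Q_j.

```python
import numpy as np

rng = np.random.default_rng(1)
k, n0, n1, d = 3, 6, 7, 2                  # illustrative sizes (assumed)
X = rng.normal(size=(n0, d))               # atoms of P_0
Y = rng.normal(size=(n1, d))               # atoms of P_1

lam = rng.dirichlet(np.ones(k))
Q0 = rng.dirichlet(np.ones(n0), size=k)
Q1 = rng.dirichlet(np.ones(n1), size=k)
gamma = np.einsum("j,ja,jb->ab", lam, Q0, Q1)

# Left-hand side: transport cost of gamma.
sq = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(axis=-1)
lhs = (gamma * sq).sum()

# Centroids mu(C_j^0), mu(C_j^1), each of shape (k, d).
mu0 = Q0 @ X
mu1 = Q1 @ Y

# cost(gamma) = sum_j lambda_j ||mu(C_j^0) - mu(C_j^1)||^2
cost = (lam * ((mu0 - mu1) ** 2).sum(axis=1)).sum()

# Within-cluster terms: int ||x - mu(C_j)||^2 dC_j(x), summed over j.
spread0 = (lam[:, None] * Q0 * ((X[None] - mu0[:, None]) ** 2).sum(-1)).sum()
spread1 = (lam[:, None] * Q1 * ((Y[None] - mu1[:, None]) ** 2).sum(-1)).sum()

rhs = cost + spread0 + spread1
```

Since the within-cluster terms are nonnegative, the cost never exceeds the full transport cost of the coupling.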
Regularized optimal coupling:

\underset{\gamma\in\Gamma_k(\hat P_0, \hat P_1)}{\text{argmin}} \displaystyle \int \|x-y\|^2 d\gamma(x,y)

Questions:

  • How to solve the above problem efficiently?
  • Is the solution robust?

k-Wasserstein barycenters

Prob. measures supported on at most k points:

\mathcal{D}_k = \left\{ \sum\limits_{j=1}^k \alpha_j \delta_{x_j}: \alpha_j \geq 0, \sum\limits_{j=1}^k \alpha_j = 1, x_j\in\mathbb{R}^d \right\}

k-Wasserstein barycenter:

\bar{P} = \underset{P\in \mathcal{D}_k}{\text{argmin}} \displaystyle\sum_{j=1}^N W_2^2 (P, P_j)

H = \underset{P\in \mathcal{D}_k}{\text{argmin}} \left\{ W_2^2(P, \hat P_0)+ W_2^2(P, \hat P_1) \right\}
\bullet \hspace{3mm} \text{There exist efficient procedures for finding the barycenters.}
\bullet \hspace{3mm} \gamma_0 = \text{OC}(\hat P_0, H), \hspace{3mm} \gamma_1 = \text{OC}(\hat P_1, H), \text{ where OC denotes an optimal coupling.}
\bullet \hspace{3mm} \text{supp}(H) = \{z_1, \dots, z_k\} \text{ (the hubs)}
\gamma_0 = \displaystyle\sum_{j=1}^k \gamma_0(\cdot | z_j) H(z_j), \hspace{3mm} \gamma_1 = \displaystyle\sum_{j=1}^k \gamma_1(\cdot | z_j) H(z_j)

Induced factored coupling:

\Rightarrow \gamma_H(A\times B) = \displaystyle\sum_{j=1}^k H(z_j) \gamma_0(A | z_j) \gamma_1(B | z_j)

\textbf{Prop:} \text{ The partitions } C_1^0,\dots, C_k^0 \text{ and } C_1^1,\dots, C_k^1
\text{induced by } H \text{ are the minimizers of}
\displaystyle\sum_{j=1}^k \left( \frac{\lambda_j}{2} \|\mu(C_j^0) - \mu(C_j^1)\|^2 + \displaystyle\sum_{\ell\in\{0,1\}} \displaystyle \int \|x - \mu(C_j^\ell)\|^2 d C_j^\ell (x) \right)
  • Estimator for the squared Wasserstein distance:
\widehat W := \text{cost}(\gamma_H)
  • Estimated transport map:
\widehat T (X_i) = X_i + \frac{1}{\sum\limits_{j=1}^k C_j^0(X_i)} \displaystyle\sum_{j=1}^k C_j^0(X_i) (\mu(C_j^1) - \mu(C_j^0))
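
Given the two couplings with the hubs, the induced factored coupling, the estimator and the map are a few matrix operations. The sketch below is a hedged illustration: it assumes the couplings are stored as arrays `gamma0` (k x n) and `gamma1` (k x m) with identical row sums H(z_j), so that C_j^0(X_i) is the entry `gamma0[j, i]`. The function name `fot_estimates` and the toy data are invented for this example.

```python
import numpy as np


def fot_estimates(gamma0, gamma1, X, Y):
    """Return (W_hat, T_hat, gamma_H) from hub couplings with equal row sums."""
    lam = gamma0.sum(axis=1)                       # lambda_j = H(z_j)
    mu0 = gamma0 @ X / lam[:, None]                # centroids mu(C_j^0)
    mu1 = gamma1 @ Y / lam[:, None]                # centroids mu(C_j^1)

    # W_hat = cost(gamma_H) = sum_j lambda_j ||mu(C_j^0) - mu(C_j^1)||^2
    W_hat = (lam * ((mu1 - mu0) ** 2).sum(axis=1)).sum()

    # T_hat(X_i): soft-cluster weighted average of the centroid displacements.
    weights = gamma0 / gamma0.sum(axis=0, keepdims=True)   # shape (k, n)
    T_hat = X + weights.T @ (mu1 - mu0)

    # gamma_H[a, b] = sum_j gamma0[j, a] * gamma1[j, b] / lambda_j
    gamma_H = gamma0.T @ (gamma1 / lam[:, None])
    return W_hat, T_hat, gamma_H


# Toy check: two far-apart points, each with its own hub, shifted by (3, 4).
X = np.array([[0.0, 0.0], [10.0, 0.0]])
Y = X + np.array([3.0, 4.0])
gamma0 = np.array([[0.5, 0.0], [0.0, 0.5]])
gamma1 = gamma0.copy()
W_hat, T_hat, gamma_H = fot_estimates(gamma0, gamma1, X, Y)
```

In this toy case each cluster is a single point, so the estimated map sends every X_i exactly onto its shifted copy.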
\text{How to estimate } \gamma_H ?
H = \underset{P\in \mathcal{D}_k}{\text{argmin}} \left\{ W_2^2(P, \hat P_0)+ W_2^2(P, \hat P_1) \right\}
\bullet \hspace{3mm} \text{It is separately convex in } \mathcal{H} = \{z_1,\dots, z_k\} \text{ and } (\gamma_0, \gamma_1)
\Rightarrow \text{ admits alternating optimization}
\text{Updating hubs:}
z_j = \dfrac{\displaystyle\sum_{i=1}^n \gamma_0(z_j, X_i)X_i + \displaystyle\sum_{i=1}^n \gamma_1(z_j, Y_i)Y_i }{ \displaystyle\sum_{i=1}^n \gamma_0(z_j, X_i) + \displaystyle\sum_{i=1}^n \gamma_1(z_j, Y_i) }
\text{Updating }(\gamma_0, \gamma_1):
\bullet \hspace{2mm} \text{Add entropic regularization}
-\varepsilon \displaystyle\sum_{i,j} (\gamma_0)_{j,i} \log{\left((\gamma_0)_{j,i}\right)} -\varepsilon \displaystyle\sum_{i,j} (\gamma_1)_{j,i} \log{\left((\gamma_1)_{j,i}\right)}
\bullet \hspace{2mm} \text{Use Sinkhorn iterations}
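
The alternating scheme can be sketched as follows. This is a minimal illustration and not the authors' code. The coupling step is an assumption: it uses Sinkhorn-type scalings in the form of iterative Bregman projections for a two-measure barycenter with fixed support, which forces both couplings to share the same hub marginal. The function names, the value of `eps`, the iteration counts and the random initialization of the hubs are all hypothetical choices.

```python
import numpy as np


def sq_dists(Z, X):
    """Squared Euclidean distances between rows of Z and rows of X."""
    return ((Z[:, None, :] - X[None, :, :]) ** 2).sum(axis=-1)


def update_couplings(Z, X, Y, eps, n_iter=200):
    """Entropic step: couplings hubs-X and hubs-Y with a common hub marginal."""
    K0 = np.exp(-sq_dists(Z, X) / eps)             # Gibbs kernels
    K1 = np.exp(-sq_dists(Z, Y) / eps)
    a = np.full(len(X), 1.0 / len(X))              # empirical weights of X
    b = np.full(len(Y), 1.0 / len(Y))              # empirical weights of Y
    u0 = np.ones(len(Z))
    u1 = np.ones(len(Z))
    for _ in range(n_iter):
        v0 = a / (K0.T @ u0)                       # match the data marginals
        v1 = b / (K1.T @ u1)
        r0 = u0 * (K0 @ v0)                        # current hub marginals
        r1 = u1 * (K1 @ v1)
        h = np.sqrt(r0 * r1)                       # geometric mean
        u0 = h / (K0 @ v0)                         # force both hub marginals to h
        u1 = h / (K1 @ v1)
    gamma0 = u0[:, None] * K0 * v0[None, :]
    gamma1 = u1[:, None] * K1 * v1[None, :]
    return gamma0, gamma1


def update_hubs(gamma0, gamma1, X, Y):
    """Hub update: weighted average of the points coupled to each hub."""
    num = gamma0 @ X + gamma1 @ Y
    den = gamma0.sum(axis=1) + gamma1.sum(axis=1)
    return num / den[:, None]


def factored_ot(X, Y, k, eps=0.5, n_outer=20, seed=0):
    rng = np.random.default_rng(seed)
    pool = np.vstack([X, Y])
    Z = pool[rng.choice(len(pool), size=k, replace=False)]   # initial hubs
    for _ in range(n_outer):
        gamma0, gamma1 = update_couplings(Z, X, Y, eps)
        Z = update_hubs(gamma0, gamma1, X, Y)
    gamma0, gamma1 = update_couplings(Z, X, Y, eps)          # match final hubs
    return Z, gamma0, gamma1


rng = np.random.default_rng(0)
X = rng.normal(size=(50, 2))
Y = rng.normal(size=(60, 2)) + np.array([3.0, 0.0])
Z, gamma0, gamma1 = factored_ot(X, Y, k=4)
```

Because the scaling loop ends with the hub-marginal projection, the two couplings have identical row sums, while the data marginals are matched only approximately after finitely many iterations. For small `eps` the kernels can underflow, and a log-domain implementation would be the safer choice.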
\textbf{Theorem:} \text{ Let }P\in\mathcal{P}(\mathbb{R}^d) \text{ supported on the unit ball}
\text{and }\hat P \text{ the empirical meas. on a sample of size } n.
\text{Then with prob. at least }1-\delta
\displaystyle\sup_{\rho\in\mathcal{D}_k} \left| W_2^2(\rho, \hat P) - W_2^2(\rho, P) \right| \lesssim \sqrt{ \dfrac{k^3 d \log{k} + \log{(1/\delta)} }{n} }
\bullet \hspace{2mm} \text{Recovers the parametric } n^{-1/2} \text{ rate, with only polynomial dependence on } d \text{ and } k.
\bullet \hspace{2mm} \text{Can be extended to compact support}.
\bullet \hspace{2mm} \text{It does not provide rates for }\widehat{W}.

Experiments:

Fragmented hypercube (synthetic):

P_0 = \text{Unif}([-1,1]^d)
P_1 = T_\#(P_0), \hspace{3mm} T(X) = X + 2 \text{sign}(X)\odot(e_1+e_2)
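
The synthetic data can be generated directly from the two formulas. The sketch below is illustrative: the sample size, dimension, seed and the function name `fragmented_hypercube` are assumptions, and the two samples are drawn independently.

```python
import numpy as np


def fragmented_hypercube(n, d, seed=0):
    """Draw n points from P_0 and, independently, n points from P_1 = T_# P_0."""
    rng = np.random.default_rng(seed)
    X = rng.uniform(-1.0, 1.0, size=(n, d))        # sample from P_0
    U = rng.uniform(-1.0, 1.0, size=(n, d))        # independent draw from P_0
    direction = np.zeros(d)
    direction[:2] = 1.0                            # e_1 + e_2 (requires d >= 2)
    Y = U + 2.0 * np.sign(U) * direction           # T(U) = U + 2 sign(U) * (e_1 + e_2)
    return X, Y


X, Y = fragmented_hypercube(n=500, d=10)
```

The map pushes the first two coordinates away from the origin by 2 in each direction, splitting the cube into four fragments, and leaves the remaining coordinates unchanged.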

[Figure: fragmented hypercube; panels OT, H, FOT]

Experiments:

Batch correction for single cell RNA data:

  • Each coordinate represents the expression-level of the corresponding gene.
\bullet \hspace{3mm} \text{Dim}\sim 10^4
\bullet \hspace{3mm} \text{Number of cells}\sim 10^2\text{ to }10^6
  • Haematopoietic (blood) datasets from different sources.

Sketch proof for Theorem 4:

\mathcal{P}_{m} = \{ S\subset \mathbb{R}^d: S \text{ is the inters. of }m\text{ closed half-spaces} \}
\text{Given }c\in \mathbb{R}^d \text{ and } S\in\mathcal{P}_{k-1} \text{ define}
f_{c, S}(x) := \|x - c\|^2 \boldsymbol{1}_{x\in S}
\textbf{Prop:} \text{ Let } P,Q\in\mathcal{P}(\mathbb{R}^d) \text{ be supported on the unit ball. Then}
\displaystyle\sup_{\rho\in \mathcal{D}_k} |W_2^2(\rho, P) - W_2^2(\rho, Q)| \leq 5 k \displaystyle\sup_{c: \|c\|\leq 1, S\in\mathcal{P}_{k-1}} | \mathbb{E}_P f_{c, S} - \mathbb{E}_Q f_{c, S} |
\textbf{Prop: } \exists \text{ univ. constant } C \text{ s.t. if } P \text{ is supp. on the unit ball}
\text{and }X_1, \dots, X_n\sim P \text{ are i.i.d., then}
\mathbb{E} \displaystyle\sup_{c: \|c\|\leq 1, S\in\mathcal{P}_{k-1}} | \mathbb{E}_P f_{c, S} - \mathbb{E}_{\hat P} f_{c, S} | \leq C \sqrt{\dfrac{k d \log{k}}{n} }
\Rightarrow \hspace{2mm} \mathbb{E} \displaystyle\sup_{\rho\in \mathcal{D}_k} |W_2^2(\rho, P) - W_2^2(\rho, \hat P)| \lesssim \sqrt{\dfrac{k^3 d \log{k}}{n} }
\text{bounded differences ineq. } \Rightarrow \text{ high-prob. result}
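
The bound in expectation follows by chaining the two propositions, taking Q to be the empirical measure in the first one (it is supported on the unit ball because the sample is). Written out as a sketch, with C the universal constant from the second proposition:

```latex
\mathbb{E} \sup_{\rho\in\mathcal{D}_k} \left| W_2^2(\rho, P) - W_2^2(\rho, \hat P) \right|
  \leq 5k \; \mathbb{E} \sup_{c: \|c\|\leq 1,\, S\in\mathcal{P}_{k-1}}
        \left| \mathbb{E}_P f_{c,S} - \mathbb{E}_{\hat P} f_{c,S} \right|
  \leq 5k \cdot C \sqrt{\frac{k d \log k}{n}}
  = 5C \sqrt{\frac{k^3 d \log k}{n}}
```

The factor k from the first proposition multiplies the square root of k from the second, which is where the k^3 under the square root comes from.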