Estimation of Wasserstein distances in the Spiked Transport Model

Study Group on Optimal Transport

Daniel Yukimura

Contributions:

  • Estimation of Wasserstein distance in high dimension (HD) for distributions that differ only on a low-dimensional (LD) subspace.
  • Lower bounds.
  • A computational-statistical gap.

Introduction:

W_p(\mu, \nu) = \displaystyle\inf_{\gamma \in \Gamma_{\mu,\nu}} \left( \int_{\mathbb{R}^d \times \mathbb{R}^d} \|x-y\|^p d\gamma(x,y) \right)^{1/p}

Wasserstein distance
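
To make the definition concrete, here is a minimal sketch for the one-dimensional case, where the optimal coupling between two empirical measures with equally many atoms is the sorted (monotone) matching. The function name `wasserstein_1d` is illustrative and not from the talk.

```python
import numpy as np

def wasserstein_1d(x, y, p=1.0):
    """W_p between two empirical measures on the real line with equally many atoms."""
    x = np.sort(np.asarray(x, dtype=float))
    y = np.sort(np.asarray(y, dtype=float))
    if x.shape != y.shape:
        raise ValueError("this sketch assumes equal sample sizes")
    # In 1D the monotone (sorted) coupling is optimal for every p >= 1.
    return float(np.mean(np.abs(x - y) ** p) ** (1.0 / p))

rng = np.random.default_rng(0)
x = rng.normal(size=500)
shifted = wasserstein_1d(x, x + 0.3, p=2)  # a pure shift is transported at cost |shift|
```

Here `shifted` equals the shift 0.3 up to floating-point error, since translating every atom by the same amount preserves the sorted order.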

Prop. 1: Let \(\mu\) be a prob. measure on \([-1,1]^d\). If \(\mu_n\) is the assoc. empirical measure of \(n\) i.i.d. samples, then for any \(p\in [1,\infty)\)

 

\mathbb{E} W_p (\mu_n, \mu) \leq r_{p,d}(n) := c_p \sqrt{d} \left\{\begin{matrix} n^{-1/2p} \phantom{(\log{n})^{1/p}} & \text{if } d < 2p\\ n^{-1/2p}(\log{n})^{1/p} & \text{if } d = 2p\\ n^{-1/d} \phantom{(\log{n})^{1/p}} & \text{if } d > 2p \end{matrix}\right.
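
The three regimes of \(r_{p,d}(n)\) can be tabulated with a short helper. This is only a sketch: the constant \(c_p\) is not specified in Prop. 1, so it is exposed as a parameter `c_p` and set to 1 by default (an assumption made purely for illustration).

```python
import math

def rate(n, d, p, c_p=1.0):
    """r_{p,d}(n) from Prop. 1; c_p is an unspecified constant (default 1 is an assumption)."""
    if d < 2 * p:
        base = n ** (-1.0 / (2 * p))
    elif d == 2 * p:
        base = n ** (-1.0 / (2 * p)) * math.log(n) ** (1.0 / p)
    else:
        # Curse of dimensionality: the exponent degrades to 1/d.
        base = n ** (-1.0 / d)
    return c_p * math.sqrt(d) * base

# Compare a low-dimensional and a high-dimensional rate at the same sample size.
low_dim = rate(10_000, d=1, p=1)
high_dim = rate(10_000, d=50, p=1)
```

With \(n = 10^4\) and \(p=1\), the rate for \(d=1\) is far smaller than for \(d=50\), which is the gap the Spiked Transport Model aims to close.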

Spiked Transport Model:

  • Fix \( \mathcal{U}\subseteq \mathbb{R}^d\) of dim. \(k \ll d\).
  • R.v. \(X^{(1)}, X^{(2)}\in \mathcal{U}\) with arbitrary distributions.
  • R.v. \(Z\perp (X^{(1)}, X^{(2)})\) supported on \(\mathcal{U}^{\perp}\).
\mu^{(1)} := \text{Law}(X^{(1)}+ Z)
\mu^{(2)} := \text{Law}(X^{(2)}+ Z)

(The two laws differ only through the low-dim. components \(X^{(1)}, X^{(2)}\).)
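
A small sampler may help fix ideas. The sketch below draws from one hypothetical instance of the model: the subspace \(\mathcal{U}\) is spanned by the first \(k\) columns of a random orthogonal matrix, the \(X^{(i)}\) are Gaussian in \(\mathcal{U}\) (the second one shifted by 2 in each coordinate), and \(Z\) is uniform on a cube in \(\mathcal{U}^{\perp}\). All of these distributional choices, and the name `sample_stm`, are assumptions for illustration only.

```python
import numpy as np

def sample_stm(n, d, k, rng):
    """Draw n points from each of mu^(1), mu^(2) for a hypothetical spiked instance."""
    # Random orthonormal basis of R^d; first k columns span U, the rest span U-perp.
    q, _ = np.linalg.qr(rng.normal(size=(d, d)))
    basis_u, basis_perp = q[:, :k], q[:, k:]
    # Low-dimensional parts X^(1), X^(2) in U (arbitrary laws; Gaussians here).
    x1 = rng.normal(size=(n, k)) @ basis_u.T
    x2 = (rng.normal(size=(n, k)) + 2.0) @ basis_u.T
    # Independent copies of the shared noise Z, supported on U-perp.
    z1 = rng.uniform(-1.0, 1.0, size=(n, d - k)) @ basis_perp.T
    z2 = rng.uniform(-1.0, 1.0, size=(n, d - k)) @ basis_perp.T
    return x1 + z1, x2 + z2, basis_u

rng = np.random.default_rng(0)
s1, s2, basis_u = sample_stm(n=2000, d=20, k=2, rng=rng)
```

Projecting `s1` and `s2` onto `basis_u` recovers the two low-dimensional laws, while their components in the orthogonal complement share the same distribution.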

Question: Given \(n\) i.i.d. observations from both \(\mu^{(1)}\) and \(\mu^{(2)}\) is it possible to estimate \(W_p(\mu^{(1)}, \mu^{(2)})\) at a rate faster than \(n^{-1/d}\)?

  • Answer is yes, but we need smoothness and decay assumptions on the measures.

Concentration Assumptions:

A prob. meas. \(\mu\) on \(\mathbb{R}^d\) satisfies the \(T_p(\sigma^2)\) transport inequality if

W_p(\nu, \mu) \leq \sqrt{2 \sigma^2 D(\nu \| \mu)} \hspace{4mm} \forall \nu\in \text{Prob}(\mathbb{R}^d)
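
As a sanity check of the definition, consider the standard Gaussian example (not taken from the slides): \(\mu = N(0, \sigma^2 I_d)\) and \(\nu = N(m, \sigma^2 I_d)\) for a shift \(m\in\mathbb{R}^d\). Both sides can be computed in closed form:

\[W_2(\nu,\mu) = \|m\|, \qquad D(\nu\|\mu) = \frac{\|m\|^2}{2\sigma^2}, \qquad \sqrt{2\sigma^2 D(\nu\|\mu)} = \|m\|.\]

So \(T_2(\sigma^2)\) holds with equality for translates. By Talagrand's inequality, \(N(0,\sigma^2 I_d)\) satisfies \(T_2(\sigma^2)\) for every \(\nu\), and since \(W_p \leq W_2\) for \(p\leq 2\), it also satisfies \(T_p(\sigma^2)\) for all \(p\in[1,2]\).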

Wasserstein Projection Pursuit:

  • Consider the set \( \mathcal{V}_k(\mathbb{R}^d) \) of \(k\times d\) matrices with orthonormal rows.
  • For \(\mu \in \text{Prob}(\mathbb{R}^d)\) and \( U\in \mathcal{V}_k(\mathbb{R}^d) \), define \(\mu_U\) as the law of \(UY\) for \(Y\sim\mu\).

Def.: For \(k\in [d]\), the \(k\)-dimensional Wasserstein distance between \(\mu^{(1)}\) and \(\mu^{(2)}\) is

\tilde W_{p,k} (\mu^{(1)}, \mu^{(2)}) = \underset{U\in \mathcal{V}_k(\mathbb{R}^d)}{\sup} W_p(\mu_U^{(1)}, \mu_U^{(2)})

\widehat W_{p,k} = \tilde W_{p,k} \left(\mu^{(1)}_n, \mu^{(2)}_n\right)

WPP estimator
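
The supremum over \(\mathcal{V}_k(\mathbb{R}^d)\) is a non-convex problem, and the slides do not specify an algorithm for it. The sketch below is therefore only a heuristic for \(k=1\): it evaluates the projected distance on randomly drawn unit directions and keeps the best one, which gives a lower bound on \(\widehat W_{p,1}\). The names `w_p_1d` and `wpp_random_search` and the random-search strategy are assumptions, not the method of the talk.

```python
import numpy as np

def w_p_1d(a, b, p):
    """W_p between 1D empirical measures with equally many atoms (sorted coupling)."""
    a, b = np.sort(a), np.sort(b)
    return float(np.mean(np.abs(a - b) ** p) ** (1.0 / p))

def wpp_random_search(x, y, p=1.0, n_dirs=200, rng=None):
    """Heuristic lower bound on the k=1 WPP estimator via random unit directions."""
    rng = np.random.default_rng() if rng is None else rng
    dirs = rng.normal(size=(n_dirs, x.shape[1]))
    dirs /= np.linalg.norm(dirs, axis=1, keepdims=True)  # each row lies in V_1(R^d)
    return max(w_p_1d(x @ u, y @ u, p) for u in dirs)

rng = np.random.default_rng(1)
x = rng.normal(size=(100, 5))
shift = np.zeros(5)
shift[0] = 1.5                      # the two samples differ only along e_1
est = wpp_random_search(x, x + shift, p=2, n_dirs=50, rng=rng)
```

In this toy example every projected distance equals \(1.5\,|u_1|\), so `est` is positive and never exceeds 1.5; it approaches 1.5 only when some sampled direction is close to \(e_1\), which becomes unlikely as \(d\) grows.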

Estimation:

Theorem 1: Let \( (\mu^{(1)}, \mu^{(2)}) \)  sats. the Spiked Transport Model (STM). For any \(p\in[1,2]\), if \( \mu^{(1)}\) and \(\mu^{(2)}\) sats. the \(T_p(\sigma^2)\) ineq., then

\mathbb{E} \left| \widehat W_{p,k} - W_p(\mu^{(1)}, \mu^{(2)}) \right| \leq c_k \sigma \left( r_{p,k}(n) + \sqrt{\frac{d\log{n}}{n}} \right)

Before proving:

Prop. 3: Under the STM

\tilde W_{p,k}(\mu^{(1)}, \mu^{(2)}) = W_{p}(\mu^{(1)}, \mu^{(2)})

Thm. 6: Let \(p\in [1,2]\). A meas. \(\mu\in \mathcal{P}_p(\mathbb{R}^d)\) (prob. with  finite pth moment) satisfies \(T_p(\sigma^2)\) if and only if the r.v. \(W_p(\mu_n, \mu)\) is \(\sigma^2/n\)-subgaussian for all \(n\).

Prop. 5: Let \(U\in \mathcal{V}_k(\mathbb{R}^d)\). For any \(p\in[1,2]\) and \(\sigma>0\), if \(\mu\) satisfies \(T_p(\sigma^2)\), then so does \(\mu_U\).

Proof for estimation:

First notice that

\[\mathbb{E} \left| \widehat W_{p,k} - W_p(\mu^{(1)}, \mu^{(2)}) \right| \leq \mathbb{E} \tilde W_{p,k} (\mu^{(1)}, \mu^{(1)}_n) + \mathbb{E} \tilde W_{p,k} (\mu^{(2)}, \mu^{(2)}_n)\] 
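
The reduction can be spelled out as follows. By Prop. 3, \(W_p(\mu^{(1)},\mu^{(2)}) = \tilde W_{p,k}(\mu^{(1)},\mu^{(2)})\), and \(\tilde W_{p,k}\) satisfies the triangle inequality because it is a supremum of the pseudo-metrics \((\mu,\nu)\mapsto W_p(\mu_U,\nu_U)\). Hence

\[\left| \widehat W_{p,k} - W_p(\mu^{(1)}, \mu^{(2)}) \right| = \left| \tilde W_{p,k}(\mu^{(1)}_n, \mu^{(2)}_n) - \tilde W_{p,k}(\mu^{(1)}, \mu^{(2)}) \right| \leq \tilde W_{p,k}(\mu^{(1)}, \mu^{(1)}_n) + \tilde W_{p,k}(\mu^{(2)}, \mu^{(2)}_n),\]

and taking expectations gives the displayed bound.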

Then we can focus on bounding \(\mathbb{E} \tilde W_{p,k} (\mu, \mu_n)\).

 

  • We assume w.l.o.g. that \(\mu\) has mean 0, and \(\sigma=1\).
  • Consider \(Z_U := W_p(\mu_U, (\mu_n)_U)\)

Lemma:  \(\exists\) r.v. \(L\) s.t. for all \(U,V\in \mathcal{V}_k(\mathbb{R}^d)\)

\[|Z_U - Z_V| \leq L\|U-V\|_{op}\]

and \(\mathbb{E} L \lesssim \sqrt{dp}\)

proof: Let \(X\sim \mu\), then

\[\begin{aligned}
|Z_U - Z_V| &\leq W_p (\mu_U, \mu_V) + W_p((\mu_n)_U, (\mu_n)_V)\\
&\leq \left( \mathbb{E}\|(U-V)X\|^p \right)^{1/p} + \left( \frac{1}{n}\sum\limits_{i=1}^n \|(U-V)X_i\|^p \right)^{1/p}\\
&\leq \|U-V\|_{op}\left(\left( \mathbb{E}\|X\|^p \right)^{1/p} + \left( \frac{1}{n}\sum\limits_{i=1}^n \|X_i\|^p \right)^{1/p}\right)\\
&=: L \|U-V\|_{op}
\end{aligned}\]

  • The process \(Z_U\) is Lipschitz,
  • By Prop. 5 and Thm. 6, each \(Z_U\) is \(n^{-1}\)-subgaussian (recall \(\sigma=1\)).

Now, using a standard \(\varepsilon\)-net argument on our estimation over \(\mathcal{V}_k(\mathbb{R}^d)\) we get

\mathbb{E} \underset{U\in \mathcal{V}_k(\mathbb{R}^d)}{\sup} (Z_U - \mathbb{E} Z_U) \lesssim \underset{\varepsilon>0}{\inf} \left\{ \varepsilon \mathbb{E}L + \sqrt{\frac{\log{\mathcal{N}(\mathcal{V}_k, \varepsilon, \|\cdot\|_{op})}}{n}} \right\}

where \(\mathcal{N}(\mathcal{V}_k, \varepsilon, \|\cdot\|_{op})\) is the covering number of \(\mathcal{V}_k\) with resp. to the op. norm.

There exists a univ. const. \(c\) such that \(\log \mathcal{N}(\mathcal{V}_k, \varepsilon, \|\cdot\|_{op}) \leq dk \log{\frac{c\sqrt{k}}{\varepsilon}}\) for \(\varepsilon\in (0,1]\). Choosing \(\varepsilon = \sqrt{k/n}\) yields

\[\mathbb{E} \underset{U\in \mathcal{V}_k(\mathbb{R}^d)}{\sup} (Z_U - \mathbb{E} Z_U) \lesssim \sqrt{\frac{dkp}{n}} + c_p\sqrt{\frac{dk\log{n}}{n}} \lesssim c_p\sqrt{\frac{dk\log{n}}{n}}\]
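
For completeness, here is how the two terms of the \(\varepsilon\)-net bound evaluate at \(\varepsilon = \sqrt{k/n}\), reading the covering bound as a bound on \(\log \mathcal{N}\) (the reading assumed in this sketch) and using \(\mathbb{E}L \lesssim \sqrt{dp}\):

\[\varepsilon\, \mathbb{E}L \lesssim \sqrt{\frac{k}{n}}\sqrt{dp} = \sqrt{\frac{dkp}{n}}, \qquad \sqrt{\frac{\log \mathcal{N}(\mathcal{V}_k, \varepsilon, \|\cdot\|_{op})}{n}} \leq \sqrt{\frac{dk \log(c\sqrt{n})}{n}} \lesssim \sqrt{\frac{dk\log{n}}{n}}.\]

The first term is dominated by the second once \(\log n \gtrsim p\), which is why only the \(\sqrt{dk\log n / n}\) term survives.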

Finally we get,

\[\begin{aligned}
\mathbb{E} \underset{U\in \mathcal{V}_k(\mathbb{R}^d)}{\sup} W_p(\mu_U, (\mu_n)_U) &\leq \underset{U\in \mathcal{V}_k(\mathbb{R}^d)}{\sup} \mathbb{E} W_p(\mu_U, (\mu_n)_U) + \mathbb{E} \underset{U\in \mathcal{V}_k(\mathbb{R}^d)}{\sup} (Z_U - \mathbb{E} Z_U)\\
&\lesssim r_{p,k}(n) + c_p\sqrt{\frac{dk\log{n}}{n}}
\end{aligned}\]

Spike estimation:

Theorem 10: Let \(p\in [1,2]\). Assume \( (\mu^{(1)}, \mu^{(2)}) \)  sats. STM and the \(T_p(\sigma^2)\) ineq. Let \(\hat{\mathcal{U}} := \text{span}(\hat U)\), where

\hat U := \underset{U\in \mathcal{V}_k(\mathbb{R}^d)}{\operatorname{argmax}} W_p( (\mu_n^{(1)})_U, (\mu_n^{(2)})_U )

Then

\mathbb{E} \sin^2(\measuredangle (\hat{\mathcal{U}}, \mathcal{U}) ) \lesssim \frac{\sigma \left( r_{p,k}(n) + c_p \sqrt{\frac{dk\log{n}}{n}} \right)}{W_p(\mu^{(1)}, \mu^{(2)})}

Lower Bounds:

Consider a compact metric space \(\mathcal{X}\) s.t.

\[c \varepsilon^{-d} \leq \mathcal{N}(\mathcal{X}, \varepsilon) \leq C \varepsilon^{-d}\]

for all \(\varepsilon \leq \text{diam}(\mathcal{X})\).

With \(\mathcal{P} = \text{Prob}(\mathcal{X})\) define

\[R(n, \mathcal{P}) := \underset{\hat W}{\inf}\underset{\mu, \nu\in\mathcal{P}}{\sup} \mathbb{E}_{\mu, \nu}|\hat W - W_p(\mu,\nu)|.\]

Theorem 11: Let \(d>2p>2\) and \(\mathcal{X}\) with the cov. number as before. Then

\[R(n,\mathcal{P}) \geq C_{d,p} (n \log{n})^{-1/d}\]

Theorem 4:

\[\mathbb{E}|\hat W - W_p(\mu^{(1)},\mu^{(2)})| \gtrsim \sigma \sqrt{\frac{d}{n}}\]

Prop. 9: Assume \(d>2p>2\) and let \(m\) be a pos. integer. Let \(u=\text{Unif}([m])\). \(\exists\) a random function \(F:[m]\rightarrow \mathcal{X}\) s.t. for any dist. \(q\) on \([m]\),

c m^{-1/d} d_{TV}(q,u)^{\frac{1}{p}} \leq W_p(F_\#q, F_\#u) \leq C_{d,p} m^{-1/d} (\chi^2(q,u))^{1/d} d_{TV}(q,u)^{\frac{1}{p} - \frac{2}{d}}

with prob. at least \(.9\)

proof: For the lower bound:

Use the cov. num. to get a set of points \(G_m = \{x_i\}_{i\in[m]}\) s.t. \(d(x_i, x_j) \gtrsim m^{-1/d}\). Then select \(F\) unif. at random from the set of bijections from \([m]\) to \(G_m\). Now, since

\[d(x,y)^p \gtrsim m^{-p/d} \boldsymbol{1}\{x\neq y\}\]

we get, for any coupling \(\pi\) of \(F_\#q\) and \(F_\#u\) with \((X,Y)\sim\pi\),

\[\int d(x,y)^p d\pi(x,y) \gtrsim m^{-p/d}\mathbb{P}[X\neq Y]\geq m^{-p/d} d_{TV}(q, u)\]

For the upper bound: ...

 

Prop. 10: Fix \(n\in\mathbb{N}\) and a const. \(\delta\in [0, .1]\). Given \(m\in\mathbb{N}\), let \(D_m\) be the set of prob. dist. \(q\) on \([m]\) sats. \(\chi^2(q, u)\leq 9\). Denote by \(D_{m,\delta}^-\) the subset of \(D_m\) where \(d_{TV}(q,u)\leq \delta\) and by \(D_m^+\) the subset sats. \(d_{TV}(q,u)\geq 1/4\). If \(m = \lceil C\delta^{-1}n\log{n} \rceil\) for a sufficiently large univ. const. \(C\) and \(n\) is sufficiently large, then

\[\underset{\psi}{\inf} \left\{ \underset{q\in D_m^+}{\sup} \mathbb{P}_q [\psi=1] + \underset{q\in D_{m,\delta}^-}{\sup} \mathbb{P}_q [\psi=0] \right\} \geq .9 \]

where the \(\inf\) is taken over all tests \(\psi\) based on \(n\) samples.

Proof of Thm. 11:

Let \(A = \{|\hat W - W_p(F_\#q, F_\#u)|\geq \Delta_d\}\), with \(\Delta_d = \frac{1}{16} c^* m^{-1/d} \). Then

 

\[\begin{aligned}
\underset{\mu, \nu\in\mathcal{P}}{\sup} \mathbb{P}_{\mu, \nu}\left[|\hat W - W_p(\mu,\nu)|\geq \Delta_d\right] \geq \dfrac{1}{2} \Bigg( &\underset{q\in D_m^+}{\sup} \mathbb{E}_F\mathbb{P}_{F_\# q, F_\# u} [A] \\
&+ \underset{q\in D_{m,\delta}^-}{\sup} \mathbb{E}_F\mathbb{P}_{F_\# q, F_\# u} [A] \Bigg)
\end{aligned}\]

We def. the rand. test

\[\psi(X_1,\dots,X_n) := \boldsymbol{1}\{\hat W(F(X_1),\dots,F(X_n);  F(Y_1),\dots,F(Y_n)) \leq 2\Delta_d\}\]

\[\begin{aligned}
\underset{\mu, \nu\in\mathcal{P}}{\sup} \mathbb{P}_{\mu, \nu}\left[|\hat W - W_p(\mu,\nu)|\geq \Delta_d\right] &\geq \dots \\
&\geq \dfrac{1}{2} \left( \underset{q\in D_m^+}{\sup} \mathbb{P}_q [\psi=1] + \underset{q\in D_{m,\delta}^-}{\sup} \mathbb{P}_q [\psi=0] \right) - .1
\end{aligned}\]

Choosing \(m = \lceil C\delta^{-1}n\log{n} \rceil\) for suff. large \(C\) and applying Prop. 10 gives \(\underset{\mu, \nu\in\mathcal{P}}{\sup} \mathbb{P}_{\mu, \nu}\left[|\hat W - W_p(\mu,\nu)|\geq \Delta_d\right] \geq .8\), and Markov's ineq. yields the claim.
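
The final Markov step, written out under the probability bound stated above and with \(\delta\) treated as a fixed constant:

\[\underset{\mu, \nu\in\mathcal{P}}{\sup} \mathbb{E}_{\mu,\nu}|\hat W - W_p(\mu,\nu)| \geq \Delta_d \underset{\mu, \nu\in\mathcal{P}}{\sup} \mathbb{P}_{\mu, \nu}\left[|\hat W - W_p(\mu,\nu)|\geq \Delta_d\right] \geq .8\, \Delta_d = \frac{.8}{16}\, c^* m^{-1/d}.\]

Since \(m = \lceil C\delta^{-1}n\log{n} \rceil\), we have \(m^{-1/d} \asymp (n\log n)^{-1/d}\) up to constants depending on \(d\), \(C\) and \(\delta\). The bound holds for every estimator \(\hat W\), so it is a lower bound on \(R(n,\mathcal{P})\).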

Computational-Statistical Gap:

Def. 5: Given a dist. \(\mathcal{D}\) on \(\mathbb{R}^d\), for any sample size \(t>0\) and \(f:\mathbb{R}^d\rightarrow [0,1]\), the oracle \(\text{VSTAT}(t)\) returns a value  \(v\in [p-\tau, p+\tau]\), where \(p = \mathbb{E}f(X)\) for \(X\sim\mathcal{D}\) and \(\tau = \frac{1}{t}\vee \sqrt{\frac{p(1-p)}{t}}\).
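
To get a feel for the tolerance \(\tau\), the sketch below computes it and simulates one admissible oracle answer. The definition only requires the answer to lie in \([p-\tau, p+\tau]\) and allows it to be chosen adversarially; drawing it uniformly at random is an assumption of this illustration, and the names `vstat_tolerance` and `simulated_vstat` are hypothetical.

```python
import math
import random

def vstat_tolerance(p, t):
    """tau = max(1/t, sqrt(p(1-p)/t)), the accuracy guaranteed by VSTAT(t)."""
    return max(1.0 / t, math.sqrt(p * (1.0 - p) / t))

def simulated_vstat(p, t, rng=random):
    """Return one admissible answer for a query with true mean p.

    Assumption: the answer is uniform in [p - tau, p + tau]; a real oracle
    may pick any value in that interval.
    """
    tau = vstat_tolerance(p, t)
    return p + rng.uniform(-tau, tau)

tau_mid = vstat_tolerance(0.5, 100)   # variance term dominates
tau_edge = vstat_tolerance(0.0, 100)  # 1/t floor dominates
answer = simulated_vstat(0.3, 50)
```

The tolerance mimics the standard deviation of an empirical mean over \(t\) samples, which is why \(t\) is referred to as a sample size even though no samples are drawn.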

Theorem 12: There exists a pos. univ. constant \(c\) s.t., for any \(d\), estimating \(W_1(\mu^{(1)}, \mu^{(2)})\) for \(\mu^{(1)}\), \(\mu^{(2)}\) sats. the STM with \(k=1\) to accuracy \(\Theta(1/\sqrt{d})\) with prob. at least \(2/3\) requires at least \(2^{cd}\) queries to \(\text{VSTAT}(2^{cd})\).