Understanding deep nets

Local Lipschitz functions and learned proximal networks

Jeremias Sulam

SILO Seminar - Nov 2023

Important Open Questions

1. Adversarial Attacks

Imperceptible perturbations (to humans) to inputs can compromise ML systems

2. Generalization

Estimate the risk of a model from training samples

3. Deep models in Inverse Problems

How should standard restoration approaches be adapted to exploit deep learning models?

Agenda

PART I
Sparse Local Lipschitzness
1. Adversarial Attacks
2. Generalization

PART II
Learned Proximal Networks
3. Deep models in Inverse Problems

PART I
Sparse Local Lipschitzness
Setting
\displaystyle \bullet~~ \text{data}~~ \{(x_i,y_i)\}_{i=1}^n \sim \mathcal D_{\mathcal X \times \mathcal Y}
\displaystyle \bullet~~ \hat{f} \in \underset{f\in\mathcal H}{\arg\min}~~ \underbrace{\frac1n \sum_{i=1}^n \ell(y_i,f(x_i))}_{\hat{R}(f)}
\displaystyle \left( \text{We want low risk:}~ R(\hat f) = \mathbb E [ \ell(y,\hat f(x)) ] \right)

Hypothesis Class
\mathcal H \ni f_\theta(x) = W^{(K+1)} \texttt{ReLU}\left( W^{(K)} \cdots \texttt{ReLU}(W^{(1)}x+b_1) \cdots + b_K \right) \quad \text{(nonlinear, complex function)}

Sensitivity
\bullet~ ~f_\theta(x)~ \text{is Lipschitz (w.r.t. inputs) if }~ \|f_\theta(x) - f_\theta(z)\|_2 \leq L_\mathcal{X} \|x-z\|_2 ~~~ \forall ~ x,z \in \mathcal X
\bullet~ ~f_\theta(x)~ \text{is Lipschitz (w.r.t. parameters) if }~ \|f_{\theta}(x) - f_{\tilde \theta}(x)\|_2 \leq L_\mathcal{H} \|\theta - \tilde \theta\|_\mathcal{H} ~~~ \forall ~ \theta,\tilde\theta \in \Theta

"Complex Deep Nets are locally simple"

Today's Agenda

\displaystyle \bullet~~ \text{data}~~ \{(x_i,y_i)\}_{i= 1}^n \overset{iid}{\sim} \mathcal D_{\mathcal X \times \mathcal Y}
\displaystyle \bullet~~ \text{find}~~ \hat{f}\in\mathcal F_{NN} ~~\text{so that}~~ \hat{f}(x)\approx y

Adversarial Robustness

Generalization

Learning to solve Inverse Problems

f(\cdot) = \langle \mathbf{w} , \varphi_\theta(\cdot) \rangle , \quad \mathbf{w} \in \mathcal W = \{ \mathbf{w} \in \mathbb{R}^p : \|\mathbf{w}\|_2\leq B \}, \theta \in \Theta
\text{learned representation }~~{\varphi_\theta: \mathcal X \to \mathbb R^p}

Sparse Local Lipschitzness

\text{Consider some representation }~ \varphi_\theta: \mathcal X \to \mathbb R^p
\varphi(x) \text{ is SLL at } x \text{ if } \exists \text{ an inactive set } I \subseteq [p] \text{ with } \mathcal P_I(\varphi(x))=\mathbf{0}, \text{ of size } s = |I|, \text{ so that } \forall ~ z : \|x-z\|_2\leq r(x,s),
\bullet~ ~ \|\varphi(x) - \varphi(z)\|_2 \leq L(x,s) \|x-z\|_2
\bullet~ ~ \mathcal P_I(\varphi(x)) = \mathcal P_I(\varphi(z)) = \mathbf 0

Example 0:

\varphi_\theta(x) = A\cdot x
\varphi_\theta(x) ~~ \text{is Sparse Local Lipschitz with }
L(x,s=0) = \sigma_{\max}(A), ~~ r(x,s=0) = \infty
L(x,s>0) = 1, ~~ r(x,s>0) = 0
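Example 0 can be checked numerically: for a linear map the s = 0 Lipschitz scale is the largest singular value and the radius is unlimited. A minimal NumPy sketch, assuming a random 5 x 3 matrix A chosen only for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((5, 3))        # hypothetical linear representation
L0 = np.linalg.norm(A, ord=2)          # sigma_max(A) = L(x, s=0)

# Empirical Lipschitz ratios over random input pairs never exceed sigma_max(A)
ratios = []
for _ in range(1000):
    x, z = rng.standard_normal(3), rng.standard_normal(3)
    ratios.append(np.linalg.norm(A @ x - A @ z) / np.linalg.norm(x - z))

# ... and the bound is attained along the top right singular vector
_, S, Vt = np.linalg.svd(A)
v_top = Vt[0]
print(max(ratios), L0, np.linalg.norm(A @ v_top))
```

No ball around x is needed here, which is why r(x, s=0) is infinite.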

ReLU Networks are Sparse Local Lipschitz

[Figure: the map \varphi(x) = \texttt{ReLU}(Wx+b), with units labeled active, weakly inactive, and strongly inactive]

ReLU Networks are Sparse Local Lipschitz

Theorem:

\text{The map }~\varphi(x) = \sigma(Wx+b),~ \sigma = \texttt{ReLU}, ~~ \text{is Sparse Local Lipschitz with }
I(x,s) = \text{Top-}s\left( - \frac{w_i^T~x + b_i}{\|w_i\|_2} \right)_{i=1..p}, ~~ I^c = [p]\setminus I(x,s)
L(x,s) = \| W_{[I^c]} \|_2, ~~ |I^c| = p-s
r(x,s) = \min_{i\in I(x,s)} \left( -\frac{w_i^T~x + b_i}{\|w_i\|_2} \right) = -\text{sort}\left[ \left( \frac{w_i^T~x + b_i}{\|w_i\|_2} \right)_{i=1..p} \right]_{s} \quad \text{(ascending sort; } r(x,0)=\infty \text{)}
\text{valid if}~ r(x,s)>0; ~~ \text{larger } s \implies \text{ smaller } r(x,s)

[Figure: normalized pre-activations \frac{w_i^Tx+b_i}{\|w_i\|_2}, split into active, weakly inactive, and strongly inactive units; the s most negative form I, the rest form I^c]
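The theorem's quantities are cheap to compute. A minimal NumPy sketch, assuming a single random layer and a hypothetical helper `sll_relu`: it builds I(x,s) from the s most negative normalized pre-activations, takes L(x,s) as the spectral norm of the remaining rows and r(x,s) as the smallest hyperplane distance within I, then checks both SLL conditions on random points inside the ball.

```python
import numpy as np

def sll_relu(W, b, x, s):
    """Inactive set I, Lipschitz scale L(x,s) and radius r(x,s) for phi(x) = ReLU(W x + b)."""
    u = (W @ x + b) / np.linalg.norm(W, axis=1)   # normalized pre-activations
    order = np.argsort(u)                          # ascending: most negative first
    I, Ic = order[:s], order[s:]                   # I = Top-s of -u
    L = np.linalg.norm(W[Ic], ord=2)               # spectral norm of the rows in I^c
    r = -u[order[s - 1]] if s > 0 else np.inf      # = min over i in I of (-u_i)
    return I, L, r

relu = lambda t: np.maximum(t, 0.0)

rng = np.random.default_rng(1)
p, d = 20, 5                                       # hypothetical sizes
W, b = rng.standard_normal((p, d)), rng.standard_normal(p)
x = rng.standard_normal(d)

n_neg = int(np.sum(W @ x + b < 0))                 # strongly inactive units at x
s = max(1, n_neg // 2)
I, L, r = sll_relu(W, b, x, s)

# Random directions and random radii strictly inside the ball ||z - x|| < r
phi_x = relu(W @ x + b)
preserved, ratios = True, []
for _ in range(2000):
    v = rng.standard_normal(d)
    z = x + rng.uniform(0, 0.99) * r * v / np.linalg.norm(v)
    phi_z = relu(W @ z + b)
    preserved &= bool(np.all(phi_z[I] == 0))       # inactive set stays inactive
    ratios.append(np.linalg.norm(phi_x - phi_z) / np.linalg.norm(x - z))

# Radius as a function of s (non-increasing)
radii = [sll_relu(W, b, x, k)[2] for k in range(1, n_neg + 1)]
print(s, L, r, max(ratios), preserved)
```

The choice of s trades the two conditions off: a larger inactive set lowers L(x,s) but shrinks the ball where it holds.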

Lemma:

\text{If } \varphi^{(i)},~ i=1,\dots,K, \text{ is sparse local Lipschitz at } \varphi^{[i-1]}(x) ~(\varphi^{[0]}(x) = x), \text{ then }
\varphi^{[K]}(x) = \varphi^{(K)} \circ \varphi^{(K-1)} \circ \dots \circ \varphi^{(1)}(x)
\text{ is SLL at $x$ with radius $r^{[K]}(x,s^{[K]})$ and Lipschitz scale }
L^{[K]}(x,s^{[K]}) = \prod_{i=1}^K L^{(i)}(\varphi^{[i-1]}(x),s^{(i)})
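The lemma can be illustrated on two random ReLU layers. The slide does not give r^{[K]} explicitly; this sketch assumes the conservative choice r = min(r1, r2 / L1), which guarantees that the second layer's input moves by at most r2. Sizes and the helper `sll_relu` are hypothetical.

```python
import numpy as np

relu = lambda t: np.maximum(t, 0.0)

def sll_relu(W, b, x, s):
    """(L, r) of x -> ReLU(W x + b) at x, with the s most negative units as inactive set."""
    u = (W @ x + b) / np.linalg.norm(W, axis=1)
    order = np.argsort(u)
    L = np.linalg.norm(W[order[s:]], ord=2)
    r = -u[order[s - 1]] if s > 0 else np.inf
    return L, r

rng = np.random.default_rng(2)
d, p1, p2 = 4, 16, 12                           # hypothetical layer sizes
W1, b1 = rng.standard_normal((p1, d)), rng.standard_normal(p1)
W2, b2 = rng.standard_normal((p2, p1)), rng.standard_normal(p2)
x = rng.standard_normal(d)
h = relu(W1 @ x + b1)                           # phi^{[1]}(x), input to layer 2

# Use half of each layer's negative pre-activations as its inactive set
s1 = max(1, int(np.sum(W1 @ x + b1 < 0)) // 2)
s2 = max(1, int(np.sum(W2 @ h + b2 < 0)) // 2)
L1, r1 = sll_relu(W1, b1, x, s1)
L2, r2 = sll_relu(W2, b2, h, s2)                # layer 2 is SLL at phi^{[1]}(x)

L_comp = L1 * L2                                # Lipschitz scale from the lemma
r_comp = min(r1, r2 / L1)                       # keeps ||phi^{[1]}(z) - phi^{[1]}(x)|| <= r2

phi = lambda v: relu(W2 @ relu(W1 @ v + b1) + b2)
ratios = []
for _ in range(2000):
    v = rng.standard_normal(d)
    z = x + rng.uniform(0, 0.99) * r_comp * v / np.linalg.norm(v)
    ratios.append(np.linalg.norm(phi(x) - phi(z)) / np.linalg.norm(x - z))
print(L_comp, r_comp, max(ratios))
```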

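Robustness Certificates

Theorem:

\underset{j}{\arg\max} ~ \langle w_j, \varphi(x) \rangle = \underset{j}{\arg\max} ~ \langle w_j, \varphi(x+v) \rangle \quad \forall~v : \|v\|_2\leq r_\text{cert}(x,s) := \min\left\{ r({x,s}) , \frac{\text{margin}(\varphi(x))}{2\|W\|_2 L({x,s}) } \right\}
\text{where } w_j \text{ is the } j\text{-th row of } W \text{ and } \text{margin}(\varphi(x)) \text{ is the gap between the two largest scores } \langle w_j, \varphi(x)\rangle

Margin vs. stability: increasing s shrinks r(x,s), but it can only decrease L(x,s), which enlarges the margin term.

Optimal Sparsity

s^* = \underset{0\leq s\leq p-\|\varphi(x)\|_0}{\arg\max}~ r_\text{cert}(x,s)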
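A sketch of the certificate and of the search for s*, assuming a one-layer ReLU representation followed by a linear head V (playing the role of W in the theorem); the helpers and sizes are hypothetical. Because s ranges only over the zero entries of φ(x), every r(x,s) is positive.

```python
import numpy as np

relu = lambda t: np.maximum(t, 0.0)

def sll_relu(W, b, x, s):
    """(L, r) of x -> ReLU(W x + b), with the s most negative units as inactive set."""
    u = (W @ x + b) / np.linalg.norm(W, axis=1)
    order = np.argsort(u)
    L = np.linalg.norm(W[order[s:]], ord=2) if s < len(u) else 0.0
    r = -u[order[s - 1]] if s > 0 else np.inf
    return L, r

def certified_radius(W, b, V, x, s):
    """r_cert(x, s) = min{ r(x, s), margin / (2 ||V||_2 L(x, s)) }."""
    scores = V @ relu(W @ x + b)
    second, first = np.sort(scores)[-2:]
    margin = first - second                        # gap between the top two scores
    L, r = sll_relu(W, b, x, s)
    stability = np.inf if L == 0 else margin / (2 * np.linalg.norm(V, ord=2) * L)
    return min(r, stability)

rng = np.random.default_rng(3)
d, p, k = 10, 50, 5                                # hypothetical sizes
W, b = rng.standard_normal((p, d)), rng.standard_normal(p)
V = rng.standard_normal((k, p))                    # linear head, rows v_j
x = rng.standard_normal(d)

n_zero = int(np.sum(relu(W @ x + b) == 0))         # p - ||phi(x)||_0
radii = [certified_radius(W, b, V, x, s) for s in range(n_zero + 1)]
s_star = int(np.argmax(radii))                     # optimal sparsity level
r_star = radii[s_star]
print(s_star, r_star, radii[0])
```

Since s = 0 recovers the usual global-Lipschitz certificate, the optimized radius is never smaller than it.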

[Figure: certified accuracy vs. perturbation energy]

Muthukumar & S. (2023). Adversarial robustness of sparse local lipschitz predictors. SIAM Journal on Mathematics of Data Science, 5(4), 920-948.

Generalization

\displaystyle \bullet~~ \text{data}~~ \{(x_i,y_i)\}_{i=1}^n \sim \mathcal D_{\mathcal X \times \mathcal Y}^n
\displaystyle \bullet~~ \hat{f} \in \underset{f\in\mathcal H}{\arg\min}~~ \underbrace{\frac1n \sum_{i=1}^n \ell(y_i,f(x_i))}_{\hat{R}(f)}
\displaystyle \bullet~~ \text{But we want } \hat{f} \text{ with low Risk:}~ R(\hat f) = \mathbb E [ \ell(y,\hat f(x)) ]

Generalization Gap:

R(\hat f) - \hat{R}(\hat f) \leq \Delta (\hat f) \in [0,1]

Generalization

Sparse Induced Norms

\text{Given a matrix ${W}$ of size $d_2 \times d_1$, its $(s_1,s_2)$-sparse norm is the }
\text{operator norm of its worst sub-matrices: }
\|W\|_{(s_1,s_2)} = \max_{|J_2|=d_2-s_2} \max_{|J_1|=d_1-s_1} \| \mathcal P_{J_2,J_1} (W) \|_2

Remark 1

\max_{i,j} |W_{i,j}| = \|W\|_{(d_1-1,d_2-1)} \leq \|W\|_{(s_1,s_2)} \leq \|W\|_{(0,0)} = \|W\|_2

Remark 2 (more sparsity, lower sparse norms)

\|W\|_{(s_1,s_2)} \leq \|W\|_{(\tilde{s}_1,\tilde s_2)} \quad \text{ for } \quad (\tilde{s}_1,\tilde s_2) \preceq (s_1,s_2)
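For small matrices the sparse norm can be computed by exhaustive search over row and column subsets, which makes both remarks easy to check. A minimal sketch, assuming a random 4 x 5 matrix and a hypothetical helper `sparse_norm` (the search is exponential in the dimensions, so this is for illustration only):

```python
import itertools
import numpy as np

def sparse_norm(W, s1, s2):
    """(s1, s2)-sparse induced norm of a d2 x d1 matrix, by brute force."""
    d2, d1 = W.shape
    best = 0.0
    for rows in itertools.combinations(range(d2), d2 - s2):      # |J_2| = d2 - s2
        for cols in itertools.combinations(range(d1), d1 - s1):  # |J_1| = d1 - s1
            sub = W[np.ix_(rows, cols)]
            best = max(best, np.linalg.norm(sub, ord=2))
    return best

rng = np.random.default_rng(0)
W = rng.standard_normal((4, 5))         # d2 = 4 rows, d1 = 5 columns
d2, d1 = W.shape
for s1, s2 in [(0, 0), (1, 1), (2, 2), (d1 - 1, d2 - 1)]:
    print((s1, s2), round(sparse_norm(W, s1, s2), 4))
```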

Generalization

Sparse Norm Balls

\tilde{f} \in \mathcal B(f,\boldsymbol \epsilon)~ \text{ if } ~ \| W_k - \tilde{W}_k \|_{(s_k,s_{k-1})} \leq \epsilon_k \quad \forall~ k

Lemma:

\text{Let } f \in \mathcal H_K \text{ with weights } \{ W_k \}_{k=1}^K \text{ and } \tilde{f} \in \mathcal B(f,\boldsymbol \epsilon). \text{ If the sparse radius at } x \text{ is large enough, }
r_k(f,x,s_k) > \gamma_k = \mathcal O(\prod_{j=1}^k \epsilon_j ),
\text{then } f(x) \text{ is SLL (at $x$) w.r.t. its parameters:}
\bullet~ ~ \| f^{(k)}(x) - \tilde{f}^{(k)}(x)\|_2 \leq \gamma_k \prod_{j=1}^k \|W_j\|_{(s_j,s_{j-1})} \quad \text{(bounded small sensitivity)}
\bullet~ ~ \mathcal P_{I_k}(f^{(k)}(x)) = \mathcal P_{I_k}(\tilde{f}^{(k)}(x)) = \mathbf{0} \quad \text{(strongly inactive entries are preserved)}

Generalization

Main result (informal)

\text{For any } f \in \mathcal H_K \text{ with bounded sparse norms, with probability } > 1-\delta,
\mathbb P[f(x)\neq y] \leq \underbrace{\hat{R}_{\gamma_K}(f)}_{\text{empirical (margin) risk}} + \underbrace{\frac Kn \sum_{i=1}^n \mathbf{1}\left[\exists k:r_k(f,x_i,s_k)<\gamma_k\right]}_{\text{sparse loss}} + \underbrace{\tilde{\mathcal O} \left( \sqrt{\frac{\text{KL}(\mathcal N(f,\boldsymbol\sigma^2)\,\|\,P)}{n} } \right)}_{\text{deviation from prior}}

Sparse norms:

\text{KL}(\mathcal N(f,\boldsymbol\sigma^2)\,\|\,P) \propto \sum_k \| W_k \|^2_{(s_k,s_{k-1})} \quad \text{(not exponential in depth)}

Sparsity:

\uparrow s ~~ \implies ~~ \downarrow \| W_k \|^2_{(s_k,s_{k-1})}, \downarrow \hat R(f)\text{, but } \uparrow \text{ sparse loss}

Generalization

PAC-Bayes Bounds
\text{1. Choose a prior distribution (independent of the sample): } ~ P =\mathcal N(\mathbf 0,\boldsymbol{\sigma}^2)
\text{2. Train a predictor: } ~ \hat{f} \leftarrow \text{ERM}(\mathcal S^n, \ell, \mathcal H_K)
\text{3. Construct a posterior based on $\hat{f}$: }~ Q = \mathcal N(\hat{f},\boldsymbol\sigma^2)
\text{Gen. Gap} = \tilde{\mathcal O}\left( \sqrt{ \frac{\text{KL}(Q\,\|\,P)}{n} } \right)
Observation:
\text{ If } \boldsymbol\sigma \text{ is small and } {f} \sim Q \implies {f} \in \mathcal B(\hat{f},\boldsymbol \epsilon) \text{ w.h.p.}
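For Gaussian prior and posterior the KL term is explicit. A worked equation, assuming a single scalar variance σ² shared by all parameters (the slide's boldface σ allows per-layer values) and writing θ̂ for the stacked weights of f̂ with biases omitted; with equal covariances the trace and log-determinant terms cancel, leaving

```latex
\text{KL}\big(\mathcal N(\hat\theta,\sigma^2 I)\,\|\,\mathcal N(\mathbf 0,\sigma^2 I)\big)
  = \tfrac12\,(\hat\theta-\mathbf 0)^T(\sigma^2 I)^{-1}(\hat\theta-\mathbf 0)
  = \frac{\|\hat\theta\|_2^2}{2\sigma^2}
  = \frac{1}{2\sigma^2}\sum_{k=1}^{K} \|\hat W_k\|_F^2
```

This Frobenius-norm dependence is the standard PAC-Bayes quantity; the main result's dependence on sparse norms comes from the sparsity-aware analysis, which this computation does not reproduce.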

Generalization

Main result (informal): reading the bound term by term

\text{empirical (margin) risk } \hat{R}_{\gamma_K}(f): \text{ on how many samples is the margin smaller than } \gamma_K?
\text{sparse loss: is the sparse local radius large enough (at all layers) over the sample?}
\text{deviation from prior: } \text{KL}(\mathcal N(f,\boldsymbol\sigma^2)\,\|\,P)

Generalization

Related results

Examples

[Figure: increasing width]

Muthukumar & S. (2023). Sparsity-aware generalization theory for deep neural networks. In The Thirty Sixth Annual Conference on Learning Theory, PMLR.

PART II
Learned Proximal Networks

Learning for Inverse Problems

y = A(x^*) + v \quad \text{(measurements)}
\hat x = \arg\min_x \frac 12 \| y - A(x) \|^2_2 + R(x) \quad \text{(reconstruction)}
\left(\text{MAP estimate with }R(x) = -\log p_x\right)

Learning for Inverse Problems

\hat x = \arg\min_x \frac 12 \| y - A(x) \|^2_2 + R(x)

Proximal Gradient Descent

x^{t+1} = \text{prox}_R \left(x^t - \nabla A(x^t)(A(x^t)-y)\right)

ADMM

x^{t+1} = \arg\min_x \frac 12 \|y-A(x)\|_2^2 + \frac \rho 2 \|x-z^t+u^t\|^2_2
z^{t+1} = \arg\min_z \frac \rho 2 \|z - (x^{t+1}+u^t)\|_2^2 + R(z) = \text{prox}_{R/\rho} \left( x^{t+1}+u^t \right)
u^{t+1} = u^t + x^{t+1} - z^{t+1}
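The proximal-gradient iteration can be made concrete with a prior whose prox is known in closed form. A minimal sketch, assuming a linear forward model A and R(x) = λ‖x‖₁ (the LASSO), whose prox is entrywise soft-thresholding; the step size η = 1/σ_max(AᵀA), the random problem, and the helper names are assumptions for illustration.

```python
import numpy as np

def soft_threshold(u, tau):
    """prox of tau * ||.||_1: shrink each entry toward zero by tau."""
    return np.sign(u) * np.maximum(np.abs(u) - tau, 0.0)

def ista(A, y, lam, n_iter=5000):
    """Proximal gradient descent on 0.5 ||y - A x||^2 + lam ||x||_1 (linear A)."""
    eta = 1.0 / np.linalg.norm(A, ord=2) ** 2          # 1 / sigma_max(A^T A)
    x = np.zeros(A.shape[1])
    for _ in range(n_iter):
        grad = A.T @ (A @ x - y)                       # gradient of the data term
        x = soft_threshold(x - eta * grad, eta * lam)  # prox step on eta * R
    return x

rng = np.random.default_rng(0)
A = rng.standard_normal((40, 20))                       # hypothetical forward model
x_true = np.zeros(20)
x_true[[2, 7, 11]] = [1.5, -2.0, 0.8]
y = A @ x_true + 0.05 * rng.standard_normal(40)
lam = 2.0
x_hat = ista(A, y, lam)
print(np.round(x_hat, 3))
```

At convergence x_hat satisfies the LASSO optimality conditions: Aᵀ(y - A x̂) equals λ·sign(x̂) on the support and is bounded by λ in magnitude elsewhere.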

Learning for Inverse Problems

\text{prox}_R \left( u \right) = \arg\min_x \frac 12 \|u - x\|_2^2 + R(x)
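The prox is well defined even for a nonconvex R, which matters once R is learned. A brute-force sketch, assuming scalar inputs, a fixed search grid, and a hypothetical helper `prox_bruteforce`: it recovers soft-thresholding for R = λ|x| and hard-thresholding at √(2λ) for R = λ·1[x ≠ 0].

```python
import numpy as np

grid = np.arange(-50_000, 50_001) * 1e-4     # candidates in [-5, 5]; contains 0.0 exactly

def prox_bruteforce(R, u):
    """prox_R(u) = argmin_x 0.5 (u - x)^2 + R(x), by exhaustive search on the grid."""
    return grid[np.argmin(0.5 * (u - grid) ** 2 + R(grid))]

lam = 1.0
R_l1 = lambda x: lam * np.abs(x)              # convex
R_l0 = lambda x: lam * (x != 0)               # nonconvex

def soft_threshold(u):                        # closed-form prox of lam * |x|
    return np.sign(u) * max(abs(u) - lam, 0.0)

def hard_threshold(u):                        # closed-form prox of lam * 1[x != 0]
    return u if abs(u) > np.sqrt(2 * lam) else 0.0

for u in [-3.0, -1.2, -0.5, 0.3, 1.5, 2.7]:
    print(u, prox_bruteforce(R_l1, u), soft_threshold(u),
          prox_bruteforce(R_l0, u), hard_threshold(u))
```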

This is a MAP denoiser...
Let's plug in any off-the-shelf denoiser!

I'm going to use any neural network...
... that computes a proximal, right?

Ongie et al. (2020). Deep learning techniques for inverse problems in imaging. IEEE Journal on Selected Areas in Information Theory, 1(1), 39-56.

Learned Proximal Networks

Proposition

\text{Let } \psi_\theta : \mathbb R^d \to \mathbb R \text{ be given by } \psi_\theta(y) = w^Tz_K + b, \text{ where }
z_1 = g(H_1y+b_1), \quad z_k = g(W_k z_{k-1} + H_ky + b_k ), k\in [2,K],
g: \text{convex, non-decreasing; } W_k \text{ and } w: \text{non-negative entries; } H_k: \text{unconstrained}
\left( \text{strongly convex variant: } \psi_\theta(y,\alpha) = \psi_\theta(y) + \frac \alpha 2 \|y\|^2_2 \right)
\text{Let } f_\theta = \nabla \psi_\theta.
\text{ Then } \exists ~ R_\theta: \mathbb R^d \to \mathbb R \text{ such that } f_\theta(y) = \text{prox}_{R_\theta}(y)
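The proposition's construction can be sketched numerically for K = 2. This NumPy sketch assumes softplus as the activation g (convex, non-decreasing, C²), hypothetical small sizes, and random weights with W_2 and w made non-negative by taking absolute values; a trained LPN would learn these instead. It writes f_θ = ∇ψ_θ by hand (backpropagation) and checks the property the proposition rests on: with α > 0, ψ_θ(·, α) is α-strongly convex, so its gradient is strongly monotone. It does not compute R_θ itself.

```python
import numpy as np

def softplus(t):                          # g: convex, non-decreasing, C^2
    return np.logaddexp(0.0, t)

def softplus_prime(t):                    # g'(t) = sigmoid(t), stable form
    return 0.5 * (1.0 + np.tanh(0.5 * t))

rng = np.random.default_rng(4)
d, m = 3, 8                               # hypothetical input dim and width, K = 2
H1, b1 = rng.standard_normal((m, d)), rng.standard_normal(m)
H2, b2 = rng.standard_normal((m, d)), rng.standard_normal(m)
W2 = np.abs(rng.standard_normal((m, m)))  # non-negative entries
w = np.abs(rng.standard_normal(m))        # non-negative entries
b, alpha = 0.1, 0.05

def psi(y):
    """psi_theta(y, alpha) = w^T z_2 + b + alpha/2 ||y||^2."""
    z1 = softplus(H1 @ y + b1)
    z2 = softplus(W2 @ z1 + H2 @ y + b2)
    return w @ z2 + b + 0.5 * alpha * (y @ y)

def f(y):
    """f_theta = grad psi_theta, via hand-written backpropagation."""
    a1 = H1 @ y + b1
    a2 = W2 @ softplus(a1) + H2 @ y + b2
    delta2 = w * softplus_prime(a2)                  # d psi / d a2
    delta1 = (W2.T @ delta2) * softplus_prime(a1)    # d psi / d a1
    return H2.T @ delta2 + H1.T @ delta1 + alpha * y

# Strong monotonicity: <f(y) - f(y'), y - y'> >= alpha ||y - y'||^2
gaps = []
for _ in range(1000):
    y, yp = rng.standard_normal(d), rng.standard_normal(d)
    gaps.append((f(y) - f(yp)) @ (y - yp) - alpha * np.sum((y - yp) ** 2))
print(min(gaps))  # non-negative up to round-off
```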

Learned Proximal Networks

Convergence (Proximal Gradient Descent):

x^{t+1} = \text{prox}_{R_\theta} \left(x^t - \eta A^T(Ax^t - y)\right) = f_\theta\left(x^t - \eta A^T(Ax^t - y)\right)
\text{Let } f_\theta \text{ be an LPN with } \alpha>0 \text{ and activations } g \in C^2, \text{ and let } 0<\eta<1/\sigma_{\max}(A^TA).
\text{Then } \exists~ x^* : \lim_{t\to\infty} x^t = x^* \text{ and }
f_\theta(x^* - \eta A^T(Ax^*-y)) = x^*

Proximal Matching

How do we train so that
\hat f(y) \approx \text{prox}_{-\sigma^2\log p_x}(y) ~?
(we don't know p_x!)

\text{Let } y = x+v,~ v \sim \mathcal N(0,\sigma^2)
f_\theta = \arg\min_{f_\theta:\text{prox}} \mathbb E_{x,y} \left[ \ell (f_\theta(y),x) \right]
\bullet ~~ \ell (f_\theta(y),x) = \|f_\theta(y) - x\|^2_2 ~~\implies~~ \mathbb E[x|y] \text{ (MMSE)}
\bullet ~~ \ell (f_\theta(y),x) = \|f_\theta(y) - x\|_1 ~~\implies~~ \texttt{median}(p_{x|y}) \text{ (entrywise)}

Proximal Matching

\text{Let } y = x+v,~ v \sim \mathcal N(0,\sigma^2)
f_\theta = \arg\min_{f_\theta:\text{prox}} \mathbb E_{x,y} \left[ \ell_\text{PM} (f_\theta(y),x) \right]
\ell_\text{PM} (f_\theta(y),x) = 1- \frac{1}{(\pi\gamma^2)^{d/2}} \exp\left( -\frac{\|f_\theta(y)-x\|_2^2}{\gamma^2} \right)

Theorem (informal)

\hat{f}^* = \lim_{\gamma \searrow 0}~ \arg\min_{f} \mathbb E_{x,y} \left[ \ell_\text{PM}(f(y),x)\right]
\hat{f}^*(y) = \arg\max_c p_{x|y}(c) = \text{prox}_{-\sigma^2\log p_x}(y)
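Why does ℓ_PM target the MAP estimate? A 1-D sketch, assuming a fixed y so that samples from a hypothetical skewed, bimodal distribution stand in for p_{x|y}: over constant predictions c, the ℓ2 minimizer lands on the sample mean (MMSE-like), whereas ℓ_PM with a small γ lands on the mode (MAP-like).

```python
import numpy as np

def pm_loss(pred, x, gamma):
    """Proximal matching loss from the slide, for targets x of shape (N, d)."""
    d = x.shape[1]
    sq = np.sum((pred - x) ** 2, axis=1)
    return 1.0 - np.exp(-sq / gamma**2) / (np.pi * gamma**2) ** (d / 2)

rng = np.random.default_rng(0)
n = 10_000
# Hypothetical skewed posterior for a fixed y: mode near 0, mean near 0.9
first = rng.random(n) < 0.7
samples = np.where(first, rng.normal(0.0, 0.1, n), rng.normal(3.0, 0.1, n))[:, None]

candidates = np.linspace(-1.0, 4.0, 1001)       # constant predictions c, step 0.005
mse = [np.mean(np.sum((c - samples) ** 2, axis=1)) for c in candidates]
pm = [np.mean(pm_loss(np.array([c]), samples, gamma=0.2)) for c in candidates]

c_mse = candidates[np.argmin(mse)]              # ~ sample mean
c_pm = candidates[np.argmin(pm)]                # ~ mode
print(c_mse, c_pm)
```

Shrinking γ narrows the Gaussian bump inside ℓ_PM, so only predictions very close to the densest region of p_{x|y} earn a low loss.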

Results

What's in your prior?

\text{Sample } y = x+v,~ \text{ with } x \sim \frac 12 \exp (-|x|) \text{ and } v \sim \mathcal N(0,\sigma^2)
(\{ x_i,y_i \}_{i=1}^n) \times \mathcal F_\text{LPN} \times \text{Proximal Matching} ~~ \to ~~ \hat f_\theta
R_\theta(x) = \langle \hat{f}^{-1}_\theta(x),x\rangle - \frac 12 \|x\|^2_2 - \psi_\theta( \hat{f}^{-1}_\theta(x) )
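The Laplacian experiment has a closed-form answer to compare against. Since -σ² log p_x(x) = σ²|x| + const, the MAP denoiser is soft-thresholding at σ², which is the gradient of the convex potential ψ(y) = ½ max(|y| - σ², 0)². A sketch, assuming this hand-written ψ stands in for a trained ψ_θ (σ = 0.5 is an arbitrary choice): plugging it into the formula for R_θ returns σ²|x| for x ≠ 0, and a brute-force posterior maximization agrees with soft-thresholding.

```python
import numpy as np

sigma = 0.5
lam = sigma**2                          # MAP for Laplace prior = prox of sigma^2 |x|

def f(y):                               # soft-thresholding = prox_{lam |.|}
    return np.sign(y) * np.maximum(np.abs(y) - lam, 0.0)

def psi(y):                             # convex potential with grad psi = f
    return 0.5 * np.maximum(np.abs(y) - lam, 0.0) ** 2

def f_inv(x):                           # inverse of f, valid for x != 0
    return x + lam * np.sign(x)

def recovered_prior(x):                 # R(x) = <f^{-1}(x), x> - x^2/2 - psi(f^{-1}(x))
    u = f_inv(x)
    return u * x - 0.5 * x**2 - psi(u)

# Brute-force MAP: argmax_c p(c | y) = argmin_c (y - c)^2 / (2 sigma^2) + |c|
grid = np.arange(-40_000, 40_001) * 1e-4
def map_bruteforce(y):
    return grid[np.argmin((y - grid) ** 2 / (2 * sigma**2) + np.abs(grid))]

xs = np.array([-3.0, -0.4, 0.05, 1.7])
print(recovered_prior(xs), lam * np.abs(xs))
```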

Results

MNIST


Fang, Buchanan & S. (2023). What's in a Prior? Learned Proximal Networks for Inverse Problems. arXiv preprint arXiv:2310.14344.

That is all

Ram Muthukumar

JHU

Zhenghan Fang

JHU

Sam Buchanan

TTIC

NSF CCF 2007649
NIH P41EB031771

Appendix

Sparse Local Lipschitzness

Example 1:

\varphi_\theta(x) = \underset{\gamma}{\arg\min} \frac12 \|x-D \gamma\|^2_2 + \lambda \|\gamma\|_1
[ Mairal et al., '12 ]

Theorem:

[ Mehta & Gray, '13;   S., Muthukumar, Arora, '20 ]
\varphi_\theta(x) ~~ \text{is Sparse Local Lipschitz with }
L(x,s) = \frac{1}{\sqrt{1-\text{RIP}(D)_{p-s}}}
r(x,s) = \text{sort}\left[\lambda \textbf{1} - | D^T (x - D\varphi_\theta(x))| ~\right]_{p-s}