Zhenghan Fang*, Sam Buchanan*, Jeremias Sulam
IMSI Computational Imaging Workshop
Wed, 7 Aug, 2024
*Equal contribution.
[Overview schematic: measurement and inversion of a signal; MAP estimation via proximal gradient descent; the proximal operator replaced by deep neural networks in Plug-and-Play]
1. When is a neural network \(f_\theta\) a proximal operator?
2. What's the prior \(R\) learned by the neural network \(f_\theta\)?
3. How can we learn the prox of the true prior?
4. Convergence guarantees for PnP
Prior Landscape
Proximal Matching Loss
Learned proximal networks (LPN)
neural networks for prox. operators
Compressed Sensing
Fang, Buchanan, Sulam. What's in a Prior? Learned Proximal Networks for Inverse Problems. ICLR 2024.
Sam Buchanan
TTIC
Jeremias Sulam
JHU
MAP estimate
Proximal Gradient Descent
ADMM
\[{\color{darkorange}\mathrm{prox}_{R}} (z) = \argmin_u \tfrac{1}{2} \|u-z\|_2^2 + R(u)\]
MAP denoiser
Prox Operator
Plug-and-Play: replace \({\color{darkorange} \mathrm{prox}_{R}}\) with an off-the-shelf denoiser (a minimal sketch follows this slide)
MAP estimate
Questions
PnP-PGD
PnP-ADMM
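As a concrete reference, here is a minimal sketch of PnP proximal gradient descent (PnP-PGD) for \(y = Ax + \varepsilon\). The callable `denoiser` stands in for \({\color{darkorange}\mathrm{prox}_{R}}\); `A`, `eta`, and the iteration count are illustrative placeholders, not settings from the paper.

```python
import numpy as np

def pnp_pgd(y, A, denoiser, eta=1.0, n_iters=100):
    """Plug-and-Play proximal gradient descent (sketch).

    Alternates a gradient step on the data-fidelity term (1/2)||Ax - y||^2
    with a denoiser standing in for prox_{eta * R}.
    For stability of the gradient step, eta should not exceed 1 / ||A^T A||.
    """
    x = A.T @ y  # simple initialization
    for _ in range(n_iters):
        grad = A.T @ (A @ x - y)      # gradient of the data-fidelity term
        x = denoiser(x - eta * grad)  # denoiser replaces the proximal step
    return x
```

With entrywise soft-thresholding as the "denoiser" this reduces to ISTA for the \(\ell_1\)-regularized problem; with a learned denoiser it is the PnP-PGD scheme named above.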
State-of-the-art neural-network-based denoisers...
\(\ell_2\) loss \(\implies \mathbb{E}[x \mid y]\), the MMSE denoiser
\(\ell_1\) loss \(\implies \mathrm{Median}[x \mid y]\) (see the calculation after this slide)
Can we learn the proximal operator of an unknown prior?
\(f_\theta = \mathrm{prox}_{-\log p_x}\)
\(R_\theta = -\log p_x\)
But we want...
\(\mathrm{prox}_{-\log p_x} = \mathrm{Mode}[x \mid y]\), the MAP denoiser
Conventional losses do not suffice!
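For reference, the textbook Bayes-estimator facts behind the two claims above: minimizing the expected \(\ell_2\) or \(\ell_1\) loss pointwise in \(y\) gives
\[
\argmin_{c}\ \mathbb{E}\!\left[\|x - c\|_2^2 \mid y\right] = \mathbb{E}[x \mid y],
\qquad
\argmin_{c}\ \mathbb{E}\!\left[\|x - c\|_1 \mid y\right] = \mathrm{Median}[x \mid y]\ \text{(entrywise)},
\]
whereas the MAP denoiser is the posterior mode, \(\argmax_{c}\ p_{x \mid y}(c) = \mathrm{Mode}[x \mid y]\), which neither loss recovers.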
Prox Matching Loss
\[\ell_{\text{PM}, \gamma}(x, y) = 1 - \frac{1}{(\pi\gamma^2)^{n/2}}\exp\left(-\frac{\|f_\theta(y) - x\|_2^2}{ \gamma^2}\right)\]
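A minimal PyTorch-style sketch of the displayed loss, assuming batched tensors; `fx` denotes the network output \(f_\theta(y)\), and the names are illustrative, not taken from the released LPN code.

```python
import math
import torch

def prox_matching_loss(fx, x, gamma):
    """l_{PM,gamma} = 1 - (pi * gamma^2)^(-n/2) * exp(-||f_theta(y) - x||_2^2 / gamma^2),
    averaged over the batch. Follows the displayed formula exactly; note that the
    normalizing constant can under/overflow numerically when n is large."""
    n = x.numel() // x.shape[0]                                   # dimension per sample
    sq_err = ((fx - x).reshape(x.shape[0], -1) ** 2).sum(dim=1)   # ||f_theta(y) - x||_2^2
    coeff = (math.pi * gamma ** 2) ** (-n / 2)
    return (1.0 - coeff * torch.exp(-sq_err / gamma ** 2)).mean()
```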
Learning the prox of a Laplacian distribution
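For context (assuming this example refers to an i.i.d. Laplace prior \(p_x(x) \propto \exp(-\|x\|_1/b)\)): the target proximal operator has a closed form, entrywise soft-thresholding, against which the learned network can be compared,
\[
\mathrm{prox}_{-\log p_x}(z)_i = \mathrm{sign}(z_i)\,\max\!\left(|z_i| - \tfrac{1}{b},\, 0\right).
\]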
Proposition (Learned Proximal Networks, LPN).
Let \(\psi_\theta: \mathbb{R}^{n} \rightarrow \mathbb{R}\) be defined by \[z_{1} = g( \mathbf H_1 y + b_1), \quad z_{k} = g(\mathbf W_k z_{k-1} + \mathbf H_k y + b_k) \ \ (k = 2, \dots, K), \quad \psi_\theta(y) = \mathbf w^T z_{K} + b\]
with \(g\) convex, non-decreasing, and all \(\mathbf W_k\) and \(\mathbf w\) non-negative.
Let \(f_\theta = \nabla \psi_{\theta}\). Then, there exists a function \(R_\theta\) such that \(f_\theta(y) = \mathrm{prox}_{R_\theta}(y)\).
Neural networks that are guaranteed to parameterize proximal operators
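A minimal PyTorch sketch of the construction in the proposition: an input-convex scalar potential \(\psi_\theta\) with softplus activations (convex and non-decreasing) and non-negative \(\mathbf W_k\), \(\mathbf w\) (enforced here by squaring, an illustrative choice), with \(f_\theta = \nabla \psi_\theta\) obtained by autograd. Layer sizes and parameterization details are assumptions, not necessarily those of the released LPN code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ScalarPotential(nn.Module):
    """psi_theta: R^n -> R, convex in the input y (input-convex network)."""
    def __init__(self, n, hidden=128, depth=3):
        super().__init__()
        self.H = nn.ModuleList([nn.Linear(n, hidden) for _ in range(depth)])  # H_k y + b_k
        self.W = nn.ParameterList([nn.Parameter(0.1 * torch.randn(hidden, hidden))
                                   for _ in range(depth - 1)])                # squared -> nonneg
        self.w = nn.Parameter(0.1 * torch.randn(hidden))                      # squared -> nonneg
        self.b = nn.Parameter(torch.zeros(1))

    def forward(self, y):
        z = F.softplus(self.H[0](y))              # z_1 = g(H_1 y + b_1)
        for W, Hk in zip(self.W, self.H[1:]):
            z = F.softplus(z @ W.pow(2) + Hk(y))  # z_k = g(W_k z_{k-1} + H_k y + b_k)
        return z @ self.w.pow(2) + self.b         # psi_theta(y) = w^T z_K + b

def f_theta(psi, y):
    """f_theta(y) = grad_y psi_theta(y); by the proposition, a proximal operator."""
    y = y.detach().requires_grad_(True)
    (grad,) = torch.autograd.grad(psi(y).sum(), y, create_graph=True)
    return grad
```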
Theorem (Prox Matching, informal).
Let
\[f^* = \argmin_{f} \lim_{\gamma \searrow 0} \mathbb{E}_{x,y} \left[ \ell_{\text{PM}, \gamma} \left( x, y \right)\right].\]
Then, almost surely (for almost all \(y\)),
\[f^*(y) = \argmax_{c} p_{x \mid y}(c) = \mathrm{prox}_{-\alpha\log p_x}(y).\]
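A one-line heuristic behind this statement (not the formal proof, and assuming \(p_{x\mid y}\) is continuous): the exponential factor in \(\ell_{\text{PM},\gamma}\), together with its normalizing constant, is exactly a Gaussian kernel of width \(\gamma\), so conditioning on \(y\),
\[
\mathbb{E}_{x \mid y}\!\left[\ell_{\text{PM},\gamma}(x, y)\right]
= 1 - \int \frac{e^{-\|f_\theta(y) - x\|_2^2/\gamma^2}}{(\pi\gamma^2)^{n/2}}\, p_{x \mid y}(x)\, dx
\;\longrightarrow\; 1 - p_{x \mid y}\!\big(f_\theta(y)\big)
\quad \text{as } \gamma \searrow 0,
\]
so minimizing the limiting loss drives \(f_\theta(y)\) to the maximizer of \(p_{x\mid y}\), i.e., the posterior mode.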
LPNs provide convergence guarantees for PnP algorithms under mild assumptions.
Theorem (Convergence of PnP-ADMM with LPN, informal)
Consider running PnP-ADMM with an LPN denoiser and a linear forward operator \(A\). If the ADMM penalty parameter satisfies \(\rho > \|A^TA\|\), then the sequence of iterates converges to a fixed point of the algorithm.
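A minimal sketch of PnP-ADMM with an LPN standing in for the regularizer's proximal step; the data-fidelity subproblem is solved in closed form for the linear model \(y = Ax + \varepsilon\). The names `lpn_denoiser`, `rho`, and the iteration count are illustrative.

```python
import numpy as np

def pnp_admm(y, A, lpn_denoiser, rho=1.0, n_iters=100):
    """PnP-ADMM (sketch) for min_x (1/2)||Ax - y||^2 + R(x),
    with the LPN standing in for prox_{R/rho}."""
    n = A.shape[1]
    x, z, u = np.zeros(n), np.zeros(n), np.zeros(n)  # u is the scaled dual variable
    M = A.T @ A + rho * np.eye(n)                    # x-update system matrix
    Aty = A.T @ y
    for _ in range(n_iters):
        x = np.linalg.solve(M, Aty + rho * (z - u))  # prox of the data-fidelity term
        z = lpn_denoiser(x + u)                      # LPN replaces prox of R
        u = u + x - z                                # dual update
    return z
```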
Gaussian noise
Convex interpolation
Fang, Buchanan, Sulam. What's in a Prior? Learned Proximal Networks for Inverse Problems. ICLR 2024.
Sam Buchanan
TTIC
Jeremias Sulam
JHU