What's in my network?
On learned proximals and testing for explanations
Jeremias Sulam



CSIP Seminar, Georgia Tech October 2024








"The biggest lesson that can be read from 70 years of AI research is that general methods that leverage computation are ultimately the most effective, and by a large margin. [...] Seeking an improvement that makes a difference in the shorter term, researchers seek to leverage their human knowledge of the domain, but the only thing that matters in the long run is the leveraging of computation. [...]
We want AI agents that can discover like we can, not which contain what we have discovered."
The Bitter Lesson, Rich Sutton 2019


What did my model learn?

PART I
Inverse Problems
PART II
Image Classification
Inverse Problems

measurements
reconstruction

Image Priors

Deep Learning in Inverse Problems

Option A: One-shot methods
Given enough training pairs \({(x_i,y_i)}\) train a network
\(f_\theta(y) = g_\theta(A^+y) \approx x\)

[Mousavi & Baraniuk, 2017]
[Ongie, Willett, et al, 2020]
Deep Learning in Inverse Problems
Option B: data-driven regularizer
- Priors as critics
[Lunz, Öktem, Schönlieb, 2020] and others ..
- via MLE
[Ye Tan, ..., Schönlieb, 2024], ...
- RED
[Romano et al, 2017] ...
- Generative Models
[Bora et al, 2017] ...
Deep Learning in Inverse Problems
Option C: Implicit Priors
Proximal Gradient Descent: \( x^{t+1} = \text{prox}_R \left(x^t - \eta A^T(Ax^t-y)\right) \)
... a denoiser
Deep Learning in Inverse Problems
any latest and greatest NN denoiser


[Venkatakrishnan et al., 2013; Zhang et al., 2017b; Meinhardt et al., 2017; Zhang et al., 2021; Kamilov et al., 2023b; Terris et al., 2023]
[Gilton, Ongie, Willett, 2019]
Proximal Gradient Descent: \( x^{t+1} = {\color{red}f_\theta} \left(x^t - \eta A^T(A(x^t)-y)\right) \)
Option C: Implicit Priors
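A minimal sketch of the plug-and-play iteration above, assuming a small random sensing matrix and using soft-thresholding (the exact prox of the \(\ell_1\) norm) as a stand-in for the learned denoiser \(f_\theta\). The names `pnp_pgd` and `soft_threshold`, the problem sizes, and the step size \(0.99/\sigma_{\max}(A)^2\) are illustrative choices, not taken from the talk.

```python
import numpy as np

def soft_threshold(v, tau):
    """Prox of tau * ||.||_1; a stand-in for a learned denoiser f_theta."""
    return np.sign(v) * np.maximum(np.abs(v) - tau, 0.0)

def pnp_pgd(A, y, denoiser, eta, n_iter=500):
    """Iterate x <- denoiser(x - eta * A^T (A x - y)), starting from zero."""
    x = np.zeros(A.shape[1])
    for _ in range(n_iter):
        grad = A.T @ (A @ x - y)       # gradient of 0.5 * ||Ax - y||^2
        x = denoiser(x - eta * grad)   # prior step: a prox or a plugged-in denoiser
    return x

rng = np.random.default_rng(0)
n, m, k = 100, 50, 5                   # signal size, measurements, nonzeros (assumed)
A = rng.standard_normal((m, n)) / np.sqrt(m)
x_true = np.zeros(n)
x_true[rng.choice(n, size=k, replace=False)] = 1.0
y = A @ x_true                         # noiseless measurements

# Assumed step size: just below 1 / sigma_max(A)^2, the inverse Lipschitz
# constant of the data-fit gradient.
eta = 0.99 / np.linalg.norm(A, 2) ** 2
lam = 0.01
x_hat = pnp_pgd(A, y, lambda v: soft_threshold(v, eta * lam), eta, n_iter=2000)
rel_err = np.linalg.norm(x_hat - x_true) / np.linalg.norm(x_true)
```

Swapping the `denoiser` argument for any other map gives the plug-and-play variant; whether the iteration still converges then depends on the properties of that map.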
Question 1)
When will \(f_\theta(x)\) compute a \(\text{prox}_R(x)\) ? and for what \(R(x)\)?
Question 2)
Can we estimate the "correct" prox?
Deep Learning in Inverse Problems
\(\mathcal H_\text{prox} = \{f : \text{prox}_R~ \text{for some }R\}\)
\(\mathcal H = \{f: \mathbb R^n \to \mathbb R^n\}\)

Interpretable Inverse Problems
Question 1)
When will \(f_\theta(x)\) compute a \(\text{prox}_R(x)\) ?
\(R(x)\) need not be convex

Learned Proximal Networks
Take \(f_\theta(x) = \nabla \psi_\theta(x)\) for convex (and differentiable) \(\psi_\theta\)
\( f(x) \in \text{prox}_R(x) ~\Leftrightarrow \exists ~ \text{convex l.s.c.}~ \psi: \mathbb R^n\to\mathbb R : f(x) \in \partial \psi(x)~\)
Theorem [Gribonval & Nikolova, 2020]
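A small numerical sketch of the construction \(f_\theta = \nabla \psi_\theta\), assuming a one-layer convex potential built from softplus units plus a quadratic term with weight `alpha`. This toy \(\psi\) is an illustrative assumption, not the architecture used in the talk. The point is that the Jacobian of \(f\) is the Hessian of \(\psi\), hence symmetric and positive semidefinite.

```python
import numpy as np

rng = np.random.default_rng(0)
n, h, alpha = 4, 16, 0.1               # input dim, hidden units, quadratic weight (assumed)
W = rng.standard_normal((h, n))
b = rng.standard_normal(h)

def psi(x):
    """Convex potential: sum of softplus(Wx + b) plus a strongly convex quadratic."""
    return np.sum(np.logaddexp(0.0, W @ x + b)) + 0.5 * alpha * x @ x

def f(x):
    """f = grad psi, the candidate proximal operator."""
    s = 1.0 / (1.0 + np.exp(-(W @ x + b)))   # sigmoid = derivative of softplus
    return W.T @ s + alpha * x

def jacobian(x):
    """Jacobian of f = Hessian of psi = W^T diag(s(1-s)) W + alpha * I."""
    s = 1.0 / (1.0 + np.exp(-(W @ x + b)))
    return W.T @ (W * (s * (1.0 - s))[:, None]) + alpha * np.eye(n)

x = rng.standard_normal(n)
J = jacobian(x)
min_eig = np.linalg.eigvalsh(J).min()   # bounded below by alpha
```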

Interpretable Inverse Problems
If so, can you know for what \(R(x)\)?
[Gribonval & Nikolova]
Easy! \[{\color{grey}y^* =} \arg\min_{y} \psi(y) - \langle y,x\rangle {\color{grey}= \hat{f}_\theta^{-1}(x)}\]
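As a hedged illustration of this recovery step: if \(f = \nabla\psi = \text{prox}_R\), the Gribonval & Nikolova characterization gives \(R(f(x)) = \langle x, f(x)\rangle - \tfrac12\|f(x)\|^2 - \psi(x)\), so evaluating \(R\) at a point \(y\) only requires inverting \(f\), which is the convex problem shown above. The sketch below checks this on a case with a known answer, the quadratic \(R(y) = \sum_i \tfrac{\lambda_i}{2} y_i^2\). The values of `lam` and the plain gradient-descent solver are assumptions made for the example.

```python
import numpy as np

lam = np.array([0.5, 1.0, 3.0])   # hypothetical ground truth: R(y) = sum(lam / 2 * y^2)

def psi(x):
    """Convex potential whose gradient is prox_R for the quadratic R above."""
    return np.sum(x ** 2 / (2.0 * (1.0 + lam)))

def f(x):
    """grad psi = prox_R."""
    return x / (1.0 + lam)

def invert(y, n_iter=500):
    """Solve f(x) = y by gradient descent on the convex objective psi(x) - <x, y>."""
    x = np.zeros_like(y)
    for _ in range(n_iter):
        x = x - (f(x) - y)        # unit step is safe here: grad psi is 1-Lipschitz
    return x

def recovered_R(y):
    """R(y) = <x, y> - 0.5 * ||y||^2 - psi(x), with x = f^{-1}(y)."""
    x = invert(y)
    return x @ y - 0.5 * y @ y - psi(x)

y = np.array([1.0, -2.0, 0.5])
R_true = np.sum(0.5 * lam * y ** 2)
R_hat = recovered_R(y)            # agrees with R_true up to solver tolerance
```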

Interpretable Inverse Problems
Question 2)
Could we have \(R(x) = -\log p_x(x)\)?
(we don't know \(p_x\)!)
i.e. \(f_\theta(y) = \text{prox}_R(y) = \texttt{MAP}(x|y)\)
Which loss function?
Proximal Matching Loss (with parameter \(\gamma\))
Theorem (informal)
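A hedged sketch of a proximal matching loss. The functional form used here, \(1 - (\pi\gamma^2)^{-n/2}\exp(-\|f(y)-x\|^2/\gamma^2)\), is recalled from Fang, Buchanan & Sulam (ICLR 2024) and its normalization should be treated as an assumption. The toy example illustrates why such a loss targets a MAP estimate: for small \(\gamma\) its minimizer sits at the mode of the targets, whereas the squared error is minimized by their mean.

```python
import numpy as np

def proximal_matching_loss(pred, target, gamma):
    """Mean over the batch of 1 - (pi * gamma^2)^(-n/2) * exp(-||pred - target||^2 / gamma^2)."""
    n = pred.shape[-1]
    sq_dist = np.sum((pred - target) ** 2, axis=-1)
    peak = (np.pi * gamma ** 2) ** (-n / 2.0)
    return np.mean(1.0 - peak * np.exp(-sq_dist / gamma ** 2))

# Toy 1-D targets (assumed): 70% of the mass at 0, 30% at 1.
targets = np.array([0.0] * 7 + [1.0] * 3)[:, None]
grid = np.linspace(-0.5, 1.5, 201)

pm = [proximal_matching_loss(np.full_like(targets, c), targets, gamma=0.1) for c in grid]
l2 = [np.mean((c - targets) ** 2) for c in grid]

c_pm = grid[int(np.argmin(pm))]   # lands on the mode of the targets
c_l2 = grid[int(np.argmin(l2))]   # lands on the mean of the targets
```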
Learned Proximal Networks



\(R(\tilde{x})\)
Theorem (PGD with Learned Proximal Networks)
Let \(f_\theta = \text{prox}_{\hat{R}} {\color{grey}\text{ with } \alpha>0}, \text{ and } 0<\eta<1/\sigma_{\max}(A) \) with smooth activations
(Analogous results hold for ADMM)
Learned Proximal Networks
Convergence guarantees for PnP

Fang, Buchanan & S. What's in a Prior? Learned Proximal Networks for Inverse Problems. ICLR 2024.

Learned Proximal Networks
What did my model learn?

PART I
Inverse Problems
PART II
Image Classification
Interpretability in Image Classification
Setting:
\((X,Y) \in \mathcal X \times \mathcal Y\)
\((X,Y) \sim P_{X,Y}\)
\(\hat{Y} = f(X) : \mathcal X \to \mathcal Y\)
- What features are important for this prediction?
- What does importance mean, exactly?



Semantic Interpretability of classifiers
How can we explain black-box predictors with semantic features?
Is the piano important for \(\hat Y = \text{cat}\)?
Is the piano important for \(\hat Y = \text{cat}\), given that there is a cute mammal in the image?
Post-hoc Interpretability Methods
Interpretable by construction


Semantic Interpretability of classifiers
Concept Bank: \(C = [c_1, c_2, \dots, c_m] \in \mathbb R^{d\times m}\)
Embeddings: \(H = f(X) \in \mathbb R^d\)
Semantics: \(Z = C^\top H \in \mathbb R^m\)
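A minimal sketch of the projection \(Z = C^\top H\), assuming a random orthonormal concept bank in place of real concept activation vectors or text embeddings. The concept names, dimensions and the synthetic embedding `h` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
d, m = 8, 3                                   # embedding dimension, number of concepts
concepts = ["cute", "whiskers", "piano"]      # hypothetical concept names

# Assumed concept bank: orthonormal columns c_1, ..., c_m obtained from a QR factorization.
C, _ = np.linalg.qr(rng.standard_normal((d, m)))

def semantics(H, C):
    """Z = C^T H for one embedding of shape (d,) or a batch of shape (N, d)."""
    return H @ C

h = 2.0 * C[:, 0] + 0.5 * C[:, 2]             # embedding aligned mostly with "cute"
z = semantics(h, C)                           # one score per concept
top_concept = concepts[int(np.argmax(z))]
```

Each entry of `z` is the inner product of the embedding with one concept vector, so it can be read as how strongly that concept is expressed in the embedding.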
Concept Activation Vectors
(Kim et al, 2018)
\(c_\text{cute}\)
Semantic Interpretability of classifiers
Vision-language models
(CLIP, BLIP, etc... )


Semantic Interpretability of classifiers
[Bhalla et al, "Splice", 2024]
Concept Bottleneck Models (CBMs)
[Koh et al '20, Yang et al '23, Yuan et al '22 ]
- Need to engineer a (large) concept bank
- Performance hit w.r.t. original predictor
\(\tilde{Y} = \hat w^\top Z\)
\(\hat w_j\) is the importance of the \(j^{th}\) concept

Desiderata
- Precise testing with guarantees (Type 1 error/FDR control)
- Fixed original predictor (post-hoc)
- Global and local importance notions
- Testing for any concepts (no need for large concept banks)

Precise notions of semantic importance
\(C = \{\text{``cute''}, \text{``whiskers''}, \dots \}\)
Global Importance
\(H^G_{0,j} : \hat{Y} \perp\!\!\!\perp Z_j \)
i.e. \(H^G_{0,j} : g(f(X)) \perp\!\!\!\perp c_j^\top f(X) \)
Global Conditional Importance
\(H^{GC}_{0,j} : \hat{Y} \perp\!\!\!\perp Z_j | Z_{-j}\)
i.e. \(H^{GC}_{0,j} : g(f(X)) \perp\!\!\!\perp c_j^\top f(X) | C_{-j}^\top f(X)\)
Precise notions of semantic importance
"The classifier (its distribution) does not change if we condition
on concepts \(S\) vs on concepts \(S\cup\{j\} \)"
\(C = \{\text{``cute''}, \text{``whiskers''}, \dots \}\)
Local Conditional Importance

Tightly related to Shapley values
\[H^{j,S}_0:~ g({\tilde H_{S \cup \{j\}}}) \overset{d}{=} g(\tilde H_S), \qquad \tilde H_S \sim P_{H|Z_S = C_S^\top f(x)} \]
[Teneggi et al, The Shapley Value Meets Conditional Independence Testing, 2023]
[Figure: local conditional importance example for \(\hat{Y}_\text{gas pump}\), comparing predictions conditioned on \(Z_{S}\) with predictions conditioned on \(Z_S\cup Z_{j}\)]
Testing by betting
\(H^G_{0,j} : \hat{Y} \perp\!\!\!\perp Z_j \iff P_{\hat{Y},Z_j} = P_{\hat{Y}} \times P_{Z_j}\)
Testing importance via two-sample tests
\(H^{GC}_{0,j} : \hat{Y} \perp\!\!\!\perp Z_j | Z_{-j} \iff P_{\hat{Y}Z_jZ_{-j}} = P_{\hat{Y}\tilde{Z}_j{Z_{-j}}}\)
\(\tilde{Z}_j \sim P_{Z_j|Z_{-j}}\)
[Shaer et al, 2023]
[Teneggi et al, 2023]
\[H^{j,S}_0:~ g({\tilde H_{S \cup \{j\}}}) \overset{d}{=} g(\tilde H_S), \qquad \tilde H_S \sim P_{H|Z_S = C_S^\top f(x)} \]
Testing by betting
[Shaer et al. 2023, Shekhar and Ramdas 2023 ]
Goal: Test a null hypothesis \(H_0\) at significance level \(\alpha\)
Standard testing by p-values
Collect data, then test, and reject if \(p \leq \alpha\)
Online testing by e-values
Any-time valid inference, monitor online and reject when \(e\geq 1/\alpha\)
- Consider a wealth process
\(K_0 = 1;\)
\(\text{for}~ t = 1, 2, \dots:~ K_t = K_{t-1}(1+\kappa_t v_t)\)
Reject \(H_0\) when \(K_t \geq 1/\alpha\)

Online testing by e-values
[Shaer et al. 2023, Shekhar and Ramdas 2023 ]
Fair game: \(~~\mathbb E_{H_0}[\kappa_t | \text{Everything seen}_{t-1}] = 0\)
\(v_t \in (0,1):\) betting fraction
\(\kappa_t \in [-1,1]\) payoff
\( K_t = K_{t-1}(1+\kappa_t v_t)\)
Testing by betting via SKIT (Podkopaev et al., 2023)
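The wealth process can be sketched in a few lines. The function name `betting_test` and the constant payoff sequences are illustrative assumptions; in an actual test the payoffs \(\kappa_t\) come from data and the fractions \(v_t\) from a betting strategy.

```python
import numpy as np

def betting_test(payoffs, fractions, alpha=0.05):
    """Run K_t = K_{t-1} * (1 + kappa_t * v_t); return (rejection time or None, wealth path)."""
    wealth, path = 1.0, []
    for t, (kappa, v) in enumerate(zip(payoffs, fractions), start=1):
        assert -1.0 <= kappa <= 1.0 and 0.0 < v < 1.0
        wealth *= 1.0 + kappa * v
        path.append(wealth)
        if wealth >= 1.0 / alpha:          # reject H_0 as soon as wealth reaches 1 / alpha
            return t, np.array(path)
    return None, np.array(path)

# Consistently positive payoffs: wealth grows as 1.25^t and crosses 20 at t = 14.
tau, path = betting_test([0.5] * 50, [0.5] * 50, alpha=0.05)

# Identically zero payoffs (a trivially fair game): wealth stays at 1, no rejection.
tau_null, path_null = betting_test([0.0] * 50, [0.5] * 50, alpha=0.05)
```

Because the wealth is monitored after every observation, the test can stop early when the evidence is strong, which is what makes the rejection time usable as a ranking of importance.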

Online testing by e-values
\(v_t \in (0,1):\) betting fraction
\(H_0: ~ P = Q\)
\(\kappa_t = \text{tanh}({\color{teal}\rho(X_t)} - {\color{teal}\rho(Y_t)})\)
Payoff function
\({\color{black}\text{MMD}(P,Q)} = \underset{\rho \in \mathcal R : \|\rho\|_\mathcal{R} \leq 1}{\sup} \mathbb E_P [\rho(X)] - \mathbb E_Q [\rho(Y)]\)
\({\color{teal}\rho} = \underset{\rho\in \mathcal R:\|\rho\|_\mathcal R\leq 1}{\arg\sup} ~\mathbb E_P [\rho(X)] - \mathbb E_Q[\rho(Y)]\)
\( K_t = K_{t-1}(1+\kappa_t v_t)\)
Data efficient
Rank induced by rejection time
Testing by betting via SKIT (Podkopaev et al., 2023)
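A simplified sketch of a sequential kernelized two-sample test in the spirit of the payoff above, under stated assumptions: a Gaussian kernel with unit bandwidth, an unnormalized plug-in witness function estimated from past samples only, and a fixed betting fraction `v`. SKIT as published selects the betting fraction adaptively, so this is not the authors' implementation.

```python
import numpy as np

def rbf(a, b, bandwidth=1.0):
    """Gaussian kernel matrix between 1-D sample arrays a and b."""
    return np.exp(-((a[:, None] - b[None, :]) ** 2) / (2.0 * bandwidth ** 2))

def witness(z, xs, ys):
    """Plug-in MMD witness from past samples: mean k(x_i, z) - mean k(y_i, z)."""
    z = np.atleast_1d(z)
    return rbf(z, xs).mean(axis=1) - rbf(z, ys).mean(axis=1)

def sequential_two_sample_test(x, y, alpha=0.05, v=0.5):
    """Bet against H_0: P = Q with payoff tanh(rho(x_t) - rho(y_t)); rho uses past data only."""
    wealth = 1.0
    for t in range(len(x)):
        if t == 0:
            kappa = 0.0                    # nothing observed yet, so no bet
        else:
            rho_x = witness(x[t], x[:t], y[:t])[0]
            rho_y = witness(y[t], x[:t], y[:t])[0]
            kappa = np.tanh(rho_x - rho_y)
        wealth *= 1.0 + v * kappa
        if wealth >= 1.0 / alpha:
            return t + 1, wealth           # rejection time (1-indexed)
    return None, wealth

rng = np.random.default_rng(0)
n = 300
# Alternative: clearly different means, so the wealth should grow quickly.
tau_alt, _ = sequential_two_sample_test(rng.normal(0, 1, n), rng.normal(2, 1, n))
# Null: same distribution, so the game is fair and rejection is unlikely.
tau_null, w_null = sequential_two_sample_test(rng.normal(0, 1, n), rng.normal(0, 1, n))
```

Estimating the witness from past samples only keeps the payoff at time \(t\) independent of the current pair under the null, which is what makes the game fair.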

rejection time


rejection rate
Important Semantic Concepts
(Reject \(H_0\))

Unimportant Semantic Concepts
(fail to reject \(H_0\))

Results
Type 1 error control
False discovery rate control
Results: CUB dataset

Results: Imagenette
Global Importance
Global Conditional Importance
Semantic comparison of vision-language models
Concluding Remarks
- Exciting open problems in making AI tools safe, trustworthy and interpretable
- Importance of clear definitions and guarantees
* Fang, Z., Buchanan, S., & J.S. (2023). What's in a Prior? Learned Proximal Networks for Inverse Problems. International Conference on Learning Representations (ICLR 2024).
* Teneggi, J. and J.S. I Bet You Did Not Mean That: Testing Semantic Importance via Betting. NeurIPS 2024 (to appear).


...that's it!









Zhenghan Fang
JHU
Jacopo Teneggi
JHU
Beepul Bharti
JHU
Sam Buchanan
TTIC
Yaniv Romano
Technion

Appendix
Learned Proximal Networks
Convergence guarantees for PnP
- [Sreehari et al., 2016; Sun et al., 2019; Chan, 2019; Teodoro et al., 2019]: Convergence of PnP for non-expansive denoisers
- [Ryu et al, 2019]: Convergence for close to contractive operators
- [Xu et al, 2020]: Convergence of Plug-and-Play priors with MMSE denoisers
- [Hurault et al., 2022]: Lipschitz-bounded denoisers
Post-hoc Interpretability in Image Classification
- Sensitivity or Gradient-based perturbations: LIME [Ribeiro et al, '16], CAM [Zhou et al, '16], Grad-CAM [Selvaraju et al, '17]
- Shapley coefficients: Shap [Lundberg & Lee, '17], ...
- Variational formulations: RDE [Macdonald et al, '19], ...
- Counterfactual & causal explanations: [Sani et al, 2020] [Singla et al '19], ...
Shapley values
Let \(([n], v)\) be an \(n\)-person cooperative game with characteristic function \(v : 2^{[n]} \to \mathbb R\)
How important is each player for the outcome of the game?
marginal contribution of player i with coalition S
Properties: efficiency, nullity, symmetry
Drawback: exponential complexity
Lloyd S Shapley. A value for n-person games. Contributions to the Theory of Games, 2(28):307–317, 1953.
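A short sketch of the exact Shapley computation, averaging each player's marginal contribution over all coalitions with the standard weights \(|S|!\,(n-|S|-1)!/n!\). The three-player game `v` below is a made-up example: players 0 and 1 are both required for a payoff of 1 and player 2 is irrelevant. The nested loops make the exponential cost in \(n\) explicit.

```python
from itertools import combinations
from math import factorial

def shapley_values(n, v):
    """Exact Shapley values of an n-player game; v maps a frozenset of players to a payoff."""
    phi = [0.0] * n
    for i in range(n):
        others = [p for p in range(n) if p != i]
        for size in range(n):
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for S in combinations(others, size):
                S = frozenset(S)
                phi[i] += weight * (v(S | {i}) - v(S))   # marginal contribution of i to S
    return phi

def v(S):
    """Hypothetical game: payoff 1 only when players 0 and 1 are both in the coalition."""
    return 1.0 if {0, 1} <= S else 0.0

phi = shapley_values(3, v)   # symmetric players 0 and 1 share the payoff; player 2 gets 0
```

The result illustrates the three listed properties: the values sum to \(v([n]) - v(\emptyset)\) (efficiency), the irrelevant player receives zero (nullity), and the two interchangeable players receive the same value (symmetry).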
Shap-Explanations
inputs
responses
predictor
How important is feature \(x_i\) for \(f(x)\)?
\(X_{S_j^c}\sim \mathcal D_{X_{S_j}={x_{S_j}}}\)
Scott Lundberg and Su-In Lee. A Unified Approach to Interpreting Model Predictions, NeurIPS , 2017
Question 2) What do these coefficients mean, really?
Precise notions of importance
Formal Feature Importance
[Candes et al, 2018]

XRT: eXplanation Randomization Test
returns a \(\hat{p}_{i,S}\) for the test above
How do we test?
Local Feature Importance
Precise notions of importance
Local Feature Importance

Given the Shapley coefficient of any feature
Then
and the (expected) p-value obtained for , i.e. ,
Theorem:
Teneggi, Bharti, Romano, and S. "SHAP-XRT: The Shapley Value Meets Conditional Independence Testing." TMLR (2023).
Shap-Explanations
Question 1)
Can we resolve the computational bottleneck (and when) ?
Question 2)
What do these coefficients mean, really?
Question 3)
How to go beyond input-features explanations?

We focus on data with certain structure:
Example:
if contains a sick cell



Hierarchical Shap (h-Shap)
Question 1) Can we resolve the computational bottleneck (and when)?
Theorem (informal)
- hierarchical Shap runs in linear time
- Under A1, h-Shap \(\to\) Shapley
[Teneggi, Luster & S., Fast hierarchical games for image explanations, IEEE TPAMI, 2022]






What has my model learned? GTech
By Jeremias Sulam