Yes, my network works!
Jeremias Sulam
... but what did it learn?
Biomedical Engineering Seminar
Yale University










"The biggest lesson that can be read from 70 years of AI research is that general methods that leverage computation are ultimately the most effective, and by a large margin. [...] Seeking an improvement that makes a difference in the shorter term, researchers seek to leverage their human knowledge of the domain, but the only thing that matters in the long run is the leveraging of computation. [...]
We want AI agents that can discover like we can, not which contain what we have discovered."
The Bitter Lesson, Rich Sutton 2019


What did my model learn?
PART I
Inverse Problems
PART II
Image Classification
Two vignettes on medical imaging


Inverse Problems

measurements
reconstruction

Magnetic Susceptibility Imaging
degree to which a material is magnetized when placed in a magnetic field

[from talk by S. Bollmann]
- Heme iron (in red blood cells): paramagnetic
- Myelin: diamagnetic



- Very important implications for neurodegenerative disorders



[Li et al, 2012]
Magnetic Susceptibility Imaging







Calculation of susceptibility through multiple orientation sampling (COSMOS)
[Liu et al, 2009]


As close as possible to "ground truth" in vivo
Impractical in the clinic




dipole inversion via
Image Priors

Deep Learning in Inverse Problems

Option A: One-shot methods
Given enough training pairs \(\{(x_i,y_i)\}_i\), train a network
\(f_\theta(y) = g_\theta(A^+y) \approx x\)

[Mousavi & Baraniuk, 2017]
[Ongie, Willett, et al, 2020]
Option B: data-driven regularizer
[Lunz, Öktem, Schönlieb, 2020][Bora et al, 2017][Romano et al, 2017][Ye Tan, ..., Schönlieb, 2024]
Deep Learning in Inverse Problems
Proximal Gradient Descent: \( x^{k+1} = \text{prox}_R \left(x^k - \eta A^T(Ax^k-y)\right) \)
What if we don't know \(R(x)\), or \(\text{prox}_R\) ?
- pick a function class
- collect data
- train via a suitable loss
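To make the iteration concrete, here is a minimal sketch of (plug-and-play) proximal gradient descent in Python. The operator, data, and soft-thresholding prox are illustrative assumptions; a trained network \(f_\theta\) can be dropped in for `prox`.

```python
import numpy as np

def pnp_pgd(A, y, prox, eta, n_iters=200, x0=None):
    """(Plug-and-play) proximal gradient descent:
        x^{k+1} = prox(x^k - eta * A^T (A x^k - y)).
    `prox` may be a classical proximal operator or a trained network f_theta."""
    x = np.zeros(A.shape[1]) if x0 is None else x0.copy()
    for _ in range(n_iters):
        grad = A.T @ (A @ x - y)          # gradient of the data-fidelity term
        x = prox(x - eta * grad)          # proximal / denoising step
    return x

# Illustrative usage with soft-thresholding (the prox of the l1 norm).
rng = np.random.default_rng(0)
A = rng.standard_normal((50, 100))
x_true = np.zeros(100); x_true[:5] = 1.0
y = A @ x_true
soft = lambda z, t=0.1: np.sign(z) * np.maximum(np.abs(z) - t, 0.0)
x_hat = pnp_pgd(A, y, soft, eta=1.0 / np.linalg.norm(A, 2) ** 2)
```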
Single Phase Quantitative Susceptibility Mapping


[Lai, Aggarwal, van Zijl, Li, Sulam, Learned Proximal Networks for Quantitative Susceptibility Mapping, MICCAI 2020]
(isotropic model)

Single Phase Quantitative Susceptibility Mapping

unseen angles during training

[Lai, Aggarwal, van Zijl, Li, Sulam, Learned Proximal Networks for Quantitative Susceptibility Mapping, MICCAI 2020]
(isotropic model)
Susceptibility Tensor Imaging



(anisotropic model)











Susceptibility Tensor Imaging
What are these networks actually computing?
Deep Learning in Inverse Problems
Proximal Gradient Descent: \( x^{t+1} = \text{prox}_R \left(x^t - \eta A^T(Ax^t-y)\right) \)
... a denoiser
Deep Learning in Inverse Problems
any state-of-the-art NN denoiser


[Venkatakrishnan et al., 2013; Zhang et al., 2017b; Meinhardt et al., 2017; Zhang et al., 2021; Kamilov et al., 2023b; Terris et al., 2023]
[Gilton, Ongie, Willett, 2019]
Proximal Gradient Descent: \( x^{t+1} = {\color{red}f_\theta} \left(x^t - \eta A^T(Ax^t-y)\right) \)
Question 1)
When will \(f_\theta(x)\) compute a \(\text{prox}_R(x)\)? And for what \(R(x)\)?
Deep Learning in Inverse Problems
\(\mathcal H_\text{prox} = \{f : f = \text{prox}_R \text{ for some } R\}\)
\(\mathcal H = \{f: \mathbb R^n \to \mathbb R^n\}\)
Question 2)
Can we estimate the "correct" prox?

Interpretable Inverse Problems
Question 1)
When will \(f_\theta(x)\) compute a \(\text{prox}_R(x)\)?
Theorem [Gribonval & Nikolova, 2020]

\( f(x) \in \text{prox}_R(x) ~\Leftrightarrow~ \exists~\text{convex l.s.c. } \psi: \mathbb R^n\to\mathbb R ~\text{s.t.}~ f(x) \in \partial \psi(x)\)
\(R(x)\) need not be convex!

Learned Proximal Networks
Take \(f_\theta(x) = \nabla \psi_\theta(x)\) for convex (and differentiable) \(\psi_\theta\)

Given \(f_\theta(x)\), we can recover \(R(x)\) by solving a convex program
Interpretable Inverse Problems
If so, can we know for what \(R(x)\)?
Yes [Gribonval & Nikolova]
Easy! \[{\color{grey}y^* =} \arg\min_{y} \psi(y) - \langle y,x\rangle {\color{grey}= \hat{f}_\theta^{-1}(x)}\]
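A minimal PyTorch sketch of this recipe, under assumptions: a small input-convex network for \(\psi_\theta\) (convexity via non-negative hidden weights and convex, non-decreasing activations — one standard construction, not necessarily the exact LPN architecture), \(f_\theta = \nabla\psi_\theta\) via autograd, and \(f_\theta^{-1}(x)\) via the convex minimization above.

```python
import torch

class ICNN(torch.nn.Module):
    """A tiny input-convex network psi(x): convex in x because the hidden-to-
    hidden weights are clamped non-negative and the activations are convex
    and non-decreasing. (An illustrative construction, not the exact LPN.)"""
    def __init__(self, dim, width=64):
        super().__init__()
        self.W_x0 = torch.nn.Linear(dim, width)
        self.W_x1 = torch.nn.Linear(dim, width)
        self.W_z1 = torch.nn.Parameter(0.1 * torch.rand(width, width))
        self.w_out = torch.nn.Parameter(0.1 * torch.rand(width))
        self.act = torch.nn.Softplus()

    def forward(self, x):                       # x: (batch, dim)
        z = self.act(self.W_x0(x))
        z = self.act(self.W_x1(x) + z @ self.W_z1.clamp(min=0).T)
        return z @ self.w_out.clamp(min=0)      # scalar psi per sample

def f_theta(psi, x):
    """f_theta(x) = grad psi(x): a valid prox by the theorem above."""
    x = x.detach().requires_grad_(True)
    return torch.autograd.grad(psi(x).sum(), x)[0]

def f_inverse(psi, x, steps=500, lr=1e-2):
    """Invert f_theta via the convex problem above: min_y psi(y) - <y, x>."""
    y = x.clone().detach().requires_grad_(True)
    opt = torch.optim.SGD([y], lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        (psi(y).sum() - (y * x).sum()).backward()
        opt.step()
    return y.detach()

# Usage: psi = ICNN(dim=16); x = torch.randn(8, 16)
# fx = f_theta(psi, x); x_rec = f_inverse(psi, fx)  # x_rec ~ x if psi strictly convex
```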

Interpretable Inverse Problems
Question 2)
Could we have \(R(x) = -\log p_x(x)\)?
(we don't know \(p_x\)!)
i.e. \(f_\theta(y) = \text{prox}_R(y) = \texttt{MAP}(x|y)\)
Which loss function?

Proximal Matching Loss
Theorem (informal): training with the proximal matching loss recovers, as \(\gamma \to 0\), the proximal operator of \(R(x) = -\log p_x(x)\), i.e. the MAP denoiser.
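Schematically, the proximal matching loss replaces the usual squared error with a delta-like kernel of width \(\gamma\). The Gaussian-shaped surrogate below is an illustrative stand-in; the exact kernel and scaling in the paper may differ.

```python
import torch

def proximal_matching_loss(f_theta, x, y, gamma=0.1):
    """Schematic proximal matching loss: a delta-like kernel of width gamma
    on the reconstruction error. As gamma -> 0, minimizers concentrate on
    the mode of p(x | y) -- the MAP denoiser -- per the informal theorem.
    The exact kernel/scaling in the LPN paper may differ from this surrogate."""
    err = torch.linalg.vector_norm(f_theta(y) - x, dim=-1)
    return torch.mean(1.0 - torch.exp(-err.pow(2) / gamma**2))

# Assumed training loop: sample pairs (x, y = x + sigma * noise) and anneal
# gamma from large (MSE-like behavior) to small (MAP-like behavior).
```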
Learned Proximal Networks






























Learned Proximal Networks

Fang, Buchanan & S. What's in a Prior? Learned Proximal Networks for Inverse Problems. ICLR 2024.

Learned Proximal Networks
+ Convergence guarantees
What did my model learn?
PART I
Inverse Problems
PART II
Image Classification
Two vignettes on medical imaging


Interpretable Image Classification

What parts of the image are important for this prediction?
What are the subsets of the input that (approximately) preserve the prediction?

- Sensitivity or gradient-based perturbations: LIME [Ribeiro et al., '16], CAM [Zhou et al., '16], Grad-CAM [Selvaraju et al., '17]
- Shapley coefficients: SHAP [Lundberg & Lee, '17], ...
- Variational formulations: RDE [Macdonald et al., '19], ...
- Counterfactual explanations: [Sani et al., 2020], [Singla et al., '19], ...

Interpretability
Shapley values
Lloyd S. Shapley. A value for n-person games. Contributions to the Theory of Games, 2(28):307–317, 1953.
Let \((N, v)\) be an \(n\)-person cooperative game with characteristic function \(v : 2^N \to \mathbb{R}\).
How important is each player for the outcome of the game?
\[\phi_i(v) = \sum_{S \subseteq N \setminus \{i\}} \frac{|S|!\,(n-|S|-1)!}{n!}\,\underbrace{\left(v(S \cup \{i\}) - v(S)\right)}_{\substack{\text{marginal contribution}\\ \text{of player } i \text{ to coalition } S}}\]
Properties: efficiency, nullity, symmetry ... but exponential complexity.
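To make the definition concrete, here is a brute-force implementation of the formula above; the exponential enumeration over coalitions is exactly the complexity bottleneck just noted. The toy characteristic function is an illustrative assumption.

```python
from itertools import combinations
from math import factorial

def shapley_values(n, v):
    """Exact Shapley values for an n-player game with characteristic
    function v: frozenset -> float. Enumerating all coalitions per player
    is exactly the exponential complexity noted above."""
    phi = [0.0] * n
    for i in range(n):
        others = [p for p in range(n) if p != i]
        for k in range(len(others) + 1):
            for S in combinations(others, k):
                w = factorial(k) * factorial(n - k - 1) / factorial(n)
                phi[i] += w * (v(frozenset(S) | {i}) - v(frozenset(S)))
    return phi

# Toy game (assumed for illustration): the coalition wins iff player 0 joins.
v = lambda S: 1.0 if 0 in S else 0.0
print(shapley_values(3, v))  # -> [1.0, 0.0, 0.0]: efficiency and nullity
```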
Shap-Explanations
inputs
responses
predictor
How important is feature \(x_i\) for \(f(x)\)?
\(X_{S_j^c}\sim \mathcal D_{X_{S_j}={x_{S_j}}}\)
Scott Lundberg and Su-In Lee. A Unified Approach to Interpreting Model Predictions. NeurIPS, 2017.
Shap-Explanations
Question 1)
Can we resolve the computational bottleneck (and when) ?
Question 2)
What do these coefficients mean, really?
Question 3)
How to go beyond input-features explanations?

We focus on data with certain structure:
Example: \(f(x) = 1\) if \(x\) contains a sick cell



Hierarchical Shap (h-Shap)
Question 1) Can we resolve the computational bottleneck (and when) ?
Theorem (informal)
- hierarchical Shap runs in linear time
- Under A1, h-Shap \(\to\) Shapley
[Teneggi, Luster & S., IEEE TPAMI, 2022]
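A minimal sketch of the hierarchical idea, under the structural assumption above (the prediction is driven by small localized regions, e.g. a sick cell): recursively split the image into quadrants and recurse only into quadrants that still trigger the classifier when everything else is masked. This is a schematic rendering — it simplifies the per-node Shapley game to a presence check — and `f` and the masking convention are assumptions.

```python
import numpy as np

def h_shap(f, x, r0=0, c0=0, size=None, tol=0.5):
    """Schematic h-Shap: recursively bisect a (power-of-two) square image
    into quadrants, recursing only where the masked image still triggers
    the classifier. When k small regions drive the prediction, this visits
    O(k log d) nodes instead of the 2^d coalitions of exact Shapley.
    NOTE: the actual h-Shap computes exact small games at each split."""
    size = x.shape[0] if size is None else size
    masked = np.zeros_like(x)
    masked[r0:r0 + size, c0:c0 + size] = x[r0:r0 + size, c0:c0 + size]
    if f(masked) < tol:            # this region alone does not support f(x)=1
        return []
    if size == 1:                  # leaf: an important pixel (or patch)
        return [(r0, c0)]
    half = size // 2
    found = []
    for dr in (0, half):
        for dc in (0, half):
            found += h_shap(f, x, r0 + dr, c0 + dc, half, tol)
    return found

# Toy usage: the classifier fires iff a "sick cell" (value 1) is visible.
f = lambda img: float(img.max() >= 1)
x = np.zeros((8, 8)); x[2, 5] = 1
print(h_shap(f, x))  # -> [(2, 5)]
```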








Question 2) What do these coefficients mean, really?
Precise notions of importance
Formal Feature Importance
[Candes et al, 2018]

How do we test?
XRT: eXplanation Randomization Test, returns a \(\hat{p}_{i,S}\) for the test above
Precise notions of importance
Local Feature Importance

Theorem:
Given the Shapley coefficient \(\phi_j\) of any feature \(j\), and the expected p-value \(\bar{p}_{j,S} = \mathbb{E}[\hat{p}_{j,S}]\) obtained for \(H_0^{j,S}\), then \(\phi_j\) is lower-bounded by a weighted sum of the quantities \((1 - \bar{p}_{j,S})\) over coalitions \(S\).
Teneggi, Bharti, Romano, and S. "SHAP-XRT: The Shapley Value Meets Conditional Independence Testing." TMLR (2023).
Question 3)
How to go beyond input-features explanations?



Is the piano important for \(\hat Y = \text{cat}\) given that there is a cute mammal?
Testing Semantic Importance
Question 3) How to go beyond input-features explanations?
Precise notions of semantic importance
embeddings \(H = f(X)\)
semantics \(Z = c^TH\)
predictions \(\hat{Y} = g(H)\)

Concept Bottleneck Models (CBM)
[Koh et al '20, Yang et al '23, Yuan et al '22, Yuksekgonul '22 ]
Precise notions of semantic importance
semantic XRT
\[H^{j,S}_0:~g(\widetilde{H}_{S \cup \{j\}}) \overset{d}{=} g(\widetilde{H}_S),\quad\widetilde{H}_C \sim P_{H | Z_C = z_C}\]
"The classifier (its distribution) does not change if we condition
on concepts \(S\) vs on concepts \(S\cup\{j\} \)"

embeddings \(H = f(X)\), semantics \(Z = c^TH\), predictions \(\hat{Y} = g(H)\)
(illustration: the distribution of \(\hat{Y}_\text{gas pump}\) when conditioning on concepts \(Z_S\) vs. \(Z_S \cup Z_{j}\))
Precise notions of semantic importance
semantic XRT
\[H^{j,S}_0:~g(\widetilde{H}_{S \cup \{j\}}) \overset{d}{=} g(\widetilde{H}_S),\quad\widetilde{H}_C \sim P_{H | Z_C = z_C}\]
Testing by Betting
- Instantiate a wealth process: \(K_0 = 1\), \(K_t = K_{t-1}(1+\kappa_t v_t)\)
- Reject \(H_0\) when \(K_t \geq 1/\alpha\)

[Shaer et al. 2023, Shekhar and Ramdas 2023 ]
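A minimal sketch of such a sequential test, with an assumed payoff stream: each \(v_t \in [-1,1]\) (e.g., whether the classifier's output changed when concept \(j\) was withheld) earns a bet of a fraction \(\kappa\) of the current wealth, and we reject once wealth crosses \(1/\alpha\). The constant betting fraction is a simplification; the cited works use adaptive strategies (e.g., ONS).

```python
import random

def test_by_betting(payoffs, alpha=0.05, kappa=0.5):
    """Sequential test by betting. Under H_0 the payoffs v_t in [-1, 1]
    have zero mean, so the wealth K_t is a nonnegative martingale and
    Ville's inequality gives P[sup_t K_t >= 1/alpha] <= alpha: rejecting
    when K_t crosses 1/alpha controls the type-1 error, anytime-valid."""
    wealth = 1.0                            # K_0 = 1
    for t, v in enumerate(payoffs, start=1):
        wealth *= 1.0 + kappa * v           # K_t = K_{t-1} (1 + kappa v_t)
        if wealth >= 1.0 / alpha:
            return True, t                  # reject H_0 at time t
    return False, len(payoffs)              # fail to reject

# Illustrative stream: under H_1 payoffs are positive on average.
random.seed(0)
stream = [random.choice([1, 1, 1, -1]) for _ in range(200)]
print(test_by_betting(stream, alpha=0.05))
```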
Precise notions of semantic importance

Important Semantic Concepts
(Reject \(H_0\))

Unimportant Semantic Concepts
(fail to reject \(H_0\))

- Type-1 error control
- False discovery rate control
(plots: rejection rate and rejection time per concept)
Precise notions of semantic importance

Concluding Remarks
- Exciting open problems in making AI tools safe, trustworthy, and interpretable
- The importance of clear definitions and guarantees
* Fang, Z., Buchanan, S., & Sulam, J. What's in a Prior? Learned Proximal Networks for Inverse Problems. ICLR 2024.
* Teneggi, J., Luster, A., & Sulam, J. Fast Hierarchical Games for Image Explanations. IEEE TPAMI, 2022.
* Teneggi, J., Bharti, B., Romano, Y., & Sulam, J. SHAP-XRT: The Shapley Value Meets Conditional Independence Testing. TMLR, 2023.
* Teneggi, J., & Sulam, J. I Bet You Did Not Mean That: Testing Semantic Importance via Betting. NeurIPS 2024 (to appear).


Interpretability
Fairness
Uncertainty Quantification


Inverse Problems
(some) Challenges in Biomedical Data Science





Fairness in Data Science

Is the model fair?



(example: a classifier predicting Pneumonia vs. Clear, 95% accurate)
Fairness in Data Science

Does your model achieve a \(\Delta_{\text{TPR}}\) of at most (say) 6% ?
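For concreteness, a minimal sketch of the quantity in question: the true-positive-rate gap \(\Delta_{\text{TPR}}\) between two groups, computed from predictions, labels, and a (possibly predicted) sensitive attribute. All variable names and the synthetic data are illustrative assumptions.

```python
import numpy as np

def tpr_gap(y_true, y_pred, group):
    """Delta_TPR: |TPR(group 0) - TPR(group 1)|, one of the two
    equalized-odds gaps. `group` is a binary sensitive attribute,
    possibly itself predicted, as studied in Bharti et al. (2023)."""
    tprs = []
    for g in (0, 1):
        pos = (group == g) & (y_true == 1)         # positives in group g
        tprs.append((y_pred[pos] == 1).mean())     # fraction detected
    return abs(tprs[0] - tprs[1])

# Illustrative check of the 6% tolerance on synthetic predictions.
rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, 1000)
group = rng.integers(0, 2, 1000)
y_pred = np.where(rng.random(1000) < 0.9, y_true, 1 - y_true)  # ~90% accurate
print(tpr_gap(y_true, y_pred, group) <= 0.06)
```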


Fairness in Data Science






Tight upper bounds on fairness violations
(optimally) Actionable
(plot: maximum TPR discrepancy vs. true TPR discrepancy)
Bharti, B., Yi, P., & Sulam, J. (2023). Estimating and Controlling for Equalized Odds via Sensitive Attribute Predictors NeurIPS 2023
Fairness in Data Science







How to quantify and report uncertainty?






(reconstructions: Filtered Back Projection vs. Deep ConvNet Model vs. Diffusion Model)
Uncertainty Quantification

For an observation \(y\)
\[y = x + \epsilon,~\epsilon \sim \mathcal{N}(0, \sigma^2\mathbb{I})\]
reconstruct \(x\) with
\[\hat{x} = F(y) \sim \mathcal{Q}_y \approx p(x \mid y)\]
Uncertainty Quantification

(panels: ground truth \(x\), measurement \(y\), reconstruction \(F(y)\))




Lemma
If \[\mathcal{I}(y)_j = \left[ Q_{\frac{\lfloor(m+1)\alpha/2\rfloor}{m}}(y_j)\,,~ Q_{\frac{\lceil(m+1)(1-\alpha/2)\rceil}{m}}(y_j)\right]\]
with empirical quantiles \(Q\) over \(m\) samples of \(F(y)\), then \(\mathcal I(y)\) provides entrywise coverage for pixel \(j\), i.e.
\[\mathbb{P}\left[\text{next sample}_j \in \mathcal{I}(y)_j\right] \geq 1 - \alpha\]
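A minimal sketch of constructing these intervals from \(m\) posterior samples per pixel (e.g., repeated draws from a diffusion model); the array shapes and levels are assumptions.

```python
import numpy as np

def entrywise_intervals(samples, alpha=0.1):
    """Entrywise quantile intervals from m posterior samples.
    samples: array of shape (m, d) -- m draws of a d-pixel reconstruction.
    Uses the conservative finite-sample levels floor((m+1)*a/2)/m and
    ceil((m+1)*(1-a/2))/m, so a fresh (m+1)-th sample lands in the interval
    with probability >= 1 - alpha, per the lemma above."""
    m = samples.shape[0]
    lo_level = np.floor((m + 1) * alpha / 2) / m
    up_level = np.ceil((m + 1) * (1 - alpha / 2)) / m
    lower = np.quantile(samples, lo_level, axis=0)
    upper = np.quantile(samples, min(up_level, 1.0), axis=0)
    return lower, upper

# Usage: m = 128 diffusion samples of a 256-pixel image (assumed shapes).
samples = np.random.default_rng(0).normal(size=(128, 256))
low, up = entrywise_intervals(samples, alpha=0.1)
```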
Uncertainty Quantification

(illustration: distribution-free entrywise intervals \(\mathcal{I}(y)_j = [\,l(y)_j,\, u(y)_j\,] \subseteq [0,1]\); panels show \(x\), \(y\), the lower and upper bounds, and the interval lengths \(|\mathcal I(y)_j|\))
Uncertainty Quantification



Risk Controlling Prediction Sets
(diagram: on \([0,1]\), the interval \(\mathcal{I}(y_j)\) contains the ground truth \(x_j\))
Procedure: for pixel \(j\),
\[\mathcal{I}_{\lambda}(y)_j = [\text{lower} - \lambda, \text{upper} + \lambda]\]
choose
\[\hat{\lambda} = \inf\{\lambda \in \mathbb{R}:~\forall \lambda' \geq \lambda,~\text{risk}(\lambda') \leq \epsilon\}\]
(diagram: the interval \(\mathcal{I}(y_j)\), inflated by \(\lambda\), contains the ground truth \(x_j\))
Risk Controlling Prediction Sets

Definition For risk level \(\epsilon\), failure probability \(\delta\), \(\mathcal{I}(y_j) \) is a RCPS if
\[\mathbb{P}\left[\mathbb{E}\left[\text{fraction of pixels not in intervals}\right] \leq \epsilon\right] \geq 1 - \delta\]
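A minimal sketch of the calibration step on held-out data, using the empirical risk for illustration; the actual RCPS guarantee (and the definition above) requires an upper confidence bound on the risk, e.g. via Hoeffding, to account for the failure probability \(\delta\). All array shapes are assumptions.

```python
import numpy as np

def calibrate_lambda(lower, upper, x, eps=0.1, lam_grid=None):
    """Pick the smallest lambda whose (empirical) risk -- the average fraction
    of pixels outside the inflated intervals [lower - lam, upper + lam] --
    stays below eps for all larger lambda. A faithful RCPS implementation
    replaces the empirical risk with an upper confidence bound at level delta.
    lower, upper, x: arrays of shape (n_calibration, d)."""
    lam_grid = np.linspace(0, 1, 201) if lam_grid is None else lam_grid
    def risk(lam):
        inside = (x >= lower - lam) & (x <= upper + lam)
        return 1.0 - inside.mean()
    risks = np.array([risk(l) for l in lam_grid])
    # risk is non-increasing in lambda: take the first grid point below eps
    ok = np.where(risks <= eps)[0]
    return lam_grid[ok[0]] if ok.size else None

# Assumed calibration data: intervals (e.g., from entrywise quantiles) + truth.
rng = np.random.default_rng(1)
x = rng.normal(size=(50, 256))
lower = x - rng.random((50, 256)) + 0.3   # intervals occasionally miss x
upper = x + rng.random((50, 256))
print(calibrate_lambda(lower, upper, x, eps=0.05))
```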

\(K\)-RCPS: High-dimensional Risk Control

scalar \(\lambda \in \mathbb{R}\), \(\mathcal{I}_{\lambda}(y)_j = [\text{low} - \lambda, \text{up} + \lambda]\) \(\rightarrow\) vector \(\bm{\lambda} \in \mathbb{R}^d\), \(\mathcal{I}_{\bm{\lambda}}(y)_j = [\text{low} - \lambda_j, \text{up} + \lambda_j]\)
Guarantee: \(\mathcal{I}_{\bm{\lambda}}(y)_j = [\text{low} - \lambda_j, \text{up} + \lambda_j]\) are RCPS
For a \(K\)-partition of the pixels \(M \in \{0, 1\}^{d \times K}\)

(partition examples: \(K=4\), \(K=8\), \(K=32\))
1. Solve
\[\tilde{\bm{\lambda}}_K = \arg\min~\sum_{k \in [K]}n_k\lambda_k~\quad\text{s.t. empirical risk} \leq \epsilon\]
2. Choose
\[\hat{\beta} = \inf\{\beta \in \mathbb{R}:~\forall \beta' \geq \beta,~\text{risk}(M\tilde{\bm{\lambda}}_K + \beta') \leq \epsilon\}\]
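Stage 1 is a convex program (the paper relaxes the empirical-risk constraint to make it tractable); stage 2 is a one-dimensional calibration identical in spirit to the scalar case. A minimal sketch of stage 2, assuming a candidate \(\tilde{\bm{\lambda}}_K\) and the partition matrix `M` are already available:

```python
import numpy as np

def calibrate_beta(lower, upper, x, lam_vec, eps=0.1, beta_grid=None):
    """Stage 2 of K-RCPS (sketch): given the per-pixel inflation
    lam_vec = M @ lam_tilde_K from stage 1, pick the smallest shift beta
    so the risk of [lower - lam_vec - beta, upper + lam_vec + beta] is
    <= eps. A faithful implementation replaces the empirical risk with an
    upper confidence bound at level delta (as in RCPS).
    lower, upper, x: (n_calibration, d); lam_vec: (d,)."""
    beta_grid = np.linspace(0, 1, 201) if beta_grid is None else beta_grid
    for beta in beta_grid:                  # risk is non-increasing in beta
        lam = lam_vec + beta
        inside = (x >= lower - lam) & (x <= upper + lam)
        if 1.0 - inside.mean() <= eps:
            return beta
    return None
```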
\(K\)-RCPS: High-dimensional Risk Control


conformalized uncertainty maps \(\hat{\bm{\lambda}}_K\) (e.g., \(K=4\), \(K=8\))

\[\mathbb{P}\left[\mathbb{E}\left[\text{fraction of pixels not in intervals}\right] \leq \epsilon\right] \geq 1 - \delta\]
Teneggi, J., Tivnan, M., Stayman, W., & Sulam, J. (2023, July). How to trust your diffusion model: A convex optimization approach to conformal risk control. In International Conference on Machine Learning. PMLR.
\(K\)-RCPS: High-dimensional Risk Control





Thank you









Zhenghan Fang
JHU
Jacopo Teneggi
JHU
Beepul Bharti
JHU
Sam Buchanan
TTIC
Yaniv Romano
Technion
Appendix
Learned Proximal Networks
Convergence guarantees for PnP
- Convergence of PnP for non-expansive denoisers [Sreehari et al., 2016; Sun et al., 2019; Chan, 2019; Teodoro et al., 2019]
- Convergence for close-to-contractive operators [Ryu et al., 2019]
- Convergence of Plug-and-Play priors with MMSE denoisers [Xu et al., 2020]
- Lipschitz-bounded denoisers [Hurault et al., 2022]
Theorem (PGD with Learned Proximal Networks)
Let \(f_\theta = \text{prox}_{\hat{R}} {\color{grey}\text{ with } \alpha>0}\), and \(0<\eta<1/\sigma_{\max}(A)\), with smooth activations; then the PnP-PGD iterates converge.
(Analogous results hold for ADMM)
[Chattopadhyay et al, 2024]
Cheaper predictors via Interpretability
Hemorrhage detection in head CT






Image-by-image supervision (strong learner)
true/false
Cheaper predictors via Interpretability
- Image-by-image supervision (strong learner): one label per image! (true/false)
- Study/volume supervision (weak learner): one label per study! (true/false)
Cheaper predictors via Interpretability


Both methods perform equally well for case screening
Teneggi, J., Yi, P. H., & Sulam, J. (2023). Examination-level supervision for deep learning–based intracranial hemorrhage detection at head CT. Radiology: Artificial Intelligence, e230159.
Cheaper predictors via Interpretability
Weak learner is more efficient for detecting positive slices

(plot: detection performance vs. number of training labels)
Cheaper predictors via Interpretability
