How to Trust Your Diffusion Model
Jacopo Teneggi, Jeremias Sulam
SPIE Photonics West 2024
January 28, 2024


A Step Back: Responsible ML
input → discriminative model → prediction
- Real-world performance
- Reproducibility
- Explainability
- Fairness
- Privacy
Explaining with the Shapley Value
f(image) = "hemorrhage"
What are the important parts of the image for this prediction?
JT, Alexandre Luster, JS (2021) "Fast Hierarchical Games for Image Explanations", TPAMI
JT*, Beepul Bharti*, Yaniv Romano, JS (2023) "SHAP-XRT: The Shapley Value Meets Conditional Independence Testing", TMLR
- h-Shap: efficient computation with accuracy guarantees
- SHAP-XRT: statistical meaning of large Shapley values
Deploying these methods to inform better practices in real-world scenarios

JT, Paul H Yi, JS (2023) "Examination-level Supervision for Deep Learning-based Intracranial Hemorrhage Detection on Head CT", Radiology: AI
A Different Scenario
input → generative model → samples
For example:
- Large Language Models
- Text-to-image Diffusion Models
- Solving Inverse Problems Stochastically
Do the same requirements (real-world performance, reproducibility, explainability, fairness, privacy) carry over to generative models?
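Running Example: CT Denoising
For an observation $y = x + \epsilon$, with $\epsilon \sim \mathcal{N}(0, \sigma^2 I)$, reconstruct $x$ with
$F(y) \sim \mathcal{Q}_y \approx p(x \mid y)$
(Figure panels: ground truth $x$, observation $y$, sample $F(y)$)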
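To make the setup concrete without a trained diffusion model, the sketch below simulates the observation model $y = x + \epsilon$ and replaces $F(y)$ with an exact posterior sampler under an assumed Gaussian prior $x \sim \mathcal{N}(0, \tau^2 I)$. The prior, the parameter `tau`, and the helper names `observe` and `sample_posterior` are illustrative assumptions, not part of the original method; in practice $F(y)$ is a stochastic diffusion sampler whose distribution $\mathcal{Q}_y$ only approximates $p(x \mid y)$.

```python
import numpy as np

def observe(x, sigma, rng):
    """Observation model: y = x + eps with eps ~ N(0, sigma^2 I)."""
    return x + sigma * rng.standard_normal(x.shape)

def sample_posterior(y, sigma, tau, m, rng):
    """m exact draws from p(x | y), ASSUMING a Gaussian prior x ~ N(0, tau^2 I).

    Stand-in for F(y): under this prior the posterior is Gaussian with mean
    tau^2 / (tau^2 + sigma^2) * y and variance tau^2 sigma^2 / (tau^2 + sigma^2).
    """
    shrink = tau**2 / (tau**2 + sigma**2)
    post_std = np.sqrt(tau**2 * sigma**2 / (tau**2 + sigma**2))
    return shrink * y + post_std * rng.standard_normal((m, *y.shape))

rng = np.random.default_rng(0)
sigma, tau = 0.5, 1.0
x = tau * rng.standard_normal((8, 8))   # toy 8x8 ground truth drawn from the prior
y = observe(x, sigma, rng)              # noisy observation
samples = sample_posterior(y, sigma, tau, m=128, rng=rng)
print(samples.shape)  # (128, 8, 8)
```

Because this toy posterior is known in closed form, the sample mean and spread can be checked directly, which a diffusion sampler does not allow; that gap is what the guarantees in the rest of the talk address.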

Is the Model Right?




Contributions: K-RCPS
- Conformal prediction: on the same observation, samples are concentrated
- Conformal risk control: for future observations, the ground truth is close
Statistically valid uncertainty quantification with the shortest intervals
JT, Matt Tivnan, J W Stayman, JS (2023) "How to Trust Your Diffusion Model: A Convex Optimization Approach to Conformal Risk Control", ICML
Calibrated Quantiles
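Definition: For $m$ samples and miscoverage level $\alpha$, let $\text{lower}_j$ and $\text{upper}_j$ be the empirical quantiles of the samples at pixel $j$ at levels $\lfloor (m+1)\alpha/2 \rfloor / m$ and $\lceil (m+1)(1-\alpha/2) \rceil / m$, respectively.
Lemma: For pixel $j$, $\mathcal{I}(y)_j = [\text{lower}_j, \text{upper}_j]$ provides entrywise coverage, i.e.
$\mathbb{P}[\text{next sample}_j \in \mathcal{I}(y)_j] \geq 1 - \alpha$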

(Figure panels: ground truth $x$, observation $y$, lower, upper, interval)
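A minimal NumPy sketch of the calibrated quantiles, assuming the $m$ samples for one observation are stacked in an array of shape (m, height, width). The helper name `calibrated_quantiles` is hypothetical. At every pixel it takes the order statistics of rank $\lfloor (m+1)\alpha/2 \rfloor$ and $\lceil (m+1)(1-\alpha/2) \rceil$, and it falls back to an unbounded side when a rank is out of range (too few samples for the chosen $\alpha$), a standard conformal convention assumed here.

```python
import math
import numpy as np

def calibrated_quantiles(samples, alpha):
    """Hypothetical helper: entrywise interval I(y) = [lower, upper] from m samples.

    samples has shape (m, *pixel_shape): m draws F(y) for the same observation y.
    """
    m = samples.shape[0]
    k_lo = math.floor((m + 1) * alpha / 2)        # rank of the lower bound (1-indexed)
    k_hi = math.ceil((m + 1) * (1 - alpha / 2))   # rank of the upper bound (1-indexed)
    ordered = np.sort(samples, axis=0)            # sort each pixel's samples independently
    pixel_shape = samples.shape[1:]
    # Out-of-range ranks (m too small for alpha) give an unbounded side.
    lower = ordered[k_lo - 1] if k_lo >= 1 else np.full(pixel_shape, -np.inf)
    upper = ordered[k_hi - 1] if k_hi <= m else np.full(pixel_shape, np.inf)
    return lower, upper

# Usage: i.i.d. Gaussian stand-ins for the m samples of one observation.
rng = np.random.default_rng(0)
m, alpha = 99, 0.1
samples = rng.standard_normal((m, 32, 32))
lower, upper = calibrated_quantiles(samples, alpha)

# One more exchangeable draw: fraction of pixels where it falls inside I(y)_j.
next_sample = rng.standard_normal((32, 32))
coverage = np.mean((lower <= next_sample) & (next_sample <= upper))
print(f"empirical entrywise coverage: {coverage:.3f} (target >= {1 - alpha})")
```

With $m = 99$ and $\alpha = 0.1$ the ranks are 5 and 95, so a new exchangeable draw lands inside each pixel's interval with probability $90/100 = 1 - \alpha$; averaging over independent pixels gives a quick empirical check of the lemma.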
Risk Controlling Prediction Sets
Definition: For risk level $\epsilon$ and failure probability $\delta$,
$\mathbb{P}\big[\mathbb{E}[\text{fraction of pixels not in intervals}] \leq \epsilon\big] \geq 1 - \delta$
(Figure: an interval $\mathcal{I}(y)$ on $[0, 1]$ that does not contain the ground truth $x$)
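The risk in this definition is the expected value of a per-image loss: the fraction of pixels whose ground truth falls outside its interval. The hedged sketch below computes that loss on a calibration set and, as in the RCPS framework, replaces the unknown expectation with an upper confidence bound (UCB) that holds with probability $1 - \delta$. The choice of Hoeffding's inequality, the helper names, and the toy data are assumptions for illustration; pixels are flattened to vectors of length $d$.

```python
import numpy as np

def miscoverage(x, lower, upper):
    """Loss for one calibration image: fraction of pixels with x_j outside I(y)_j."""
    return float(np.mean((x < lower) | (x > upper)))

def hoeffding_ucb(losses, delta):
    """(1 - delta) upper confidence bound on the risk, for losses in [0, 1].

    ASSUMPTION: Hoeffding's inequality is used for simplicity; RCPS admits
    other (tighter) concentration bounds.
    """
    n = len(losses)
    return float(np.mean(losses) + np.sqrt(np.log(1 / delta) / (2 * n)))

# Usage with toy calibration data: n ground truths and fixed intervals [-2, 2].
rng = np.random.default_rng(0)
n, d = 500, 64
x_cal = rng.standard_normal((n, d))
lower_cal = np.full((n, d), -2.0)
upper_cal = np.full((n, d), 2.0)

losses = np.array([miscoverage(x, lo, up)
                   for x, lo, up in zip(x_cal, lower_cal, upper_cal)])
epsilon, delta = 0.1, 0.1
ucb = hoeffding_ucb(losses, delta)
print(f"empirical risk {losses.mean():.4f}, UCB {ucb:.4f}, controls risk: {ucb <= epsilon}")
```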
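Procedure: For pixel $j$, let
$\mathcal{I}_\lambda(y)_j = [\text{lower}_j - \lambda, \text{upper}_j + \lambda]$
and choose
$\hat{\lambda} = \inf\{\lambda \in \mathbb{R} : \forall \lambda' \geq \lambda,\ \text{risk}(\lambda') \leq \epsilon\}$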
(Figure: the interval expanded by $\lambda$, which now contains the ground truth $x$)
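A sketch of the $\hat{\lambda}$ selection under the same assumptions: a Hoeffding UCB stands in for $\text{risk}(\lambda')$ (RCPS controls an upper confidence bound on the risk rather than the risk itself), and the helper name and toy data are hypothetical. Scanning a finite grid from the largest $\lambda$ downward and stopping at the first failure returns the smallest grid value for which the bound holds at that $\lambda$ and at every larger one, mirroring the infimum in the procedure.

```python
import numpy as np

def rcps_lambda(x_cal, lower_cal, upper_cal, epsilon, delta, grid):
    """Hypothetical helper: smallest lambda on `grid` whose risk UCB, and that of
    every larger grid value, is at most epsilon (the inf in the procedure)."""
    n = x_cal.shape[0]
    slack = np.sqrt(np.log(1 / delta) / (2 * n))   # ASSUMED Hoeffding UCB width
    lam_hat = None
    for lam in sorted(grid, reverse=True):         # scan from large to small lambda
        outside = (x_cal < lower_cal - lam) | (x_cal > upper_cal + lam)
        ucb = outside.mean(axis=1).mean() + slack  # per-image loss, then average
        if ucb > epsilon:
            break                                  # condition fails for this lambda
        lam_hat = lam
    return lam_hat  # None if even the largest grid value fails

# Usage: toy ground truths with fixed heuristic intervals [-1, 1].
rng = np.random.default_rng(1)
n, d = 200, 64
x_cal = rng.standard_normal((n, d))
lower_cal, upper_cal = np.full((n, d), -1.0), np.full((n, d), 1.0)
lam_hat = rcps_lambda(x_cal, lower_cal, upper_cal, epsilon=0.1, delta=0.1,
                      grid=np.linspace(0.0, 3.0, 301))
print(f"lambda_hat = {lam_hat:.2f}")
```

The same scalar $\lambda$ is added to every pixel, so pixels whose intervals were already wide enough are widened as much as the hardest ones.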
Using the same $\lambda$ for all pixels is suboptimal for interval length.
We want to find the shortest intervals that control risk.
K-RCPS: High-dimensional Risk Control
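scalar $\lambda \in \mathbb{R}$ → vector $\lambda \in \mathbb{R}^d$
$\mathcal{I}_\lambda(y)_j = [\text{lower}_j - \lambda, \text{upper}_j + \lambda]$ → $\mathcal{I}_\lambda(y)_j = [\text{lower}_j - \lambda_j, \text{upper}_j + \lambda_j]$
(Figure: RCPS ($\lambda_1 = \lambda_2 = \lambda$) vs. K-RCPS, showing the gain in interval length)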
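A short worked equation may help motivate the objective in the procedure that follows. Assume the pixels are split into $K$ groups, with $n_k$ pixels in group $k$ and $\lambda_j = \lambda_k$ for every pixel $j$ in group $k$ (this grouping is the K-partition used by K-RCPS). The total length of the intervals $[\text{lower}_j - \lambda_j, \text{upper}_j + \lambda_j]$ then splits into a term that does not depend on $\lambda$ and a weighted sum of the $\lambda_k$:

```latex
\sum_{j=1}^{d} \big[(\text{upper}_j + \lambda_j) - (\text{lower}_j - \lambda_j)\big]
  = \sum_{j=1}^{d} (\text{upper}_j - \text{lower}_j) + 2 \sum_{k \in [K]} n_k \lambda_k
```

So, for fixed lower and upper bounds, minimizing the total interval length over $\lambda$ is the same as minimizing $\sum_{k \in [K]} n_k \lambda_k$.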
K-RCPS: The Procedure
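For a K-partition of the pixels $M \in \{0,1\}^{d \times K}$, with $n_k$ pixels in group $k$:
1. Solve
$\tilde{\lambda}_K = \arg\min_{\lambda} \sum_{k \in [K]} n_k \lambda_k$ s.t. empirical risk $\leq \epsilon$
2. Choose
$\hat{\beta} = \inf\{\beta \in \mathbb{R} : \forall \beta' \geq \beta,\ \text{risk}(M\tilde{\lambda}_K + \beta') \leq \epsilon\}$
(Figure panels: K = 4, K = 8, K = 32)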
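Step 2 is a one-dimensional search like RCPS, shifting the whole vector $M\tilde{\lambda}_K$ by a common $\beta$. The sketch below assumes $\tilde{\lambda}_K$ is already available: the values in `lam_tilde` are hypothetical placeholders, not the output of the optimization in step 1 (which this sketch does not implement), and, as before, a Hoeffding UCB stands in for $\text{risk}(\cdot)$. Helper names, the 2-group partition, and the toy data are illustrative assumptions.

```python
import numpy as np

def partition_matrix(labels, K):
    """One-hot K-partition M in {0,1}^{d x K} from per-pixel group labels in [0, K)."""
    return np.eye(K, dtype=int)[labels]

def krcps_beta(x_cal, lower_cal, upper_cal, M, lam_tilde, epsilon, delta, grid):
    """Step 2 sketch: smallest beta on `grid` such that lambda = M @ lam_tilde + beta
    keeps the (ASSUMED Hoeffding) risk UCB <= epsilon for it and all larger betas."""
    n = x_cal.shape[0]
    slack = np.sqrt(np.log(1 / delta) / (2 * n))
    lam_pixels = M @ lam_tilde                    # per-pixel lambda_j, shape (d,)
    beta_hat = None
    for beta in sorted(grid, reverse=True):       # scan from large to small beta
        lam = lam_pixels + beta
        outside = (x_cal < lower_cal - lam) | (x_cal > upper_cal + lam)
        if outside.mean(axis=1).mean() + slack > epsilon:
            break
        beta_hat = beta
    return beta_hat

# Usage: toy data where the second group of pixels is noisier.
rng = np.random.default_rng(2)
n, d, K = 200, 64, 2
labels = np.repeat(np.arange(K), d // K)          # hypothetical 2-group partition
x_cal = np.where(labels == 0, 0.5, 2.0) * rng.standard_normal((n, d))
lower_cal, upper_cal = np.full((n, d), -1.0), np.full((n, d), 1.0)
M = partition_matrix(labels, K)
lam_tilde = np.array([-0.5, 1.0])                 # HYPOTHETICAL step-1 output
beta_hat = krcps_beta(x_cal, lower_cal, upper_cal, M, lam_tilde,
                      epsilon=0.15, delta=0.1, grid=np.linspace(-1.0, 3.0, 401))
lam_hat = M @ lam_tilde + beta_hat                # final per-pixel lambda
print(f"beta_hat = {beta_hat:.2f}, mean interval length = {np.mean(2 + 2 * lam_hat):.2f}")
```

As long as $\tilde{\lambda}_K$ does not depend on the calibration data, the risk guarantee rests on step 2 alone; step 1 only shapes how the interval length is distributed across groups. With `lam_tilde` set to zeros, the search reduces to scalar RCPS with $\beta$ in the role of $\lambda$.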
K-RCPS: Results


(Figure: $\hat{\lambda}_K$ and conformalized uncertainty maps for K = 4 and K = 8)
Thank You!

Jacopo Teneggi

Matt Tivnan

Web Stayman

Jeremias Sulam

GitHub Repo
[SPIE] K-RCPS
By Jacopo Teneggi