Jacopo Teneggi, Matthew Tivnan, Webster J. Stayman, Jeremias Sulam
Poster Session 2 #203 @ ICML 2023
How to Trust Your Diffusion Model:
A Convex Optimization Approach to Conformal Risk Control
Motivation
From Song et al., "Score-Based Generative Modeling through Stochastic Differential Equations" (2021)
Diffusion models can generate varied and high-quality images
Setting
Stochastic image denoising
\[y = x + v,~v \sim \mathcal{N}(0, \mathbb{I}\sigma_0^2)\]
sample \(x\) from \(y\) via reverse-time diffusion
\[F(y) \sim \mathcal{Q}_y \approx p(x | y)\]
[Figure: ground truth \(x\), noisy observation \(y\), and reverse-diffusion sample \(F(y)\) for four examples]
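To make the setting concrete, the sketch below simulates the observation model \(y = x + v\) and replaces the reverse-time diffusion sampler \(F\) with a hypothetical stand-in. It assumes a Gaussian prior \(x \sim \mathcal{N}(0, s^2\mathbb{I})\) (an illustrative assumption, not the image prior used here), for which \(p(x | y)\) is Gaussian and can be sampled exactly, so that \(\mathcal{Q}_y = p(x | y)\). The names `observe`, `toy_posterior_sampler`, and `s` are illustrative only.

```python
import numpy as np

def observe(x, sigma0, rng):
    """Forward model y = x + v with v ~ N(0, sigma0^2 I)."""
    return x + sigma0 * rng.standard_normal(x.shape)

def toy_posterior_sampler(y, sigma0, s, m, rng):
    """Draw m samples from p(x | y) under a hypothetical prior x ~ N(0, s^2 I).

    Stand-in for the reverse-time diffusion sampler F: with a Gaussian prior,
    the posterior is Gaussian and can be sampled exactly.
    """
    shrink = s**2 / (s**2 + sigma0**2)                          # posterior mean is shrink * y
    post_std = np.sqrt(s**2 * sigma0**2 / (s**2 + sigma0**2))  # posterior standard deviation
    return shrink * y + post_std * rng.standard_normal((m,) + y.shape)

rng = np.random.default_rng(0)
d, sigma0, s = 16, 0.5, 1.0
x = s * rng.standard_normal(d)   # ground truth drawn from the toy prior
y = observe(x, sigma0, rng)      # noisy observation
samples = toy_posterior_sampler(y, sigma0, s, m=128, rng=rng)  # shape (128, 16)
```

The resulting `samples` array, of shape \((m, d)\), plays the role of the \(m\) samples \(F(y)^{(1)}, \dots, F(y)^{(m)}\) used by the calibrated quantiles.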
Open Questions
Reverse-time diffusion
\[F(y) \sim \mathcal{Q}_y \approx p(x | y)\]
Q1: How concentrated are the samples for the same observation?
Q2: How far are the reconstructed images from the ground truth?
Calibrated Quantiles Provide Entrywise Coverage
Lemma (informal). For a miscoverage level \(\alpha \in (0, 1)\), the entrywise calibrated quantiles computed over \(m\) samples \(F(y)^{(1)}, \dots, F(y)^{(m)}\)
\[\mathcal{I}^{\alpha}(y)_j = [\hat{l}_{j,\alpha/2}, \hat{u}_{j,1 - \alpha/2}]\]
guarantee that, for a new sample \(F(y)\) and each feature \(j \in [d]\),
\[\mathbb{P}[F(y)_j \in \mathcal{I}^{\alpha}(y)_j] \geq 1 - \alpha\]
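A minimal NumPy sketch of the entrywise calibrated quantiles, assuming the usual conformal finite-sample correction: the lower end is the \(\lfloor (m+1)\alpha/2 \rfloor\)-th order statistic and the upper end the \(\lceil (m+1)(1-\alpha/2) \rceil\)-th one. The exact correction used by the authors may differ, and `calibrated_quantiles` is a hypothetical name.

```python
import math
import numpy as np

def calibrated_quantiles(samples, alpha):
    """Entrywise calibrated quantile intervals from m samples F(y)^(1..m).

    samples: array of shape (m, d); returns (lower, upper), each of shape (d,).
    Assumes the conformal finite-sample correction (1-indexed order statistics):
    lower = floor((m+1)*alpha/2)-th, upper = ceil((m+1)*(1-alpha/2))-th.
    """
    m, d = samples.shape
    k_lo = math.floor((m + 1) * alpha / 2)
    k_hi = math.ceil((m + 1) * (1 - alpha / 2))
    sorted_samples = np.sort(samples, axis=0)  # sort each feature independently
    # With too few samples the order statistic does not exist: use an infinite end.
    lower = sorted_samples[k_lo - 1] if k_lo >= 1 else np.full(d, -np.inf)
    upper = sorted_samples[k_hi - 1] if k_hi <= m else np.full(d, np.inf)
    return lower, upper
```

With \(m = 19\) and \(\alpha = 0.1\), the interval spans the sample minimum and maximum of each feature; with fewer samples one or both ends become infinite, which is the price of a finite-sample guarantee.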
Optimal Mean Length Risk Control
Risk Controlling Prediction Set (RCPS) [Bates et al., 2021]
Define
\[\mathcal{I}_\lambda(y)_j = [\hat{l}_j - \lambda, \hat{u}_j + \lambda],~\lambda \in \mathbb{R}\]
Q: What fraction of the ground-truth features \(x_j\) fall outside \(\mathcal{I}_\lambda(y)_j\)?
For a risk level \(\epsilon\) and a failure probability \(\delta\),
\[\hat{\lambda} = \inf \{\lambda \in \mathbb{R}:~\text{UCB}(\mathcal{S}_{\text{cal}}, \lambda', \delta) < \epsilon,~\forall \lambda' \geq \lambda\}\]
guarantees conformal risk control
\[\mathbb{P}[\mathbb{E}[\text{fraction of pixels not in the intervals}] \leq \epsilon] \geq 1 - \delta\]
Here \(\text{UCB}(\mathcal{S}_{\text{cal}}, \lambda, \delta)\) is an upper confidence bound, at confidence level \(1 - \delta\), on the risk at \(\lambda\), computed on a calibration set \(\mathcal{S}_{\text{cal}}\).
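The RCPS calibration step can be sketched as follows, assuming a Hoeffding upper confidence bound (one of several bounds discussed by Bates et al.) and a finite grid of candidate \(\lambda\) values in place of \(\mathbb{R}\). The loss is the fraction of ground-truth pixels outside \(\mathcal{I}_\lambda(y)\); it is nonincreasing in \(\lambda\), so scanning from the largest \(\lambda\) downward and stopping at the first violation recovers the infimum in the definition of \(\hat{\lambda}\). All function names are hypothetical.

```python
import numpy as np

def miscoverage(x, lower, upper, lam):
    """Fraction of ground-truth pixels outside [lower - lam, upper + lam], per calibration point.

    x, lower, upper: arrays of shape (n, d); lam: scalar or array broadcastable to (d,).
    """
    outside = (x < lower - lam) | (x > upper + lam)
    return outside.mean(axis=1)

def hoeffding_ucb(losses, delta):
    """Hoeffding upper confidence bound on the mean of losses bounded in [0, 1]."""
    n = losses.shape[0]
    return losses.mean() + np.sqrt(np.log(1 / delta) / (2 * n))

def rcps_lambda(x, lower, upper, epsilon, delta, lambdas):
    """Smallest grid lambda such that UCB < epsilon for every larger grid value.

    Returns None if even the largest lambda on the grid violates the bound.
    """
    lam_hat = None
    for lam in np.sort(lambdas)[::-1]:  # from largest to smallest
        if hoeffding_ucb(miscoverage(x, lower, upper, lam), delta) >= epsilon:
            break  # first violation: every smaller lambda is excluded
        lam_hat = lam
    return lam_hat
```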
Limitation: choosing a scalar \(\lambda\) is suboptimal for the mean interval length of \(\mathcal{I}_{\lambda}(y)\)
Extend to \(\bm{\lambda} \in \mathbb{R}^d\) with
\[\mathcal{I}_{\bm{\lambda}}(y)_j = [\hat{l}_j - \lambda_j, \hat{u}_j + \lambda_j],~\lambda_j \in \mathbb{R}\]
and minimize the mean interval length
\[\hat{\bm{\lambda}} = \underset{\bm{\lambda} \in \mathbb{R}^d}{\arg\min} \sum_{j \in [d]} \lambda_j~\quad~\text{s.t. risk is controlled}\]
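The reduction from mean interval length to \(\sum_j \lambda_j\) follows from writing out the length of each interval, assuming every interval is non-empty (\(\hat{u}_j - \hat{l}_j + 2\lambda_j \geq 0\)):
\[\frac{1}{d}\sum_{j \in [d]} \big((\hat{u}_j + \lambda_j) - (\hat{l}_j - \lambda_j)\big) = \frac{1}{d}\sum_{j \in [d]} (\hat{u}_j - \hat{l}_j) + \frac{2}{d}\sum_{j \in [d]} \lambda_j\]
The first term does not depend on \(\bm{\lambda}\), so minimizing the mean interval length is equivalent to minimizing \(\sum_{j \in [d]} \lambda_j\).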
The constraint set is not convex, since the fraction of features outside the intervals is a step function of each \(\lambda_j\)
We propose a convex upper bound on the fraction of features outside the intervals
\[\ell^{\gamma}(x, \mathcal{I}_{\bm{\lambda}}(y))\]
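The slides do not spell out \(\ell^{\gamma}\), so the sketch below uses a different, hypothetical hinge surrogate to show how a convex upper bound on the 0-1 loss can be built. With \(r_j = \max(\hat{l}_j - x_j,\, x_j - \hat{u}_j)\), feature \(j\) lies outside \(\mathcal{I}_{\bm{\lambda}}(y)_j\) exactly when \(r_j > \lambda_j\); the hinge \(\max(0, 1 + (r_j - \lambda_j)/\gamma)\) with \(\gamma > 0\) is convex in \(\lambda_j\) and never smaller than that indicator. This is not the authors' \(\ell^{\gamma}\), and all names are hypothetical.

```python
import numpy as np

def outside_distance(x, lower, upper):
    """Signed distance r_j of x_j beyond [lower_j, upper_j]; negative when inside."""
    return np.maximum(lower - x, x - upper)

def zero_one_loss(x, lower, upper, lam):
    """Fraction of features outside [lower - lam, upper + lam]."""
    return np.mean(outside_distance(x, lower, upper) > lam)

def hinge_upper_bound(x, lower, upper, lam, gamma):
    """Hypothetical convex surrogate: mean of max(0, 1 + (r_j - lam_j) / gamma).

    Each term is a maximum of affine functions of lam_j (hence convex); it is
    >= 1 whenever r_j > lam_j and >= 0 otherwise, so it upper-bounds the 0-1 loss.
    """
    r = outside_distance(x, lower, upper)
    return np.mean(np.maximum(0.0, 1.0 + (r - lam) / gamma))
```

Smaller \(\gamma\) makes the surrogate tighter (it vanishes once \(\lambda_j \geq r_j + \gamma\)) at the cost of a steeper slope.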
[Figure: regions of \(\bm{\lambda}\) where the risk is controlled and where it is not, with the \(K\)-RCPS output \(\hat{\bm{\lambda}}_K\)]
\(K\)-RCPS procedure
0. User-defined
\[\text{risk tolerance}~\epsilon,~\text{failure probability}~\delta\]
\[\text{partition matrix}~M \in \{0,1\}^{d \times K}\]
1. Solve
\[\tilde{\bm{\lambda}}_K = \underset{\bm{\lambda} \in \mathbb{R}^K}{\arg\min} \sum_{k \in [K]} n_k\lambda_k~\quad~\text{s.t. convex upper bound} \leq \epsilon\]
where \(n_k = \sum_{j \in [d]} M_{jk}\) is the number of features in group \(k\)
2. Backtrack along \(M\tilde{\bm{\lambda}}_K + \beta\mathbb{1}\)
\[\hat{\beta} = \inf\{\beta \in \mathbb{R}:~\text{UCB}(\mathcal{S}_{\text{cal}}, M\tilde{\bm{\lambda}}_K + \beta'\mathbb{1}, \delta) < \epsilon,~\forall \beta' \geq \beta\}\]
3. Output
\[\hat{\bm{\lambda}}_K = M\tilde{\bm{\lambda}}_K + \hat{\beta}\mathbb{1}\]
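Steps 2 and 3 can be sketched in NumPy, taking \(\tilde{\bm{\lambda}}_K\) from step 1 as given (any convex solver could produce it). As in a plain RCPS calibration, this assumes a Hoeffding upper confidence bound and a finite grid of \(\beta\) values in place of \(\mathbb{R}\); `k_rcps_backtrack` and the helper names are hypothetical.

```python
import numpy as np

def miscoverage(x, lower, upper, lam):
    """Fraction of pixels outside [lower - lam, upper + lam], per calibration point."""
    outside = (x < lower - lam) | (x > upper + lam)
    return outside.mean(axis=1)

def hoeffding_ucb(losses, delta):
    """Hoeffding upper confidence bound on the mean of losses bounded in [0, 1]."""
    n = losses.shape[0]
    return losses.mean() + np.sqrt(np.log(1 / delta) / (2 * n))

def k_rcps_backtrack(x, lower, upper, M, lam_tilde, epsilon, delta, betas):
    """Steps 2-3: backtrack along M @ lam_tilde + beta * 1 and return lambda_hat_K.

    M: (d, K) 0/1 partition matrix; lam_tilde: (K,) output of step 1 (taken as given).
    betas: finite grid of candidate shifts. Returns None if no grid value works.
    """
    base = M @ lam_tilde  # broadcast each group's lambda_k to its features
    beta_hat = None
    for beta in np.sort(betas)[::-1]:  # from largest to smallest shift
        if hoeffding_ucb(miscoverage(x, lower, upper, base + beta), delta) >= epsilon:
            break
        beta_hat = beta
    if beta_hat is None:
        return None
    return base + beta_hat
```

Because every feature is shifted by the same \(\beta\), the backtracking is a one-dimensional search, and the differences between groups found in step 1 are preserved in \(\hat{\bm{\lambda}}_K\).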
[Figure: RCPS \((\lambda_1 = \lambda_2 = \lambda)\) versus \(K\)-RCPS, and the resulting gain]
Theorem (informal). For any partition matrix \(M \in \{0, 1\}^{d \times K}\),
\[\hat{\bm{\lambda}}_K = M\tilde{\bm{\lambda}}_K + \hat{\beta}\mathbb{1}\]
with
\[\tilde{\bm{\lambda}}_K = \underset{\bm{\lambda} \in \mathbb{R}^K}{\arg\min} \sum_{k \in [K]} n_k\lambda_k~\quad~\text{s.t. convex upper bound} \leq \epsilon\]
and
\[\hat{\beta} = \inf\{\beta \in \mathbb{R}:~\text{UCB}(\mathcal{S}_{\text{cal}}, M\tilde{\bm{\lambda}}_K + \beta'\mathbb{1}, \delta) < \epsilon,~\forall \beta' \geq \beta\}\]
provides risk control at level \(\epsilon\) with failure probability \(\delta\)
Thank you!
Jacopo Teneggi
Matt Tivnan
Web Stayman
Jeremias Sulam
Poster #203
Poster Session 2, Tue. July 25
Link to code