Jacopo Teneggi, Matthew Tivnan, Webster J. Stayman, Jeremias Sulam

Poster Session 2 #203 @ ICML 2023

How to Trust Your Diffusion Model:

A Convex Optimization Approach to Conformal Risk Control

Motivation

From Song et al. "Score-Based Generative Modeling through Stochastic Differential Equations" (2021)

Diffusion models can generate varied and high-quality images

Setting

Stochastic image denoising

\[y = x + v,~v \sim \mathcal{N}(0, \mathbb{I}\sigma_0^2)\]

sample \(x\) from \(y\) via reverse-time diffusion

\[F(y) \sim \mathcal{Q}_y \approx p(x | y)\]

Figure: four example triplets of ground truth \(x\), noisy observation \(y\), and sample \(F(y)\).

Open Questions

Reverse-time diffusion

\[F(y) \sim \mathcal{Q}_y \approx p(x | y)\]

Q1: How concentrated are the samples on the same observation?

Q2: How far are the reconstructed images from the ground truth?

\[\downarrow\]

Calibrated Quantiles

Provide Entrywise Coverage

Lemma (informal). For a miscoverage level \(\alpha \in (0, 1)\), the entrywise calibrated quantiles over \(F(y)^{(1)}, \dots, F(y)^{(m)}\)

\[\mathcal{I}^{\alpha}(y)_j = [\hat{l}_{j,\alpha/2}, \hat{u}_{j,1 - \alpha/2}]\]

guarantee that for each feature \(j \in [d]\)

\[\mathbb{P}[F(y)_j \in \mathcal{I}^{\alpha}(y)_j] \geq 1 - \alpha\]
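A minimal sketch of entrywise interval bounds from \(m\) samples is given below. It assumes one standard order-statistic construction (ranks \(\lfloor (m+1)\alpha/2 \rfloor\) and \(\lceil (m+1)(1-\alpha/2) \rceil\)), which yields coverage of at least \(1 - \alpha\) for a fresh exchangeable sample; the exact definition of the calibrated quantiles is not stated on the slide and may differ. The function name is hypothetical.

```python
import numpy as np

def calibrated_quantiles(samples, alpha):
    """Entrywise bounds (lower, upper) from samples of shape (m, d).

    ASSUMPTION: finite-sample adjusted order statistics; if the required
    rank falls outside 1..m the bound is infinite.
    """
    m = samples.shape[0]
    ordered = np.sort(samples, axis=0)           # order statistics per feature
    k_lo = int(np.floor((m + 1) * alpha / 2))
    k_hi = int(np.ceil((m + 1) * (1 - alpha / 2)))
    feat_shape = samples.shape[1:]
    lower = ordered[k_lo - 1] if k_lo >= 1 else np.full(feat_shape, -np.inf)
    upper = ordered[k_hi - 1] if k_hi <= m else np.full(feat_shape, np.inf)
    return lower, upper

rng = np.random.default_rng(0)
samples = rng.standard_normal((128, 4))          # m = 128 samples, d = 4 features
lower, upper = calibrated_quantiles(samples, alpha=0.1)
```

Note that these intervals cover a new sample \(F(y)_j\), not the ground truth \(x_j\); the latter is what risk control addresses.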

Optimal Mean Length Risk Control

Risk Controlling Prediction Set (RCPS) [Bates et al., 2021]

Risk level \(\epsilon\) and failure probability \(\delta\)

\[\hat{\lambda} = \inf \{\lambda \in \mathbb{R}:~\text{UCB}(\mathcal{S}_{\text{cal}}, \lambda', \delta) < \epsilon,~\forall \lambda' \geq \lambda\}\]

guarantees conformal risk control

\[\mathbb{P}[\mathbb{E}[\text{number of pixels not in intervals}] \leq \epsilon] \geq 1 - \delta\]

Define

\[\mathcal{I}_\lambda(y)_j = [\hat{l}_j - \lambda, \hat{u}_j + \lambda],~\lambda \in \mathbb{R}\]

Q: How many ground truth features \(x_j\) are outside \(\mathcal{I}(y)_j\)?
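The scalar RCPS calibration can be sketched as a scan over a grid of \(\lambda\) values, from the most conservative downward. This is a sketch under stated assumptions: the UCB is taken to be Hoeffding's bound for losses in \([0, 1]\), the loss is the fraction of features outside their intervals, and the grid is finite; all function names are hypothetical.

```python
import numpy as np

def miscoverage(x, lower, upper, lam):
    """Per-image fraction of features x_j outside [lower_j - lam, upper_j + lam]."""
    outside = (x < lower - lam) | (x > upper + lam)
    return outside.mean(axis=1)

def hoeffding_ucb(losses, delta):
    """Hoeffding upper confidence bound on the mean of losses in [0, 1] (ASSUMED UCB)."""
    n = losses.shape[0]
    return losses.mean() + np.sqrt(np.log(1.0 / delta) / (2.0 * n))

def rcps_scalar(x_cal, lower_cal, upper_cal, eps, delta, lam_grid):
    """Smallest grid value lam such that UCB < eps for every larger grid value."""
    lam_hat = None
    for lam in np.sort(lam_grid)[::-1]:          # largest lambda first
        losses = miscoverage(x_cal, lower_cal, upper_cal, lam)
        if hoeffding_ucb(losses, delta) >= eps:
            break                                # first violation: stop
        lam_hat = lam
    return lam_hat                               # None if even the largest lam fails

rng = np.random.default_rng(0)
n, d = 500, 32
x_cal = rng.standard_normal((n, d))              # synthetic calibration ground truth
lower_cal = np.full((n, d), -0.5)                # synthetic interval bounds
upper_cal = np.full((n, d), 0.5)
eps, delta = 0.1, 0.1
lam_hat = rcps_scalar(x_cal, lower_cal, upper_cal, eps, delta, np.linspace(0.0, 3.0, 121))
```

Scanning from large to small \(\lambda\) mirrors the "for all \(\lambda' \geq \lambda\)" condition in the definition of \(\hat{\lambda}\).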

Optimal Mean Length Risk Control

Limitation: choosing a scalar \(\lambda\) is suboptimal for the mean interval length of \(\mathcal{I}_{\lambda}(y)\)

Extend to \(\bm{\lambda} \in \mathbb{R}^d\) with

\[\mathcal{I}_{\bm{\lambda}}(y)_j = [\hat{l}_j - \lambda_j, \hat{u}_j + \lambda_j],~\lambda_j \in \mathbb{R}\]

and minimize mean interval length

\[\hat{\bm{\lambda}} = \underset{\bm{\lambda} \in \mathbb{R}^d}{\arg\min} \sum_{j \in [d]} \lambda_j~\quad~\text{s.t. risk is controlled}\]

The constraint set is not convex

Optimal Mean Length Risk Control

We propose a convex upper bound

\[\ell^{\gamma}(x, \mathcal{I}_{\bm{\lambda}}(y))\]
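To make the idea of a convex upper bound concrete, the sketch below uses a hinge-type surrogate of the 0-1 miscoverage loss. The specific form is an assumption chosen for illustration, not necessarily the \(\ell^{\gamma}\) of the talk: with \(q = \gamma / (1 - \gamma)\), interval center \(c_j\) and half-length \(r_j(\bm{\lambda}) = (\hat{u}_j - \hat{l}_j)/2 + \lambda_j > 0\), each feature contributes \([(1 + q)|x_j - c_j| / r_j(\bm{\lambda}) - q]_+\). This term is at least 1 whenever \(x_j\) is outside the interval and is convex in \(\lambda_j\) while \(r_j > 0\).

```python
import numpy as np

def zero_one_loss(x, lower, upper, lam):
    """Fraction of features outside [lower_j - lam_j, upper_j + lam_j]."""
    outside = (x < lower - lam) | (x > upper + lam)
    return outside.mean()

def hinge_surrogate(x, lower, upper, lam, gamma):
    """ASSUMED hinge-type convex upper bound on zero_one_loss (requires half_len > 0)."""
    q = gamma / (1.0 - gamma)
    center = (lower + upper) / 2.0
    half_len = (upper - lower) / 2.0 + lam
    return np.maximum((1.0 + q) * np.abs(x - center) / half_len - q, 0.0).mean()

rng = np.random.default_rng(0)
d = 8
lower, upper = -np.ones(d), np.ones(d)
x = 2.0 * rng.standard_normal(d)
lam_a, lam_b = rng.uniform(0.0, 2.0, d), rng.uniform(0.0, 2.0, d)
gamma = 0.5
bound_a = hinge_surrogate(x, lower, upper, lam_a, gamma)
bound_b = hinge_surrogate(x, lower, upper, lam_b, gamma)
bound_mid = hinge_surrogate(x, lower, upper, (lam_a + lam_b) / 2.0, gamma)
```

In this sketch, larger \(\gamma\) widens the flat zero region of the hinge inside the interval, at the cost of a steeper slope near the boundary.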

Optimal Mean Length Risk Control

Figure: regions labeled "risk is controlled" and "risk is not controlled".

\(K\)-RCPS procedure

0. User-defined

\[\text{risk tolerance}~\epsilon,~\text{failure probability}~\delta\]

\[\text{partition matrix}~M \in \{0,1\}^{d \times K}\]

1. Solve

\[\tilde{\bm{\lambda}}_K = \underset{\bm{\lambda} \in \mathbb{R}^K}{\arg\min} \sum_{k \in [K]} n_k\lambda_k~\quad~\text{s.t. convex upper bound} \leq \epsilon\]

2. Backtrack along \(M\tilde{\bm{\lambda}}_K + \beta\mathbb{1}\)

\[\hat{\beta} = \inf\{\beta \in \mathbb{R}:~\text{UCB}(\mathcal{S}_{\text{cal}}, M\tilde{\bm{\lambda}}_K + \beta'\mathbb{1}, \delta) < \epsilon,~\forall \beta' \geq \beta\}\]

3. Output

\[\hat{\bm{\lambda}}_K = M\tilde{\bm{\lambda}}_K + \hat{\beta}\mathbb{1}\]
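Steps 2 and 3 of the procedure can be sketched as follows, taking \(\tilde{\bm{\lambda}}_K\) from step 1 as given. Assumptions in this sketch: the UCB is Hoeffding's bound, the loss is the fraction of features outside their intervals, \(\beta\) is searched over a finite grid, and the value of `lam_tilde` is a hand-picked stand-in rather than the solution of the convex program. Function names are hypothetical.

```python
import numpy as np

def hoeffding_ucb(losses, delta):
    """Hoeffding upper confidence bound on the mean of losses in [0, 1] (ASSUMED UCB)."""
    n = losses.shape[0]
    return losses.mean() + np.sqrt(np.log(1.0 / delta) / (2.0 * n))

def backtrack(x_cal, lower_cal, upper_cal, M, lam_tilde, eps, delta, beta_grid):
    """Search beta along M @ lam_tilde + beta * 1 and return (lam_hat, beta_hat)."""
    base = M @ lam_tilde                         # per-feature offsets, shape (d,)
    beta_hat = None
    for beta in np.sort(beta_grid)[::-1]:        # most conservative beta first
        lam = base + beta
        outside = (x_cal < lower_cal - lam) | (x_cal > upper_cal + lam)
        if hoeffding_ucb(outside.mean(axis=1), delta) >= eps:
            break                                # first violation: stop
        beta_hat = beta
    if beta_hat is None:
        raise ValueError("beta_grid does not reach a risk-controlling value")
    return base + beta_hat, beta_hat

rng = np.random.default_rng(0)
n, d, K = 500, 32, 2
M = np.zeros((d, K))
M[:16, 0] = 1.0                                  # features 0..15 in group 0
M[16:, 1] = 1.0                                  # features 16..31 in group 1
n_k = M.sum(axis=0)                              # group sizes used in the objective
scale = np.concatenate([np.full(16, 0.5), np.full(16, 2.0)])
x_cal = rng.standard_normal((n, d)) * scale      # group 1 is harder to cover
lower_cal, upper_cal = np.full((n, d), -0.5), np.full((n, d), 0.5)
lam_tilde = np.array([0.0, 2.0])                 # stand-in for the step 1 solution
eps, delta = 0.1, 0.1
lam_hat, beta_hat = backtrack(x_cal, lower_cal, upper_cal, M, lam_tilde,
                              eps, delta, np.linspace(-0.5, 4.0, 91))
```

The output keeps the per-group offsets from \(\tilde{\bm{\lambda}}_K\) and shifts all of them by the same \(\hat{\beta}\), so only a one-dimensional search is needed after the convex program.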

Optimal Mean Length Risk Control

Figure: RCPS \((\lambda_1 = \lambda_2 = \lambda)\) versus \(K\)-RCPS, with the gain annotated.

Optimal Mean Length Risk Control

Theorem (informal). For any partition matrix \(M \in \{0, 1\}^{d \times K}\),

\[\hat{\bm{\lambda}}_K = M\tilde{\bm{\lambda}}_K + \hat{\beta}\mathbb{1}\]

with

\[\tilde{\bm{\lambda}}_K = \underset{\bm{\lambda} \in \mathbb{R}^K}{\arg\min} \sum_{k \in [K]} n_k\lambda_k~\quad~\text{s.t. convex upper bound} \leq \epsilon\]

and

\[\hat{\beta} = \inf\{\beta \in \mathbb{R}:~\text{UCB}(\mathcal{S}_{\text{cal}}, M\tilde{\bm{\lambda}}_K + \beta'\mathbb{1}, \delta) < \epsilon,~\forall \beta' \geq \beta\}\]

provides risk control at level \(\epsilon\) with failure probability \(\delta\).

Thank you!

Jacopo Teneggi

Matt Tivnan

Web Stayman

Jeremias Sulam

Poster #203

Poster Session 2, Tue. July 25
