SCALABLE BAYESIAN INFERENCE WITH DIFFERENTIABLE SIMULATORS FOR THE
COSMOLOGICAL ANALYSIS OF THE DESI SPECTROSCOPIC SURVEY

Hugo SIMON-ONFROY, 
PhD student supervised by Arnaud DE MATTIA and François LANUSSE

CSI, 2024/10/18


The universe recipe (so far)

$$\frac{H(a)}{H_0} = \sqrt{\Omega_r\, a^{-4} + (\Omega_b + \Omega_c)\, a^{-3} + \Omega_\kappa\, a^{-2} + \Omega_\Lambda}$$

instantaneous expansion rate

energy content

Cosmological principle + Einstein equation
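For illustration, the dimensionless expansion rate can be coded directly in JAX (the density values below are placeholder numbers, not measurements):

import jax.numpy as jnp

def E(a, Om_r=8e-5, Om_b=0.05, Om_c=0.25, Om_k=0.0, Om_L=0.7):
    # H(a)/H0: each density dilutes with its own power of the scale factor a
    return jnp.sqrt(Om_r / a**4 + (Om_b + Om_c) / a**3 + Om_k / a**2 + Om_L)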

+ Inflation

\(\delta_L \sim \mathcal G(0, \mathcal P)\)

\(\sigma_8:= \sigma[\delta_L * \boldsymbol 1_{r \leq 8}]\)

initial field

primordial power spectrum

std. of fluctuations smoothed at \(8 \text{ Mpc/h}\)

\(\Omega := \{ \Omega_c, \Omega_b, \Omega_\Lambda, H_0, \sigma_8, n_s,...\}\)

Linear matter spectrum

Structure growth

$$\begin{align*}\operatorname{\boldsymbol{H}}(\Omega\mid \delta_g) &= \boldsymbol{H}(\delta_g \mid \Omega) + \boldsymbol{H}(\Omega) - \boldsymbol{H}(\delta_g)\\&= \boldsymbol{H}(\Omega) - \boldsymbol{I}(\Omega;\delta_g) \leq \boldsymbol{H}(\Omega)\end{align*}$$

\(\boldsymbol{H}(X)\) = missing info on \(X\)

[Information diagram: \(\boldsymbol H(\Omega)\) and \(\boldsymbol H(\delta_g)\) overlap in the mutual information \(\boldsymbol I(\Omega;\delta_g)\); the remainder \(\boldsymbol H(\Omega\mid\delta_g)\) is the info on \(\Omega\) still missing after observing \(\delta_g\).]

A high-dimensional inference problem

\(\Omega := \{ \Omega_c, \Omega_b, \Omega_\Lambda, H_0, \sigma_8, n_s,...\}\)

Linear matter spectrum

Structure growth

  • Cosmological model links the cosmology \(\Omega\) to the initial field \(\delta_L\) to the galaxy density field \(\delta_g\)
  • Cosmological parameter inference is obtained by marginalizing the full posterior over the initial field: $$\boldsymbol{p}(\Omega \mid \delta_g) = \int \boldsymbol{p}(\Omega, \delta_L \mid \delta_g) \;\mathrm d \delta_L$$

How to perform this marginalization?

Bayes, information flavor

[Information diagram: entropies \(\boldsymbol H(X)\), \(\boldsymbol H(Y_1)\), \(\boldsymbol H(Y_2)\) overlap in the mutual informations \(\boldsymbol I(X;Y_1)\) and \(\boldsymbol I(X;Y_2\mid Y_1)\), leaving the conditional entropies \(\boldsymbol H(X\mid Y_1)\) and \(\boldsymbol H(X\mid Y_1,Y_2)\).]

\(\boldsymbol{H}(X)\) = missing information on \(X\) = number of bits needed to communicate \(X\)

$$\begin{align*}\operatorname{\boldsymbol{H}}(X\mid Y) &= \boldsymbol{H}(Y \mid X) + \boldsymbol{H}(X) - \boldsymbol{H}(Y)\\&= \boldsymbol{H}(X) - \boldsymbol{I}(X;Y) \leq \boldsymbol{H}(X)\end{align*}$$

Use summary statistics

Marginalize then sample

  • build a marginal model linking cosmo \(\Omega\) and some summary stat \(s\), then sample from the simpler \(\boldsymbol p(\Omega \mid s)\) (see the sketch below)
  • trade-off: analytical/variational tractability vs. info content of \(s\)
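As an illustration of marginalize-then-sample, a minimal NumPyro sketch of a marginal model with a Gaussian likelihood on a power-spectrum summary \(s\); theory_pk and the covariance C are hypothetical placeholders, not the actual analysis code:

import numpyro
import numpyro.distributions as dist

def marginal_model(s_obs=None, C=None):
    # priors on a couple of cosmological parameters
    Omega_m = numpyro.sample('Omega_m', dist.Uniform(0.1, 0.5))
    sigma8 = numpyro.sample('sigma8', dist.Uniform(0.5, 1.1))
    # hypothetical theory/emulator prediction for the summary stat
    s_theory = theory_pk(Omega_m, sigma8)
    # Gaussian likelihood: tractable, but only as informative as s itself
    numpyro.sample('s', dist.MultivariateNormal(s_theory, covariance_matrix=C), obs=s_obs)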

summary stat inference

A tractable candidate: the power spectrum


We gotta pump this information up

  • Field-level: all the data
  • CNN, GNN...: learn the stat
  • WST, 1D-PDFs, Holes...: multiscale count
  • Peak, Void, Split, Cluster...: object correlations
  • 3PCF, Bispectrum: more correlations
  • 2PCF, Power spectrum: standard analysis

[Information scale from \(0\) to \(\boldsymbol H(\delta_g)\)]

  • At large scales, the matter density field is almost Gaussian, so the power spectrum is an almost lossless compression, and it is relatively tractable.
  • To probe the smaller, non-Gaussian scales, two strategies:
  1. Marginalize then sample
    • build a marginal model linking cosmo \(\Omega\) and some summary stat \(s\)
    • trade-off analytical/variational tractability vs. info content of \(s\)
  2. Sample then marginalize
    • full history reconstruction without info loss
    • requires simulating LSS formation for each sample

Compress or not compress

How much information can we still gain (not lose)?

Model based field-level inference

summary stat inference

More information is better, but how much better?

Simulation Based Inference from summary stat
(build a surrogate model from simulations)

Model Based Inference at the field level
(explicitly solve)


For the full field, marginalize-then-sample is hardly tractable, so instead sample then marginalize:

  • sample jointly \(\Omega\) and \(\delta_L\), then marginalize over the \(\delta_L\) samples (see the sketch below)
  • full history reconstruction without info loss
  • requires simulating LSS formation for each sample
  • probing the non-Gaussian scales of interest in the DESI volume requires \(\operatorname{dim}(\delta_L)\geq 1024^3\)
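A minimal sketch of this marginalization, assuming a joint NumPyro MCMC over sites 'Omega_m', 'sigma8' and 'delta_L' has already been run (site names are illustrative):

# joint draws over (Omega, delta_L): dict of site name -> array of samples
samples = mcmc.get_samples()
# keeping only the cosmology components *is* the marginalization over delta_L
omega_samples = {name: samples[name] for name in ('Omega_m', 'sigma8')}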

A high-dimensional sampling problem

Some useful programming tools

  • NumPyro
    • Probabilistic Programming Language (PPL)
    • Powered by JAX
    • Integrated samplers
  • JAX
    • GPU acceleration
    • Just-In-Time (JIT) compilation acceleration
    • Automatic vectorization/parallelization
    • Automatic differentiation
  1. Prior on
    • Cosmology \(\Omega\)
    • Initial field \(\delta_L\)
    • Dark matter-galaxy connection (Lagrangian galaxy biases) \(b\)
  2. Initialize matter particles
  3. LSS formation (LPT+PM)
  4. Populate matter field with galaxies
  5. Galaxy peculiar velocities (RSD)
  6. Observational noise

Let's build a cosmological model

  • Fast and differentiable model thanks to JaxPM
  • No need to approximate the likelihood of your favorite stat from simulations: the simulation is the likelihood
  • Still, \(\simeq 1024^3\) parameters is huge!
  • Some scalable methods proposed by Lavaux+2018, Bayer+2023

Which sampling methods can scale to high dimensions?
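A minimal NumPyro sketch of the six-step model above (a rough illustration, not the actual analysis code: lpt_pm_forward and noise_std are hypothetical placeholders standing in for the JaxPM-based simulator):

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

mesh_shape = (64, 64, 64)  # as in the benchmark setting
noise_std = 0.1            # placeholder noise level

def field_level_model(obs=None):
    # 1. priors on cosmology, initial field, and galaxy biases
    Omega_m = numpyro.sample('Omega_m', dist.Uniform(0.1, 0.5))
    sigma8 = numpyro.sample('sigma8', dist.Uniform(0.5, 1.1))
    delta_L = numpyro.sample('delta_L', dist.Normal(jnp.zeros(mesh_shape), 1.).to_event(3))
    b1 = numpyro.sample('b1', dist.Normal(1., 0.5))
    # 2.-5. differentiable forward model: particles, LPT+PM gravity, biasing, RSD
    delta_g = lpt_pm_forward(Omega_m, sigma8, delta_L, b1)  # hypothetical JaxPM wrapper
    # 6. Gaussian observational noise
    numpyro.sample('obs', dist.Normal(delta_g, noise_std).to_event(3), obs=obs)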


Why care about differentiable model?

  • Classical MCMCs
    • agnostic random moves + MH acceptance step = blind natural selection
    • small moves yield correlated samples
  • SOTA MCMCs rely on the gradient of the model log-probability to drive the dynamics towards high-density regions.

  • gradient descent alone → posterior mode
  • Brownian motion alone → exploding Gaussian
  • gradient descent + Brownian motion = Langevin dynamics → posterior
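A minimal JAX sketch of one unadjusted Langevin step, assuming a model log-density logp_fn: the gradient term is the descent part, the noise term the Brownian part.

import jax
import jax.numpy as jnp

def langevin_step(key, q, logp_fn, eps):
    # drift towards high density (gradient descent part)...
    drift = 0.5 * eps**2 * jax.grad(logp_fn)(q)
    # ...plus random exploration (Brownian part)
    noise = eps * jax.random.normal(key, q.shape)
    return q + drift + noise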

Hamiltonian Monte Carlo (HMC)

  • To travel farther, add inertia; this yields less correlated chains.
    • the sample particle at position \(q\) now has a momentum \(p\) and a mass matrix \(M\)
    • the target \(\boldsymbol{p}(q)\) becomes \(\boldsymbol{p}(q , p) := e^{-\mathcal H(q,p)}\), with Hamiltonian $$\mathcal H(q,p) := -\log \boldsymbol{p}(q) + \frac 1 2 p^\top M^{-1} p$$
    • at each step, resample the momentum \(p \sim \mathcal N(0,M)\)
    • let \((q,p)\) follow the Hamiltonian dynamics for a time length \(L\); the arrival point becomes the new MH proposal (see the leapfrog sketch below).
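A minimal leapfrog sketch of this Hamiltonian dynamic, assuming an identity mass matrix \(M = I\) (so \(\dot q = p\) and \(\dot p = \nabla \log \boldsymbol p(q)\)); an illustration, not the actual sampler implementation:

import jax

def leapfrog(q, p, logp_fn, eps, n_steps):
    grad_logp = jax.grad(logp_fn)
    p = p + 0.5 * eps * grad_logp(q)      # initial half momentum kick
    for _ in range(n_steps - 1):
        q = q + eps * p                   # full position drift
        p = p + eps * grad_logp(q)        # full momentum kick
    q = q + eps * p                       # last position drift
    p = p + 0.5 * eps * grad_logp(q)      # final half momentum kick
    return q, p                           # MH-accept using H(q, p)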

Variations around HMC

  • No U-Turn Sampler (NUTS)
    • trajectory length \(L\) auto-tuned
    • samples drawn along trajectory
  • NUTS within Gibbs (NUTSGibbs), i.e. alternating sampling over parameter subsets.


[HMC diagram: 1) mass \(M\) particle at \(q\), 2) random kick \(p\), 3) Hamiltonian dynamic]

Benchmarking

  • Model setting: \(64^3\) mesh, \((640\textrm{ Mpc/h})^3\) box, 1LPT, 2nd-order Lagrangian bias expansion, RSD, and Gaussian observational noise.
  • Parameter space: initial field \(\delta_L\), cosmology \(\Omega\! =\! \{\Omega_m, \sigma_8\}\), and galaxy biases \(b\!=\!\{b_1,b_2,b_{s^2},b_{\nabla^2}\}\); a total of \(64^3 + 2 + 4\) parameters.
  • For NUTSGibbs: sampling is split between \(\delta_L\) and the rest (common in the literature).

  • Results suggest no particular advantage to splitting the sampling between the initial field and the rest; cf. Simon-Onfroy et al., in prep.

[Figure: joint reconstruction of the initial field, yielding a posterior over the full universe history.]

[Figure: number of model evaluations to yield one effective sample; the higher the worse.]

Activities 2023-2024

  • Schools
    • Rodolphe Clédassou (2023, 2024)
    • Bayes@CIRM (2023)
  • Teaching
    • M1 Maths and L2 Biostats at UPS
  • Conferences and workshops
    • ML in Astro IAP/CCA
    • Poster at Cosmo21
    • Talk at EDSU Tools
    • Talk at DESI meeting
    • Bayes intro at Cosmostat
  • Visit at Flatiron Institute
  • DESI observing (cloudy...)

...and what's next

  • 1st author paper (< December 2024)

    • benchmark more proposed samplers
    • compare to SBI approaches
  • Longer term
    • applications on DESI data (survey selection function)
    • annealing/diffusion based sampler


Recap...

  • Field-level inference may be relevant to fully capture the cosmological information in the data.

  • Leverage modern computational tools to build a fast and differentiable cosmological model.

  • A standardized benchmark for comparing MCMC samplers on field-level inference tasks, to select methods suited to Stage-IV galaxy surveys.

  • At large scales, the matter density field is almost Gaussian, so the power spectrum is an almost lossless compression.
  • At smaller scales however, the matter density field is non-Gaussian.

Gaussianity and beyond

2 fields, 1 power spectrum: Gaussian or N-body?

$$\boxed{\min_s \operatorname{\boldsymbol{H}}(\Omega\mid s(\delta_g))} = \boldsymbol{H}(\Omega)  - \max_s \boldsymbol{I}(\Omega  ; s(\delta_g))$$

Which stats are relevant for cosmo inference?

[Information diagram: \(\boldsymbol H(\Omega)\) against \(\boldsymbol H(\delta_g)\), the latter split into Gaussianities and non-Gaussianities. Candidate stats \(s_1\), \(s_2\) and the power spectrum \(\mathcal P\) are compared: a relevant stat has high mutual info with \(\Omega\) even if low info (e.g. \(\mathcal P\) on Gaussian scales); a stat with both high info and high mutual info is also relevant; a stat with high info but low mutual info is irrelevant.]

Compress or not compress

So JAX in practice?

  • GPU accelerate
  • JIT compile
  • Vectorize/Parallelize
  • Auto-diff

import jax
import jax.numpy as np  # drop-in NumPy replacement, then enjoy

def function(x):                # any pure array-in/array-out function
    return np.sum(x**2)

function = jax.jit(function)    # function is so fast now!
gradient = jax.grad(function)   # too bad if you love chain ruling by hand
vfunction = jax.vmap(function)  # vectorize over a batch axis
pfunction = jax.pmap(function)  # parallelize across devices: for-loops are for-losers

So NumPyro in practice?

from jax import jit, vmap, grad, random
from numpyro import sample, distributions as dist, infer, render_model
from numpyro.handlers import seed, condition
from numpyro.infer.util import log_density

def model():
    z = sample('Z', dist.Normal(0, 1))
    x = sample('X', dist.Normal(z**2, 1))
    return x

render_model(model, render_distributions=True)

# condition the model on a value of X drawn from a seeded run (site names are case-sensitive)
x_sample = dict(X=seed(model, 42)())
obs_model = condition(model, x_sample)

# log-density of the conditioned model, and its compiled, vectorized score function
logp_fn = lambda z: log_density(obs_model, (), {}, {'Z': z})[0]
score_vfn = jit(vmap(grad(logp_fn)))

# run NUTS on the conditioned model
kernel = infer.NUTS(obs_model)
mcmc = infer.MCMC(kernel, num_warmup=500, num_samples=1000)
mcmc.run(random.PRNGKey(43))
samples = mcmc.get_samples()
  • Probabilistic Programming Language
  • JAX machinery
  • Integrated samplers
  • Effective Sample Size (ESS)
    • number of i.i.d. samples that would yield the same statistical power.
    • For a sample sequence of size \(N\) with autocorrelations \(\rho_t\): $$N_\textrm{eff} = \frac{N}{1+2 \sum_{t=1}^{+\infty}\rho_t}$$ so aim for samples that are as uncorrelated as possible (see the sketch below).
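A naive sketch of this estimate from the empirical autocorrelations (illustration only; robust estimators such as numpyro.diagnostics.effective_sample_size should be preferred in practice):

import numpy as np

def ess(chain):
    x = np.asarray(chain) - np.mean(chain)
    n = len(x)
    # empirical autocovariances at lags 0..n-1, normalized so that rho_0 = 1
    acov = np.correlate(x, x, mode='full')[n - 1:] / n
    rho = acov / acov[0]
    # truncate the sum at the first negative autocorrelation
    first_neg = np.argmax(rho[1:] < 0) + 1 if np.any(rho[1:] < 0) else n
    return n / (1 + 2 * np.sum(rho[1:first_neg]))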

  • The main limiting computational factor is the model evaluation (e.g. N-body), so MCMC efficiency is characterized by \(N_\text{eval} / N_\text{eff}\), the number of model evaluations per effective sample.

How to compare samplers?
