SCALABLE BAYESIAN INFERENCE WITH DIFFERENTIABLE SIMULATORS FOR THE
COSMOLOGICAL ANALYSIS OF THE DESI SPECTROSCOPIC SURVEY
Hugo SIMON-ONFROY,
PhD student supervised by Arnaud DE MATTIA and François LANUSSE
CSI, 2024/10/18
The universe recipe (so far)
(H/H0)² = Ωr a⁻⁴ + (Ωb + Ωc) a⁻³ + Ωκ a⁻² + ΩΛ
H: instantaneous expansion rate; Ω terms: energy content
Cosmological principle + Einstein equation
+ Inflation
δL ∼ G(0, P): initial field drawn from a Gaussian with primordial power spectrum P
σ8 := σ[δL ∗ 1_{r≤8}]: std. of fluctuations smoothed at 8 Mpc/h

Ω:={Ωc,Ωb,ΩΛ,H0,σ8,ns,...}


Linear matter spectrum
Structure growth
H(Ω|δg) = H(δg|Ω) + H(Ω) − H(δg) = H(Ω) − I(Ω; δg) ≤ H(Ω)
[Venn diagram: H(Ω), H(δg), their overlap I(Ω; δg), and the remainder H(Ω|δg); legend: H(X) = missing info on X]
- Cosmological model links the cosmology Ω to the initial field δL and to the galaxy density field δg
- Cosmological parameter inference is obtained by marginalizing the full posterior over the initial field: p(Ω|δg) = ∫ p(Ω, δL|δg) dδL
A high-dimensional inference problem
How to perform this marginalization?
Use summary statistics
Marginalize then sample
- build a marginal model linking the cosmology Ω to some summary statistic s, then sample from the simpler p(Ω|s) (see the sketch below)
- trade-off between analytical/variational tractability and the information content of s
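To make this marginalize-then-sample route concrete, here is a minimal NumPyro-style sketch, assuming a Gaussian likelihood for a measured power spectrum s; theory_pk, the priors, and the covariance are illustrative placeholders, not the DESI pipeline.

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro import infer
from jax.random import PRNGKey

def theory_pk(k, Omega_m, sigma8):
    # placeholder for a Boltzmann-code / emulator prediction of P(k)
    return sigma8**2 * (k / 0.1) ** (Omega_m - 2.0)

def marginal_model(k, s_obs, cov):
    # priors on the cosmology Omega
    Omega_m = numpyro.sample("Omega_m", dist.Uniform(0.1, 0.5))
    sigma8 = numpyro.sample("sigma8", dist.Uniform(0.5, 1.2))
    # Gaussian likelihood for the summary statistic s
    s_theory = theory_pk(k, Omega_m, sigma8)
    numpyro.sample("s", dist.MultivariateNormal(s_theory, covariance_matrix=cov), obs=s_obs)

# sample the simpler p(Omega | s) with NUTS (toy data)
k = jnp.linspace(0.02, 0.2, 20)
cov = 0.01 * jnp.eye(20)
s_obs = theory_pk(k, 0.3, 0.8)
mcmc = infer.MCMC(infer.NUTS(marginal_model), num_warmup=500, num_samples=1000)
mcmc.run(PRNGKey(0), k, s_obs, cov)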




summary stat inference
A tractable candidate: the power spectrum
We gotta pump this information up
[Diagram: ladder of summary statistics ranked by information content, from 0 up to H(δg): 2PCF/power spectrum, 3PCF/bispectrum, peaks/voids/splits/clusters, WST/1D-PDFs/holes, CNN/GNN, field-level]
- At large scales, the matter density field is almost Gaussian, so the power spectrum is an almost lossless compression, and it is relatively tractable (see the sketch after this list)
- To probe smaller, non-Gaussian scales, go beyond the standard analysis (2PCF / power spectrum) and add:
  - more correlations: 3PCF, bispectrum
  - object correlations: peaks, voids, splits, clusters...
  - multiscale counts: WST, 1D-PDFs, holes...
  - learned statistics: CNN, GNN...
  - all the data: field-level
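Since the power spectrum is the tractable workhorse here, below is a minimal JAX sketch of an isotropic P(k) estimator for a density mesh; the FFT normalization, binning, and box size are illustrative assumptions, not the DESI estimator.

import jax.numpy as jnp

def power_spectrum(delta, box_size=640.0, n_bins=32):
    # delta: real-space overdensity mesh of shape (n, n, n)
    n = delta.shape[0]
    delta_k = jnp.fft.rfftn(delta) * (box_size / n) ** 3       # Fourier modes
    # wavenumber norm |k| on the rfft grid
    k1d = 2 * jnp.pi * jnp.fft.fftfreq(n, d=box_size / n)
    kz1d = 2 * jnp.pi * jnp.fft.rfftfreq(n, d=box_size / n)
    kx, ky, kz = jnp.meshgrid(k1d, k1d, kz1d, indexing="ij")
    k_norm = jnp.sqrt(kx**2 + ky**2 + kz**2)
    # spherical binning of |delta_k|^2 / V
    edges = jnp.linspace(0.0, k_norm.max(), n_bins + 1)
    idx = jnp.clip(jnp.digitize(k_norm.ravel(), edges) - 1, 0, n_bins - 1)
    pk = jnp.abs(delta_k.ravel()) ** 2 / box_size**3
    power = jnp.bincount(idx, weights=pk, length=n_bins)
    counts = jnp.bincount(idx, length=n_bins)
    k_centers = 0.5 * (edges[1:] + edges[:-1])
    return k_centers, power / jnp.maximum(counts, 1)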
More information is better,
but how much better?
Model Based Inference at the field level
(explicitly solve)

How much information can we still gain (not lose)?



field-level inference
summary stat inference
For the full field, marginalize-then-sample is hardly tractable,
so sample-then-marginalize
- sample Ω and δL jointly, then marginalize over the δL samples
- full history reconstruction without information loss
- requires simulating LSS formation for each sample
- probing the non-Gaussian scales of interest in the DESI volume requires dim(δL) ≥ 1024³
A high-dimensional sampling problem


Some useful programming tools
- NumPyro
  - Probabilistic Programming Language (PPL)
  - Powered by JAX
  - Integrated samplers
- JAX
  - GPU acceleration
  - Just-In-Time (JIT) compilation
  - Automatic vectorization/parallelization
  - Automatic differentiation
- Priors on:
  - Cosmology Ω
  - Initial field δL
  - Dark matter-galaxy connection (Lagrangian galaxy biases) b
- Initialize matter particles
- LSS formation (LPT+PM)
- Populate matter field with galaxies
- Galaxy peculiar velocities (RSD)
- Observational noise (full pipeline sketched below)
Let's build a cosmological model
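A minimal NumPyro sketch of this pipeline's probabilistic structure, under strong simplifying assumptions: the LPT+PM evolution, Lagrangian biasing, and RSD are collapsed into a placeholder evolve function (the real model uses JaxPM), and the mesh, priors, and noise level are illustrative.

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

MESH = (64, 64, 64)   # toy mesh; the target analysis needs >= 1024^3

def evolve(delta_L, Omega_m, sigma8, b):
    # placeholder for LPT+PM structure formation, Lagrangian biasing and RSD
    return sigma8 * delta_L * (1.0 + b[0] * Omega_m * delta_L)

def model(obs=None):
    # priors on cosmology, initial field, and galaxy biases
    Omega_m = numpyro.sample("Omega_m", dist.Uniform(0.1, 0.5))
    sigma8 = numpyro.sample("sigma8", dist.Uniform(0.5, 1.2))
    b = numpyro.sample("b", dist.Normal(jnp.zeros(4), 1.0))
    delta_L = numpyro.sample("delta_L", dist.Normal(jnp.zeros(MESH), 1.0))
    # deterministic forward evolution to the galaxy density field
    delta_g = evolve(delta_L, Omega_m, sigma8, b)
    # Gaussian observational noise
    numpyro.sample("obs", dist.Normal(delta_g, 0.1), obs=obs)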





- Fast and differentiable model thanks to JaxPM
- still, ≃ 1024³ parameters is huge!
- Some methods proposed by Lavaux+2018, Bayer+2023
Which sampling methods can scale to high dimensions?
Why care about a differentiable model?
Variations around HMC
- Hamiltonian Monte Carlo (HMC): to travel farther, add inertia (one transition sketched below)
  - model gradient drives sample particles towards high-density regions
  - yields less correlated chains
  - [Diagram: 1) mass-M particle at q, 2) random kick p, 3) Hamiltonian dynamics]
- No U-Turn Sampler (NUTS)
  - trajectory length auto-tuned
  - samples drawn along trajectory
- NUTSGibbs, i.e. alternating sampling over parameter subsets
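For intuition, a minimal one-transition HMC sketch in JAX, assuming a unit mass matrix and a fixed step size and trajectory length (what NUTS would auto-tune); this is not the NumPyro implementation used in the analysis.

import jax
import jax.numpy as jnp

def hmc_step(key, q, logp_fn, step_size=0.1, n_leapfrog=20):
    grad_logp = jax.grad(logp_fn)
    key_kick, key_accept = jax.random.split(key)
    p = jax.random.normal(key_kick, q.shape)        # 2) random kick p (mass M = I)
    q_new, p_new = q, p                             # 1) particle of mass M at q
    for _ in range(n_leapfrog):                     # 3) Hamiltonian dynamics (leapfrog)
        p_new = p_new + 0.5 * step_size * grad_logp(q_new)
        q_new = q_new + step_size * p_new
        p_new = p_new + 0.5 * step_size * grad_logp(q_new)
    # Metropolis correction for the integration error
    log_ratio = (logp_fn(q_new) - 0.5 * jnp.sum(p_new**2)) - (logp_fn(q) - 0.5 * jnp.sum(p**2))
    accept = jnp.log(jax.random.uniform(key_accept)) < log_ratio
    return jnp.where(accept, q_new, q)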
Benchmarking
- model setting: 64³ mesh, (640 Mpc/h)³ box, 1LPT, 2nd-order Lagrangian bias expansion, RSD and Gaussian observational noise.
- parameter space: initial field δL, cosmology Ω = {Ωm, σ8}, and galaxy biases b = {b1, b2, bs², b∇²}. Total of 64³ + 2 + 4 parameters.
- For NUTSGibbs: split sampling between δL and the rest (common in the literature)
- Results suggest no particular advantage to splitting the sampling between the initial field and the rest, cf. Simon-Onfroy et al., in prep.
Reconstruct the initial field simultaneously, yielding a posterior on the full universe history
[Figure: number of evaluations to yield one effective sample: the higher the worse]


Activities 2023-2024
- Schools
- Rodolphe Clédassou (2023, 2024)
- Bayes@CIRM (2023)
- Teaching
- M1 Maths and L2 Biostats at UPS
- Conferences and workshops
- ML in Astro IAP/CCA
- Poster at Cosmo21
- Talk at EDSU Tools
- Talk at a DESI meeting
- Bayes intro at Cosmostat
- Visit at Flatiron Institute
- DESI observing (cloudy...)
...and what's next
- 1st-author paper (< December 2024)
  - benchmark more proposed samplers
  - compare to SBI approaches
- Longer term
  - applications to DESI data (survey selection function)
  - annealing/diffusion-based samplers
Recap...
- Leverage modern computational tools to build a fast and differentiable cosmological model
- Field-level inference can scale to Stage-IV galaxy surveys, and is relevant to fully capture the cosmological information in the data
- At large scales, the matter density field is almost Gaussian, so the power spectrum is an almost lossless compression
- At smaller scales, however, the matter density field is non-Gaussian
Gaussianity and beyond

2 fields, 1 power spectrum: Gaussian or N-body?
min_s H(Ω | s(δg)) = H(Ω) − max_s I(Ω; s(δg))
[Venn diagram: H(Ω) and H(δg), split into Gaussianities and non-Gaussianities, with example statistics H(s1), H(s2), H(P): a relevant stat (low info but high mutual info), an irrelevant stat (high info but low mutual info), and also a relevant stat (high info and mutual info)]
Which stats are relevant for cosmo inference?
Compress or not compress
So JAX in practice?
- GPU accelerate
- JIT compile
- Vectorize/Parallelize
- Auto-diff
import jax
import jax.numpy as np                            # then enjoy
def function(x): return np.sum(np.sin(x) * x)     # any pure JAX function
function = jax.jit(function)                      # function is so fast now!
gradient = jax.grad(function)                     # too bad if you love chain ruling by hand
vfunction = jax.vmap(function)                    # for-loops are for-losers
pfunction = jax.pmap(function)


So NumPyro in practice?
import numpyro.distributions as dist
from numpyro import sample, render_model, infer
from numpyro.handlers import seed, condition
from numpyro.infer.util import log_density
from jax.random import PRNGKey

def model():
    z = sample('Z', dist.Normal(0, 1))
    x = sample('X', dist.Normal(z**2, 1))
    return x

render_model(model, render_distributions=True)    # visualize the graphical model

x_sample = dict(X=seed(model, 42)())              # draw a mock observation of X
obs_model = condition(model, x_sample)            # condition the model on it
logp_fn = lambda z: log_density(obs_model, (), {}, {'Z': z})[0]

from jax import jit, vmap, grad
score_vfn = jit(vmap(grad(logp_fn)))              # vectorized, compiled score function

kernel = infer.NUTS(obs_model)
mcmc = infer.MCMC(kernel, num_warmup=500, num_samples=1000)   # example values
mcmc.run(PRNGKey(43))
samples = mcmc.get_samples()
- Probabilistic Programming
- JAX machinery
- Integrated samplers


- Effective Sample Size (ESS)
  - number of i.i.d. samples that would yield the same statistical power
  - for a sample sequence of size N with autocorrelation ρ, N_eff = N / (1 + 2 ∑_{t=1}^{+∞} ρ_t), so aim for samples as uncorrelated as possible
  - the main limiting computational factor is the model evaluation (e.g. N-body), so characterize MCMC efficiency by N_eval / N_eff (see the sketch below)
How to compare samplers?
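A rough sketch of this metric, assuming a simple truncation of the autocorrelation sum at the first non-positive lag; production estimates (e.g. ArviZ or numpyro.diagnostics) use more careful estimators.

import jax.numpy as jnp

def effective_sample_size(chain):
    # chain: 1D array of N correlated samples of one parameter
    n = chain.shape[0]
    x = chain - chain.mean()
    acov = jnp.correlate(x, x, mode="full")[n - 1:] / n    # autocovariance, lags 0..N-1
    rho = acov[1:] / acov[0]                               # autocorrelation rho_t, t >= 1
    keep = jnp.cumprod(jnp.where(rho > 0, 1.0, 0.0))       # truncate at first non-positive lag
    tau = 1.0 + 2.0 * jnp.sum(rho * keep)                  # integrated autocorrelation time
    return n / tau                                         # N_eff

# sampler efficiency: evaluations per effective sample (the higher the worse)
# n_eval_per_eff = n_model_evaluations / effective_sample_size(chain)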