Probabilistic Programming for Cosmology with JAX

Hugo SIMON-ONFROY, 
PhD student supervised by
Arnaud DE MATTIA and François LANUSSE

CoPhy, 2024/11/19

The universe recipe (so far)

$$\frac{H(a)}{H_0} = \sqrt{\Omega_r a^{-4} + (\Omega_b + \Omega_c) a^{-3} + \Omega_\kappa a^{-2} + \Omega_\Lambda}$$

(instantaneous expansion rate, in terms of the energy content)

Cosmological principle + Einstein equation

+ Inflation

\(\delta_L \sim \mathcal G(0, \mathcal P)\) (initial field, with primordial power spectrum \(\mathcal P\))

\(\sigma_8 := \sigma[\delta_L * \boldsymbol 1_{r \leq 8}]\) (std. of fluctuations smoothed at \(8 \text{ Mpc}/h\))

\(\Omega := \{ \Omega_c, \Omega_b, \Omega_\Lambda, H_0, \sigma_8, n_s,...\}\)

[Figures: linear matter spectrum; structure growth]

Cosmological modeling and inference

[DAG: \(\Omega \to \delta_L \to \delta_g\); inference runs in the opposite direction]

Please define model

Model: an observation generator, i.e. the link between latent and observed variables

[Figure: example DAG over variables \(A\)–\(G\)]

  • Frequentist: parameters are fixed but unknown
  • Bayesian: parameters are random because unknown


Represented as a DAG, which expresses a joint probability factorization, i.e. the conditional dependencies between variables

Node legend: latent parameter, latent variable, observed variable.

Definitions

evidence/prior predictive: $$\boldsymbol{p}(y) := \int \boldsymbol{p}(y \mid x) \boldsymbol{p}(x)d x$$

posterior predictive: $$\boldsymbol{p}(y_1 \mid y_0) := \int \boldsymbol{p}(y_1 \mid x) \boldsymbol{p}(x \mid y_0)d x$$

Bayes' theorem:

\(\underbrace{\boldsymbol{p}(x \mid y)}_{\text{posterior}}\underbrace{\boldsymbol{p}(y)}_{\text{evidence}} = \underbrace{\boldsymbol{p}(y \mid x)}_{\text{likelihood}} \,\underbrace{\boldsymbol{p}(x)}_{\text{prior}}\)


Marginalization: [DAG figure: integrate out a variable \(Z\)]

Conditioning/Bayes inversion: [DAG figure: reverse the \(X\)-\(Y\) dependency]



Cosmology is (sometimes) hard

Probabilistic Programming: Wish List

\(\begin{cases}x,y,z \text{ samples}\\\boldsymbol{p}(x,y,z)\\\Phi= -\log \boldsymbol{p}\\\nabla\Phi,\nabla^2 \Phi,\dots\end{cases}\)

  1. Modeling (physics job!): e.g. in JAX with \(\texttt{NumPyro}\), \(\texttt{JaxCosmo}\), \(\texttt{JaxPM}\), \(\texttt{DISCO-EB}\)...
  2. DAG compilation (auto!): with PPLs, e.g. \(\texttt{NumPyro}\), \(\texttt{Stan}\), \(\texttt{PyMC}\)...
  3. Inference (stats job!): e.g. MCMC or VI with \(\texttt{NumPyro}\), \(\texttt{Stan}\), \(\texttt{PyMC}\), \(\texttt{emcee}\)...
  4. Extract and viz (stats job!): e.g. KDE with \(\texttt{GetDist}\), \(\texttt{GMM-MI}\), \(\texttt{mean()}\), \(\texttt{std()}\)...

Some useful programming tools

  • NumPyro
    • Probabilistic Programming Language (PPL)
    • Powered by JAX
    • Integrated samplers
  • JAX
    • GPU acceleration
    • Just-In-Time (JIT) compilation acceleration
    • Automatic vectorization/parallelization
    • Automatic differentiation

JAX reminder

  • GPU accelerate
  • JIT compile
  • Vectorize/parallelize
  • Auto-diff

import jax
import jax.numpy as np  # then enjoy the NumPy API on GPU

function = jax.jit(function)
# function is so fast now!
gradient = jax.grad(function)
# too bad if you love chain ruling by hand
vfunction = jax.vmap(function)  # vectorize
pfunction = jax.pmap(function)  # parallelize across devices
# for-loops are for-losers
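To make these transforms concrete, here is a tiny self-contained sketch (the function and values are purely illustrative):

import jax
import jax.numpy as jnp

# a toy log-density-like function
def f(x):
    return -0.5 * jnp.sum(x**2)

fast_f = jax.jit(f)              # compiled on first call
grad_f = jax.grad(f)             # exact gradient, no chain ruling by hand
batch_grad_f = jax.vmap(grad_f)  # gradients for a whole batch at once

xs = jnp.ones((4, 3))
print(fast_f(xs[0]))           # -1.5
print(batch_grad_f(xs).shape)  # (4, 3)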

Probabilistic Modeling and Compilation

import numpy as np
from scipy.stats import norm

# 1. the simulator: x ~ N(0, 1), y | x ~ N(x**2, 1)
def simulate(seed):
    rng = np.random.RandomState(seed)
    x = rng.randn()
    y = rng.randn() + x**2
    return x, y

# 2. its associated joint log prob
log_prior = lambda x: norm.logpdf(x, 0, 1)
log_lik = lambda x, y: norm.logpdf(y, x**2, 1)
log_joint = lambda x, y: log_lik(x, y) + log_prior(x)

# 3. its hand-derived gradients (tedious and error-prone)
grad_log_prior = lambda x: -x
grad_log_lik = lambda x, y: np.stack([2 * x * (y - x**2), -(y - x**2)])
grad_log_joint = lambda x, y: grad_log_lik(x, y) + np.stack([grad_log_prior(x), np.zeros_like(y)])
# Hessian left as an exercise...
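By contrast, autodiff recovers all of these derivatives mechanically. A minimal check, assuming the same toy model (jax.scipy provides the Normal logpdf):

import jax
from jax.scipy.stats import norm

jax_log_joint = lambda x, y: norm.logpdf(y, x**2, 1.) + norm.logpdf(x, 0., 1.)
# gradient w.r.t. both x and y, no chain rule by hand
auto_grad = jax.grad(jax_log_joint, argnums=(0, 1))
print(auto_grad(0.5, 1.2))  # matches the hand-derived grad_log_joint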

To fully define a probabilistic model, we need

  1. its simulator
  2. its associated joint log prob
  3. possibly its gradients, Hessian... (useful for inference)

We should only care about the simulator!


$$\begin{cases}X \sim \mathcal N(0,1)\\Y\mid X \sim \mathcal N(X^2, 1)\end{cases}$$

my model on paper

my model in code

What if we modify it, to fit the variances or use a non-Gaussian distribution? Recompute everything by hand? :'(

Modeling with NumPyro

from numpyro import sample
import numpyro.distributions as dist
from numpyro.handlers import seed, condition
from numpyro.infer.util import log_density

# define the model as a simulator... and that's it
def model():
    x = sample('x', dist.Normal(0, 1))
    y = sample('y', dist.Normal(x**2, 1))
    return y

# simulate, get the joint log prob
y_sample = seed(model, 42)()
log_joint = lambda x, y: log_density(model, (), {}, {'x': x, 'y': y})[0]

# condition the model on the observation
obs_model = condition(model, {'y': y_sample})
logp_fn = lambda x: log_density(obs_model, (), {}, {'x': x})[0]

# + JAX machinery
from jax import jit, vmap, grad
force_vfn = jit(vmap(grad(logp_fn)))
  • Define the model as a simulator... and that's it
  • Simulate, get the log prob
  • Condition the model
  • + JAX machinery
  • And more!
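For instance (values illustrative), the conditioned log prob and its compiled, batched gradient can be evaluated directly:

import jax.numpy as jnp

print(logp_fn(0.5))                       # log prob of x given the observed y
print(force_vfn(jnp.linspace(-2, 2, 5)))  # gradients for a whole batch of x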

So NumPyro in practice?

from jax import random as jr
from numpyro import infer, render_model

def model():
    x = sample('x', dist.Normal(0, 1))
    y = sample('y', dist.Normal(x**2, 1))
    return y

render_model(model, render_distributions=True)

y_sample = dict(y=seed(model, 42)())
obs_model = condition(model, y_sample)
logp_fn = lambda x: log_density(obs_model, (), {}, {'x': x})[0]

from jax import jit, vmap, grad
force_vfn = jit(vmap(grad(logp_fn)))

kernel = infer.NUTS(obs_model)
num_warmup, num_samples = 500, 1000
mcmc = infer.MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
mcmc.run(jr.key(43))
samples = mcmc.get_samples()
  • Probabilistic programming
  • + JAX machinery
  • Integrated samplers
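And for step 4 of the wish list, summaries can be extracted directly from the chain (a minimal sketch; GetDist et al. offer much more):

x_samples = samples['x']
print(x_samples.mean(), x_samples.std())  # posterior mean and std of x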

Why care about differentiable models?

  • Classical MCMCs:
    • agnostic random moves + MH acceptance step = blind natural selection
    • small moves yield correlated samples
  • SOTA MCMCs rely on the gradient of the model log prob to drive the dynamics towards the highest-density regions

gradient descent (→ posterior mode) + Brownian motion (→ exploding Gaussian) = Langevin dynamics (→ posterior)

Hamiltonian Monte Carlo (HMC)

  • To travel farther, add inertia:
    • the sample particle at position \(q\) now has a momentum \(p\) and a mass matrix \(M\)
    • the target \(\boldsymbol{p}(q)\) becomes \(\boldsymbol{p}(q, p) := e^{-\mathcal H(q,p)}\), with Hamiltonian $$\mathcal H(q,p) := -\log \boldsymbol{p}(q) + \frac 1 2 p^\top M^{-1} p$$
    • at each step, resample the momentum \(p \sim \mathcal N(0,M)\)
    • let \((q,p)\) follow the Hamiltonian dynamics for a time length \(L\); the arrival point becomes the new MH proposal (see the sketch below)
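A minimal leapfrog sketch of one HMC proposal, assuming a unit mass matrix \(M = I\) (illustrative only, not NumPyro's actual implementation):

import jax
import jax.numpy as jnp

def hmc_proposal(logp, q, key, step_size=0.1, n_steps=20):
    # 1) random kick: resample momentum p ~ N(0, I)
    p = jax.random.normal(key, q.shape)
    grad_logp = jax.grad(logp)
    # 2) leapfrog integration of the Hamiltonian dynamics
    p = p + 0.5 * step_size * grad_logp(q)
    for _ in range(n_steps - 1):
        q = q + step_size * p
        p = p + step_size * grad_logp(q)
    q = q + step_size * p
    p = p + 0.5 * step_size * grad_logp(q)
    # 3) the arrival point (q, p) is the new MH proposal
    # (the MH acceptance step is omitted here)
    return q, p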

Variations around HMC

  • No U-Turn Sampler (NUTS)
    • trajectory length \(L\) auto-tuned
    • samples drawn along trajectory
  • NUTS-within-Gibbs, i.e. alternating NUTS sampling over parameter subsets


Why care about differentiable models?

  • The model gradient drives sample particles towards high-density regions
  • Hamiltonian Monte Carlo (HMC): to travel farther, add inertia
  • Yields less correlated chains

[Figure: 1) particle of mass \(M\) at position \(q\); 2) random kick \(p\); 3) Hamiltonian dynamics]

Inference with NumPyro

# model, obs_model, logp_fn defined as above
kernel = infer.NUTS(obs_model)
mcmc = infer.MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
mcmc.run(jr.key(43))
samples = mcmc.get_samples()
  • probabilistic model
  • + JAX machinery
  • = gradient-based MCMC

How to N-body-differentiate?

[Figure: PM solver loop. Particles \((q, p)\) are painted* onto a grid to get \(\delta(x)\); fft* gives \(\delta(k)\); solving Vlasov-Poisson computes the forces (ifft*); the forces are read* back at the particle positions to move them.
*: differentiable, e.g. with JAX, see JaxPM]
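As a toy illustration of the point (a 1D sketch, not JaxPM's actual 3D pipeline), gradients flow straight through an FFT-based Poisson solve:

import jax
import jax.numpy as jnp

def potential(delta_x):
    # solve Poisson in Fourier space: phi(k) = -delta(k) / k^2
    k = 2 * jnp.pi * jnp.fft.fftfreq(delta_x.shape[0])
    k2 = jnp.where(k == 0, 1., k**2)  # dodge the division by zero at k=0
    phi_k = jnp.where(k == 0, 0., -jnp.fft.fft(delta_x) / k2)
    return jnp.fft.ifft(phi_k).real

# autodiff flows through fft/ifft: the whole solver is differentiable
grad_fn = jax.grad(lambda d: jnp.sum(potential(d) ** 2))
print(grad_fn(jnp.arange(8.)))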

Let's build a cosmological model

  1. Prior on
    • cosmology \(\Omega\)
    • initial field \(\delta_L\)
    • dark matter-galaxy connection (Lagrangian galaxy biases) \(b\)
  2. Initialize matter particles
  3. LSS formation (LPT + PM)
  4. Populate matter field with galaxies
  5. Galaxy peculiar velocities (RSD)
  6. Observational noise

  • Fast and differentiable model thanks to JaxPM
  • Still, \(\simeq 1024^3\) parameters is huge!
  • Some methods have been proposed by Lavaux+2018, Bayer+2023
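Schematically, steps 1-6 above map onto a NumPyro model like the sketch below. Caution: linear_field, nbody, populate_galaxies and apply_rsd are hypothetical placeholders standing in for JaxPM-style operators, not actual API:

from numpyro import sample
import numpyro.distributions as dist

def field_level_model():
    # 1. priors on cosmology, initial field and galaxy biases
    Omega_m = sample('Omega_m', dist.Uniform(0.1, 0.5))
    sigma8 = sample('sigma8', dist.Uniform(0.5, 1.2))
    delta_L = sample('delta_L', dist.Normal(0, 1).expand([64, 64, 64]))
    b = sample('b', dist.Normal(1, 0.5))
    # 2.-3. initialize particles and form LSS (LPT + PM)
    delta_m = nbody(Omega_m, sigma8, linear_field(delta_L))  # hypothetical helpers
    # 4.-5. galaxy biasing and redshift-space distortions
    delta_g = apply_rsd(populate_galaxies(delta_m, b))       # hypothetical helpers
    # 6. observational noise
    return sample('obs', dist.Normal(delta_g, 0.1))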

Which sampling methods can scale to high dimensions?

Gaussianity and beyond

  • At large scales, the matter density field is almost Gaussian, so the power spectrum is an almost lossless compression
  • At smaller scales however, the matter density field is non-Gaussian

[Figure: 2 fields, 1 power spectrum: Gaussian or N-body?]

How to compare samplers?

  • Effective Sample Size (ESS)
    • the number of i.i.d. samples that would yield the same statistical power
    • for a sample sequence of size \(N\) with autocorrelation \(\rho\), $$N_\text{eff} = \frac{N}{1+2 \sum_{t=1}^{+\infty}\rho_t}$$ so aim for samples that are as uncorrelated as possible
  • The main limiting computational factor is model evaluation (e.g. N-body), so characterize MCMC efficiency by \(N_\text{eval} / N_\text{eff}\)
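In practice, \(N_\text{eff}\) can be estimated from the chain autocorrelation, e.g. with NumPyro's diagnostics (a minimal sketch reusing the samples from above; the leading axis indexes chains):

from numpyro.diagnostics import effective_sample_size

ess = effective_sample_size(samples['x'][None, ...])
# rough efficiency proxy (NUTS performs several model evals per step)
print(ess, num_samples / ess)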