Project proposal: Field-level analysis of primordial non-Gaussianity with DESI tracers

Hugo SIMON-ONFROY, 
PhD student supervised by
Arnaud DE MATTIA and François LANUSSE

PNG meeting, 2025/02/26

The universe recipe (so far)

\frac{H}{H_0} = \sqrt{\Omega_r + \Omega_b + \Omega_c + \Omega_\kappa + \Omega_\Lambda}

instantaneous expansion rate

energy content

Cosmological principle + Einstein equation

+ Inflation

\delta_L \sim \mathcal G(0, \mathcal P)

\sigma_8 := \sigma[\delta_L * \boldsymbol 1_{r \leq 8}]

initial field

primordial power spectrum

std. of fluctuations smoothed at 8 Mpc/h

\Omega := \{ \Omega_c, \Omega_b, \Omega_\Lambda, H_0, \sigma_8, n_s, \dots \}

Linear matter spectrum

Structure growth

Cosmological modeling and inference

[Diagram: \Omega \to \delta_L \to \delta_g (forward model); inference runs the other way]

Please define model

Model: an observation generator, link between latent and observed variables

[Example DAG with nodes A, B, C, D, E, F, G]

Frequentist: parameters are fixed but unknown

Bayesian: parameters are random because unknown


Represented as a DAG, which expresses a joint probability factorization, i.e. the conditional dependencies between variables

latent param

latent variable

observed variable

Definitions

evidence/prior predictive: \boldsymbol{p}(y) := \int \boldsymbol{p}(y \mid x)\, \boldsymbol{p}(x)\, \mathrm{d}x

posterior predictive: \boldsymbol{p}(y_1 \mid y_0) := \int \boldsymbol{p}(y_1 \mid x)\, \boldsymbol{p}(x \mid y_0)\, \mathrm{d}x
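
For concreteness, a minimal Monte Carlo sketch of the evidence integral, assuming the toy model used later in these slides (X ~ N(0,1), Y|X ~ N(X^2, 1)); illustration only:

# Toy sketch: Monte Carlo estimate of the evidence p(y) = ∫ p(y|x) p(x) dx
# for the illustrative model x ~ N(0, 1), y|x ~ N(x^2, 1).
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

def evidence(y, key, n=10_000):
    x = jax.random.normal(key, (n,))                   # draw x ~ p(x)
    return jnp.mean(norm.pdf(y, loc=x**2, scale=1.0))  # average the likelihood p(y | x)

print(evidence(1.3, jax.random.key(0)))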

Bayes theorem:

\underbrace{\boldsymbol{p}(x \mid y)}_{\text{posterior}}\,\underbrace{\boldsymbol{p}(y)}_{\text{evidence}} = \underbrace{\boldsymbol{p}(y \mid x)}_{\text{likelihood}}\,\underbrace{\boldsymbol{p}(x)}_{\text{prior}}

Marginalization: [DAG illustration, integrating out a latent node]

Conditioning: [DAG illustration, fixing an observed node]



Cosmology is (sometimes) hard

Typical cosmological analyses

  • Summary stat inference: marginalize the joint model over the latent fields to get p(s \mid \Omega, b), then invert/condition on the observed summary statistic s.
  • Field-level inference: condition the joint model directly on the observed field \delta_g, then marginalize over \delta_L.

[DAG illustrations: \Omega, b \to \delta_L \to \delta_g \to s, with marginalize, invert, and condition arrows]

Two approaches to cosmological inference

Probabilistic Programming: Wish List

\begin{cases} x, y, z \text{ samples} \\ \boldsymbol{p}(x,y,z) \\ U = -\log \boldsymbol{p} \\ \nabla U, \nabla^2 U, \dots \end{cases}

1. Modeling

e.g. in JAX with NumPyro, JaxCosmo, JaxPM, DISCO-EB...

2. DAG compilation

PPLs, e.g. NumPyro, Stan, PyMC...

3. Inference

e.g. MCMC or VI with NumPyro, Stan, PyMC, emcee...

4. Extract and viz

e.g. KDE with GetDist, GMM-MI, mean(), std()...

[Diagram annotations: modeling (1) is the physics job, DAG compilation (2) is automatic, inference (3) and extraction (4) are the stats job]


Summary stat inference

Problem:

  • s is too simple \implies lossy compression
  • s is too complex \implies intractable marginalization


Field-level inference

The Problem:

  • high-dimensional integral \boldsymbol{p}(\Omega \mid \delta_g) = \int \boldsymbol{p}(\Omega, \delta_L \mid \delta_g)\, \mathrm{d}\delta_L
  • to probe scales of 5 Mpc/h in the DESI volume, \operatorname{dim}(\delta_L) \simeq 1024^3 \approx 10^9
     

The Promise:

  • "lossless" explicit inference


  1. Prior on
    • Cosmology \Omega
    • Initial field \delta_L
    • Lagrangian galaxy biases b
      (dark matter-galaxy connection)
  2. LSS formation (LPT or N-body)
  3. Apply galaxy bias
  4. Redshift-Space Distortions
  5. Observational noise (a minimal NumPyro sketch of such a model follows below)

Field-Level Modeling

Fast and differentiable model thanks to JAX (NumPyro and JaxPM)
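
To make these steps concrete, a minimal self-contained NumPyro sketch of such a generative model (an illustration, not the actual pipeline): the LPT/N-body step is replaced by a placeholder linear-growth scaling, only a linear bias is kept, and RSD is omitted.

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

mesh_shape = (16, 16, 16)  # tiny mesh for illustration

def field_level_model(obs=None):
    # 1. Priors on cosmology, bias, and the initial linear field
    omega_m = numpyro.sample('Omega_m', dist.Uniform(0.1, 0.5))
    sigma8 = numpyro.sample('sigma8', dist.Uniform(0.5, 1.2))
    b1 = numpyro.sample('b1', dist.Normal(1.0, 0.5))
    delta_L = numpyro.sample('delta_L', dist.Normal(jnp.zeros(mesh_shape), 1.0))
    # 2. Structure formation (placeholder: linear growth instead of LPT/N-body)
    delta_m = sigma8 * jnp.sqrt(omega_m / 0.3) * delta_L
    # 3.-4. Galaxy bias (linear only here); RSD omitted in this sketch
    delta_g = b1 * delta_m
    # 5. Observational (Gaussian) noise
    numpyro.sample('obs', dist.Normal(delta_g, 0.1), obs=obs)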

MCMC sampling via Hamiltonian dynamics

Hamiltonian Monte Carlo algorithms sample a target density \boldsymbol{p} by simulating a lightweight particle (position q, momentum p) in a sparse medium:

  • The particle follows the Hamiltonian dynamics for some time before being kicked.

     
  • canonical Hamiltonian Monte Carlo (HMC) employs the Newtonian Hamiltonian \mathcal H(q,p) := U(q) + K_\text{Newt}(p) = -\log \boldsymbol{p}(q) + \frac{p^2}{2m}
    Limitation: the particle must change energy level to correctly sample the target

     
  • Micro-Canonical Hamiltonian Monte Carlo (MCHMC) employs the Energy Sampling Hamiltonian \mathcal H(q,p) := U(q) + K_\text{ESH}(p) = -\log \boldsymbol{p}(q) + \frac d 2 \log \frac{p^2}{md}
    The particle samples the target on a single energy level

MCMC sampling via Hamiltonian dynamics

  • canonical Hamiltonian Monte Carlo (HMC) samples a target density \boldsymbol{p} by simulating a Newtonian lightweight particle (q, p) in a sparse medium (a minimal sketch of one step is given below):
    • Most of the time, the particle follows the Hamiltonian dynamics given by \mathcal H(q,p) := U(q) + K_\text{Newt}(p) = -\log \boldsymbol{p}(q) + \frac 1 2 p^\top M^{-1} p
    • Sometimes the particle gets kicked: its momentum is refreshed p \sim \mathcal N(0, M)
    • Ergodicity ensures the particle samples the canonical ensemble \boldsymbol p_{C}(q, p) \propto e^{-\mathcal H(q, p)}, and its positions follow the target \boldsymbol p(q) = \int \boldsymbol p_{C}(q, p)\, \mathrm{d}p
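
A minimal sketch of one HMC step (leapfrog integration, momentum refresh, MH acceptance), assuming a scalar-valued log-density logp_fn over a flat parameter vector and a unit mass matrix; an illustration, not a production sampler.

import jax
import jax.numpy as jnp

def leapfrog(q, p, grad_logp, step_size, n_steps):
    # Integrate Hamilton's equations for H(q, p) = -log p(q) + p^2 / 2
    p = p + 0.5 * step_size * grad_logp(q)      # half kick
    for _ in range(n_steps - 1):
        q = q + step_size * p                   # drift
        p = p + step_size * grad_logp(q)        # full kick
    q = q + step_size * p
    p = p + 0.5 * step_size * grad_logp(q)      # final half kick
    return q, p

def hmc_step(key, q, logp_fn, step_size=0.1, n_steps=10):
    key_p, key_u = jax.random.split(key)
    p = jax.random.normal(key_p, q.shape)       # momentum refresh p ~ N(0, I)
    q_new, p_new = leapfrog(q, p, jax.grad(logp_fn), step_size, n_steps)
    # Metropolis-Hastings acceptance on the Hamiltonian difference
    dH = (logp_fn(q_new) - 0.5 * p_new @ p_new) - (logp_fn(q) - 0.5 * p @ p)
    accept = jnp.log(jax.random.uniform(key_u)) < dH
    return jnp.where(accept, q_new, q)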

MCMC sampling via Hamiltonian dynamics

  • Micro-Canonical Hamiltonian Monte Carlo (MCHMC) samples a target density \boldsymbol{p} by simulating a non-Newtonian particle (q, p) in a sparse medium:
    • Most of the time, the particle follows the Hamiltonian dynamics given by \mathcal H(q,p) := U(q) + K_\text{ESH}(p) = -\log \boldsymbol{p}(q) + \frac d 2 \log \frac{p^\top M^{-1} p}{d}
    • Sometimes the particle gets kicked: its momentum is refreshed p \sim \mathcal N(0, M)
    • Ergodicity ensures the particle samples the micro-canonical ensemble \boldsymbol p_{MC}(q, p) \propto \delta(\mathcal H(q, p) - E), and its positions follow the target \boldsymbol p(q) = \int \boldsymbol p_{MC}(q, p)\, \mathrm{d}p
  • Hamiltonian samplers are already in use for field-level inference (Lavaux+2018, Bayer+2023)
     
  • We produce a consistent benchmark of these methods for galaxy clustering

MCMC sampler comparison

NUTS = auto-tuned variant of HMC
MCLMC = Langevin variant of MCHMC

10 times fewer model evaluations required

MCLMC scales better than NUTS/HMC in high dimension

Model preconditioning

  • Sampling is easier when the posterior is an isotropic Gaussian
  • So we reparametrize the model assuming a tractable Kaiser model:
    linear growth + linear Eulerian bias + flat-sky RSD + Gaussian noise (an illustrative sketch of this reparametrization follows below)

10 times fewer evaluations required

😊

😔
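
As an illustration of the reparametrization idea (a sketch only, not the exact Kaiser-based preconditioning used here): sample a unit-variance white field and rescale it in Fourier space by an approximate standard deviation, so the sampler sees a nearly isotropic Gaussian. approx_std_fourier is a placeholder input.

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def preconditioned_linear_field(mesh_shape, approx_std_fourier):
    # approx_std_fourier: placeholder for an approximate (e.g. Kaiser-model)
    # standard deviation per Fourier mode, shaped like the rfftn output
    z = numpyro.sample('z', dist.Normal(jnp.zeros(mesh_shape), 1.0))  # white field, easy to sample
    delta_L = jnp.fft.irfftn(approx_std_fourier * jnp.fft.rfftn(z), s=mesh_shape)
    return numpyro.deterministic('delta_L', delta_L)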

What it looks like

Jointly infer the cosmological and bias parameters and the initial linear matter field

Simon-Onfroy in prep.

w_g = 1 + b_{1}\,\delta_L + b_{2}\,\delta_L^{2} + b_{s^2}\,s^{2} + b_{\nabla^2}\,\nabla^2 \delta_L
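
As a toy illustration of evaluating such Lagrangian weights on a periodic mesh (a sketch under simplifying assumptions, not the pipeline code), with the shear scalar s^2 and the Laplacian computed spectrally:

import jax.numpy as jnp

def bias_weights(delta_L, b1, b2, bs2, bnabla2, box_size=1.0):
    # delta_L: linear density field on a periodic (n, n, n) mesh
    n = delta_L.shape[0]
    kvec = 2 * jnp.pi * jnp.fft.fftfreq(n, d=box_size / n)
    kx, ky, kz = jnp.meshgrid(kvec, kvec, kvec, indexing='ij')
    k2 = kx**2 + ky**2 + kz**2
    k2_safe = jnp.where(k2 > 0, k2, 1.0)
    delta_k = jnp.fft.fftn(delta_L)
    # Tidal field s_ij = (k_i k_j / k^2 - delta_ij / 3) delta_k, shear scalar s^2 = sum_ij s_ij^2
    ks = (kx, ky, kz)
    s2 = jnp.zeros_like(delta_L)
    for i in range(3):
        for j in range(3):
            sij = jnp.real(jnp.fft.ifftn((ks[i] * ks[j] / k2_safe - (i == j) / 3.0) * delta_k))
            s2 = s2 + sij**2
    lap = jnp.real(jnp.fft.ifftn(-k2 * delta_k))  # nabla^2 delta_L
    return 1 + b1 * delta_L + b2 * delta_L**2 + bs2 * s2 + bnabla2 * lap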

Field-level inference of PNG

w_g = 1 + b_{1}\,\delta_L + b_{2}\,\delta_L^{2} + b_{s^2}\,s^{2} + b_\phi f_{\rm NL}\,\phi_L + b_{\phi,\delta} f_{\rm NL}\,\phi_L \delta_L

\Phi_{\mathrm{NG}} = \phi_L + f_{\mathrm{NL}}\left(\phi_L^{2} - \langle \phi_L^{2} \rangle\right)
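
A toy sketch of this local-type transformation applied to a Gaussian potential mesh (illustration only; the grid, normalization, and f_NL value are placeholders):

import jax
import jax.numpy as jnp

def local_png_potential(phi, f_nl):
    # Phi_NG = phi + f_NL * (phi^2 - <phi^2>), mean subtracted so the mean is unchanged
    return phi + f_nl * (phi**2 - jnp.mean(phi**2))

phi = jax.random.normal(jax.random.key(0), (64, 64, 64))  # toy Gaussian potential
phi_ng = local_png_potential(phi, f_nl=10.0)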

PNG inference is a natural first application of our FL pipeline since most constraints come from large scales
(coarser mesh and simpler model)

Goal: application planned on DESI Y1 LRGs and QSOs.
If the strategy proves effective, extend the analysis to Y1 ELGs.

Preliminary results without light-cone and systematics

DESI Galaxy density field

  • Understand the bias and noise model at the field level
    for a typical voxel size of \simeq 40 Mpc/h
  • 2LPT and bias expansion + modified Poisson noise should be enough
  • Model tested at the field level against HOD-populated AbacusSummit mocks

Field-level modeling of PNG: TODO List

DESI survey selection function

  • Varying line-of-sight
  • Light cone evolution
  • Model for photometric systematics and integral constraints

We can generate mocks and process them with the standard analysis pipeline for an apples-to-apples comparison


Eventually, we will try to extend this pipeline to smaller scales (equivalent to a full-shape analysis)

Some useful programming tools

  • NumPyro
    • Probabilistic Programming Language (PPL)
    • Powered by JAX
    • Integrated samplers
  • JAX
    • GPU acceleration
    • Just-In-Time (JIT) compilation acceleration
    • Automatic vectorization/parallelization
    • Automatic differentiation

JAX reminder

  • GPU accelerate

     
  • JIT compile

     
  • Vectorize/Parallelize

     
  • Auto-diff

 

import jax
import jax.numpy as np  # then enjoy

def function(x):  # any JAX-traceable function of arrays
    return np.sum(x**2)

function = jax.jit(function)
# function is so fast now!
gradient = jax.grad(function)
# too bad if you love chain ruling by hand
vfunction = jax.vmap(function)
pfunction = jax.pmap(function)
# for-loops are for-losers

Probabilistic Modeling and Compilation

import numpy as np
from scipy.stats import norm

def simulate(seed):
    rng = np.random.RandomState(seed)
    x = rng.randn()
    y = rng.randn() + x**2
    return x, y

log_prior = lambda x: norm.logpdf(x, 0, 1)
log_lik = lambda x, y: norm.logpdf(y, x**2, 1)
log_joint = lambda x, y: log_lik(x, y) + log_prior(x)

grad_log_prior = lambda x: -x  # d/dx log N(x; 0, 1)
grad_log_lik = lambda x, y: np.stack([2 * x * (y - x**2), -(y - x**2)])  # (d/dx, d/dy)
grad_log_joint = lambda x, y: grad_log_lik(x, y) + np.stack([grad_log_prior(x), np.zeros_like(y)])
# Hessian left as an exercise...

To fully define a probabilistic model, we need

  1. its simulator



     
  2. its associated joint log prob



     
  3. possibly its gradients, Hessian... (useful for inference)



     

We should only care about the simulator!


\begin{cases} X \sim \mathcal N(0,1) \\ Y \mid X \sim \mathcal N(X^2, 1) \end{cases}

my model on paper

my model in code

What if we modify it to fit the variances, or use non-Gaussian distributions? Recompute everything? :'(

Modeling with NumPyro

from jax import jit, vmap, grad
from numpyro import sample, distributions as dist
from numpyro.handlers import seed, condition
from numpyro.infer.util import log_density

def model():  # define the model as a simulator... and that's it
    x = sample('x', dist.Normal(0, 1))
    y = sample('y', dist.Normal(x**2, 1))
    return y

y_sample = seed(model, 42)()  # simulate
log_joint = lambda x, y: log_density(model, (), {}, {'x': x, 'y': y})[0]  # get log prob
obs_model = condition(model, {'y': y_sample})  # condition the model on the observed value
logp_fn = lambda x: log_density(obs_model, (), {}, {'x': x})[0]
force_vfn = jit(vmap(grad(logp_fn)))  # + JAX machinery
  • Define model as simulator... and that's it


     
  • Simulate, get log prob


     
  • +JAX machinery


     
  • Condition model


     
  • And more!

So NumPyro in practice?

import jax.random as jr
from jax import jit, vmap, grad
from numpyro import infer, render_model, sample, distributions as dist
from numpyro.handlers import seed, condition
from numpyro.infer.util import log_density

def model():
    x = sample('x', dist.Normal(0, 1))
    y = sample('y', dist.Normal(x**2, 1))
    return y

render_model(model, render_distributions=True)

y_sample = dict(y=seed(model, 42)())
obs_model = condition(model, y_sample)
logp_fn = lambda x: log_density(obs_model, (), {}, {'x': x})[0]
force_vfn = jit(vmap(grad(logp_fn)))

kernel = infer.NUTS(obs_model)
mcmc = infer.MCMC(kernel, num_warmup=500, num_samples=1000)
mcmc.run(jr.key(43))
samples = mcmc.get_samples()
  • Probabilistic Programming





     
  • JAX machinery

     
  • Integrated samplers

Why care about a differentiable model?

  • Classical MCMCs
    • agnostic random moves
      + MH acceptance step
      = blind natural selection

       
    • small moves yield correlated samples
       
  • SOTA MCMCs rely on the gradient of the model log-prob to drive the dynamics towards the highest-density regions

[Illustration: gradient descent → posterior mode; Brownian motion → exploding Gaussian; gradient descent + Brownian motion = Langevin dynamics → posterior]

Hamiltonian Monte Carlo (HMC)

  • To travel farther, add inertia.
    • the sample particle at position q now has a momentum p and a mass matrix M
    • the target \boldsymbol{p}(q) becomes \boldsymbol{p}(q, p) := e^{-\mathcal H(q,p)}, with Hamiltonian \mathcal H(q,p) := -\log \boldsymbol{p}(q) + \frac 1 2 p^\top M^{-1} p
    • at each step, resample the momentum p \sim \mathcal N(0,M)
    • let (q,p) follow the Hamiltonian dynamics for a time length L, then the arrival point becomes the new MH proposal.

Variations around HMC

  • No U-Turn Sampler (NUTS)
    • trajectory length L auto-tuned
    • samples drawn along the trajectory
  • NUTS within Gibbs, i.e. alternating sampling over parameter subsets.


Why care about a differentiable model?

  • Model gradient drives sample particles towards high density regions
     
  • Hamiltonian Monte Carlo (HMC):
    to travel farther, add inertia
     
  • Yields less correlated chains

[Cartoon: 1) particle of mass M at position q, 2) random kick p, 3) Hamiltonian dynamics]

Inference with NumPyro

(same NumPyro model and NUTS inference code as shown above)
  • Probabilistic Model





     
  • +JAX machinery

     
  • = gradient-based MCMC

How to N-body-differentiate?

[PM loop diagram: particles (q, p) → paint* → \delta(x) → fft* → \delta(k); solve Vlasov-Poisson to compute forces; ifft* → read* → apply forces to move particles]

*: differentiable, e.g. with JAX, see JaxPM

  • At large scales, the matter density field is almost Gaussian, so the power spectrum is an almost lossless compression
  • At smaller scales however, the matter density field is non-Gaussian

Gaussianity and beyond

2 fields, 1 power spectrum: Gaussian or N-body?

  • Effective Sample Size (ESS)
    • number of i.i.d. samples that would yield the same statistical power.
    • For a sample sequence of size N with autocorrelation \rho, N_\textrm{eff} = \frac{N}{1+2 \sum_{t=1}^{+\infty}\rho_t}, so aim for samples that are as uncorrelated as possible.

       
  • The main limiting computational factor is model evaluation (e.g. N-body), so we characterize MCMC efficiency by N_\text{eval} / N_\text{eff} (a minimal ESS estimate is sketched below)
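
A minimal sketch of estimating N_eff from a 1-D chain via its empirical autocorrelation, truncating the sum once the autocorrelation turns negative; an illustration, not the estimator used in the benchmark.

import jax.numpy as jnp

def effective_sample_size(chain):
    # Empirical (FFT-based) autocorrelation of the centered chain
    x = chain - jnp.mean(chain)
    n = x.size
    f = jnp.fft.rfft(x, n=2 * n)
    acf = jnp.fft.irfft(f * jnp.conj(f))[:n]
    rho = acf / acf[0]
    # Keep lags up to the first negative autocorrelation
    keep = jnp.cumprod(rho[1:] > 0)
    return n / (1 + 2 * jnp.sum(rho[1:] * keep))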

How to compare samplers?
