Project proposal: Field-level analysis of primordial non-Gaussianity with DESI tracers
Hugo SIMON-ONFROY,
PhD student supervised by
Arnaud DE MATTIA and François LANUSSE
PNG meeting, 2025/02/26
The universe recipe (so far)
$\left(\frac{H}{H_0}\right)^2 = \Omega_r + \Omega_b + \Omega_c + \Omega_\kappa + \Omega_\Lambda$
(left: instantaneous expansion rate; right: energy content)
Cosmological principle + Einstein equation
+ Inflation
$\delta_L \sim \mathcal{G}(0, P)$ (initial field, with primordial power spectrum P)
$\sigma_8 := \sigma[\delta_L * \mathbb{1}_{r \le 8}]$ (std. of fluctuations smoothed at 8 Mpc/h)

$\Omega := \{\Omega_c, \Omega_b, \Omega_\Lambda, H_0, \sigma_8, n_s, \dots\}$
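To make these definitions concrete, here is a minimal JAX sketch that draws a Gaussian field δL with a toy power-law spectrum on a small mesh and estimates σ8 by top-hat smoothing at 8 Mpc/h. The mesh size, box size, spectrum shape, amplitude, and FFT normalization are illustrative placeholders, not the actual linear ΛCDM setup.

import jax
import jax.numpy as jnp

# Toy setup (placeholders): small mesh, 500 Mpc/h box, power-law spectrum
mesh, box, ns = 64, 500.0, 0.96
cell = box / mesh

# Fourier grid
kvec = 2 * jnp.pi * jnp.fft.fftfreq(mesh, d=cell)
kx, ky, kz = jnp.meshgrid(kvec, kvec, kvec, indexing='ij')
k = jnp.sqrt(kx**2 + ky**2 + kz**2)
ksafe = jnp.where(k > 0, k, 1.0)

# delta_L ~ G(0, P): color white noise by sqrt(P(k)) in Fourier space
# (toy amplitude and shape; ignores the transfer function and exact conventions)
Pk = jnp.where(k > 0, 1e3 * ksafe**(ns - 4), 0.0)
white = jax.random.normal(jax.random.key(0), (mesh,) * 3)
delta_k = jnp.fft.fftn(white) * jnp.sqrt(Pk / cell**3)
delta_L = jnp.fft.ifftn(delta_k).real

# sigma_8: std of delta_L convolved with a top-hat of radius R = 8 Mpc/h
R = 8.0
kR = ksafe * R
W = jnp.where(k > 0, 3 * (jnp.sin(kR) - kR * jnp.cos(kR)) / kR**3, 1.0)
sigma8 = jnp.std(jnp.fft.ifftn(delta_k * W).real)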


Linear matter spectrum
Structure growth
Cosmological modeling and inference
[DAG: Ω → δL → δg (forward model), with inference running in the opposite direction]
Please define model
Model: an observation generator, link between latent and observed variables
Frequentist:
fixed but unknown parameters
Bayesian:
random because unknown
Represented as a DAG which expresses a joint probability factorization, i.e. conditional dependencies between variables
[DAG legend: latent parameter, latent variable, observed variable]
Definitions
evidence / prior predictive: $p(y) := \int p(y \mid x)\, p(x)\, dx$
posterior predictive: $p(y_1 \mid y_0) := \int p(y_1 \mid x)\, p(x \mid y_0)\, dx$
Bayes theorem:
$\underbrace{p(x \mid y)}_{\text{posterior}} \; \underbrace{p(y)}_{\text{evidence}} = \underbrace{p(y \mid x)}_{\text{likelihood}} \; \underbrace{p(x)}_{\text{prior}}$
[DAG illustrations on X, Y, Z: marginalization and conditioning]
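As a quick numerical check of the evidence formula above, the sketch below estimates $p(y) = \int p(y \mid x)\, p(x)\, dx$ by Monte Carlo for a toy Gaussian model whose evidence is known in closed form; the model X ∼ N(0, 1), Y | X ∼ N(X, 1) and the observed value are chosen purely for illustration.

import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Toy model: X ~ N(0, 1), Y | X ~ N(X, 1), so analytically Y ~ N(0, 2)
y_obs = 1.3

# Monte Carlo estimate of the evidence p(y) = E_{x ~ p(x)}[p(y | x)]
x = jax.random.normal(jax.random.key(0), (100_000,))  # draws from the prior
evidence_mc = jnp.mean(norm.pdf(y_obs, x, 1.0))

# Closed-form evidence for comparison
evidence_exact = norm.pdf(y_obs, 0.0, jnp.sqrt(2.0))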
Cosmology is (sometimes) hard
Typical cosmological analyses
Field-level inference vs. summary statistic inference
[Diagrams: graphical models over Ω, b, δL, δg and summary statistic s, related by marginalize / invert / condition operations]
Two approaches to cosmological inference
Probabilistic Programming: Wish List
$\begin{cases} x, y, z \text{ samples} \\ p(x, y, z) \\ U = -\log p \\ \nabla U, \nabla^2 U, \dots \end{cases}$
1. Modeling (physics job), e.g. in JAX with NumPyro, JaxCosmo, JaxPM, DISCO-EB...
2. DAG compilation (auto!), by PPLs, e.g. NumPyro, Stan, PyMC...
3. Inference (stats job), e.g. MCMC or VI with NumPyro, Stan, PyMC, emcee...
4. Extract and viz (stats job), e.g. KDE with GetDist, GMM-MI, mean(), std()...

Typical cosmological analyses
Field-level inference vs. summary statistic inference [recap of the two graphical models over Ω, δL, δg, s]
Summary stat inference
Problem:
- s is too simple ⟹ lossy compression
- s is too complex ⟹ intractable marginalization
[Diagram: summary-statistic graphical model over Ω, δL, δg, s, with marginalize / invert steps]
Field-level inference
The Problem:
- high-dimensional integral $p(\Omega \mid \delta_g) = \int p(\Omega, \delta_L \mid \delta_g)\, d\delta_L$
- to probe scales of 5 Mpc/h in the DESI volume, $\dim(\delta_L) \simeq 1024^3$
The Promise:
- "lossless" explicit inference
[Diagram: field-level graphical model over Ω, δL, δg, with marginalize / invert steps]
Field-Level Modeling
- Prior on:
  - cosmology Ω
  - initial field δL
  - Lagrangian galaxy biases b (dark matter-galaxy connection)
- LSS formation (LPT or N-body)
- Apply galaxy bias
- Redshift-space distortions
- Observational noise
Fast and differentiable model thanks to NumPyro and JaxPM (see the sketch below).
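A minimal NumPyro sketch of such a generative model is given below. The LSS formation step is a hypothetical placeholder (`lpt_displacement` stands in for the actual LPT/N-body solver, e.g. from JaxPM); the priors, mesh size, linear-only bias, and Gaussian noise are illustrative choices, not the ones used in the analysis, and RSD is omitted.

import numpyro
import numpyro.distributions as dist

def lpt_displacement(delta_L, Omega_m):
    # Hypothetical placeholder for LSS formation (LPT or N-body, e.g. via JaxPM);
    # the real model evolves delta_L gravitationally, here it is the identity.
    return delta_L

def field_level_model(mesh_shape=(64, 64, 64), obs=None):
    # Priors on cosmology and (linear) galaxy bias: illustrative choices
    Omega_m = numpyro.sample('Omega_m', dist.Uniform(0.1, 0.5))
    b1 = numpyro.sample('b1', dist.Normal(1.0, 0.5))
    # Prior on the initial linear field: standardized white noise on the mesh
    delta_L = numpyro.sample('delta_L', dist.Normal(0.0, 1.0).expand(mesh_shape))
    # LSS formation and galaxy bias (placeholder dynamics, linear bias only)
    delta_m = lpt_displacement(delta_L, Omega_m)
    mean_g = b1 * delta_m
    # Observational noise (Gaussian here; the real model can be Poisson-like)
    return numpyro.sample('delta_g', dist.Normal(mean_g, 0.1), obs=obs)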

MCMC sampling via Hamiltonian dynamics
Hamiltonian Monte Carlo algorithms sample a target density p by simulating a lightweight particle (position q, momentum p) in a sparse medium:
- The particle follows the Hamiltonian dynamics for some time before being kicked.
- Canonical Hamiltonian Monte Carlo (HMC) employs the Newtonian Hamiltonian $H(q, p) := U(q) + K_{\mathrm{Newt}}(p) = -\log p(q) + \frac{p^2}{2m}$
  Limitation: the particle must change energy level to correctly sample the target.
- Micro-Canonical Hamiltonian Monte Carlo (MCHMC) employs the Energy Sampling Hamiltonian $H(q, p) := U(q) + K_{\mathrm{ESH}}(p) = -\log p(q) + \frac{d}{2}\log\frac{p^2}{m d}$
  The particle samples the target on a single energy level.
MCMC sampling via Hamiltonian dynamics
- Canonical Hamiltonian Monte Carlo (HMC) samples a target density p by simulating a Newtonian lightweight particle (q, p) in a sparse medium:
  - Most of the time, the particle follows the Hamiltonian dynamics given by $H(q, p) := U(q) + K_{\mathrm{Newt}}(p) = -\log p(q) + \frac{1}{2} p^\top M^{-1} p$
  - Sometimes the particle gets kicked: its momentum is refreshed, $p \sim \mathcal{N}(0, M)$
  - Ergodicity ensures the particle samples the canonical ensemble $p_C(q, p) \propto e^{-H(q, p)}$, and its positions follow the target $p(q) = \int p_C(q, p)\, dp$
MCMC sampling via Hamiltonian dynamics
- Micro-Canonical Hamiltonian Monte Carlo (MCHMC) samples a target density p by simulating a non-Newtonian particle (q, p) in a sparse medium:
  - Most of the time, the particle follows the Hamiltonian dynamics given by $H(q, p) := U(q) + K_{\mathrm{ESH}}(p) = -\log p(q) + \frac{d}{2}\log\frac{p^\top M^{-1} p}{d}$
  - Sometimes the particle gets kicked: its momentum is refreshed, $p \sim \mathcal{N}(0, M)$
  - Ergodicity ensures the particle samples the micro-canonical ensemble $p_{MC}(q, p) \propto \delta(H(q, p) - E)$, and its positions follow the target $p(q) = \int p_{MC}(q, p)\, dp$
- Hamiltonian samplers are already in use for field-level inference (Lavaux+2018, Bayer+2023)
- We produce a consistent benchmark of these methods for galaxy clustering

MCMC sampler comparison
NUTS = auto-tuned variant of HMC
MCLMC = Langevin variant of MCHMC
10 times fewer model evaluations required
MCLMC scales better than NUTS/HMC in high dimension
Model preconditioning
- Sampling is easier when the posterior is an isotropic Gaussian
- So we reparametrize the model assuming a tractable Kaiser model:
  linear growth + linear Eulerian bias + flat-sky RSD + Gaussian noise (see the sketch below)

10 times fewer evaluations required
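A hedged sketch of this kind of reparametrization: instead of sampling δL directly, sample a standardized field z and rescale it mode by mode with the approximate per-mode posterior standard deviation predicted by the Kaiser model, so that the sampler sees a roughly isotropic Gaussian. `kaiser_posterior_std` is a hypothetical placeholder for that analytic computation, and the profile used here is a dummy.

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def kaiser_posterior_std(k):
    # Hypothetical placeholder for the per-mode posterior std derived from the
    # tractable Kaiser model (linear growth + linear Eulerian bias + flat-sky RSD
    # + Gaussian noise); dummy profile, for illustration only.
    return 1.0 / jnp.sqrt(1.0 + k**2)

def preconditioned_field(name, k):
    # Sample a standardized latent field: isotropic Gaussian, easy for the sampler
    z = numpyro.sample(name + '_std', dist.Normal(0.0, 1.0).expand(k.shape))
    # Deterministically rescale each mode to its approximate posterior scale
    return numpyro.deterministic(name, kaiser_posterior_std(k) * z)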


What it looks like
Jointly infer the cosmological and bias parameters and the initial linear matter field
Simon-Onfroy in prep.
$w_g = 1 + b_1 \delta_L + b_2 \delta_L^2 + b_{s^2} s^2 + b_{\nabla^2} \nabla^2 \delta_L$
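As an illustration of how the quadratic operators entering w_g can be built from δL on a mesh, the sketch below computes δL², the tidal term s², and ∇²δL with FFTs. The grid setup is a toy one and numerical details (smoothing, dealiasing) are ignored.

import jax.numpy as jnp

def bias_operators(delta_L, box):
    # Toy FFT-based computation of delta^2, s^2 and laplacian(delta) on a mesh
    mesh = delta_L.shape[0]
    kvec = 2 * jnp.pi * jnp.fft.fftfreq(mesh, d=box / mesh)
    kx, ky, kz = jnp.meshgrid(kvec, kvec, kvec, indexing='ij')
    k2 = kx**2 + ky**2 + kz**2
    k2_safe = jnp.where(k2 > 0, k2, 1.0)
    delta_k = jnp.fft.fftn(delta_L)

    # Tidal tensor s_ij(k) = (k_i k_j / k^2 - delta_ij / 3) delta(k), then s^2 = sum_ij s_ij^2
    ks = (kx, ky, kz)
    s2 = jnp.zeros_like(delta_L)
    for i in range(3):
        for j in range(3):
            sij_k = (ks[i] * ks[j] / k2_safe - (i == j) / 3.0) * delta_k
            s2 = s2 + jnp.fft.ifftn(jnp.where(k2 > 0, sij_k, 0.0)).real ** 2

    lap_delta = jnp.fft.ifftn(-k2 * delta_k).real  # nabla^2 delta
    return delta_L**2, s2, lap_delta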
Field-level inference of PNG
$w_g = 1 + b_1 \delta_L + b_2 \delta_L^2 + b_{s^2} s^2 + b_\phi f_{\mathrm{NL}} \phi_L + b_{\phi,\delta} f_{\mathrm{NL}} \phi_L \delta_L$
$\Phi_{\mathrm{NG}} = \phi_L + f_{\mathrm{NL}} (\phi_L^2 - \langle \phi_L^2 \rangle)$
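The local-PNG transform above is a one-liner on a mesh; in the snippet below, the Gaussian potential φL and the fNL value are placeholders chosen purely for illustration.

import jax
import jax.numpy as jnp

fNL = 10.0                                                       # placeholder amplitude
phi_L = 1e-5 * jax.random.normal(jax.random.key(0), (64,) * 3)   # toy Gaussian potential
Phi_NG = phi_L + fNL * (phi_L**2 - jnp.mean(phi_L**2))           # local-type PNG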
PNG inference is a natural first application of our FL pipeline since most constraints come from large scales
(coarser mesh and simpler model)
Goal: application planned on DESI Y1 LRG and QSO.
If the strategy proves effective, extend the analysis to Y1 ELGs.

Field-level modeling of PNG: TODO list
[Figure: DESI galaxy density field, preliminary results without light-cone and systematics]
- Understand the bias and noise model at field level: for a typical voxel size of ≃40 Mpc/h, 2LPT and a bias expansion + modified Poisson noise should be enough
- Model tested at field level against HOD-populated AbacusSummit mocks
- DESI survey selection function:
  - varying line of sight
  - light-cone evolution
  - model for photometric systematics and integral constraints
We can then generate mocks and process them through the standard analysis pipeline for an apples-to-apples comparison.
Eventually, we will try to extend this pipeline to smaller scales (equivalent to a full-shape analysis).


Some useful programming tools
- NumPyro
  - Probabilistic Programming Language (PPL)
  - Powered by JAX
  - Integrated samplers
- JAX
  - GPU acceleration
  - Just-In-Time (JIT) compilation acceleration
  - Automatic vectorization/parallelization
  - Automatic differentiation
JAX reminder
- GPU acceleration
- JIT compilation
- Vectorization/parallelization
- Auto-diff

import jax
import jax.numpy as jnp  # then enjoy

def function(x):  # any pure JAX function
    return jnp.sum(x**2)

function = jax.jit(function)     # function is so fast now!
gradient = jax.grad(function)    # too bad if you love chain-ruling by hand
vfunction = jax.vmap(function)   # vectorize over a batch axis
pfunction = jax.pmap(function)   # parallelize over devices; for-loops are for losers

Probabilistic Modeling and Compilation
import numpy as np
from scipy.stats import norm

def simulate(seed):
    rng = np.random.RandomState(seed)
    x = rng.randn()
    y = rng.randn() + x**2
    return x, y

log_prior = lambda x: norm.logpdf(x, 0, 1)
log_lik = lambda x, y: norm.logpdf(y, x**2, 1)
log_joint = lambda x, y: log_lik(x, y) + log_prior(x)

grad_log_prior = lambda x: -x
grad_log_lik = lambda x, y: np.stack([2 * x * (y - x**2), -(y - x**2)])
grad_log_joint = lambda x, y: grad_log_lik(x, y) + np.stack([grad_log_prior(x), np.zeros_like(y)])
# Hessian left as an exercise...
To fully define a probabilistic model, we need
- its simulator
- its associated joint log prob
- possibly its gradients, Hessian... (useful for inference)
We should only care about the simulator!
my model on paper:
$\begin{cases} X \sim \mathcal{N}(0, 1) \\ Y \mid X \sim \mathcal{N}(X^2, 1) \end{cases}$
my model in code: the simulator and log-probs above
What if we modify it to fit the variances, or use a non-Gaussian distribution? Recompute everything? :'(

Modeling with NumPyro
from jax import jit, vmap, grad
from numpyro import sample
import numpyro.distributions as dist
from numpyro.handlers import seed, condition
from numpyro.infer.util import log_density

def model():
    x = sample('x', dist.Normal(0, 1))
    y = sample('y', dist.Normal(x**2, 1))
    return y

y_sample = dict(y=seed(model, 42)())
log_joint = lambda x, y: log_density(model, (), {}, {'x': x, 'y': y})[0]
obs_model = condition(model, y_sample)
logp_fn = lambda x: log_density(obs_model, (), {}, {'x': x})[0]
force_vfn = jit(vmap(grad(logp_fn)))
- Define model as simulator... and that's it
- Simulate, get log prob
- +JAX machinery
- Condition model
- And more!

So NumPyro in practice?
# (imports as on the previous slide, plus:)
import jax.random as jr
from numpyro import infer, render_model

def model():
    x = sample('x', dist.Normal(0, 1))
    y = sample('y', dist.Normal(x**2, 1))
    return y

render_model(model, render_distributions=True)

y_sample = dict(y=seed(model, 42)())
obs_model = condition(model, y_sample)
logp_fn = lambda x: log_density(obs_model, (), {}, {'x': x})[0]

from jax import jit, vmap, grad
force_vfn = jit(vmap(grad(logp_fn)))

kernel = infer.NUTS(obs_model)
mcmc = infer.MCMC(kernel, num_warmup=500, num_samples=1000)
mcmc.run(jr.key(43))
samples = mcmc.get_samples()
- Probabilistic Programming
- JAX machinery
- Integrated samplers


Why care about differentiable model?
- Classical MCMCs:
  - agnostic random moves + MH acceptance step = blind natural selection
  - small moves yield correlated samples
- SOTA MCMCs rely on the gradient of the model log-prob to drive the dynamics towards high-density regions


[Diagram: gradient descent → posterior mode; + Brownian motion → exploding Gaussian; = Langevin dynamics → posterior]


Hamiltonian Monte Carlo (HMC)
- To travel farther, add inertia.
- the sample particle at position q now has a momentum p and a mass matrix M
- the target p(q) becomes $p(q, p) := e^{-H(q, p)}$, with Hamiltonian $H(q, p) := -\log p(q) + \frac{1}{2} p^\top M^{-1} p$
- at each step, resample the momentum $p \sim \mathcal{N}(0, M)$
- let (q, p) follow the Hamiltonian dynamics for a trajectory length L; the arrival point becomes the new MH proposal (see the sketch below)
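A minimal sketch of one HMC transition (momentum refresh, leapfrog integration of the Hamiltonian dynamics, MH accept/reject), assuming a unit mass matrix M = I; the step size and number of leapfrog steps are illustrative.

import jax
import jax.numpy as jnp

def hmc_step(key, q, logp_fn, step_size=0.1, n_leapfrog=20):
    # One HMC transition targeting p(q) = exp(logp_fn(q)), with M = identity
    grad_logp = jax.grad(logp_fn)
    key_p, key_u = jax.random.split(key)
    # 1) refresh the momentum p ~ N(0, M)
    p = jax.random.normal(key_p, q.shape)
    H0 = -logp_fn(q) + 0.5 * jnp.sum(p**2)
    # 2) leapfrog integration of the Hamiltonian dynamics
    q_new, p_new = q, p + 0.5 * step_size * grad_logp(q)
    for _ in range(n_leapfrog - 1):
        q_new = q_new + step_size * p_new
        p_new = p_new + step_size * grad_logp(q_new)
    q_new = q_new + step_size * p_new
    p_new = p_new + 0.5 * step_size * grad_logp(q_new)
    # 3) Metropolis-Hastings accept/reject on the energy difference
    H1 = -logp_fn(q_new) + 0.5 * jnp.sum(p_new**2)
    accept = jax.random.uniform(key_u) < jnp.exp(H0 - H1)
    return jnp.where(accept, q_new, q)

For instance, hmc_step(jax.random.key(0), jnp.zeros(3), lambda q: -0.5 * jnp.sum(q**2)) performs one transition targeting a standard Gaussian.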

Variations around HMC
- No U-Turn Sampler (NUTS)
  - trajectory length L auto-tuned
  - samples drawn along trajectory
  - NUTS within Gibbs, i.e. alternating sampling over parameter subsets
Why care about differentiable model?
- The model gradient drives sample particles towards high-density regions
- Hamiltonian Monte Carlo (HMC): to travel farther, add inertia
- Yields less correlated chains
[Diagram: 1) particle of mass M at q, 2) random kick p, 3) Hamiltonian dynamics]


Inference with NumPyro
# (imports and setup as on the previous slides)
def model():
    x = sample('x', dist.Normal(0, 1))
    y = sample('y', dist.Normal(x**2, 1))
    return y

render_model(model, render_distributions=True)

y_sample = dict(y=seed(model, 42)())
obs_model = condition(model, y_sample)
logp_fn = lambda x: log_density(obs_model, (), {}, {'x': x})[0]

from jax import jit, vmap, grad
force_vfn = jit(vmap(grad(logp_fn)))

kernel = infer.NUTS(obs_model)
mcmc = infer.MCMC(kernel, num_warmup=500, num_samples=1000)
mcmc.run(jr.key(43))
samples = mcmc.get_samples()
- Probabilistic Model
- +JAX machinery
- = gradient-based MCMC


How to N-body-differentiate?







[Diagram: PM force loop: particles (q, p) → paint* → δ(x) → fft* → δ(k) → solve Vlasov-Poisson to compute forces → ifft* → read* → apply forces to move particles]
*: differentiable, e.g. with JAX, see JaxPM
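A hedged sketch of the Fourier-space Poisson solve at the core of this loop; the paint and read steps are left out (in practice they are the differentiable painting/readout kernels provided by JaxPM), and units and constants are dropped.

import jax.numpy as jnp

def gravitational_forces(delta_x, box):
    # Toy differentiable force mesh from a density contrast mesh:
    # fft -> solve Poisson in Fourier space -> gradient -> ifft
    mesh = delta_x.shape[0]
    kvec = 2 * jnp.pi * jnp.fft.fftfreq(mesh, d=box / mesh)
    kx, ky, kz = jnp.meshgrid(kvec, kvec, kvec, indexing='ij')
    k2 = kx**2 + ky**2 + kz**2
    k2_safe = jnp.where(k2 > 0, k2, 1.0)

    delta_k = jnp.fft.fftn(delta_x)
    # Poisson equation (units dropped): -k^2 phi(k) = delta(k)
    phi_k = jnp.where(k2 > 0, -delta_k / k2_safe, 0.0)
    # Force = -grad(phi), i.e. -i k_i phi(k) per component, back to real space
    return jnp.stack([jnp.fft.ifftn(-1j * ki * phi_k).real for ki in (kx, ky, kz)], axis=-1)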
- At large scales, the matter density field is almost Gaussian, so the power spectrum is an almost lossless compression
- At smaller scales however, the matter density field is non-Gaussian
Gaussianity and beyond

2 fields, 1 power spectrum: Gaussian or N-body?
How to compare samplers?
- Effective Sample Size (ESS): the number of i.i.d. samples that yield the same statistical power.
- For a sample sequence of size N with autocorrelation ρ, $N_{\mathrm{eff}} = \frac{N}{1 + 2 \sum_{t=1}^{+\infty} \rho_t}$, so aim for samples as uncorrelated as possible.
- The main limiting computational factor is model evaluation (e.g. N-body), so we characterize MCMC efficiency by $N_{\mathrm{eval}} / N_{\mathrm{eff}}$.
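A minimal sketch of this ESS estimate for a 1D chain, truncating the autocorrelation sum at the first negative lag (a simple heuristic; production estimators use more careful truncation rules).

import jax.numpy as jnp

def effective_sample_size(chain):
    # N_eff = N / (1 + 2 * sum_t rho_t) for a 1D chain, with a simple truncation
    x = chain - jnp.mean(chain)
    N = x.shape[0]
    var = jnp.mean(x**2)
    # Autocorrelation rho_t for lags t = 1 .. N-1
    rho = jnp.array([jnp.mean(x[:N - t] * x[t:]) / var for t in range(1, N)])
    # Keep lags up to the first negative autocorrelation
    keep = jnp.cumprod((rho > 0).astype(rho.dtype))
    return N / (1.0 + 2.0 * jnp.sum(rho * keep))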