Hugo SIMON,
PhD student supervised by
Arnaud DE MATTIA and François LANUSSE
CoBALt, 2025/06/30
Hugo SIMON,
PhD student supervised by
Arnaud DE MATTIA and François LANUSSE
Sesto, 2025/07/17
Hugo SIMON,
PhD student supervised by
Arnaud DE MATTIA and François LANUSSE
CSI, 2025/09/02
Hugo SIMON,
PhD student supervised by
Arnaud DE MATTIA and François LANUSSE
PNG Meeting, 2025/06/18
$$\frac{H(a)}{H_0} = \sqrt{\Omega_r a^{-4} + (\Omega_b + \Omega_c)\,a^{-3} + \Omega_\kappa a^{-2} + \Omega_\Lambda}$$
instantaneous expansion rate
energy content
Cosmological principle + Einstein equation
+ Inflation
\(\delta_L \sim \mathcal G(0, \mathcal P)\)
\(\sigma_8:= \sigma[\delta_L * \boldsymbol 1_{r \leq 8}]\)
initial field
primordial power spectrum
std. of fluctuations smoothed at \(8 \text{ Mpc}/h\)
\(\Omega := \{ \Omega_m, \Omega_\Lambda, H_0, \sigma_8, f_\mathrm{NL},...\}\)
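The smoothing definition of \(\sigma_8\) above can be sketched numerically: smooth the linear field with a spherical top-hat of radius \(8\ \mathrm{Mpc}/h\) and take the standard deviation, i.e. integrate the power spectrum against the window. This is an illustrative toy, not the talk's code; the helper names `tophat` and `sigma_R` and the spectrum `P_toy` are made up:

```python
import numpy as np

def tophat(x):
    # Fourier transform of the 3D spherical top-hat window 1_{r <= R}
    return 3.0 * (np.sin(x) - x * np.cos(x)) / x**3

def sigma_R(P, R):
    # sigma_R^2 = (1 / 2 pi^2) * \int dk k^2 P(k) W(kR)^2, trapezoidal rule
    k = np.logspace(-4, 2, 4096)
    integrand = k**2 * P(k) * tophat(k * R) ** 2 / (2 * np.pi**2)
    return np.sqrt(np.sum(0.5 * (integrand[1:] + integrand[:-1]) * np.diff(k)))

P_toy = lambda k: k * np.exp(-k)    # made-up spectrum, for shape only
sigma8_toy = sigma_R(P_toy, R=8.0)  # R in Mpc/h, mimicking sigma_8
```

Since \(\sigma_R\) scales as the square root of the spectrum amplitude, \(\sigma_8\) is a natural amplitude parameter.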
Linear matter spectrum
Structure growth
\(\Omega\)
\(\delta_L\)
\(\delta_g\)
inference
\(128^3\) PM on 8 GPUs:
4h MCLMC vs. \(\geq\) 80h HMC
Fast & differentiable model with
model \(=\begin{cases}x \sim \mathcal N(0,1)\\y\mid x \sim \mathcal N(x^3, 1)\end{cases}\)
Among all possible worlds \(x,y\), restrict to the ones compatible with observation \(y_0\):
$$\underbrace{\mathrm p(x \mid y_0)}_{\text{posterior}} = \frac{\overbrace{\mathrm p(y_0 \mid x)}^{\text{likelihood}}}{\underbrace{\mathrm p(y_0)}_{\text{evidence}}}\underbrace{\mathrm p(x)}_{\text{prior}}$$
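For a one-dimensional latent variable, Bayes' theorem can be applied by brute force on a grid. A minimal sketch for the slide's toy model \(x \sim \mathcal N(0,1)\), \(y \mid x \sim \mathcal N(x^3, 1)\); the observed value `y0 = 1.5` is an arbitrary choice for illustration:

```python
import numpy as np

# unnormalized log N(v; m, 1)
log_norm = lambda v, m: -0.5 * (v - m) ** 2 - 0.5 * np.log(2 * np.pi)

x = np.linspace(-4, 4, 2001)
y0 = 1.5  # arbitrary fixed observation

# log p(x) + log p(y0 | x) on the grid
log_joint = log_norm(x, 0.0) + log_norm(y0, x**3)

# posterior: normalize prior * likelihood by the evidence p(y0)
post = np.exp(log_joint - log_joint.max())
post /= np.sum(post) * (x[1] - x[0])
```

The posterior mode sits between the prior's preference for \(x = 0\) and the likelihood's preference for \(x^3 \approx y_0\).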
\(x\)
\(y\)
condition
\(x\)
\(y\)
\(\mathrm{p}(x,y)\)
\(\mathrm{p}(x \mid y_0)\)
Numerically:
Model: an observation generator, link between latent and observed variables
\(F\)
\(C\)
\(D\)
\(E\)
\(A\)
\(B\)
\(G\)
Frequentist:
fixed but unknown parameters
Bayesian:
random because unknown
Represented as a DAG, which expresses a joint probability factorization, i.e. the conditional dependencies between variables
latent param
latent variable
observed variable
evidence/prior predictive: $$\mathrm{p}(y) := \int \mathrm{p}(y \mid x) \mathrm{p}(x)\mathrm d x$$
posterior predictive: $$\mathrm{p}(y_1 \mid y_0) := \int \mathrm{p}(y_1 \mid x) \mathrm{p}(x \mid y_0)\mathrm d x$$
\(y\)
\(x\)
Bayes thm
\(\sim\)
\(\underbrace{\mathrm{p}(x \mid y)}_{\text{posterior}}\underbrace{\mathrm{p}(y)}_{\text{evidence}} = \underbrace{\mathrm{p}(y \mid x)}_{\text{likelihood}} \,\underbrace{\mathrm{p}(x)}_{\text{prior}}\)
\(y\)
\(x\)
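The evidence/prior predictive definition above is just a prior-weighted average of the likelihood, so it can be estimated by Monte Carlo. A sketch for the toy model \(x \sim \mathcal N(0,1)\), \(y \mid x \sim \mathcal N(x^3, 1)\) (the observation `y0 = 1.5` is arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
y0 = 1.5

# p(y0) = E_{x ~ prior}[ p(y0 | x) ], estimated over prior draws
xs = rng.normal(size=200_000)
lik = np.exp(-0.5 * (y0 - xs**3) ** 2) / np.sqrt(2 * np.pi)
evidence = lik.mean()
```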
Marginalization:
\(y\)
\(x\)
\(z\)
\(z\)
\(x\)
\(y\)
\(x\)
\(y\)
\(x\)
Conditioning:
\(x\)
\(z\)
\(s\)
\(\Omega\)
\(\begin{cases}x,y,z \text{ samples}\\\mathrm{p}(x,y,z)\\ U= -\log \mathrm{p}\\\nabla U,\nabla^2 U,\dots\end{cases}\)
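For the toy model, the quantities a PPL hands to the sampler, \(U = -\log \mathrm p\) and \(\nabla U\), can be written by hand and checked against finite differences. A sketch under the model \(x \sim \mathcal N(0,1)\), \(y \mid x \sim \mathcal N(x^3, 1)\) (not the talk's code; in practice autodiff does this for you):

```python
import numpy as np

# U(x, y) = -log p(x, y), up to additive constants
U = lambda x, y: 0.5 * x**2 + 0.5 * (y - x**3) ** 2
dU_dx = lambda x, y: x - 3 * x**2 * (y - x**3)
dU_dy = lambda x, y: y - x**3

# central finite-difference check of the hand-written gradient
x0, y0, eps = 0.7, 1.2, 1e-6
fd_x = (U(x0 + eps, y0) - U(x0 - eps, y0)) / (2 * eps)
fd_y = (U(x0, y0 + eps) - U(x0, y0 - eps)) / (2 * eps)
```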
1. Modeling
e.g. in JAX with \(\texttt{NumPyro}\), \(\texttt{JaxCosmo}\), \(\texttt{JaxPM}\), \(\texttt{DISCO-EB}\)...
2. DAG compilation
PPLs e.g. \(\texttt{NumPyro}\), \(\texttt{Stan}\), \(\texttt{PyMC}\)...
3. Inference
e.g. MCMC or VI with \(\texttt{NumPyro}\), \(\texttt{Stan}\), \(\texttt{PyMC}\), \(\texttt{emcee}\)...
4. Extract and viz
e.g. KDE with \(\texttt{GetDist}\), \(\texttt{GMM-MI}\), \(\texttt{mean()}\), \(\texttt{std()}\)...
Physics job
Stats job
Stats job
auto!
\(x\)
\(z\)
\(y\)
\(x\)
\(y\)
1
2
3
4
Field-level inference
Summary stat inference
\(\Omega\)
\(s\)
\(\delta_g\)
\(\Omega\)
\(\delta_L\)
\(s\)
marginalize
condition
marginalize
\(\Omega\)
\(s\)
\(\delta_g\)
\(\Omega\)
\(\delta_L\)
condition
Cosmo model
\(\mathrm{p}(\Omega,s)\)
\(\mathrm{p}(\Omega \mid s)\)
\(\Omega\)
\(\delta_g\)
\(\mathrm{p}(\Omega,\delta_L,\delta_g, s)= \mathrm{p}(s \mid \delta_g) \, \mathrm{p}(\delta_g \mid \Omega,\delta_L)\, \mathrm{p}(\delta_L \mid \Omega)\, \mathrm{p}(\Omega)\)
\(\mathrm{p}(\Omega,\delta_L \mid \delta_g)\)
\(\mathrm{p}(\Omega \mid \delta_g)\)
\(\delta_g\)
\(\Omega\)
\(\delta_L\)
\(s\)
Cosmo model
Problem:
The Problem:
The Promise:
Field-level inference
Summary stat inference
$$\boldsymbol{H}(X\mid Y_1)$$
$$\boldsymbol{H}(X)$$
$$\boldsymbol{H}( Y_1)$$
$$\boldsymbol{I}(X; Y_1)$$
$$\boldsymbol{H}( Y_2)$$
$$\boldsymbol{I}(X\mid Y_1; Y_2)$$
$$\boldsymbol{H}(X\mid Y_1,Y_2)$$
\(\boldsymbol{H}(X)\) = missing information on \(X\) = number of bits needed to communicate \(X\)
$$\boldsymbol{H}(X\mid Y_1)$$
$$\begin{align*}\operatorname{\boldsymbol{H}}(X\mid Y) &= \boldsymbol{H}(Y \mid X) + \boldsymbol{H}(X) - \boldsymbol{H}(Y)\\&= \boldsymbol{H}(X) - \boldsymbol{I}(X;Y) \leq \boldsymbol{H}(X)\end{align*}$$
$$\boxed{\min_s \operatorname{\mathrm{H}}(\Omega\mid s(\delta_g))} = \mathrm{H}(\Omega) - \max_s \mathrm{I}(\Omega ; s(\delta_g))$$
$$\mathrm{H}(\Omega)$$
$$\mathrm{H}(\delta_g)$$
$$\mathrm{H}(\mathcal s_1)$$
$$\mathrm{H}(\mathcal s_2)$$
$$\mathrm{H}(\mathcal P)$$
non-Gaussianities
relevant stat
(low info but high mutual info)
irrelevant stat
(high info but low mutual info)
also a relevant stat
(high info and mutual info)
Which stats are relevant for cosmo inference?
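The relevant/irrelevant distinction can be made concrete with a jointly Gaussian toy, where the mutual information follows from the correlation coefficient, \(\mathrm I = -\tfrac12\log(1-\rho^2)\). This is an illustrative sketch (the statistics `s_relevant` and `s_irrelevant` are invented stand-ins, not the talk's summary statistics):

```python
import numpy as np

rng = np.random.default_rng(1)
n = 100_000
X = rng.normal(size=n)

s_relevant = X + 0.5 * rng.normal(size=n)  # low entropy added, tracks X
s_irrelevant = 3.0 * rng.normal(size=n)    # high entropy, ignores X

def gaussian_mi(a, b):
    # I(a; b) = -1/2 log(1 - rho^2) for jointly Gaussian variables (nats)
    rho = np.corrcoef(a, b)[0, 1]
    return -0.5 * np.log(1 - rho**2)

mi_rel = gaussian_mi(X, s_relevant)
mi_irr = gaussian_mi(X, s_irrelevant)
```

The noisy statistic carries almost all of \(X\)'s information despite adding little of its own; the high-entropy one carries essentially none.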
Fast and differentiable model thanks to \(\texttt{NumPyro}\) and \(\texttt{JaxPM}\)
gradients,
they make me:
\((\boldsymbol q, \boldsymbol p)\)
\(\delta(\boldsymbol x)\)
\(\delta(\boldsymbol k)\)
paint*
read*
fft*
ifft*
fft*
*: differentiable, e.g. via \(\texttt{JaxPM}\), in \(\mathcal O(n \log n)\)
apply forces
to move particles
solve Vlasov-Poisson
to compute forces
\(\begin{cases}\dot {\boldsymbol q} \propto \boldsymbol p\\ \dot{\boldsymbol p} = \boldsymbol f \end{cases}\)
\(\begin{cases}\nabla^2 \phi \propto \delta\\ \boldsymbol f = -\nabla \phi \end{cases} \implies \boldsymbol f \propto \frac{i\boldsymbol k}{k^2} \delta\)
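The Fourier-space force rule \(\boldsymbol f \propto \frac{i\boldsymbol k}{k^2} \delta(\boldsymbol k)\) can be sketched with plain FFTs. This is a NumPy stand-in for the \(\texttt{JaxPM}\) kernels, with arbitrary grid size and normalization:

```python
import numpy as np

def forces(delta, boxsize=1.0):
    # Solve nabla^2 phi ∝ delta in Fourier space, then f = -nabla phi,
    # i.e. f(k) = i k / k^2 * delta(k)
    n = delta.shape[0]
    k1 = 2 * np.pi * np.fft.fftfreq(n, d=boxsize / n)
    kx, ky, kz = np.meshgrid(k1, k1, k1, indexing="ij")
    k2 = kx**2 + ky**2 + kz**2
    k2[0, 0, 0] = 1.0  # avoid division by zero; the k=0 numerator vanishes anyway
    dk = np.fft.fftn(delta)
    f = [np.fft.ifftn(1j * ki / k2 * dk).real for ki in (kx, ky, kz)]
    return np.stack(f)

# a single overdense cell attracts the surrounding matter
delta = np.zeros((16, 16, 16))
delta[8, 8, 8] = 1.0
f = forces(delta)
```

Cells on either side of the overdensity feel equal and opposite forces pointing toward it, as expected for an attractive \(1/r^2\) force.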
Radial mass profile in \(\mathrm{cMpc}/h\) of an initially point-like overdensity at the origin
(Infinite Impulse Response)
\(z=0.3\)
$$\begin{cases}\frac{\Delta \theta^\mathrm{fid}}{\Delta \theta} = \frac{D_A}{D_A^\mathrm{fid}}\frac{r_d^\mathrm{fid}}{r_d}=: \alpha_\perp\\\frac{\Delta z^\mathrm{fid}}{\Delta z} = \frac{D_H}{D_H^\mathrm{fid}}\frac{r_d^\mathrm{fid}}{r_d}=: \alpha_\parallel\end{cases}$$
\(r_d\)
\(r_d\)
\(\Delta \theta\)
\(\Delta z\)
\(D_A(z)\)
\(\Delta \theta = \frac{r_d}{D_A}\)
Measuring \(\alpha_\perp(z)\) and \(\alpha_\parallel(z)\) at multiple redshifts constrains \(\Omega_m\) and \(H_0 r_d\)
\(\alpha_\mathrm{iso}\)
\(\alpha_\mathrm{AP}\)
\(\alpha_\parallel\)
\(\alpha_\perp\)
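The \(\alpha\) definitions translate directly into code. A sketch assuming one common convention for the derived combinations, \(\alpha_\mathrm{iso} = (\alpha_\perp^2 \alpha_\parallel)^{1/3}\) and \(\alpha_\mathrm{AP} = \alpha_\parallel/\alpha_\perp\) (conventions vary between analyses):

```python
def alphas(DA, DH, rd, DA_fid, DH_fid, rd_fid):
    # alpha_perp, alpha_par from the slide's definitions, plus the
    # isotropic / Alcock-Paczynski combinations used in BAO fits
    a_perp = (DA / DA_fid) * (rd_fid / rd)
    a_par = (DH / DH_fid) * (rd_fid / rd)
    a_iso = (a_perp**2 * a_par) ** (1 / 3)
    a_ap = a_par / a_perp
    return a_perp, a_par, a_iso, a_ap
```

When the fiducial cosmology matches the truth, all four \(\alpha\)'s are unity by construction.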
Pictorial BAO pre-recon
Pictorial BAO post-recon
HMC (e.g. Neal 2011)
Jointly inferring cosmology, bias parameters, and the initial matter field
A drunk man will find his way home, but a drunk bird may get lost forever (\(\mathrm p \approx 0.66\))
🌸 Shizuo Kakutani
\(-\nabla\)
\(d \approx 1\)
🏠
🚶♀️
To maintain a constant move-away probability, the step size must scale as \(d^{-1/2}\)
\(d \gg 1\)
🪺
🐦
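The curse of dimensionality above can be seen directly with a random-walk Metropolis sampler: keeping the step size fixed, the acceptance rate collapses as \(d\) grows. A toy sketch (target, step size, and chain length are arbitrary choices):

```python
import numpy as np

def rwm_acceptance(d, step=0.5, n=4000, seed=0):
    # random-walk Metropolis on a standard normal in d dimensions,
    # with the SAME step size regardless of d
    rng = np.random.default_rng(seed)
    logp = lambda v: -0.5 * np.sum(v**2)
    x, acc = np.zeros(d), 0
    for _ in range(n):
        prop = x + step * rng.normal(size=d)
        if np.log(rng.uniform()) < logp(prop) - logp(x):
            x, acc = prop, acc + 1
    return acc / n

acc_low = rwm_acceptance(d=1)     # high acceptance in low dimension
acc_high = rwm_acceptance(d=100)  # nearly everything rejected
```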
Recipe😋 to sample from \(\mathrm p \propto e^{-U}\)
gradient guides particle toward high density sets
scales poorly with dimension
must average over all energy levels
Hamiltonian Monte Carlo (e.g. Neal 2011)
Recipe😋 to sample from \(\mathrm p \propto e^{-U}\)
single energy/speed level
let's try avoiding that
gradient guides particle toward high density sets
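The recipe can be sketched as a minimal HMC loop: random momentum kick, leapfrog integration of the Hamiltonian dynamics, Metropolis adjustment. A toy on a standard-normal target (not the talk's sampler; step size and trajectory length are arbitrary):

```python
import numpy as np

U = lambda x: 0.5 * np.sum(x**2)  # standard normal target, p ∝ e^{-U}
grad_U = lambda x: x

def hmc(n_samples, eps=0.2, L=10, seed=0):
    rng = np.random.default_rng(seed)
    x, out = np.zeros(1), []
    for _ in range(n_samples):
        p = rng.normal(size=x.shape)                  # random kick
        xn, pn = x.copy(), p - 0.5 * eps * grad_U(x)  # leapfrog half-kick
        for i in range(L):
            xn = xn + eps * pn
            if i < L - 1:
                pn = pn - eps * grad_U(xn)
        pn = pn - 0.5 * eps * grad_U(xn)
        # Metropolis adjustment on the Hamiltonian H = U + |p|^2 / 2
        dH = (U(x) + 0.5 * p @ p) - (U(xn) + 0.5 * pn @ pn)
        if np.log(rng.uniform()) < dH:
            x = xn
        out.append(x[0])
    return np.array(out)

samples = hmc(5000)
```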
MicroCanonical HMC (Robnik+2022)
Hamiltonian Monte Carlo (e.g. Neal 2011)
MicroCanonical HMC (Robnik+2022)
Jointly inferring cosmology, bias parameters, and the initial matter field allows reconstruction of the full history of the universe
= NUTS within Gibbs
= auto-tuned HMC
= adjusted MCHMC
= unadjusted Langevin MCHMC
10 times fewer evaluations required
The unadjusted microcanonical sampler outperforms all adjusted samplers
\(128^3\) PM on 8 GPUs:
4h MCLMC vs. \(\geq\)80h NUTS
Mildly dependent on the formation model and volume
Probing smaller scales could be harder
reducing stepsize rapidly brings bias under Monte Carlo error
Local-type PNG is constrained by the induced scale-dependent bias
\(\phi_{\mathrm{NL}}=\phi+{\color{purple}f_{\mathrm{NL}}}\phi^{2}\)
\(\delta(\boldsymbol k)\simeq\left(b_{1}+ b_\phi {\color{purple}f_\mathrm{NL}}k^{-2} \right) \delta_L(\boldsymbol k)\)
$$\begin{align*}w_g&=1+{\color{purple}b_{1}}\,\delta_{\rm L}+{\color{purple}b_{2}}\delta_{\rm L}^{2}+{\color{purple}b_{s^2}}s^{2}+ {\color{purple}b_{\nabla^2}} \nabla^2 \delta _{\rm L}\\&\quad\quad\! + {\color{purple}b_\phi f_{\rm NL}} \phi + {\color{purple} b_{\phi\delta} f_{\rm NL}} \phi \delta_{\rm L}\\\Delta \boldsymbol q_\parallel &= H^{-1} \dot{\boldsymbol q}_\parallel + {\color{purple}b_{\nabla_\parallel}} \nabla_\parallel \delta_\mathrm{L}\end{align*}$$
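The scale-dependent bias can be sketched as a Fourier-space kernel. Illustrative numbers only; the transfer-function and growth factors that enter the full \(b_\phi\) term are omitted, and the helper name `png_bias` is invented:

```python
import numpy as np

def png_bias(k, b1=2.0, b_phi=3.0, f_nl=10.0):
    # linear bias with the local-PNG scale-dependent term:
    # b(k) = b1 + b_phi * f_NL / k^2  (units and transfer function omitted)
    return b1 + b_phi * f_nl / k**2

k = np.array([0.001, 0.01, 0.1])  # h/Mpc
bk = png_bias(k)
```

The \(k^{-2}\) term blows up on large scales, which is why local-type PNG is constrained mainly by the largest-scale modes.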
\(\phi_{\mathrm{NL}}=\phi+{\color{purple}f_{\mathrm{NL}}}\phi^{2}\)
\(\boldsymbol q_\mathrm{LPT} \simeq \boldsymbol q_\mathrm{in} + \Psi_\mathrm{LPT}(\boldsymbol q_\mathrm{in}, z(\boldsymbol q_\mathrm{in}))\)
one-shot 2LPT light-cone
\(n_g^\mathrm{obs}(\boldsymbol q) = (1+\delta_g(\boldsymbol q))\, {\color{purple}\bar n_g(\,r)}\, {\color{blue}W(\boldsymbol q)}\, {\color{purple}\beta_i} {\color{green}T^i(\theta)}\)
RIC relaxation + selection + imaging templates
\(\delta_g \sim \mathcal N(\delta_g^\mathrm{det}, \sigma^2)\) with
\(\sigma(k) = {\color{purple}\sigma_0}(1+{\color{purple}\sigma_2} k^2 + {\color{purple}\sigma_{\mu2}}(k\mu)^2)\)
EFT-based modeling, many scale cuts alleviating discretization effects (see Stadler+2024)
Radial Integral Constraint
\(\delta_g \propto n_g - \braket{n_g}\approx n_g - \bar n_g(r)\)
i.e. impose \(\bar \delta_g(r) = 0\)
Global Integral Constraint
\(\delta_g \propto n_g - \braket{n_g} \approx n_g - \bar n_g\)
i.e. impose \(\bar \delta_g = 0\)
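The two integral constraints amount to subtracting either one global mean or a per-radial-shell mean. A minimal sketch on mock counts (the array shape, a grid of radial bins times angular cells, is invented for illustration):

```python
import numpy as np

rng = np.random.default_rng(2)
n_g = rng.poisson(10.0, size=(50, 200))  # mock counts: (radial bin, angular cell)

# Global IC: divide by one global mean, so the overall mean contrast is zero
delta_gic = n_g / n_g.mean() - 1.0

# Radial IC: divide by the mean of each radial shell,
# so every shell averages to zero by construction
nbar_r = n_g.mean(axis=1, keepdims=True)
delta_ric = n_g / nbar_r - 1.0
```

The RIC is the stronger constraint: it removes the purely radial modes, which is why it interacts with the radial selection function.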
PRELIMINARY: \(k_\mathrm{max} \approx 0.04\, [h/\mathrm{Mpc}]\)
\(\sigma[f_\mathrm{NL}] \approx 20\), consistent with power spectrum analysis (Chaussidon+2024)
Next steps:
PRELIMINARY RESULTS: \(\alpha_\mathrm{iso}\)/\(\Omega_m\) too well-constrained with AP
AP modifies the simulated tracer density, which in practice we don't know
[Figure: radial density \(n(r)\) vs \(r\ [\mathrm{Mpc}/h]\), without vs. with Jacobian correction]
Did not solve the \(\alpha_\mathrm{iso}\) constraint issue. Probably discretization artifacts:
Next steps:
import jax
import jax.numpy as np
# then enjoy
function = jax.jit(function)   # function is so fast now!
gradient = jax.grad(function)  # too bad if you love chain ruling by hand
vfunction = jax.vmap(function)
pfunction = jax.pmap(function) # for-loops are for-losers

import numpy
def simulate(seed):
    rng = numpy.random.RandomState(seed)
    x = rng.randn()
    y = rng.randn() + x**2
    return x, y

from scipy.stats import norm
log_prior = lambda x: norm.logpdf(x, 0, 1)
log_lik = lambda x, y: norm.logpdf(y, x**2, 1)
log_joint = lambda x, y: log_lik(x, y) + log_prior(x)

grad_log_prior = lambda x: -x
grad_log_lik = lambda x, y: np.stack([2 * x * (y - x**2), x**2 - y])
grad_log_joint = lambda x, y: grad_log_lik(x, y) + np.stack([grad_log_prior(x), np.zeros_like(y)])
# Hessian left as an exercise...

To fully define a probabilistic model, we need
We should only care about the simulator!
\(Y\)
\(X\)
$$\begin{cases}X \sim \mathcal N(0,1)\\Y\mid X \sim \mathcal N(X^2, 1)\end{cases}$$
my model on paper
my model on code
What if we modify it, e.g. to fit the variances, or to use non-Gaussian distributions? Recompute everything? :'(
import jax.random as jr
from numpyro import distributions as dist
from numpyro import infer, render_model, sample
from numpyro.handlers import condition, seed
from numpyro.infer.util import log_density

def model():
    x = sample('x', dist.Normal(0, 1))
    y = sample('y', dist.Normal(x**3, 1))
    return y

render_model(model, render_distributions=True)

y0 = dict(y=seed(model, 42)())
obs_model = condition(model, y0)
log_joint = lambda x, y: log_density(model, (), {}, {'x': x, 'y': y})[0]
logd_fn = lambda x: log_density(obs_model, (), {}, {'x': x})[0]

from jax import jit, vmap, grad
force_vfn = jit(vmap(grad(logd_fn)))

kernel = infer.NUTS(obs_model)
mcmc = infer.MCMC(kernel, num_warmup=n_warmup, num_samples=n_samples)
mcmc.run(jr.key(43))
samples = mcmc.get_samples()
[Figure: gradient descent alone → posterior mode; Brownian motion alone → exploding Gaussian; gradient descent + Brownian motion = Langevin → posterior]
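The "gradient descent + Brownian motion = Langevin" picture fits in one loop: an unadjusted Langevin sketch on a standard-normal toy target (not the talk's code; step size and chain length are arbitrary):

```python
import numpy as np

def ula(n, step=0.05, seed=3):
    # Unadjusted Langevin: gradient descent on U plus Brownian noise,
    # x <- x - step * grad U(x) + sqrt(2 * step) * xi
    rng = np.random.default_rng(seed)
    grad_U = lambda x: x  # U = x^2 / 2, i.e. N(0,1) target
    x, out = 0.0, []
    for _ in range(n):
        x = x - step * grad_U(x) + np.sqrt(2 * step) * rng.normal()
        out.append(x)
    return np.array(out)

chain = ula(20_000)
```

Without a Metropolis adjustment the chain carries an \(\mathcal O(\mathrm{step})\) bias, which is why reducing the step size brings the bias under the Monte Carlo error.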
Variations around HMC
1) mass \(M\) particle at \(q\)
2) random kick \(p\)
3) Hamiltonian dynamics
Thank you!