Hugo SIMON-ONFROY,
PhD student supervised by
Arnaud DE MATTIA and François LANUSSE
CoPhy, 2024/11/19
$$\frac{H(a)}{H_0} = \sqrt{\Omega_r\, a^{-4} + (\Omega_b + \Omega_c)\, a^{-3} + \Omega_\kappa\, a^{-2} + \Omega_\Lambda}$$
instantaneous expansion rate
energy content
Cosmological principle + Einstein equation
+ Inflation
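As a quick numerical aside (a sketch with illustrative density-parameter values, not taken from the slides), the dimensionless expansion rate \(E(a) = H(a)/H_0\) follows directly from the Friedmann equation above:

import jax.numpy as jnp

def expansion_rate(a, Om_r=8e-5, Om_b=0.05, Om_c=0.26, Om_k=0.0, Om_L=0.69):
    # E(a) = H(a)/H0; the Omega values here are placeholders, not fitted numbers
    return jnp.sqrt(Om_r / a**4 + (Om_b + Om_c) / a**3 + Om_k / a**2 + Om_L)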
\(\delta_L \sim \mathcal G(0, \mathcal P)\)
\(\sigma_8:= \sigma[\delta_L * \boldsymbol 1_{r \leq 8}]\)
initial field
primordial power spectrum
std. of fluctuations smoothed at \(8 \text{ Mpc/h}\)
\(\Omega := \{ \Omega_c, \Omega_b, \Omega_\Lambda, H_0, \sigma_8, n_s,...\}\)
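As a sketch of the \(\sigma_8\) definition above (assuming \(k\) and \(P(k)\) come as tabulated arrays, e.g. from JaxCosmo), the smoothing is most easily done in Fourier space with a top-hat window of radius 8 Mpc/h:

import jax.numpy as jnp

def sigma8(k, pk, R=8.0):
    # rms of the linear field smoothed with a top-hat of radius R [Mpc/h]
    x = k * R
    window = 3 * (jnp.sin(x) - x * jnp.cos(x)) / x**3              # Fourier-space top-hat
    integrand = k**2 * pk * window**2 / (2 * jnp.pi**2)
    return jnp.sqrt(jnp.sum(0.5 * (integrand[1:] + integrand[:-1]) * jnp.diff(k)))  # trapezoid rule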
Linear matter spectrum
Structure growth
\(\Omega\)
\(\delta_L\)
\(\delta_g\)
inference
Model: an observation generator, i.e. the link between latent and observed variables
\(F\)
\(C\)
\(D\)
\(E\)
\(A\)
\(B\)
\(G\)
Frequentist:
fixed but unknown parameters
Bayesian:
random because unknown
\(A\)
\(B\)
\(C\)
\(F\)
\(D\)
\(E\)
\(G\)
Represented as a DAG which expresses a joint probability factorization, i.e. conditional dependencies between variables
latent param
latent variable
observed variable
evidence/prior predictive: $$\boldsymbol{p}(y) := \int \boldsymbol{p}(y \mid x) \boldsymbol{p}(x)d x$$
posterior predictive: $$\boldsymbol{p}(y_1 \mid y_0) := \int \boldsymbol{p}(y_1 \mid x) \boldsymbol{p}(x \mid y_0)d x$$
\(Y\)
\(X\)
Bayes thm
\(\sim\)
\(\underbrace{\boldsymbol{p}(x \mid y)}_{\text{posterior}}\underbrace{\boldsymbol{p}(y)}_{\text{evidence}} = \underbrace{\boldsymbol{p}(y \mid x)}_{\text{likelihood}} \,\underbrace{\boldsymbol{p}(x)}_{\text{prior}}\)
\(Y\)
\(X\)
Marginalization:
\(Y\)
\(X\)
\(Z\)
\(Z\)
\(X\)
\(Y\)
\(X\)
Conditioning/Bayes inversion:
\(Y\)
\(X\)
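These integrals rarely have closed forms; as a minimal sketch (using the toy model \(X \sim \mathcal N(0,1)\), \(Y \mid X \sim \mathcal N(X^2, 1)\) from the next slides), the evidence can be estimated by Monte Carlo over the prior:

import jax
import jax.numpy as jnp
from jax.scipy.stats import norm
from jax.scipy.special import logsumexp

def log_evidence(y, key, n=100_000):
    # p(y) = ∫ p(y|x) p(x) dx, estimated by averaging the likelihood over prior draws
    x = jax.random.normal(key, (n,))                  # x ~ N(0, 1)
    return logsumexp(norm.logpdf(y, x**2, 1.0)) - jnp.log(n)

print(log_evidence(1.3, jax.random.key(0)))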
\(s\)
\(\Omega\)
\(\begin{cases}x,y,z \text{ samples}\\\boldsymbol{p}(x,y,z)\\\Phi= -\log \boldsymbol{p}\\\nabla\Phi,\nabla^2 \Phi,\dots\end{cases}\)
1. Modeling
e.g. in JAX with \(\texttt{NumPyro}\), \(\texttt{JaxCosmo}\), \(\texttt{JaxPM}\), \(\texttt{DISCO-EB}\)...
2. DAG compilation
PPLs e.g. \(\texttt{NumPyro}\), \(\texttt{Stan}\), \(\texttt{PyMC}\)...
3. Inference
e.g. MCMC or VI with \(\texttt{NumPyro}\), \(\texttt{Stan}\), \(\texttt{PyMC}\), \(\texttt{emcee}\)...
4. Extract and viz
e.g. KDE with \(\texttt{GetDist}\), \(\texttt{GMM-MI}\), \(\texttt{mean()}\), \(\texttt{std()}\)...
Physics job
Stats job
Stats job
auto!
\(X\)
\(Z\)
\(Y\)
\(X\)
\(Z\)
import jax
import jax.numpy as np
# then enjoy
function = jax.jit(function)
# function is so fast now!
gradient = jax.grad(function)
# too bad if you love chain ruling by hand
vfunction = jax.vmap(function)
pfunction = jax.pmap(function)
# for-loops are for-losers
import numpy as np

def simulate(seed):
    rng = np.random.RandomState(seed)
    x = rng.randn()              # x ~ N(0, 1)
    y = rng.randn() + x**2       # y | x ~ N(x^2, 1)
    return x, y
import numpy as np
from scipy.stats import norm

log_prior = lambda x: norm.logpdf(x, 0, 1)
log_lik = lambda x, y: norm.logpdf(y, x**2, 1)
log_joint = lambda x, y: log_lik(x, y) + log_prior(x)

grad_log_prior = lambda x: -x
grad_log_lik = lambda x, y: np.stack([2 * x * (y - x**2), -(y - x**2)])
grad_log_joint = lambda x, y: grad_log_lik(x, y) + np.stack([grad_log_prior(x), np.zeros_like(y)])
# Hessian left as an exercise...
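For comparison (a sketch of the JAX alternative, not part of the original snippet), the same gradients, and the Hessian, come for free from the log-densities themselves:

import jax
from jax.scipy.stats import norm

log_prior = lambda x: norm.logpdf(x, 0.0, 1.0)
log_lik = lambda x, y: norm.logpdf(y, x**2, 1.0)
log_joint = lambda x, y: log_lik(x, y) + log_prior(x)

grad_log_joint = jax.grad(log_joint, argnums=(0, 1))   # d/dx and d/dy, no chain ruling
hess_log_joint = jax.hessian(log_joint)                # the "exercise" is one line too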
To fully define a probabilistic model, we need samples, densities, and their gradients.
Yet we should only have to care about the simulator!
\(Y\)
\(X\)
$$\begin{cases}X \sim \mathcal N(0,1)\\Y\mid X \sim \mathcal N(X^2, 1)\end{cases}$$
my model on paper
my model in code
What if we modify the model to fit the variances, or use non-Gaussian distributions? Do we recompute everything by hand? :'(
from numpyro import sample
import numpyro.distributions as dist
from numpyro.handlers import seed, condition
from numpyro.infer.util import log_density
from jax import jit, vmap, grad

def model():
    x = sample('x', dist.Normal(0, 1))
    y = sample('y', dist.Normal(x**2, 1))
    return y

y_sample = dict(y=seed(model, 42)())                                      # draw y from the prior predictive
log_joint = lambda x, y: log_density(model, (), {}, {'x': x, 'y': y})[0]
obs_model = condition(model, y_sample)                                    # condition the model on the observed y
logp_fn = lambda x: log_density(obs_model, (), {}, {'x': x})[0]
force_vfn = jit(vmap(grad(logp_fn)))
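For example (a usage sketch), the compiled, vectorized force can then be evaluated on a whole batch of positions:

import jax.numpy as jnp

xs = jnp.linspace(-3.0, 3.0, 100)
forces = force_vfn(xs)      # grad of log p(x, y=y_sample) w.r.t. x, vmapped and jitted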
import jax.random as jr
from jax import jit, vmap, grad
from numpyro import infer, render_model, sample
import numpyro.distributions as dist
from numpyro.handlers import seed, condition
from numpyro.infer.util import log_density

def model():
    x = sample('x', dist.Normal(0, 1))
    y = sample('y', dist.Normal(x**2, 1))
    return y

render_model(model, render_distributions=True)

y_sample = dict(y=seed(model, 42)())
obs_model = condition(model, y_sample)
logp_fn = lambda x: log_density(obs_model, (), {}, {'x': x})[0]
force_vfn = jit(vmap(grad(logp_fn)))

kernel = infer.NUTS(obs_model)
mcmc = infer.MCMC(kernel, num_warmup=1_000, num_samples=1_000)  # illustrative sizes; keyword-only in NumPyro
mcmc.run(jr.key(43))
samples = mcmc.get_samples()
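Step 4 then reduces to standard post-processing of the samples dictionary (a sketch assuming GetDist is installed):

import numpy as onp

x_samples = onp.asarray(samples['x'])
print(x_samples.mean(), x_samples.std())           # quick summaries

from getdist import MCSamples, plots
gd = MCSamples(samples=x_samples[:, None], names=['x'], labels=['x'])
plots.get_single_plotter().plot_1d(gd, 'x')        # smoothed 1D marginal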
gradient descent + Brownian = Langevin
posterior mode + exploding Gaussian = posterior
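In code, that intuition is one line per ingredient (a toy unadjusted Langevin sketch, step size chosen arbitrarily):

import jax
import jax.numpy as jnp

def langevin_step(key, q, grad_logp, eps=1e-2):
    # gradient ascent on log p plus Brownian noise targets the posterior, not just its mode
    return q + eps * grad_logp(q) + jnp.sqrt(2 * eps) * jax.random.normal(key, q.shape)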
Variations around HMC
1) mass \(M\) particle at \(q\)
2) random kick \(p\)
3) Hamiltonian dynamics
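Concretely, one HMC transition (a hedged sketch with identity mass matrix \(M = I\) and arbitrary step size; NUTS automates these choices) draws a momentum, integrates Hamilton's equations with leapfrog, then accepts or rejects:

import jax
import jax.numpy as jnp

def hmc_step(key, q, logp, eps=1e-2, n_leapfrog=20):
    key_p, key_u = jax.random.split(key)
    grad_logp = jax.grad(logp)
    p = jax.random.normal(key_p, q.shape)              # 2) random kick
    q_new, p_new = q, p
    for _ in range(n_leapfrog):                        # 3) leapfrog-integrated Hamiltonian dynamics
        p_new = p_new + 0.5 * eps * grad_logp(q_new)
        q_new = q_new + eps * p_new
        p_new = p_new + 0.5 * eps * grad_logp(q_new)
    h = lambda q, p: -logp(q) + 0.5 * jnp.sum(p**2)    # Hamiltonian: potential + kinetic energy
    accept = jnp.log(jax.random.uniform(key_u)) < h(q, p) - h(q_new, p_new)
    return jnp.where(accept, q_new, q)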
\((q, p)\)
\(\delta(x)\)
\(\delta(k)\)
paint*
read*
fft*
ifft*
fft*
*: differentiable, e.g. with JAX, see JaxPM
apply forces
to move particles
solve Vlasov-Poisson
to compute forces
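A toy version of this paint, FFT, Poisson solve, read loop (a sketch with nearest-grid-point painting; JaxPM uses CIC painting and the proper cosmological factors):

import jax.numpy as jnp

def pm_forces(pos, mesh_shape=(64, 64, 64)):
    # pos: (N, 3) particle positions in grid units
    idx = jnp.round(pos).astype(int) % jnp.array(mesh_shape)                     # nearest grid point
    delta = jnp.zeros(mesh_shape).at[idx[:, 0], idx[:, 1], idx[:, 2]].add(1.0)   # paint
    delta = delta / delta.mean() - 1.0                                           # density contrast delta(x)
    delta_k = jnp.fft.fftn(delta)                                                # fft -> delta(k)
    kvec = jnp.meshgrid(*(2 * jnp.pi * jnp.fft.fftfreq(n) for n in mesh_shape), indexing="ij")
    k2 = sum(k**2 for k in kvec)
    pot_k = jnp.where(k2 > 0, -delta_k / jnp.where(k2 > 0, k2, 1.0), 0.0)        # Poisson: phi(k) = -delta(k)/k^2
    # force = -grad(phi) = -i k phi(k); ifft and read back at the particle positions
    return jnp.stack([jnp.fft.ifftn(-1j * k * pot_k).real[idx[:, 0], idx[:, 1], idx[:, 2]]
                      for k in kvec], axis=-1)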
Which sampling methods can scale to high dimensions?
2 fields, 1 power spectrum: Gaussian or N-body?