2021 July DESC Collaboration Meeting
Francois Lanusse (@EiffL), for the jax-cosmo contributors
And more (not yet in the picture because I'm lagging behind on PR merging)
- Jean-Eric Campagne
- Joe Zuntz
- Tilman Troester
- Ben Horowitz
import jax
import jax.numpy as np
import jax_cosmo as jc
# .... define probes, and load a data vector
def gaussian_likelihood(theta):
    # Build the cosmology for given parameters
    cosmo = jc.Planck15(Omega_c=theta[0], sigma8=theta[1])
    # Compute mean and covariance
    mu, cov = jc.angular_cl.gaussian_cl_covariance_and_mean(cosmo,
                                                            ell, probes)
    # Returns likelihood of data under model
    return jc.likelihood.gaussian_likelihood(data, mu, cov)
# Fisher matrix in just one line:
F = - jax.hessian(gaussian_likelihood)(theta)
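The one-line Fisher matrix can be checked on a toy problem that does not need jax-cosmo. The sketch below is only an illustration: the design matrix `A`, the noise level `sigma`, the data vector and the name `toy_log_likelihood` are made up for this example and are not part of jax-cosmo. For a model whose mean is linear in the parameters, minus the Hessian of the Gaussian log-likelihood should match the closed-form Fisher matrix.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy model: the mean of the data vector is linear in theta.
A = jnp.array([[1.0, 0.0],
               [1.0, 1.0],
               [0.0, 2.0]])            # design matrix (assumed)
sigma = 0.5                            # noise standard deviation (assumed)
data = jnp.array([1.0, 2.0, 2.0])      # made-up data vector

def toy_log_likelihood(theta):
    # Gaussian log-likelihood, up to an additive constant
    residual = data - A @ theta
    return -0.5 * jnp.sum(residual**2) / sigma**2

theta_fid = jnp.array([1.0, 1.0])      # fiducial parameters (assumed)

# Same pattern as on the slide: Fisher matrix as minus the Hessian
F = -jax.hessian(toy_log_likelihood)(theta_fid)

# Closed-form Fisher matrix for a linear model with uncorrelated noise
F_analytic = A.T @ A / sigma**2
```

Comparing `F` with `F_analytic` is a quick sanity check before trusting the same pattern on a full cosmological likelihood.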
DES Y1 posterior, jax-cosmo HMC vs Cobaya MH
(credit: Joe Zuntz)
def log_posterior(theta):
    return gaussian_likelihood(theta) + log_prior(theta)
score = jax.grad(log_posterior)(theta)
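To show where the score enters a gradient-based sampler such as HMC, here is a minimal sketch of a single leapfrog step. It is not the sampler used for the DES Y1 comparison: the standard-normal `log_posterior` below is a hypothetical stand-in for the jax-cosmo posterior, and `leapfrog` and `hamiltonian` are illustrative helper names.

```python
import jax
import jax.numpy as jnp

def log_posterior(theta):
    # Hypothetical stand-in: a standard normal log-posterior (up to a constant)
    return -0.5 * jnp.sum(theta**2)

# Gradient of the log-posterior with respect to the parameters
score = jax.grad(log_posterior)

def leapfrog(theta, momentum, step_size):
    # Half step in momentum, full step in position, half step in momentum
    momentum = momentum + 0.5 * step_size * score(theta)
    theta = theta + step_size * momentum
    momentum = momentum + 0.5 * step_size * score(theta)
    return theta, momentum

def hamiltonian(theta, momentum):
    # Potential energy (minus log-posterior) plus kinetic energy
    return -log_posterior(theta) + 0.5 * jnp.sum(momentum**2)

theta0 = jnp.array([1.0, -0.5])
p0 = jnp.array([0.3, 0.2])
theta1, p1 = leapfrog(theta0, p0, step_size=0.01)

# The leapfrog integrator approximately conserves the Hamiltonian
energy_error = hamiltonian(theta1, p1) - hamiltonian(theta0, p0)
```

A full HMC sampler would chain many such steps and add a Metropolis accept/reject test on the energy error.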
Given (g)riz photometry, find a tomographic bin assignment method that optimizes a 3x2pt analysis.
Makinen et al. 2021, arXiv:2107.07405
$ pip install mylibrary
=> a JAX-based library will meet all of these criteria
import jax.numpy as np
m = np.eye(10) # Some matrix
def my_func(x):
    return m.dot(x).sum()
x = np.linspace(0,1,10)
y = my_func(x)
from jax import grad
df_dx = grad(my_func)
y = df_dx(x)
def pure_fun(x):
    return 2 * x**2 + 3 * x + 2
def impure_fun_side_effect(x):
    print('I am a side effect')
    return 2 * x**2 + 3 * x + 2
C = 10. # A global variable
def impure_fun_uses_globals(x):
    return 2 * x**2 + 3 * x + C
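Why purity matters shows up as soon as these functions are passed through `jax.jit`. The sketch below reuses the two impure functions from the slide, with one change that is an assumption of this example: the `print` is replaced by an append to a list (`trace_log`) so the side effect can be counted. The Python body runs only while JAX traces the function, and a global is read once at trace time.

```python
import jax
import jax.numpy as jnp

trace_log = []  # records each time the Python body actually runs

def impure_fun_side_effect(x):
    trace_log.append("traced")  # side effect: only happens during tracing
    return 2 * x**2 + 3 * x + 2

jitted = jax.jit(impure_fun_side_effect)
jitted(jnp.array(1.0))
jitted(jnp.array(2.0))  # same shape and dtype: compiled version is reused
n_traces = len(trace_log)

C = 10.0  # a global variable

def impure_fun_uses_globals(x):
    return 2 * x**2 + 3 * x + C

jitted_global = jax.jit(impure_fun_uses_globals)
first = jitted_global(jnp.array(1.0))   # C = 10 is baked in at trace time
C = 0.0
second = jitted_global(jnp.array(1.0))  # cached: still uses the traced value of C
```

Here the side effect fires once rather than twice, and changing `C` after the first call has no effect on the second, which is why JAX transformations expect pure functions.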
# Decorator for jitting
@jax.jit
def my_fun(W, x):
    return W.dot(x)
# or as an explicit transformation
my_fun_jitted = jax.jit(my_fun)
def f(x):
    return 2 * x**2 + 3 * x + 2
df_dx = jax.grad(f)
jac = jax.jacobian(f)
hess = jax.hessian(f)
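As a quick check, the transformed functions can be evaluated at a point and compared with the derivatives worked out by hand. The evaluation point `x0` is an arbitrary choice for this sketch.

```python
import jax

def f(x):
    return 2 * x**2 + 3 * x + 2

df_dx = jax.grad(f)      # analytically, f'(x)  = 4x + 3
hess = jax.hessian(f)    # analytically, f''(x) = 4

x0 = 1.0                 # arbitrary evaluation point
first_derivative = df_dx(x0)
second_derivative = hess(x0)
```

Note that `jax.grad` expects a function with a scalar output, which is the case here when `x` is a scalar.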
# As decorator
@jax.vmap
def f(x):
    return 2 * x**2 + 3 * x + 2
# Can be composed (starting from the undecorated, scalar f,
# since jax.grad needs a scalar-output function)
df_dx = jax.jit(jax.vmap(jax.grad(f)))
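A minimal sketch of the composed transformation in action, assuming `f` is the plain scalar function (without the `@jax.vmap` decorator). The input grid `xs` and the name `df_dx_batched` are illustrative choices.

```python
import jax
import jax.numpy as jnp

def f(x):
    return 2 * x**2 + 3 * x + 2

# grad differentiates the scalar function, vmap maps it over a batch,
# and jit compiles the whole pipeline
df_dx_batched = jax.jit(jax.vmap(jax.grad(f)))

xs = jnp.linspace(0.0, 1.0, 5)
slopes = df_dx_batched(xs)

# Analytic derivative for comparison: f'(x) = 4x + 3
expected = 4 * xs + 3
```

The order matters: `jax.grad` is applied first because it needs a scalar output, and `jax.vmap` then lifts the resulting derivative to work on whole arrays.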
But more seriously, happy to chat anytime if you are interested in contributing, or just want to know more about JAX :-)