jax-cosmo: Finally an Automatically Differentiable Cosmology Library!
2021 July DESC Collaboration Meeting
Francois Lanusse (@EiffL), for the jax-cosmo contributors

And more (not yet in the picture because I'm lagging behind on PR merging)
- Jean-Eric Campagne
- Joe Zuntz
- Tilman Troester
- Ben Horowitz
Do you want any of this?
Fisher Forecasts become exact and instantaneous

import jax
import jax.numpy as np
import jax_cosmo as jc
# .... define probes, and load a data vector
def gaussian_likelihood(theta):
    # Build the cosmology for given parameters
    cosmo = jc.Planck15(Omega_c=theta[0], sigma8=theta[1])
    # Compute mean and covariance
    mu, cov = jc.angular_cl.gaussian_cl_covariance_and_mean(cosmo, ell, probes)
    # returns likelihood of data under model
    return jc.likelihood.gaussian_likelihood(data, mu, cov)
# Fisher matrix in just one line:
F = - jax.hessian(gaussian_likelihood)(theta)
- No derivatives were harmed by finite differences in the computation of this Fisher!
- You can scale to thousands of input parameters, no issues of numerical precision.
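The same one-liner can be checked on a toy problem where the answer is known. The sketch below does not use jax-cosmo: the linear model, the matrix `A`, the covariance, and the fiducial point are all made-up stand-ins, chosen because the Fisher matrix of a linear Gaussian model has the closed form A^T C^-1 A.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy setup (not from jax-cosmo): linear model mu(theta) = A @ theta
A = jnp.array([[1.0, 0.5],
               [0.0, 2.0],
               [1.5, 1.0]])
cov = jnp.diag(jnp.array([0.1, 0.2, 0.05]))   # fixed data covariance
inv_cov = jnp.linalg.inv(cov)
theta_fid = jnp.array([0.3, 0.8])             # fiducial parameters
data = A @ theta_fid                          # noiseless data vector

def log_likelihood(theta):
    # Gaussian log-likelihood, up to a theta-independent constant
    r = data - A @ theta
    return -0.5 * r @ inv_cov @ r

# Fisher matrix as minus the Hessian of the log-likelihood
F = -jax.hessian(log_likelihood)(theta_fid)

# Closed form for a linear model with fixed covariance
F_analytic = A.T @ inv_cov @ A
```

Here `F` and `F_analytic` should agree to floating-point precision, with no step size to tune.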
Inference becomes fast and scalable
- Gradients of the log posterior are required for modern efficient and scalable inference techniques:
  - Variational Inference
  - Hamiltonian Monte-Carlo
- With gradients, you can scale to thousands of parameters

DES Y1 posterior, jax-cosmo HMC vs Cobaya MH
(credit: Joe Zuntz)
def log_posterior(theta):
    return gaussian_likelihood(theta) + log_prior(theta)
score = jax.grad(log_posterior)(theta)
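To illustrate where the score enters Hamiltonian Monte Carlo, here is a minimal leapfrog integrator. It is only a sketch: the standard-normal `log_posterior` is a hypothetical stand-in for the cosmological posterior, the mass matrix is assumed to be the identity, and a full sampler would add momentum resampling and a Metropolis accept/reject step.

```python
import jax
import jax.numpy as jnp

def log_posterior(theta):
    # Hypothetical stand-in: a standard normal posterior
    return -0.5 * jnp.sum(theta**2)

score = jax.grad(log_posterior)  # gradient of the log posterior

def leapfrog(theta, p, step_size, n_steps):
    # Leapfrog integration of Hamiltonian dynamics (unit mass matrix)
    p = p + 0.5 * step_size * score(theta)       # initial half step in momentum
    for _ in range(n_steps - 1):
        theta = theta + step_size * p            # full step in position
        p = p + step_size * score(theta)         # full step in momentum
    theta = theta + step_size * p                # last full step in position
    p = p + 0.5 * step_size * score(theta)       # final half step in momentum
    return theta, p

def hamiltonian(theta, p):
    # Potential energy (minus log posterior) plus kinetic energy
    return -log_posterior(theta) + 0.5 * jnp.sum(p**2)

theta0 = jnp.array([1.0, -0.5])
p0 = jnp.array([0.3, 0.7])
theta1, p1 = leapfrog(theta0, p0, step_size=0.1, n_steps=10)

# A small energy error means a high acceptance probability for the proposal
energy_error = hamiltonian(theta1, p1) - hamiltonian(theta0, p0)
```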
You can differentiate through Cosmology
Given (g)riz photometry, find a tomographic bin assignment method that optimizes a 3x2pt analysis.
- Strategy with differentiable physics:
  - Introduce a parametric "bin assignment function"
  - Optimize this function to maximize DETF FoM
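The strategy above can be sketched on a toy problem. Everything in this example is hypothetical: the two Fisher matrices `F1` and `F2`, the single parameter `a`, and the use of sqrt(det F) as a stand-in for a DETF-style figure of merit. The point is only that once the figure of merit is a differentiable function of the assignment parameters, plain gradient ascent applies.

```python
import jax
import jax.numpy as jnp

# Hypothetical Fisher information carried by two candidate assignments
F1 = jnp.diag(jnp.array([4.0, 1.0]))
F2 = jnp.diag(jnp.array([1.0, 4.0]))

def log_fom(a):
    # Soft, differentiable "assignment": weight w in (0, 1)
    w = jax.nn.sigmoid(a)
    F = w * F1 + (1.0 - w) * F2
    # log of sqrt(det F), used here as a toy figure of merit
    _, logdet = jnp.linalg.slogdet(F)
    return 0.5 * logdet

grad_fn = jax.jit(jax.grad(log_fom))

a = 2.0                      # initial guess for the assignment parameter
learning_rate = 1.0
for _ in range(300):
    a = a + learning_rate * grad_fn(a)   # gradient ascent on the figure of merit

w_opt = jax.nn.sigmoid(a)    # for this symmetric toy, the optimum is an even split
```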

- Differentiable Physics for Score Compression
Makinen et al. 2021, arXiv:2107.07405

What Framework Should you Adopt?

Some top-level considerations
- [Low Barrier to Entry] You don't want contributors to have to learn an entirely new language to start contributing (I'm looking at you, Julia!).
- [Maintainability] You want readable code, without several layers of implementation (I'm looking at you, GalSim and CCL!).
- [Trivially installable] This should always work:
$ pip install mylibrary
- [Performance] Duh! But more precisely, you want to be able to take advantage of next-gen platforms like Perlmutter. Code should work 100% the same on CPU or GPU for the user.
=> A JAX-based library will meet all of these criteria

JAX: NumPy + Autograd + XLA
- JAX uses the NumPy API
=> You can copy/paste existing code, works pretty much out of the box
- JAX is a successor of autograd
=> You can transform any function to get forward or backward automatic derivatives (grad, jacobians, hessians, etc)
- JAX uses XLA as a backend
=> Same framework as used by TensorFlow (supports CPU, GPU, TPU execution)
import jax.numpy as np
m = np.eye(10) # Some matrix
def my_func(x):
    return m.dot(x).sum()
x = np.linspace(0,1,10)
y = my_func(x)
from jax import grad
df_dx = grad(my_func)
y = df_dx(x)

Pure functions
=> JAX is designed for functional programming: your code should be built around pure functions (no side effects)
- Enables caching functions as XLA expressions
- Enables JAX's extremely powerful concept of composable function transformations
Just in Time (jit) Compilation
=> jax.jit() transformation will compile an entire function as XLA, which then runs in one go on GPU.
Arbitrary order forward and backward autodiff
=> jax.grad() transformation will apply the d/dx operator to a function f
=> jax.vmap() transformation will add a batch dimension to the input/output of any function
def pure_fun(x):
    return 2 * x**2 + 3 * x + 2
def impure_fun_side_effect(x):
    print('I am a side effect')
    return 2 * x**2 + 3 * x + 2
C = 10. # A global variable
def impure_fun_uses_globals(x):
    return 2 * x**2 + 3 * x + C
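Why purity matters becomes visible once such a function is jitted. In this small sketch (the names `trace_log`, `state`, and `impure_fun` are made up for illustration), the Python body runs only once, while JAX traces it, so the side effect fires a single time and the external value read at trace time is baked into the compiled function.

```python
import jax

trace_log = []           # records every time the Python body actually runs
state = {"C": 10.0}      # mutable state living outside the function

@jax.jit
def impure_fun(x):
    trace_log.append("traced")              # side effect: only happens during tracing
    return 2 * x**2 + 3 * x + state["C"]    # external value is captured at trace time

first = impure_fun(1.0)    # traces and compiles with C = 10
state["C"] = 20.0          # this later change is invisible to the cached function
second = impure_fun(2.0)   # same input type: reuses the compiled code, still C = 10

n_traces = len(trace_log)  # the body ran once, not twice
```

Passing `C` as an explicit argument makes the function pure again and removes the surprise.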
# Decorator for jitting
@jax.jit
def my_fun(W, x):
    return W.dot(x)
# or as an explicit transformation
my_fun_jitted = jax.jit(my_fun)
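As a quick sanity check, the jitted function can be compared against the plain Python one. This is a minimal sketch with made-up inputs `W` and `x`; it checks the result only, not the speed.

```python
import jax
import jax.numpy as jnp

def my_fun(W, x):
    return W.dot(x)

my_fun_jitted = jax.jit(my_fun)

W = jnp.arange(6.0).reshape(2, 3)   # hypothetical 2x3 matrix
x = jnp.ones(3)

y_eager = my_fun(W, x)              # op-by-op execution
y_jit = my_fun_jitted(W, x)         # first call traces and compiles, later calls reuse the result
y_jit_again = my_fun_jitted(W, x)   # served from the compilation cache
```

When timing jitted code, keep in mind that JAX dispatches asynchronously, so calling `.block_until_ready()` on the result gives a fair measurement.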
def f(x):
    return 2 * x**2 + 3 * x + 2
df_dx = jax.grad(f)
jac = jax.jacobian(f)
hess = jax.hessian(f)
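"Arbitrary order" means the transformations can be nested, and forward and reverse mode are both available. A short sketch on the same polynomial, evaluated at an arbitrarily chosen point `x0`:

```python
import jax

def f(x):
    return 2 * x**2 + 3 * x + 2

df = jax.grad(f)               # first derivative: 4x + 3
d2f = jax.grad(jax.grad(f))    # second derivative: 4
d3f = jax.grad(d2f)            # third derivative: 0

fwd = jax.jacfwd(f)            # forward-mode derivative
rev = jax.jacrev(f)            # reverse-mode derivative

x0 = 1.5
values = (df(x0), d2f(x0), d3f(x0), fwd(x0), rev(x0))
```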
# As decorator (e.g. @jax.vmap placed above the definition)
def f(x):
    return 2 * x**2 + 3 * x + 2
# Can be composed
df_dx = jax.jit(jax.vmap(jax.grad(f)))
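To see what the composed transformation does, it can be applied to a batch of inputs and compared with the derivative worked out by hand. The input grid `xs` is an arbitrary choice for this sketch.

```python
import jax
import jax.numpy as jnp

def f(x):
    return 2 * x**2 + 3 * x + 2

# grad differentiates the scalar function, vmap maps it over a batch,
# and jit compiles the whole pipeline
df_dx = jax.jit(jax.vmap(jax.grad(f)))

xs = jnp.linspace(0.0, 1.0, 5)   # a batch of input points
batched = df_dx(xs)              # derivative evaluated at every point
expected = 4 * xs + 3            # derivative computed by hand
```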
- Cosmology library embracing JAX computing model
=> Autodiff, GPU acceleration, automated batching
- MIT licensed, collaboratively developed
=> Follows the all-contributors guidelines
- Inspired by CCL's tracer mechanism
=> See design document
- Unit tested against CCL

Status of Implementation
- Certainly not feature-complete but approximately DES Y1 3x2pt capable
- Power spectrum:
  - Eisenstein-Hu Power
  - halofit
- Tracers:
  - Lensing (including NLA, multiplicative bias)
  - NumberCounts
- 2pt functions
- Limber, harmonic space
- Gaussian covariance tools
- These were the main features needed for the DESC 3x2pt Tomo Challenge
- Took about a week to implement, another week to tidy up a bit
Let's take it for a spin


JAX vs Julia vs TensorFlow vs PyTorch
Let's go with JAX:
- It's normal NumPy, don't need to learn anything new
- Under the hood, JAX and TensorFlow both compile through XLA (and PyTorch can too, via PyTorch/XLA), so performance is not the deciding factor
- Very powerful function transformation mechanisms
Do We Need to Rewrite CCL from Scratch?
Yes, if we want differentiability there is no workaround.
Trying to progressively adapt CCL with differentiable components would be a nightmare.
- But:
- When you already have a validated numerical scheme and implementation to compare to, there is no "research/uncertainty", it's purely a matter of translating code. I coded jax-cosmo in about a week.
=> CCL would continue to play a major role as a standard reference.
- It can be an opportunity to bring all of our experience in the design of a new library.
What about CAMB and CLASS?
- Quick and Easy solution: train an emulator
- E.g. CosmoPower from A. Spurio Mancini
- Longer term solution:
- Someone (maybe Zack) will code a differentiable version of those
- We could kickstart such a project
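The appeal of the emulator route is that anything written in JAX is differentiable for free. The sketch below is deliberately tiny and entirely hypothetical: `expensive_model` is a stand-in for a slow Boltzmann-code output, and the emulator is a cubic polynomial fitted by least squares rather than a neural network like the ones used by real emulators.

```python
import jax
import jax.numpy as jnp

def expensive_model(x):
    # Hypothetical stand-in for a slow, non-differentiable code
    return jnp.sin(x)

def features(x):
    # Cubic polynomial features; works for scalars and 1-D arrays
    return jnp.stack([jnp.ones_like(x), x, x**2, x**3], axis=-1)

# "Training set": evaluate the expensive model on a grid of inputs
x_train = jnp.linspace(0.0, 1.0, 50)
y_train = expensive_model(x_train)

# Fit the emulator coefficients by linear least squares
coeffs, *_ = jnp.linalg.lstsq(features(x_train), y_train)

def emulator(x):
    return features(x) @ coeffs

x0 = 0.5
emu_value = emulator(x0)
emu_grad = jax.grad(emulator)(x0)           # gradient through the emulator
true_grad = jax.grad(expensive_model)(x0)   # available here only because the toy is in JAX
```

Inside the training range the emulator and its gradient should track the underlying function closely; outside that range there is no such guarantee.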
Where to host such projects?
- These tools will be useful to all modern cosmological surveys, but they are non-trivial to implement, test, and maintain.
- You want to ensure broad user adoption and help build a community to support the maintenance and development of the project
- 2 options:
- A DESC-sponsored project (like CCL), but developed completely openly (think astropy)
- Open project hosted in some sort of external entity, enabling anyone to contribute, across any collaboration.
- That was the idea behind the Differentiable Universe Initiative
- With our new DESC Software policy, I think option 1 is now doable.
In Conclusion...

But more seriously, happy to chat anytime if you are interested in contributing, or just want to know more about JAX :-)
By eiffl
Introduction to JAX and the jax-cosmo library