jax-cosmo: Finally an Automatically Differentiable Cosmology Library!
2021 July DESC Collaboration Meeting
Francois Lanusse (@EiffL), for the jax-cosmo contributors
And more (not yet in the picture because I'm lagging behind on PR merging)
 Jean-Eric Campagne
 Joe Zuntz
 Tilman Troester
 Ben Horowitz
Why do you want any of this?
Fisher Forecasts become exact and instantaneous
import jax
import jax.numpy as np
import jax_cosmo as jc

# .... define probes, and load a data vector

def gaussian_likelihood(theta):
    # Build the cosmology for given parameters
    cosmo = jc.Planck15(Omega_c=theta[0], sigma8=theta[1])
    # Compute mean and covariance
    mu, cov = jc.angular_cl.gaussian_cl_covariance_and_mean(cosmo, ell, probes)
    # returns likelihood of data under model
    return jc.likelihood.gaussian_likelihood(data, mu, cov)

# Fisher matrix in just one line:
F = jax.hessian(gaussian_likelihood)(theta)
 No derivatives were harmed by finite differences in the computation of this Fisher!
 You can scale to thousands of input parameters, no issues of numerical precision.
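The snippet above needs jax-cosmo probes and a data vector set up; the same pattern can be seen end-to-end on a toy likelihood (a made-up straight-line model with known Gaussian noise, not anything from jax-cosmo), where jax.hessian hands back the exact Fisher matrix directly:

```python
import jax
import jax.numpy as jnp

# Toy model: mu(theta) = theta0 + theta1 * x, observed with known noise sigma
x = jnp.linspace(0.0, 1.0, 50)
sigma = 0.1
theta_fid = jnp.array([1.0, 2.0])
data = theta_fid[0] + theta_fid[1] * x  # noiseless mock at the fiducial point

def negative_log_likelihood(theta):
    mu = theta[0] + theta[1] * x
    return 0.5 * jnp.sum((data - mu) ** 2 / sigma ** 2)

# Fisher matrix = Hessian of -log L at the fiducial parameters,
# no finite-difference step sizes to tune
F = jax.hessian(negative_log_likelihood)(theta_fid)
```

For this linear model the Hessian matches the analytic Fisher matrix exactly, which is a handy sanity check before moving to a real data vector.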
Inference becomes fast and scalable
 Gradients of the log posterior are required for modern efficient and scalable inference techniques:
 Variational Inference
 Hamiltonian Monte Carlo
 With gradients, you can scale to thousands of parameters
DES Y1 posterior, jax-cosmo HMC vs Cobaya MH
(credit: Joe Zuntz)
def log_posterior(theta):
    return gaussian_likelihood(theta) + log_prior(theta)

score = jax.grad(log_posterior)(theta)
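For concreteness, here is a minimal sketch of the leapfrog integration at the heart of HMC, driven entirely by jax.grad; the standard-normal toy posterior is a stand-in for the real likelihood, not jax-cosmo API:

```python
import jax
import jax.numpy as jnp

def log_posterior(theta):
    # Toy standard-normal posterior standing in for the real one
    return -0.5 * jnp.sum(theta ** 2)

grad_log_post = jax.jit(jax.grad(log_posterior))

def leapfrog(theta, p, step_size=0.1, n_steps=10):
    # One HMC trajectory: half-step momentum, alternating full steps,
    # half-step momentum to close
    p = p + 0.5 * step_size * grad_log_post(theta)
    for _ in range(n_steps - 1):
        theta = theta + step_size * p
        p = p + step_size * grad_log_post(theta)
    theta = theta + step_size * p
    p = p + 0.5 * step_size * grad_log_post(theta)
    return theta, p

theta = jnp.zeros(2)
p = jnp.ones(2)
theta_new, p_new = leapfrog(theta, p)
```

The gradient is the only model-specific ingredient, which is why a differentiable likelihood makes samplers like this essentially free.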
You can differentiate through Cosmology
Given (g)riz photometry, find a tomographic bin assignment method that optimizes a 3x2pt analysis.
 Strategy with differentiable physics:
 Introduce a parametric "bin assignment function"
 Optimize this function to maximize DETF FoM
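As a toy illustration of what a parametric bin assignment function can look like (the softmax assignment and the figure of merit below are invented placeholders, not the actual challenge entry): galaxies are softly assigned to bins, so the assignment stays differentiable with respect to the bin centres and can be tuned by gradient ascent:

```python
import jax
import jax.numpy as jnp

def soft_assignment(z, centres, sharpness=10.0):
    # Differentiable "soft" tomographic bin assignment:
    # softmax over negative squared distance to learnable bin centres
    logits = -sharpness * (z[:, None] - centres[None, :]) ** 2
    return jax.nn.softmax(logits, axis=-1)

def toy_fom(centres, z):
    # Placeholder figure of merit rewarding bins that stay populated
    # (a real analysis would maximize the DETF FoM here instead)
    weights = soft_assignment(z, centres)
    counts = weights.sum(axis=0)
    return jnp.sum(jnp.log(counts))

z = jnp.linspace(0.1, 2.0, 500)        # mock galaxy redshifts
centres = jnp.array([0.4, 0.9, 1.5])   # initial bin centres

# Gradient ascent on the figure of merit, through the assignment itself
for _ in range(100):
    centres = centres + 0.05 * jax.grad(toy_fom)(centres, z)
```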
 Differentiable Physics for Score Compression
Makinen et al. 2021, arXiv:2107.07405
What Framework Should you Adopt?
Some top-level considerations

[Low Barrier to Entry] You don't want contributors to have to learn an entirely new language to start contributing (I'm looking at you, Julia!).

[Maintainability] You want readable code, without several layers of implementation (I'm looking at you, GalSim and CCL!)

[Trivially installable] This should always work:
$ pip install mylibrary
 [Performance] Duh! But more precisely, you want to be able to take advantage of next-gen platforms like Perlmutter. Code should work 100% the same on CPU or GPU for the user.
=> a JAX-based library will meet all of these criteria
What Is JAX?
JAX: NumPy + Autograd + XLA
 JAX uses the NumPy API
=> You can copy/paste existing code, works pretty much out of the box
 JAX is a successor to Autograd
=> You can transform any function to get forward or backward automatic derivatives (grad, jacobians, hessians, etc)
 JAX uses XLA as a backend
=> Same framework as used by TensorFlow (supports CPU, GPU, TPU execution)
import jax.numpy as np

m = np.eye(10) # Some matrix

def my_func(x):
    return m.dot(x).sum()

x = np.linspace(0, 1, 10)
y = my_func(x)

from jax import grad

df_dx = grad(my_func)
y = df_dx(x)

Pure functions
=> JAX is designed for functional programming: your code should be built around pure functions (no side effects)
 Enables caching functions as XLA expressions
 Enables JAX's extremely powerful concept of composable function transformations

Just in Time (jit) Compilation
=> jax.jit() transformation will compile an entire function as XLA, which then runs in one go on GPU.

Arbitrary order forward and backward autodiff
=> jax.grad() transformation will apply the d/dx operator to a function f

Autovectorization
=> jax.vmap() transformation will add a batch dimension to the input/output of any function
def pure_fun(x):
    return 2 * x**2 + 3 * x + 2

def impure_fun_side_effect(x):
    print('I am a side effect')
    return 2 * x**2 + 3 * x + 2

C = 10. # A global variable
def impure_fun_uses_globals(x):
    return 2 * x**2 + 3 * x + C

# Decorator for jitting
@jax.jit
def my_fun(W, x):
    return W.dot(x)

# or as an explicit transformation
my_fun_jitted = jax.jit(my_fun)

def f(x):
    return 2 * x**2 + 3 * x + 2

df_dx = jax.grad(f)
jac = jax.jacobian(f)
hess = jax.hessian(f)
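These transformations also nest, which is where arbitrary-order derivatives come from; for example the second derivative of 2x³ is 12x:

```python
import jax

def g(x):
    return 2 * x**3

# grad composes with itself: d2g/dx2 = 12 x
d2g_dx2 = jax.grad(jax.grad(g))
y = d2g_dx2(2.0)  # 24.0
```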
# As decorator
@jax.vmap
def f(x):
    return 2 * x**2 + 3 * x + 2

# Can be composed
df_dx = jax.jit(jax.vmap(jax.grad(f)))
 Cosmology library embracing JAX computing model
=> Autodiff, GPU acceleration, automated batching
 MIT licensed, collaboratively developed
=> Follows the all-contributors guidelines
 Inspired by CCL's tracer mechanism
=> See design document
 Unit tested against CCL
Status of Implementation
 Certainly not feature-complete, but approximately DES Y1 3x2pt capable
 Power spectrum:
 Eisenstein-Hu power spectrum
 halofit
 Tracers:
 Lensing (including NLA, multiplicative bias)
 NumberCounts
 2pt functions
 Limber, harmonic space
 Gaussian covariance tools
 These were the main features needed for the DESC 3x2pt Tomo Challenge
 Took about a week to implement, another week to tidy up a bit
Let's take it for a spin
My Two Cents
JAX vs Julia vs TensorFlow vs PyTorch

Let's go with JAX:
 It's normal NumPy, don't need to learn anything new
 Under the hood, they will all use XLA, so no difference
 Very powerful function transformation mechanisms
Do We Need to Rewrite CCL from Scratch?

Yes, if we want differentiability there is no workaround.
Trying to progressively adapt CCL with differentiable components would be a nightmare.
 But:
 When you already have a validated numerical scheme and implementation to compare to, there is no "research/uncertainty", it's purely a matter of translating code. I coded jaxcosmo in about a week.
=> CCL would continue to play a major role as a standard reference.
 It can be an opportunity to bring all of our experience in the design of a new library.
What about CAMB and CLASS?
 Quick and Easy solution: train an emulator
 E.g. CosmoPower from A. Spurio Mancini
 Longer term solution:
 Someone (maybe Zack) will code a differentiable version of those
 We could kickstart such a project
Where to host such projects?
 These tools will be useful to all modern cosmological surveys, but they are non-trivial to implement, test, and maintain.
 You want to ensure broad user adoption

Help build a community to support the maintenance and development of the project
 2 options:
 A DESC-sponsored project (like CCL), but developed completely openly (think astropy)
 Open project hosted in some sort of external entity, enabling anyone to contribute, across any collaboration.
 That was the idea behind the Differentiable Universe Initiative
 With our new DESC Software policy, I think option 1 is now doable.
In Conclusion...
But more seriously, happy to chat anytime if you are interested in contributing, or just want to know more about JAX :)