Multi-GPU Computing in JAX for Automatically Differentiable
High Performance Computing

Francois Lanusse, @EiffL

Why do we need to start thinking about scaling up?

  • Motivation from ML perspective:
    • Machine Learning models are getting better, but bigger
    • The dimensionality of data increases (e.g. high resolution images, 3D)
  • Motivation from Physics perspective:
    • Models become very large for Stage IV surveys (e.g. N-body sims)

=> In both cases, a given model will not fit on a single GPU!

Credit: Li et al. 2021

We also have access to a new generation of supercomputers

NERSC 9 system: Perlmutter

  • 1536 GPU nodes, each one with 4x NVIDIA A100 (40GB)
  • High performance HPE/Cray Slingshot interconnect
  • Ranks in top 10 most powerful systems in the world

How does parallel computing work?

  • GPUs are great for SIMD (Single Instruction Multiple Data)
    • This requires many many simple cores, which all have access to the same memory
    • If your problem fits in memory, this is the best solution!
  • When the data is so large that it cannot fit into a single computer, you need SPMD (Single Program Multiple Data)
    • Each process can live on a different physical device, and is only in charge of storing and processing a fraction of the total data
    • Processes need to talk to each other in order to complete the desired global computation (e.g. MPI)
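
The SPMD pattern can be sketched without an MPI installation by simulating the processes in a loop. This is only an illustration: `n_ranks`, `local_sum`, and the plain Python sum standing in for an all-reduce are assumptions made for the example, not part of any library.

```python
import numpy as np

n_ranks = 4                               # hypothetical number of processes
data = np.arange(16, dtype=np.float64)    # the "global" dataset

# Each rank only stores its own shard of the data
shards = np.array_split(data, n_ranks)

def local_sum(shard):
    """The single program that every rank runs on its own shard."""
    return shard.sum()

# Every rank computes a partial result independently...
partials = [local_sum(s) for s in shards]

# ...then a collective (an all-reduce in MPI terms) combines them,
# so that every rank ends up holding the same global value.
global_sum = sum(partials)
all_rank_results = [global_sum] * n_ranks
```

In a real run the loop disappears: each process executes `local_sum` on its own device, and the final combination is the step that requires communication.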

NVIDIA Ampere architecture

Technical solutions for fast communication between GPUs

  • CUDA-aware MPI: an implementation of the Message Passing Interface (MPI) standard that allows direct memory exchange between GPUs, potentially on different physical machines

  • NVIDIA Collective Communication Library (NCCL): Proprietary NVIDIA library, highly optimized for GPU communications directly within CUDA kernels

Where does JAX come into this picture?

  • JAX is awesome for several reasons:
    • Allows you to write NumPy code, that executes on GPU
    • Provides automatic differentiation

  • How can we use it for large-scale High Performance Computing?
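
As a minimal sketch of the two properties listed above, the snippet below writes a loss in NumPy style and asks JAX for its compiled gradient. The linear model and the names `loss`, `w_true`, `x`, `y` are illustrative assumptions; the code runs on whichever backend JAX finds (CPU if no GPU is present).

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    """Mean squared error of a linear model, written like plain NumPy."""
    return jnp.mean((x @ w - y) ** 2)

# Compose transformations: differentiate w.r.t. the first argument, then compile
grad_loss = jax.jit(jax.grad(loss))

x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
w_true = jnp.array([0.5, -1.0])
y = x @ w_true

g_at_truth = grad_loss(w_true, x, y)       # gradient vanishes at the exact solution
g_at_zero = grad_loss(jnp.zeros(2), x, y)  # non-zero away from it
```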

The Manual Way - MPI4JAX

MPI4JAX - Zero-copy MPI communication of JAX arrays

  • In a nutshell, provides a JAX wrapper around MPI primitives
    • Compiled against MPI4PY, relies on CUDA-aware MPI for GPUDirect RDMA
    • Primitives can be included directly in jitted code!
from mpi4py import MPI
import jax
import jax.numpy as jnp
import mpi4jax

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

@jax.jit
def foo(arr):
   arr = arr + rank
   # Sum the array over all processes; the second return value is a token
   arr_sum, _ = mpi4jax.allreduce(arr, op=MPI.SUM, comm=comm)
   return arr_sum

a = jnp.zeros((3, 3))
result = foo(a)

if rank == 0:
   print(result)

This code is executed on all processes, each of which has a single GPU

mpirun -n 4 python

How to make this work on Perlmutter?

  • Step I: Follow the instructions of the jax-perlmutter-tutorials GitHub repo to set up a JAX environment at NERSC:

$ module load python cudnn/8.2.0 nccl/2.11.4 cudatoolkit
$ pip install --upgrade "jax[cuda]" -f

  • Step II: For maximum convenience, "The mpi4py provided by the python or cray-python modules is not CUDA-aware. You will have to build CUDA-aware mpi4py in a custom environment using the instructions below." (source). So, you need to build it:

$ module load PrgEnv-gnu # In addition to the previously loaded modules
$ MPICC="cc -target-accel=nvidia80 -shared" pip install --force --no-cache-dir --no-binary=mpi4py mpi4py

  • Step III: Launch your Python script like so (from an salloc'd node, for instance):

$ srun  -n 4 -c 32 --gpus-per-task 1 --gpu-bind=none python

Examples of Applications

  • Example of a physical nonlinear shallow-water model distributed on 4 GPUs (Häfner & Vicentini, 2021)
  • MPI is used to ensure proper boundary conditions between processes by performing a halo exchange
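
The halo exchange used in the shallow-water example can be sketched on a 1D periodic grid, with the processes simulated in a loop. This is an illustration only, not the shallow-water code: `n_ranks`, `halo_exchange`, and `laplacian` are hypothetical names, and in a real run each ghost-cell assignment would be an MPI send/receive pair.

```python
import numpy as np

n_ranks = 4                                  # hypothetical number of processes
u_global = np.sin(np.linspace(0.0, 2.0 * np.pi, 32, endpoint=False))

# Each rank owns a contiguous chunk, padded with one ghost cell on each side
chunks = np.split(u_global, n_ranks)
local = [np.pad(c, 1) for c in chunks]

def halo_exchange(local):
    """Fill the ghost cells from the neighbouring chunks (periodic domain)."""
    n = len(local)
    for r in range(n):
        left, right = (r - 1) % n, (r + 1) % n
        local[r][0] = local[left][-2]    # last interior cell of left neighbour
        local[r][-1] = local[right][1]   # first interior cell of right neighbour
    return local

def laplacian(u):
    """Three-point stencil evaluated on the interior cells of a padded chunk."""
    return u[:-2] - 2.0 * u[1:-1] + u[2:]

local = halo_exchange(local)
lap_distributed = np.concatenate([laplacian(u) for u in local])

# Reference: the same stencil applied to the undivided periodic array
lap_reference = np.roll(u_global, 1) - 2.0 * u_global + np.roll(u_global, -1)
```

Once the ghost cells are filled, each chunk can apply the stencil locally and the concatenated result matches the single-array computation.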

For a more cosmology-oriented problem: MPI parallelism in JaxPM

  • A key ingredient in fast N-body solvers is the ability to compute distributed 3D Fast Fourier Transforms.
    • Requires transposing a 3D density field, each time redistributing the array differently on the processor mesh
    • Requires AlltoAll operations
  • In a WIP branch, JaxPM has the tools needed to distribute a FastPM implementation.
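
The redistribution step is easiest to see in 2D. The sketch below simulates an all-to-all exchange with Python lists to transpose a row-distributed matrix; it is a toy illustration under assumed names (`n_ranks`, `send`, `recv`), not the JaxPM implementation.

```python
import numpy as np

n_ranks = 4                                   # hypothetical number of processes
N = 8                                         # global array is N x N
A = np.arange(N * N, dtype=np.float64).reshape(N, N)
n = N // n_ranks

# Initial layout: rank r holds rows [r*n, (r+1)*n) of A, shape (n, N)
local = [A[r * n:(r + 1) * n, :] for r in range(n_ranks)]

# Each rank cuts its block into one (n, n) tile per destination rank
send = [[blk[:, j * n:(j + 1) * n] for j in range(n_ranks)] for blk in local]

# All-to-all: tile j of rank r becomes tile r of rank j
recv = [[send[r][j] for r in range(n_ranks)] for j in range(n_ranks)]

# Each rank stacks what it received and transposes locally:
# rank j now holds rows [j*n, (j+1)*n) of A.T
local_t = [np.concatenate(tiles, axis=0).T for tiles in recv]

A_t = np.concatenate(local_t, axis=0)
```

A distributed 3D FFT repeats this pattern between its 1D transforms, so that the axis being transformed is always fully local to each process.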

Density field computed on 8 GPUs with MPI4JAX

So, this works, but...

  • The developer (you!) needs to manually take care of all the collective operations needed to ensure the correct result.
  • For complex collectives (i.e. other than all gather) the gradients are not known a priori. The developer will have to implement custom gradients around the functions that have communications.
  • MPI has a well-known limitation: it does not handle messages larger than 2 GB. MPI4JAX currently does not implement a workaround for that.
import jax.numpy as jnp
import mpi4jax

def fft3d(arr, comms=None):
    """Computes forward FFT, note that the output is transposed.

    comms is a pair of MPI communicators, one per axis of the process mesh.
    """
    if comms is None:
        # Single-device case: same FFT/transpose sequence, no communication
        arr = jnp.fft.fft(arr).transpose([1, 2, 0])  # [y, z, x]
        arr = jnp.fft.fft(arr).transpose([1, 2, 0])  # [z, x, y]
        return jnp.fft.fft(arr)

    shape = list(arr.shape)
    nx = comms[0].Get_size()
    ny = comms[1].Get_size()

    # First FFT along z
    arr = jnp.fft.fft(arr)  # [x, y, z]
    arr = arr.reshape(shape[:-1]+[nx, shape[-1] // nx])
    arr = arr.transpose([2, 1, 3, 0])  # [y, z, x]
    arr, token = mpi4jax.alltoall(arr, comm=comms[0])
    arr = arr.transpose([1, 2, 0, 3]).reshape(shape)  # Now [y, z, x]

    # Second FFT along x
    arr = jnp.fft.fft(arr)
    arr = arr.reshape(shape[:-1]+[ny, shape[-1] // ny])
    arr = arr.transpose([2, 1, 3, 0])  # [z, x, y]
    # Pass the token on to keep the two communications ordered
    arr, token = mpi4jax.alltoall(arr, comm=comms[1], token=token)
    arr = arr.transpose([1, 2, 0, 3]).reshape(shape)  # Now [z, x, y]

    # Third FFT along y
    return jnp.fft.fft(arr)

=> It's very artisanal, not very jaxy!
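
To give an idea of what "implementing custom gradients around the functions that have communications" involves, here is a minimal sketch using `jax.custom_vjp`. The communication is replaced by a local stand-in (`jnp.roll`, as if every element were sent to its right-hand neighbour), and `shift_right` is a hypothetical name; with a real collective the backward rule would itself contain a communication call.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def shift_right(x):
    """Stand-in for a communication step: every element moves one slot over."""
    return jnp.roll(x, 1)

def shift_right_fwd(x):
    # Nothing needs to be saved for the backward pass
    return shift_right(x), None

def shift_right_bwd(_, g):
    # The adjoint of "send to the right" is "send to the left"
    return (jnp.roll(g, -1),)

shift_right.defvjp(shift_right_fwd, shift_right_bwd)

w = jnp.array([1.0, 2.0, 3.0, 4.0])
x = jnp.zeros(4)

grad_custom = jax.grad(lambda x: jnp.sum(shift_right(x) * w))(x)

# Cross-check against JAX's own derivative of jnp.roll
grad_auto = jax.grad(lambda x: jnp.sum(jnp.roll(x, 1) * w))(x)
```

The point of the sketch is the bookkeeping: for every communicating function, the developer has to work out the adjoint operation by hand and register it.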

The Magical Future of JAX

The low-level technical side 

  • JAX relies on the XLA (Accelerated Linear Algebra) library for compiling and executing jitted code.

  • Around 2021-2022, support for low-level collective operations was added to XLA, with NCCL as a backend on GPU clusters \o/

    => JAX is technically natively parallelisable through XLA communication primitives on machines like Perlmutter.
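
As a small sketch of what such a communication primitive looks like from Python, the snippet below uses `jax.pmap` with `jax.lax.psum`, which JAX hands to XLA as an all-reduce. The axis name `"devices"` and the function name `global_sum` are arbitrary choices for the example; it also runs on a single CPU device, where the reduction is trivial.

```python
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()     # number of devices visible to this process

def global_sum(x):
    # Sum x across all devices along the named axis
    return jax.lax.psum(x, axis_name="devices")

# pmap runs one copy of the function per device; the leading axis of the
# input is split across devices
x = jnp.arange(n_dev * 3, dtype=jnp.float32).reshape(n_dev, 3)
summed = jax.pmap(global_sum, axis_name="devices")(x)
```

Every device ends up with the same row, equal to the sum of all rows of `x`.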

The high-level JAX parallelism API

  • Things are still evolving a lot! JAX 0.4.0 is around the corner and will change everything!
  • The idea: You should be able to write your code as if it would execute on a single GPU, JAX should figure out the rest to make it run on many GPUs! Compatible with vmap, jit, grad, etc.
  • Up to JAX v0.3, two methods exist, xmap and pjit, each documented here:
import jax
import numpy as np
from jax.experimental import maps
from jax.experimental import PartitionSpec
from jax.experimental.pjit import pjit

# Uses the experimental JAX v0.3 API and assumes 8 available devices
mesh_shape = (4, 2)
devices = np.asarray(jax.devices()).reshape(*mesh_shape)
# 'x', 'y' axis names are used here for simplicity
mesh = maps.Mesh(devices, ('x', 'y'))

# Example input whose shape divides evenly over the (4, 2) mesh
input_data = np.arange(8 * 2).reshape(8, 2)

f = pjit(
  lambda x: 2*x + 1,
  in_axis_resources=None,
  out_axis_resources=PartitionSpec('x', 'y'))

# Sends data to accelerators based on partition_spec
with maps.Mesh(mesh.devices, mesh.axis_names):
  data = f(input_data)

pjit example
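
For comparison, here is a hedged sketch of the same computation written against the `jax.sharding` API found in more recent JAX releases (0.4 series and later). This goes beyond the v0.3 API shown on the slide: the 1D mesh, the axis name `'x'`, and the array shape are assumptions chosen so that the snippet also runs on a single device.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# 1D mesh over whatever devices are available (a single CPU also works)
devices = np.asarray(jax.devices())
mesh = Mesh(devices, ("x",))

# Split the leading array axis across the mesh axis 'x'
sharding = NamedSharding(mesh, PartitionSpec("x"))
input_data = jnp.arange(len(devices) * 2 * 4, dtype=jnp.float32).reshape(-1, 4)
input_data = jax.device_put(input_data, sharding)

# Ordinary jit: the function is written as if for a single device
f = jax.jit(lambda x: 2 * x + 1)
data = f(input_data)
```

The function body contains no reference to devices at all; the placement of `input_data` is the only thing that tells the compiler how to distribute the work.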



By eiffl


A little overview of how to use JAX for High Performance Computing on GPU clusters
