Multi-GPU Computing in JAX for Automatically Differentiable High Performance Computing
Francois Lanusse, @EiffL
Why do we need to start thinking about scaling up?
Motivation from ML perspective:
- Machine Learning models are getting better, but also bigger
- The dimensionality of data increases (e.g. high resolution images, 3D)
Motivation from Physics perspective:
- Models become very large for Stage IV surveys (e.g. N-body simulations)
=> In both cases, a given model will not fit on a single GPU!
Credit: Li et al. 2021
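To put rough numbers on the physics case, here is a back-of-the-envelope estimate. It assumes float32 positions and velocities only (6 values per particle) and ignores density meshes, forces and the intermediates kept for automatic differentiation, so real requirements are higher; the 40 GB budget is the A100 memory quoted for Perlmutter.

```python
# Back-of-the-envelope memory estimate for the particle state of an N-body run.
# Assumption: float32 positions + velocities only (6 values per particle);
# meshes, forces and autodiff intermediates are ignored.
BYTES_PER_FLOAT32 = 4
A100_MEMORY_BYTES = 40 * 10**9  # 40 GB A100 (Perlmutter GPU nodes)


def particle_state_bytes(n_per_side: int, values_per_particle: int = 6) -> int:
    """Bytes needed to store positions and velocities of n_per_side**3 particles."""
    return n_per_side**3 * values_per_particle * BYTES_PER_FLOAT32


for n in (512, 1024, 2048):
    size = particle_state_bytes(n)
    print(f"{n}^3 particles: {size / 2**30:.0f} GiB, "
          f"fits on one A100: {size < A100_MEMORY_BYTES}")
```

Under these assumptions a 2048^3 run already needs about 192 GiB for the particle state alone, several times the memory of a single A100.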
We also have access to a new generation of supercomputers
NERSC 9 system: Perlmutter
- 1536 GPU nodes, each one with 4x NVIDIA A100 (40GB)
- High performance HPE/Cray Slingshot interconnect
- Ranks among the top 10 most powerful systems in the world
How does parallel computing work?
- GPUs are great for SIMD (Single Instruction Multiple Data)
- This relies on many simple cores, which all have access to the same memory
If your problem fits in memory, this is the best solution!
- When the data is so large that it cannot fit into a single computer, you need SPMD (Single Program Multiple Data)
- Each process can live on a different physical device, and is only in charge of storing and processing a fraction of the total data
- Processes need to talk to each other in order to complete the desired global computation (e.g. MPI)
NVIDIA Ampere architecture
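The SPMD pattern can be mimicked on a single machine to see the data layout involved. In this minimal sketch the leading array axis stands in for "processes", `jax.vmap` runs the same local program on every shard, and a plain reduction replaces the MPI communication step; the array and the computed quantity (a global mean of squares) are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

# Single-process simulation of SPMD: axis 0 plays the role of "processes",
# each holding one shard of a larger global array. No real communication here.
n_procs = 4
global_data = jnp.arange(16.0)
shards = global_data.reshape(n_procs, -1)  # shard i "lives" on process i


def local_program(shard):
    # Same program on every process, different data: a local partial result.
    return jnp.sum(shard**2)


# vmap applies local_program to each shard, like SPMD processes would.
partial_sums = jax.vmap(local_program)(shards)

# "Communication" step: combine partial results into a global quantity.
# With MPI this would be an allreduce across processes.
global_mean_sq = jnp.sum(partial_sums) / global_data.size
```

On a real cluster each shard would sit in a different GPU's memory, and the final reduction is exactly where processes need to talk to each other.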
Technical solutions for fast communication between GPUs
- CUDA-aware MPI: an implementation of the Message Passing Interface (MPI) standard that allows direct memory exchange between GPUs, potentially on different physical machines
- NVIDIA Collective Communication Library (NCCL): NVIDIA library, highly optimized for GPU communications directly within CUDA kernels
Where does JAX come into this picture?
- JAX is awesome for several reasons:
- Allows you to write NumPy code that executes on GPU
- Provides automatic differentiation
- How can we use it for large-scale High Performance Computing?
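As a minimal illustration of the first two points, the sketch below writes a NumPy-style loss with `jax.numpy`, differentiates it with `jax.grad` and compiles it with `jax.jit`. The loss function and data are made up for the example; the same code runs on CPU or GPU depending on what JAX finds.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    # Plain NumPy-style code: mean squared error of a linear model.
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)


# grad differentiates w.r.t. the first argument (w); jit compiles through XLA.
grad_loss = jax.jit(jax.grad(loss))

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
w = jnp.zeros(2)
g = grad_loss(w, x, y)
```

Here the gradient is (2/n) x^T (x w - y), which for w = 0 gives [-7, -10].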
The Manual Way - MPI4JAX
MPI4JAX - Zero-copy MPI communication of JAX arrays
- In a nutshell, provides a JAX wrapper around MPI primitives
- Compiled against mpi4py, relies on CUDA-aware MPI for GPUDirect RDMA
- Primitives can be included directly in jitted code!
from mpi4py import MPI
import jax
import jax.numpy as jnp
import mpi4jax

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

@jax.jit
def foo(arr):
    arr = arr + rank
    # Sum the arrays of all processes; allreduce also returns a token
    arr_sum, _ = mpi4jax.allreduce(arr, op=MPI.SUM, comm=comm)
    return arr_sum

a = jnp.zeros((3, 3))
result = foo(a)

if rank == 0:
    print(result)
This code is executed by every process; each process has a single GPU.
mpirun -n 4 python myapp.py
How to make this work on Perlmutter?
Step I: Follow the instructions of the jax-perlmutter-tutorials GitHub repo to set up a JAX environment at NERSC:
$ module load python cudnn/8.2.0 nccl/2.11.4 cudatoolkit
$ pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
Step II: Build a CUDA-aware mpi4py, because "The mpi4py provided by the python or cray-python modules is not CUDA-aware. You will have to build CUDA-aware mpi4py in a custom environment using the instructions below." (source):
$ module load PrgEnv-gnu # In addition to the previously loaded modules
$ MPICC="cc -target-accel=nvidia80 -shared" pip install --force --no-cache-dir --no-binary=mpi4py mpi4py
Step III: Launch your Python script like so (for instance, from a node allocated with salloc):
$ srun -n 4 -c 32 --gpus-per-task 1 --gpu-bind=none python my_script.py
Examples of Applications
- Example of a physical nonlinear shallow-water model distributed on 4 GPUs (Häfner & Vicentini, 2021)
- MPI is used to ensure proper boundary conditions between processes by performing a halo exchange
Find out more in the mpi4jax documentation: https://mpi4jax.readthedocs.io/en/latest/shallow-water.html
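The halo exchange used in the shallow-water example can be sketched on a single process. This is a minimal illustration, not the mpi4jax implementation: it assumes a periodic domain split along its first axis into slabs stored along a leading "process" axis, and uses `jnp.roll` as a stand-in for the send/receive with neighbouring ranks. The helper name `halo_exchange` is hypothetical.

```python
import jax.numpy as jnp


def halo_exchange(slabs, width=1):
    """Hypothetical helper. slabs: [n_procs, nx_local, ny] -> [n_procs, nx_local + 2*width, ny].

    jnp.roll along the process axis stands in for MPI send/recv (periodic domain).
    """
    # Last `width` rows of the left neighbour (rank - 1) become the lower halo.
    lower = jnp.roll(slabs[:, -width:], shift=1, axis=0)
    # First `width` rows of the right neighbour (rank + 1) become the upper halo.
    upper = jnp.roll(slabs[:, :width], shift=-1, axis=0)
    return jnp.concatenate([lower, slabs, upper], axis=1)


field = jnp.arange(8 * 3).reshape(8, 3)  # global 8 x 3 field
slabs = field.reshape(4, 2, 3)           # 4 "processes", 2 rows each
padded = halo_exchange(slabs)            # shape (4, 4, 3)
```

After the exchange, each process can apply a local stencil (e.g. a finite-difference update) to its padded slab without further communication until the next time step.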
For a more cosmology-oriented problem: MPI parallelism in JaxPM
- A key ingredient in fast N-body solvers is the ability to compute distributed 3D Fast Fourier Transforms.
- Requires transposing a 3D density field, each time redistributing the array differently on the processor mesh
- Requires AlltoAll operations
- In a WIP branch, JaxPM has the tools needed to distribute a FastPM implementation.
Density field computed on 8 GPUs with MPI4JAX
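The transpose logic behind a distributed 3D FFT can be checked on a single device. In this sketch each 1D FFT acts on the last axis and a transpose brings the next axis into last position; in the distributed version, each of those transposes is where an AlltoAll is needed. The function name `fft3d_by_transposes` is hypothetical, and the output is left in the transposed [z, x, y] layout.

```python
import numpy as np
import jax.numpy as jnp


def fft3d_by_transposes(arr):
    # Hypothetical helper: 3D FFT as three 1D FFTs along the last axis.
    arr = jnp.fft.fft(arr)          # FFT along z, layout [x, y, z]
    arr = arr.transpose([1, 2, 0])  # -> [y, z, x]  (AlltoAll when distributed)
    arr = jnp.fft.fft(arr)          # FFT along x
    arr = arr.transpose([1, 2, 0])  # -> [z, x, y]  (AlltoAll when distributed)
    return jnp.fft.fft(arr)         # FFT along y, output layout [z, x, y]


x = jnp.asarray(np.random.default_rng(0).normal(size=(8, 8, 8)), dtype=jnp.float32)
out = fft3d_by_transposes(x)
# Same result as a full 3D FFT, up to the [z, x, y] transposed layout.
ref = jnp.fft.fftn(x).transpose([2, 0, 1])
```

Keeping the output transposed avoids one extra transpose (and hence one extra AlltoAll) per FFT, as long as the rest of the code works in that layout.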
So, this works, but...
- The developer (you!) needs to manually take care of all the collective operations needed to ensure the correct result.
- For complex collectives (i.e. anything other than all-gather), the gradients are not known a priori. The developer has to implement custom gradients around the functions that perform communications.
- MPI has a well-known limitation: it does not handle messages larger than 2 GB. MPI4JAX currently doesn't implement a workaround for that.
import jax.numpy as jnp
import mpi4jax

def fft3d(arr, comms=None):
    """ Computes forward FFT, note that the output is transposed
    """
    if comms is not None:
        # Assumes equal local sizes along x and y (e.g. a cubic grid on a
        # square nx == ny process mesh), so reshape(shape) keeps block sizes
        shape = list(arr.shape)
        nx = comms[0].Get_size()
        ny = comms[1].Get_size()
        # First FFT along z
        arr = jnp.fft.fft(arr) # [x, y, z]
        arr = arr.reshape(shape[:-1]+[nx, shape[-1] // nx])
        arr = arr.transpose([2, 1, 3, 0]) # [y, z, x]
        arr, token = mpi4jax.alltoall(arr, comm=comms[0])
        arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [y, z, x]
        # Second FFT along x
        arr = jnp.fft.fft(arr)
        arr = arr.reshape(shape[:-1]+[ny, shape[-1] // ny])
        arr = arr.transpose([2, 1, 3, 0]) # [z, x, y]
        arr, token = mpi4jax.alltoall(arr, comm=comms[1], token=token)
        arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [z, x, y]
        # Third FFT along y
        return jnp.fft.fft(arr)
    # Single-device fallback, with the same transposed [z, x, y] output layout
    return jnp.fft.fftn(arr).transpose([2, 0, 1])
=> It's very artisanal, not very jaxy!
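To illustrate what "implementing custom gradients" around communication involves, here is a minimal sketch with `jax.custom_vjp`. The function `redistribute` is a hypothetical local stand-in for a communication primitive (a plain axis permutation, not an actual MPI call); the point is that the hand-written backward pass must apply the adjoint of the data movement, here the inverse permutation.

```python
import jax
import jax.numpy as jnp

PERM = (1, 2, 0)
INV_PERM = tuple(PERM.index(i) for i in range(len(PERM)))  # (2, 0, 1)


@jax.custom_vjp
def redistribute(x):
    # Hypothetical stand-in for a data-redistribution (communication) step.
    return jnp.transpose(x, PERM)


def redistribute_fwd(x):
    # No residuals needed: the backward pass only depends on the cotangent.
    return redistribute(x), None


def redistribute_bwd(_, g):
    # Adjoint of a permutation is the inverse permutation.
    return (jnp.transpose(g, INV_PERM),)


redistribute.defvjp(redistribute_fwd, redistribute_bwd)

# Usage: the custom gradient matches JAX's own autodiff of jnp.transpose.
x = jnp.arange(24.0).reshape(2, 3, 4)
w = jnp.arange(24.0).reshape(3, 4, 2)  # shape of the permuted array
g_custom = jax.grad(lambda x: jnp.sum(redistribute(x) * w))(x)
g_auto = jax.grad(lambda x: jnp.sum(jnp.transpose(x, PERM) * w))(x)
```

With a real MPI collective, JAX cannot derive this backward pass itself, which is why each communicating function needs such a wrapper.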
The Magical Near Future of JAX
The low-level technical side
- JAX relies on the XLA (Accelerated Linear Algebra) library for compiling and executing jitted code.
- Around 2021-2022, support for low-level collective operations has been added to XLA, with NCCL as a backend on GPU clusters \o/
=> In principle, JAX code can be parallelised natively on machines like Perlmutter through XLA's communication primitives.
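For comparison with the mpi4jax example, here is a sketch of the same `foo` computation written with JAX's own collectives instead of MPI. It assumes a single machine and uses `jax.pmap`, an existing JAX parallel API not covered in these slides: `jax.lax.axis_index` plays the role of the MPI rank and `jax.lax.psum` performs the sum across devices. The axis name `"devices"` is an arbitrary choice.

```python
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()


def foo(arr):
    idx = jax.lax.axis_index("devices")  # plays the role of the MPI rank
    arr = arr + idx
    # Sum across all devices, like mpi4jax.allreduce with MPI.SUM.
    return jax.lax.psum(arr, axis_name="devices")


a = jnp.zeros((n_dev, 3, 3))  # leading axis: one 3x3 slice per local device
result = jax.pmap(foo, axis_name="devices")(a)
```

Every device ends up with the same array, filled with the sum of all device indices; unlike the mpi4jax version, JAX knows how to differentiate through `psum`.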
The high-level JAX parallelism API
- Things are still evolving a lot! JAX 0.4.0 is around the corner and will change everything!
- The idea: you should be able to write your code as if it were executing on a single GPU, and JAX should figure out the rest to make it run on many GPUs! Compatible with vmap, jit, grad, etc.
- Up to JAX v0.3, two methods exist, xmap and pjit, each documented here:
import jax
from jax.experimental import maps
from jax.experimental import PartitionSpec
from jax.experimental.pjit import pjit
import numpy as np

# JAX v0.3-era API; the 4 x 2 mesh requires 8 devices
mesh_shape = (4, 2)
devices = np.asarray(jax.devices()).reshape(*mesh_shape)
# 'x', 'y' axis names are used here for simplicity
mesh = maps.Mesh(devices, ('x', 'y'))

input_data = np.arange(8 * 2).reshape(8, 2)

in_axis_resources = None
out_axis_resources = PartitionSpec('x', 'y')

f = pjit(
    lambda x: 2 * x + 1,
    in_axis_resources=in_axis_resources,
    out_axis_resources=out_axis_resources)

# Sends data to accelerators based on partition_spec
with mesh:
    data = f(input_data)
pjit example
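With the jax.Array mechanism introduced in JAX 0.4, the same idea can be sketched with `jax.sharding` instead of `pjit` and `maps`. This is a minimal sketch assuming JAX >= 0.4: the mesh simply spans whatever devices are available (a single CPU device works too), and the array sizes are illustrative.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Assumes JAX >= 0.4. Mesh over all available devices, shaped (n_dev, 1).
n_dev = len(jax.devices())
mesh = Mesh(np.array(jax.devices()).reshape(n_dev, 1), ("x", "y"))

# Global array split along 'x' across devices (row count divisible by n_dev).
input_data = jnp.arange(8 * 2 * n_dev, dtype=jnp.float32).reshape(8 * n_dev, 2)
sharding = NamedSharding(mesh, PartitionSpec("x", "y"))
x = jax.device_put(input_data, sharding)

# Written as single-device code; jit works directly on the sharded input.
f = jax.jit(lambda x: 2 * x + 1)
data = f(x)
```

The function itself contains no parallelism at all: the layout is carried by the array's sharding, and XLA inserts whatever communication the computation requires.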
Conclusion
- We should not shy away from thinking large-scale: it is already possible (e.g. mpi4jax), and it will only get easier with time.
- JAX is moving in the direction of automated parallelisation on GPU clusters!
- Things to keep an eye on:
- New JAX Array mechanism with upcoming JAX v0.4.0
- jaxDecomp: JAX bindings to NVIDIA cuDecomp library (join me!)
jax-hpc
A little overview of how to use JAX for High Performance Computing on GPU clusters