Multi-GPU Computing in JAX for Automatically Differentiable High Performance Computing
Francois Lanusse, @EiffL
Why do we need to start thinking about scaling up?
Motivation from ML perspective:
- Machine Learning models are getting better, but bigger
- The dimensionality of data increases (e.g. high resolution images, 3D)
Motivation from Physics perspective:
- Models become very large for Stage IV surveys (e.g. N-body simulations)
=> In both cases, a given model will not fit on a single GPU!
Credit: Li et al. 2021
We also have access to a new generation of supercomputers
How does parallel computing work?
- GPUs are great for SIMD (Single Instruction Multiple Data)
- This requires many simple cores, which all have access to the same memory
If your problem fits in memory, this is the best solution!
- When the data is so large that it cannot fit into a single computer, you need SPMD (Single Program Multiple Data)
- Each process can live on a different physical device and is only in charge of storing and processing a fraction of the total data
- Processes need to talk to each other in order to complete the desired global computation (e.g. MPI)
NVIDIA Ampere architecture
Technical solutions for fast communication between GPUs
- CUDA-aware MPI: implementations of the Message Passing Interface (MPI) standard that allow direct memory exchange between GPUs, potentially on different physical machines
- NVIDIA Collective Communication Library (NCCL): NVIDIA library highly optimized for GPU communications, with collectives running directly within CUDA kernels
Where does JAX come into this picture?
- JAX is awesome for several reasons:
- Allows you to write NumPy code that executes on GPU
- Provides automatic differentiation
- How can we use it for large-scale High Performance Computing?
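To make the first two points concrete, here is a small sketch (not from the talk): a NumPy-style function written with `jax.numpy`, compiled with `jax.jit` and differentiated with `jax.grad`. The function `loss` is a hypothetical example chosen so that its gradient is known analytically.

```python
import jax
import jax.numpy as jnp

# Hypothetical scalar function written with the NumPy-like jax.numpy API
def loss(x):
    return jnp.sum(jnp.sin(x) ** 2)

# jax.jit compiles the function with XLA (it runs on a GPU if one is available)
fast_loss = jax.jit(loss)
# jax.grad builds the gradient function automatically; it can be jitted too
grad_loss = jax.jit(jax.grad(loss))

x = jnp.linspace(0.0, 1.0, 5)
print(fast_loss(x))
print(grad_loss(x))  # analytically 2 sin(x) cos(x) = sin(2x)
```

The same function can be composed with `vmap`, `jit` and `grad` in any order, which is what makes scaling it up an attractive prospect.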
The Manual Way - MPI4JAX
MPI4JAX - Zero-copy MPI communication of JAX arrays
In a nutshell, provides a JAX wrapper around MPI primitives
Compiled against mpi4py; relies on CUDA-aware MPI for GPUDirect RDMA
- Primitives can be included directly in jitted code!
This code is executed on all processes, each of which has a single GPU:
mpirun -n 4 python myapp.py
How to make this work on Perlmutter?
Step I: Follow the instructions of the jax-perlmutter-tutorials GitHub repo to set up a JAX environment at NERSC:
$ module load python cudnn/8.2.0 nccl/2.11.4 cudatoolkit
$ pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
Step II: Build a CUDA-aware mpi4py, since "The mpi4py provided by the python or cray-python modules is not CUDA-aware. You will have to build CUDA-aware mpi4py in a custom environment using the instructions below." (source):
$ module load PrgEnv-gnu # In addition to the previously loaded modules
$ MPICC="cc -target-accel=nvidia80 -shared" pip install --force --no-cache-dir --no-binary=mpi4py mpi4py
Step III: Launch your Python script like so (from a node allocated with salloc, for instance):
$ srun -n 4 -c 32 --gpus-per-task 1 --gpu-bind=none python my_script.py
Find out more in the mpi4jax documentation: https://mpi4jax.readthedocs.io/en/latest/shallow-water.html
- Example of a physical nonlinear shallow-water model distributed on 4 GPUs (Häfner & Vicentini, 2021)
- MPI is used to ensure proper boundary conditions between processes by performing a halo exchange
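The halo-exchange idea can be illustrated without MPI. The sketch below is our own illustration, not the shallow-water code: it splits a periodic 1D field into chunks on a single device as if each chunk lived on a different process, pads each chunk with one ghost cell copied from each neighbour, and checks that a purely local stencil then reproduces the global result. In a real mpi4jax code, those neighbour copies would be point-to-point messages (for instance `mpi4jax.sendrecv`) rather than array indexing.

```python
import jax.numpy as jnp

def laplacian_global(u):
    # Periodic second difference computed on the full field
    return jnp.roll(u, 1) - 2 * u + jnp.roll(u, -1)

def laplacian_with_halos(u, n_procs):
    # Split the field as if each chunk lived on a different process
    chunks = jnp.split(u, n_procs)
    out = []
    for i, local in enumerate(chunks):
        # "Halo exchange": fetch the edge cells of the left and right neighbours
        left_halo = chunks[(i - 1) % n_procs][-1:]
        right_halo = chunks[(i + 1) % n_procs][:1]
        padded = jnp.concatenate([left_halo, local, right_halo])
        # Purely local stencil on the padded chunk
        out.append(padded[:-2] - 2 * padded[1:-1] + padded[2:])
    return jnp.concatenate(out)

u = jnp.sin(jnp.linspace(0.0, 2 * jnp.pi, 16, endpoint=False))
print(jnp.allclose(laplacian_global(u), laplacian_with_halos(u, n_procs=4)))
```

Only one cell per side is exchanged because this stencil reaches one neighbour in each direction; wider stencils need wider halos.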
Examples of Applications
For a more cosmology-oriented problem: MPI parallelism in JaxPM
- A key ingredient in fast N-body solvers is the ability to compute distributed 3D Fast Fourier Transforms.
- Requires transposing a 3D density field, each time redistributing the array differently on the processor mesh
- Requires all-to-all operations
- In a WIP branch, JaxPM has the tools needed to distribute a FastPM implementation.
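Why transposes are enough: a 3D FFT factorises into 1D FFTs along each axis, so a process only needs the axis currently being transformed to be fully local. The single-device sketch below is an illustration, not JaxPM code. It applies a 1D FFT along the last axis, then rotates the axes so that the next axis becomes last; in a distributed code, each rotation is where the all-to-all exchange would happen. After three rounds, the result matches `jnp.fft.fftn` up to an axis permutation.

```python
import jax.numpy as jnp
import numpy as np

def fft3d_by_transposes(arr):
    # arr has axes [x, y, z]; each step FFTs the last axis, then rotates axes.
    # In a distributed code, each transpose would be an all-to-all exchange.
    arr = jnp.fft.fft(arr)              # FFT along z, axes [x, y, z]
    arr = arr.transpose([1, 2, 0])      # axes [y, z, x]
    arr = jnp.fft.fft(arr)              # FFT along x
    arr = arr.transpose([1, 2, 0])      # axes [z, x, y]
    return jnp.fft.fft(arr)             # FFT along y, output axes [z, x, y]

field = jnp.asarray(np.random.default_rng(0).normal(size=(8, 8, 8)))
# Reference: full 3D FFT, reordered from [x, y, z] to [z, x, y]
ref = jnp.fft.fftn(field).transpose([2, 0, 1])
print(jnp.allclose(fft3d_by_transposes(field), ref, atol=1e-3))
```

The output is left in transposed [z, x, y] order, which saves one final redistribution.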
Density field computed on 8 GPUs with MPI4JAX
So, this works, but...
- The developer (you!) needs to manually take care of all the collective operations needed to ensure the correct result.
- For complex collectives (i.e. anything other than all-gather), the gradients are not known a priori. The developer will have to implement custom gradients around the functions that perform communications.
- MPI has a well-known limitation: it does not handle messages larger than 2 GB. MPI4JAX currently doesn't implement a workaround for that.
```python
import jax.numpy as jnp
import mpi4jax

def fft3d(arr, comms=None):
    """ Computes forward FFT, note that the output is transposed """
    # Only the distributed branch is shown on the slide
    if comms is not None:
        shape = list(arr.shape)
        nx = comms.Get_size()
        ny = comms.Get_size()

        # First FFT along z
        arr = jnp.fft.fft(arr)  # [x, y, z]
        # Split the last axis into nx chunks and move the chunk index first,
        # so that alltoall sends chunk i to process i
        arr = arr.reshape(shape[:-1] + [nx, shape[-1] // nx])
        arr = arr.transpose([2, 1, 3, 0])  # [y, z, x]
        arr, token = mpi4jax.alltoall(arr, comm=comms)
        arr = arr.transpose([1, 2, 0, 3]).reshape(shape)  # Now [y, z, x]

        # Second FFT along x
        arr = jnp.fft.fft(arr)
        arr = arr.reshape(shape[:-1] + [ny, shape[-1] // ny])
        arr = arr.transpose([2, 1, 3, 0])  # [z, x, y]
        arr, token = mpi4jax.alltoall(arr, comm=comms, token=token)
        arr = arr.transpose([1, 2, 0, 3]).reshape(shape)  # Now [z, x, y]

        # Third FFT along y
        return jnp.fft.fft(arr)
```
=> It's very artisanal, not very jaxy!
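On the custom-gradient point: a transpose or all-to-all is a linear redistribution of data, so its vector-Jacobian product is simply the inverse redistribution applied to the cotangent. Below is a hedged single-device sketch with `jax.custom_vjp`, where a hypothetical `redistribute` function (here just an axis permutation) stands in for the communication step and its backward rule applies the inverse permutation. With mpi4jax, the forward and backward rules would each call the corresponding collective instead.

```python
import jax
import jax.numpy as jnp

PERM = (1, 2, 0)      # [x, y, z] -> [y, z, x]
INV_PERM = (2, 0, 1)  # [y, z, x] -> [x, y, z]

@jax.custom_vjp
def redistribute(arr):
    # Hypothetical stand-in for a collective such as an all-to-all
    return arr.transpose(PERM)

def redistribute_fwd(arr):
    # Linear op: no residuals need to be saved for the backward pass
    return redistribute(arr), None

def redistribute_bwd(_, g):
    # The VJP of a permutation is the inverse permutation of the cotangent
    return (g.transpose(INV_PERM),)

redistribute.defvjp(redistribute_fwd, redistribute_bwd)

def loss_custom(arr, w):
    return jnp.sum(redistribute(arr) * w)

def loss_reference(arr, w):
    # Same computation, differentiated automatically by JAX
    return jnp.sum(arr.transpose(PERM) * w)

arr = jnp.ones((2, 3, 4))
w = jnp.arange(24.0).reshape(3, 4, 2)
g1 = jax.grad(loss_custom)(arr, w)
g2 = jax.grad(loss_reference)(arr, w)
print(jnp.allclose(g1, g2))
```

Writing one such rule per communication function is exactly the manual work described above.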
Future of JAX
The low-level technical side
- JAX relies on the XLA (Accelerated Linear Algebra) library for compiling and executing jitted code.
- Around 2021-2022, support for low-level collective operations has been added to XLA, with NCCL as a backend on GPU clusters \o/
=> In principle, JAX can therefore be parallelised natively through XLA communication primitives on machines like Perlmutter.
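As a small illustration of these XLA collectives from the JAX side (our example, not from the slides): `jax.pmap` runs one copy of a function per local device, and `jax.lax.psum` inside it lowers to an XLA all-reduce, which on GPU clusters goes through the NCCL backend mentioned above. The axis name `'devices'` and the function `normalise` are arbitrary choices for this sketch.

```python
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()

def normalise(x):
    # psum is an all-reduce over the named axis, i.e. across devices
    total = jax.lax.psum(jnp.sum(x), axis_name='devices')
    return x / total

# pmap maps the leading axis of the input over the local devices
normalise_pmapped = jax.pmap(normalise, axis_name='devices')

x = jnp.arange(1.0, 1.0 + n_dev * 4).reshape(n_dev, 4)
y = normalise_pmapped(x)
print(y.sum())  # 1.0 up to rounding, whatever the number of devices
```

Unlike the mpi4jax version, no communication code is written by hand here: the collective is part of the compiled XLA program.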
The high-level JAX parallelism API
- Things are still evolving a lot! JAX 0.4.0 is around the corner and will change everything!
The idea: you should be able to write your code as if it were executing on a single GPU, and JAX should figure out the rest to make it run on many GPUs. Compatible with vmap, jit, grad, etc.
- Up to JAX v0.3, two experimental methods exist, xmap and pjit, each with its own documentation. For example, with pjit:
```python
import jax
import numpy as np
from jax.experimental import maps
from jax.experimental import PartitionSpec
from jax.experimental.pjit import pjit

# Pre-0.4 experimental API; this mesh requires exactly 8 devices
mesh_shape = (4, 2)
devices = np.asarray(jax.devices()).reshape(*mesh_shape)
# 'x', 'y' axis names are used here for simplicity
mesh = maps.Mesh(devices, ('x', 'y'))

# Input replicated on every device (None); output split over both mesh axes
f = pjit(
    lambda x: 2 * x + 1,
    in_axis_resources=None,
    out_axis_resources=PartitionSpec('x', 'y'))

# Hypothetical input: each dimension divisible by the matching mesh axis
input_data = np.arange(8 * 2, dtype=np.float32).reshape(8, 2)

# Sends data to accelerators based on partition_spec
with maps.Mesh(mesh.devices, mesh.axis_names):
    data = f(input_data)
```
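For comparison, here is a hedged sketch of the same computation with the `jax.Array` / `jax.sharding` API that arrived with JAX 0.4. The names used (`Mesh`, `NamedSharding`, `PartitionSpec`, `jax.device_put`) exist in recent JAX releases, but this part of the API has kept evolving, so check the documentation of your version. To run anywhere, the sketch lays out all available devices along `'x'`, with a single device along `'y'`.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Arrange all available devices in a 2D mesh of shape (n_devices, 1)
devices = np.asarray(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, ('x', 'y'))
n_dev = devices.shape[0]

# Shard the first array axis over 'x' and the second over 'y'
sharding = NamedSharding(mesh, PartitionSpec('x', 'y'))
input_data = jnp.arange(n_dev * 4 * 2, dtype=jnp.float32).reshape(n_dev * 4, 2)
sharded = jax.device_put(input_data, sharding)

# Plain jit: the computation follows the sharding of its input
f = jax.jit(lambda x: 2 * x + 1)
data = f(sharded)
print(data.sharding)
```

The function itself is written exactly as for a single GPU; only the placement of the input changes.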
- We should not shy away from thinking at large scale: it is already possible (e.g. with mpi4jax), and it will only get easier with time.
- JAX is moving in the direction of automated parallelisation on GPU clusters!
- Things to keep an eye on:
- New JAX Array mechanism with upcoming JAX v0.4.0
- jaxDecomp: JAX bindings to NVIDIA cuDecomp library (join me!)