Francois Lanusse, @EiffL
=> In both cases, a given model will not fit on a single GPU!
Credit: Li et al. 2021
NERSC 9 system: Perlmutter
NVIDIA Ampere architecture
from mpi4py import MPI
import jax
import jax.numpy as jnp
import mpi4jax
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
@jax.jit
def foo(arr):
    # Each rank adds its own rank id, then all ranks sum the result
    arr = arr + rank
    arr_sum, _ = mpi4jax.allreduce(arr, op=MPI.SUM, comm=comm)
    return arr_sum

a = jnp.zeros((3, 3))
result = foo(a)

# Every rank holds the same reduced array; print it once
if rank == 0:
    print(result)
This code is executed on all processes; each process drives a single GPU.
mpirun -n 4 python myapp.py
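With --gpu-bind=none (see the srun line below), each process has to pick its own GPU. A minimal sketch of one common way to do this, assuming srun exports SLURM_LOCALID (one process per GPU):

import os
from mpi4py import MPI

# Pin this process to one GPU before JAX initializes CUDA.
# SLURM_LOCALID is the node-local rank set by srun; fall back
# to the world rank if it is not set.
local_rank = int(os.environ.get("SLURM_LOCALID",
                                MPI.COMM_WORLD.Get_rank()))
os.environ["CUDA_VISIBLE_DEVICES"] = str(local_rank)

import jax  # from here on, this process only sees its own GPU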
$ module load python cudnn/8.2.0 nccl/2.11.4 cudatoolkit
$ pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
$ module load PrgEnv-gnu # In addition to the previously loaded modules
$ MPICC="cc -target-accel=nvidia80 -shared" pip install --force --no-cache-dir --no-binary=mpi4py mpi4py
$ srun -n 4 -c 32 --gpus-per-task 1 --gpu-bind=none python my_script.py
Find out more in the mpi4jax docs: https://mpi4jax.readthedocs.io/en/latest/shallow-water.html
Density field computed on 8 GPUs with MPI4JAX
def fft3d(arr, comms=None):
    """Computes forward 3D FFT; note that the output is transposed."""
    if comms is not None:
        shape = list(arr.shape)
        nx = comms[0].Get_size()
        ny = comms[1].Get_size()

    # First FFT along z
    arr = jnp.fft.fft(arr)  # [x, y, z]
    arr = arr.reshape(shape[:-1] + [nx, shape[-1] // nx])
    arr = arr.transpose([2, 1, 3, 0])  # [y, z, x]
    arr, token = mpi4jax.alltoall(arr, comm=comms[0])
    arr = arr.transpose([1, 2, 0, 3]).reshape(shape)  # Now [y, z, x]

    # Second FFT along x
    arr = jnp.fft.fft(arr)
    arr = arr.reshape(shape[:-1] + [ny, shape[-1] // ny])
    arr = arr.transpose([2, 1, 3, 0])  # [z, x, y]
    arr, token = mpi4jax.alltoall(arr, comm=comms[1], token=token)
    arr = arr.transpose([1, 2, 0, 3]).reshape(shape)  # Now [z, x, y]

    # Third FFT along y
    return jnp.fft.fft(arr)
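For context, a sketch of how the two communicators could be built by splitting COMM_WORLD into a 2D process grid. The grid shape and the array size below are illustrative assumptions, not part of the original code:

from mpi4py import MPI
import jax.numpy as jnp

# Assumed 2 x 2 process grid (4 ranks total): comms[0] spans one
# axis of the grid, comms[1] the other.
cart = MPI.COMM_WORLD.Create_cart(dims=[2, 2])
comms = [cart.Sub([True, False]), cart.Sub([False, True])]

# Each process holds one local slab of the global field.
local_field = jnp.zeros((64, 64, 64))
k_field = fft3d(local_field, comms=comms)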
=> It's very artisanal, not very JAX-y!
import jax
from jax.experimental import maps
from jax.experimental import PartitionSpec
from jax.experimental.pjit import pjit
import numpy as np
mesh_shape = (4, 2)
devices = np.asarray(jax.devices()).reshape(*mesh_shape)
# 'x', 'y' axis names are used here for simplicity
mesh = maps.Mesh(devices, ('x', 'y'))
f = pjit(
    lambda x: 2 * x + 1,
    in_axis_resources=None,
    out_axis_resources=PartitionSpec('x', 'y'))

# Sends data to accelerators based on partition_spec
with maps.Mesh(mesh.devices, mesh.axis_names):
    data = f(input_data)
pjit example
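A self-contained version of the same sketch: input_data is not defined on the slide, so an 8 x 8 array is assumed here so that it tiles evenly over the (4, 2) mesh.

import numpy as np
import jax
from jax.experimental import maps
from jax.experimental import PartitionSpec
from jax.experimental.pjit import pjit

mesh_shape = (4, 2)
devices = np.asarray(jax.devices()).reshape(*mesh_shape)
mesh = maps.Mesh(devices, ('x', 'y'))

f = pjit(lambda x: 2 * x + 1,
         in_axis_resources=None,
         out_axis_resources=PartitionSpec('x', 'y'))

# Assumed input: 8 x 8, so each of the 4 x 2 devices ends up
# holding a 2 x 4 tile of the sharded output.
input_data = np.arange(64, dtype=np.float32).reshape(8, 8)

with maps.Mesh(mesh.devices, mesh.axis_names):
    data = f(input_data)  # input replicated, output sharded over ('x', 'y')

print(data.shape)  # (8, 8): global logical shape, sharded over the mesh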