François Lanusse, @EiffL
[Diagram: a compute node pairing an Intel Alder Lake i5 CPU (6 cores) with an NVIDIA GeForce RTX 3090 GPU (10,496 CUDA cores, 128 per Streaming Multiprocessor), with high-speed GPU memory, a connection to the CPU, and a fast interconnect to other GPUs.]
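From Python, JAX can report which of these accelerators it sees (a minimal sketch; the output depends on the machine):

import jax

# List the accelerators available to JAX on this machine,
# e.g. a single RTX 3090, or the GPUs of a Perlmutter node.
print(jax.devices())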
Computer Graphics
Deep Learning
High Performance Computing
Attention layer: the cornerstone of ChatGPT
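Under the hood, an attention layer is just a few dense matrix products, exactly the workload GPUs excel at. A minimal sketch of single-head scaled dot-product attention (no masking, no batching):

import jax.numpy as jnp
from jax.nn import softmax

def attention(q, k, v):
    # softmax(Q K^T / sqrt(d)) V: a stack of matrix multiplies
    d = q.shape[-1]
    weights = softmax(q @ k.T / jnp.sqrt(d), axis=-1)
    return weights @ v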
Illustration of one node of the Perlmutter supercomputer, the NERSC-9 system (NERSC)
=> GPUs are becoming essential for HPC, and key to reaching exascale!
=> Modern Deep Learning frameworks look like normal Python, but execute your code on the GPU!
import jax.numpy as np

m = np.eye(10)  # Some matrix

def my_func(x):
    return m.dot(x).sum()

x = np.linspace(0, 1, 10)
y = my_func(x)  # Executes on the GPU when one is available
from jax import grad

df_dx = grad(my_func)  # Gradient of my_func with respect to x
y = df_dx(x)
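The same function can also be compiled end-to-end with XLA (a minimal sketch using jax.jit, the same mechanism as the @jax.jit decorator in the MPI example below):

from jax import jit

fast_func = jit(my_func)  # traced and compiled by XLA on first call
y = fast_func(x)          # later calls reuse the compiled GPU kernel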
from mpi4py import MPI
import jax
import jax.numpy as jnp
import mpi4jax

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

@jax.jit
def foo(arr):
    arr = arr + rank
    # Sum arr across all MPI processes, from inside jitted code
    arr_sum, _ = mpi4jax.allreduce(arr, op=MPI.SUM, comm=comm)
    return arr_sum

a = jnp.zeros((3, 3))
result = foo(a)

if rank == 0:
    print(result)
This code is executed on all processes; each one drives a single GPU.
mpirun -n 4 python myapp.py
Find out more in the mpi4jax documentation: https://mpi4jax.readthedocs.io/en/latest/shallow-water.html
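One common way to get that single-GPU-per-process setup (an assumption on my part, not shown here; job schedulers can also do this for you) is to pin each rank to one device before JAX initializes:

import os
from mpi4py import MPI

# Hypothetical pinning scheme: rank i uses GPU i on a 4-GPU node,
# as on Perlmutter. Must be set before jax is first imported.
os.environ["CUDA_VISIBLE_DEVICES"] = str(MPI.COMM_WORLD.Get_rank() % 4)

import jax
print(MPI.COMM_WORLD.Get_rank(), jax.devices())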
DES Y1 3x2pt posterior, jax-cosmo HMC vs Cobaya MH
(credit: Joe Zuntz)
def log_posterior(theta):
    return gaussian_likelihood(theta) + log_prior(theta)

score = jax.grad(log_posterior)(theta)  # the score: gradient of the log-posterior
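gaussian_likelihood and log_prior are placeholders from the slide; a self-contained toy version (hypothetical stand-ins, not the actual jax-cosmo DES likelihood) would be:

import jax
import jax.numpy as jnp

def gaussian_likelihood(theta):  # toy: standard-normal log-likelihood
    return -0.5 * jnp.sum(theta ** 2)

def log_prior(theta):  # toy: flat prior
    return 0.0

def log_posterior(theta):
    return gaussian_likelihood(theta) + log_prior(theta)

theta = jnp.array([0.5, -1.0])
score = jax.grad(log_posterior)(theta)  # exact gradient, as needed by HMC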