20th Heidelberg Summer School
Sept. 2025
Université Paris-Saclay, Université Paris Cité, CEA, CNRS, AIM
The success of Deep Learning over the last ~10 years has relied on several factors:
The goal of this lecture is to reuse Deep Learning software and computational tools for Physics
[Figure: initial conditions of a cosmological simulation, the reconstructed initial conditions, the reconstructed late-time density field, and the data being fitted.]
Typically, for a neural network defined by \( f(\mathbf{\theta}, \mathbf{x}) \) and a loss function \( \mathcal{L} \), we need to evaluate the gradient of the loss with respect to the parameters \( \mathbf{\theta} \): \( \nabla_{\mathbf{\theta}} \mathcal{L}(\mathbf{\theta}, \mathbf{x}, \mathbf{y}) \)
A first option is finite differences, e.g. the forward difference
$$ \frac{\partial \mathcal{L}}{\partial \theta_i} \approx \frac{\mathcal{L}(\mathbf{\theta} + h \mathbf{e}_i, \mathbf{x}, \mathbf{y}) - \mathcal{L}(\mathbf{\theta}, \mathbf{x}, \mathbf{y})}{h} $$
for \( h \in \mathbb{R}^{+*} \) sufficiently small, where \( \mathbf{e}_i \) is the \( i \)-th unit vector.
This approach induces two types of error:
Truncation error, from the neglected higher-order terms of the Taylor expansion (decreases as \( h \) decreases)
Roundoff error, from finite floating-point precision (grows as \( h \) decreases)
| Method | Error rate | Function calls for gradient |
|---|---|---|
| Forward/backward difference | \( O(h) \) | \( n + 1 \) |
| Central difference | \( O(h^2) \) | \( 2n \) |
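To make this tradeoff concrete, here is a small numerical check (not from the lecture; the test function and step sizes are arbitrary choices) comparing forward and central differences against the exact derivative:

```python
import numpy as np

# Test function and its exact derivative at x = 1 (arbitrary example)
f = np.sin
df_exact = np.cos(1.0)

for h in [1e-1, 1e-4, 1e-12]:
    forward = (f(1.0 + h) - f(1.0)) / h            # O(h) truncation error
    central = (f(1.0 + h) - f(1.0 - h)) / (2 * h)  # O(h^2) truncation error
    print(h, abs(forward - df_exact), abs(central - df_exact))

# For very small h, roundoff error dominates and both estimates degrade.
```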
from sympy import Symbol, sin, diff
x = Symbol('x')
f = (sin(x))**(sin(x))
derivative_f = diff(f, x)
print(f"The function is: {f}")
print(f"The derivative is: {derivative_f}")
The function is: sin(x)**sin(x)
The derivative is: (log(sin(x))*cos(x) + cos(x))*sin(x)**sin(x)
def func(x):
    for i in range(4):
        x = 2*x**2 + 3*x + 1
    return x

f = func(x)
df = diff(f, x)
def func(x):
    if x > 2:
        return x**2
    else:
        return x**3

try:
    func(x)
except TypeError as err:
    print("Error:", err)

Error: cannot determine truth value of Relational
[Diagram: evaluation of the computational graph in a forward pass, followed by gradient propagation in a backward pass.]
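As a concrete illustration (my own sketch, not from the slides), JAX exposes both modes of automatic differentiation directly: `jax.jvp` pushes a tangent through the forward pass, while `jax.vjp` builds the backward pass that pulls a cotangent back to the inputs:

```python
import jax
import jax.numpy as jnp

def g(x):
    return jnp.sin(x) * x**2

# Forward mode: push a tangent through the forward pass alongside the primal evaluation
y, tangent = jax.jvp(g, (1.5,), (1.0,))

# Reverse mode: run the forward pass, then pull a cotangent back through the backward pass
y, vjp_fun = jax.vjp(g, 1.5)
(cotangent,) = vjp_fun(1.0)

# For a scalar function both modes agree with jax.grad
print(tangent, cotangent, jax.grad(g)(1.5))
```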
import jax.numpy as np

m = np.eye(10)  # Some matrix

def my_func(x):
    return m.dot(x).sum()

x = np.linspace(0, 1, 10)
y = my_func(x)

from jax import grad

df_dx = grad(my_func)
y = df_dx(x)
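Unlike the symbolic approach, `jax.grad` evaluates derivatives at concrete input values, so the Python control flow that broke SymPy above is handled transparently (a quick check of my own, reusing the piecewise function from earlier):

```python
import jax

def func(x):
    if x > 2:
        return x**2
    else:
        return x**3

print(jax.grad(func)(3.0))  # 6.0, i.e. d/dx x**2 at x = 3
print(jax.grad(func)(1.0))  # 3.0, i.e. d/dx x**3 at x = 1
```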
def pure_fun(x):
    return 2 * x**2 + 3 * x + 2


def impure_fun_side_effect(x):
    print('I am a side effect')
    return 2 * x**2 + 3 * x + 2


C = 10.  # A global variable

def impure_fun_uses_globals(x):
    return 2 * x**2 + 3 * x + C
# Decorator for jitting
@jax.jit
def my_fun(W, x):
    return W.dot(x)

# or as an explicit transformation
my_fun_jitted = jax.jit(my_fun)
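This is where purity matters in practice: `jax.jit` traces the function once and bakes in any global it reads at trace time. A small illustrative sketch of my own, reusing `impure_fun_uses_globals` from above:

```python
import jax

jitted = jax.jit(impure_fun_uses_globals)

print(jitted(1.0))  # 15.0, traced with C = 10.

C = 100.            # Updating the global...
print(jitted(1.0))  # ...still 15.0: the old value was captured at trace time
```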
def f(x):
    return 2 * x**2 + 3 * x + 2

df_dx = jax.grad(f)
jac = jax.jacobian(f)
hess = jax.hessian(f)
# As a decorator
@jax.vmap
def f_batched(x):
    return 2 * x**2 + 3 * x + 2

# Transformations can be composed (f is the scalar function defined above)
df_dx = jax.jit(jax.vmap(jax.grad(f)))
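For instance (a quick usage sketch of my own), the composed transformation evaluates the element-wise derivative \( 4x + 3 \) over a whole batch in one call:

```python
import jax.numpy as jnp

xs = jnp.linspace(0., 1., 5)
print(df_dx(xs))  # element-wise derivative 4*x + 3 evaluated over the batch
```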
Deconvolution
Inpainting
Denoising
For a linear inverse problem \( \mathbf{y} = \mathbf{A}\mathbf{x} + \mathbf{n} \), the operator \( \mathbf{A} \) is known and encodes our physical understanding of the problem.
Some neural network \( f_{\theta} \)
[Figure: Hubble Space Telescope vs. ground-based telescope images.]
With these concepts in hand, we can for instance estimate the Maximum A Posteriori (MAP) solution:
$$ \hat{\mathbf{x}}_{MAP} = \arg\max_{\mathbf{x}} \; \log p(\mathbf{x} | \mathbf{y}) = \arg\max_{\mathbf{x}} \; \left[ \log p(\mathbf{y} | \mathbf{x}) + \log p(\mathbf{x}) \right] $$
Let's write down the different pieces corresponding to our example deconvolution problem:
=> Maximizing the log posterior requires differentiating through the instrument model! (PSF convolution)
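Below is a minimal sketch (my own, not the lecture's code) of what this looks like in JAX: the observed image `y`, the PSF `psf`, the noise level `sigma`, and the simple Gaussian prior are all placeholder assumptions; the point is that `jax.grad` differentiates straight through the convolution in the instrument model.

```python
import jax
import jax.numpy as jnp
from jax.scipy.signal import fftconvolve

def log_likelihood(x, y, psf, sigma=0.1):
    # Instrument model: convolve the candidate image x with the PSF
    model = fftconvolve(x, psf, mode='same')
    return -0.5 * jnp.sum((y - model) ** 2) / sigma**2

def log_prior(x):
    # Placeholder isotropic Gaussian prior
    return -0.5 * jnp.sum(x ** 2)

def log_posterior(x, y, psf):
    return log_likelihood(x, y, psf) + log_prior(x)

@jax.jit
def map_step(x, y, psf, lr=1e-2):
    # One gradient-ascent step on the log posterior;
    # autodiff propagates through the PSF convolution
    return x + lr * jax.grad(log_posterior)(x, y, psf)

# Starting from x = jnp.zeros_like(y), iterating map_step moves towards the MAP solution.
```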
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

# Create a mixture of two scalar Gaussians:
gm = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(
        probs=[0.3, 0.7]),
    components_distribution=tfd.Normal(
        loc=[-1., 1.],
        scale=[0.1, 0.5]))

# Evaluate probability
gm.log_prob(1.0)
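Everything about this distribution object is JAX-transformable. Continuing the snippet above (a short check of my own), we can sample with an explicit PRNG key and differentiate the log-density:

```python
key = jax.random.PRNGKey(0)
samples = gm.sample(1000, seed=key)   # in the JAX substrate, sampling takes an explicit PRNG key
dlogp = jax.grad(gm.log_prob)(1.0)    # the log-density is differentiable like any JAX function
```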
Let's go to this notebook:
https://tinyurl.com/jaxpm
solution: https://tinyurl.com/jaxpm-solution
import flax.linen as nn
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions


class MDN(nn.Module):
    num_components: int

    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(128)(x))
        x = nn.relu(nn.Dense(128)(x))
        x = nn.tanh(nn.Dense(64)(x))

        # Instead of regressing directly the value of the mass, the network
        # will now try to estimate the parameters of a mass distribution.
        categorical_logits = nn.Dense(self.num_components)(x)
        loc = nn.Dense(self.num_components)(x)
        scale = 1e-3 + nn.softplus(nn.Dense(self.num_components)(x))

        dist = tfd.MixtureSameFamily(
            mixture_distribution=tfd.Categorical(logits=categorical_logits),
            components_distribution=tfd.Normal(loc=loc, scale=scale))

        # To make it understand the batch dimension
        dist = tfd.Independent(dist)

        # Returns a distribution!
        return dist
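Since the network returns a distribution, training simply maximizes the likelihood of the observed masses under it. A minimal sketch of the loss (my own; `x_batch`, `y_batch`, and the number of components are assumed placeholders):

```python
import jax
import jax.numpy as jnp

model = MDN(num_components=8)  # hypothetical number of mixture components

def loss_fn(variables, x_batch, y_batch):
    dist = model.apply(variables, x_batch)     # the MDN returns a tfd distribution
    return -jnp.mean(dist.log_prob(y_batch))   # negative log-likelihood of the observed masses

# variables = model.init(jax.random.PRNGKey(0), x_batch)
grad_fn = jax.jit(jax.grad(loss_fn))           # gradients flow through the distribution's log_prob
```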
The core idea of Variational Inference is to assume a parametric model distribution \( q_\phi \) and fit it to the unknown posterior by minimizing the KL divergence:
$$ \mathbb{D}_{KL}\left( q_\phi(z) \parallel p(z | x) \right) = \mathbb{E}_{z \sim q_{\phi}} \left[ \log \frac{q_\phi(z)}{p(z | x)} \right] $$
$$ = \mathbb{E}_{z \sim q_{\phi}} \left[ \log q_\phi(z) \right] - \mathbb{E}_{z \sim q_{\phi}} \left[ \log p(z | x) \right] $$
$$ = \mathbb{E}_{z \sim q_{\phi}} \left[ \log q_\phi(z) \right] - \mathbb{E}_{z \sim q_{\phi}} \left[ \log p(x | z) + \log p(z) - \log p(x)\right] $$
$$ = \underbrace{\mathbb{D}_{KL}\left( q_\phi(z) \parallel p(z) \right) - \mathbb{E}_{z \sim q_{\phi}} \left[ \log p(x | z) \right]}_{-\mathrm{ELBO}} + \log p(x) $$
Since \( \log p(x) \) does not depend on \( \phi \), minimizing this KL divergence is equivalent to maximizing the ELBO \( \mathbb{E}_{z \sim q_{\phi}} \left[ \log p(x | z) \right] - \mathbb{D}_{KL}\left( q_\phi(z) \parallel p(z) \right) \).
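As an illustration (a minimal sketch of my own, not the lecture's code), the ELBO can be estimated by Monte Carlo with reparameterized samples and maximized with `jax.grad`; the Gaussian variational family, the placeholder `log_joint`, and the scalar latent are all assumptions of this sketch:

```python
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

def log_joint(z):
    # Placeholder for log p(x, z) = log p(x | z) + log p(z);
    # in practice this is the physical model's likelihood plus the prior
    return tfd.Normal(0., 1.).log_prob(z)

def elbo(phi, key, n_samples=32):
    # Gaussian variational family q_phi(z) over a scalar latent (an assumption of this sketch)
    q = tfd.Normal(loc=phi['loc'], scale=jnp.exp(phi['log_scale']))
    z = q.sample(n_samples, seed=key)  # reparameterized samples: gradients flow back to phi
    # Monte Carlo estimate of E_q[ log p(x, z) - log q_phi(z) ], i.e. the ELBO
    return jnp.mean(log_joint(z) - q.log_prob(z))

# Maximizing the ELBO is equivalent to minimizing the KL divergence above
elbo_and_grad = jax.jit(jax.value_and_grad(elbo))

# phi = {'loc': jnp.zeros(()), 'log_scale': jnp.zeros(())}
# value, grads = elbo_and_grad(phi, jax.random.PRNGKey(0))
```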