Francois Lanusse @EiffL
From this excellent tutorial:
Figures and data from Ho et al. 2019
What we will feed the model:
What kind of uncertainties should I most worry about?
Training data from the MultiDark Planck 2 N-body simulation (Klypin et al. 2016), with 261287 clusters.
and what is it?
import jax.numpy as np

m = np.eye(10)  # Some matrix

def my_func(x):
    return m.dot(x).sum()

x = np.linspace(0, 1, 10)
y = my_func(x)

from jax import grad

df_dx = grad(my_func)  # Function returning the gradient of my_func
y = df_dx(x)           # Gradient evaluated at x
def pure_fun(x):
    return 2 * x**2 + 3 * x + 2

def impure_fun_side_effect(x):
    print('I am a side effect')
    return 2 * x**2 + 3 * x + 2

C = 10.  # A global variable

def impure_fun_uses_globals(x):
    return 2 * x**2 + 3 * x + C
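Why purity matters becomes visible once such a function is compiled. The sketch below is a minimal illustration, assuming the standard behavior of `jax.jit`: Python side effects run only while the function is being traced, and a global is read once at trace time. The names `trace_log` and `impure_fun` are hypothetical and only serve this example.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical helper: records each time Python runs the function body
C = 10.0        # a global variable read inside the function

@jax.jit
def impure_fun(x):
    trace_log.append("traced")    # side effect: happens only during tracing
    return 2 * x**2 + 3 * x + C   # C is baked in as a constant at trace time

first = impure_fun(jnp.array(1.0))   # traces and compiles, using C = 10
C = 20.0                             # rebinding the global afterwards...
second = impure_fun(jnp.array(1.0))  # ...is not seen by the cached compiled function
```

Both calls return the value computed with the original `C`, and the body runs only once in Python, which is why jitted functions should take everything they depend on as explicit arguments.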
import jax

# Decorator for jitting
@jax.jit
def my_fun(W, x):
    return W.dot(x)

# or as an explicit transformation
my_fun_jitted = jax.jit(my_fun)
def f(x):
    return 2 * x**2 + 3 * x + 2

df_dx = jax.grad(f)
jac = jax.jacobian(f)
hess = jax.hessian(f)
# As decorator
@jax.vmap
def f_batched(x):
    return 2 * x**2 + 3 * x + 2

# Can be composed, starting from the scalar function f
df_dx = jax.jit(jax.vmap(jax.grad(f)))
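As a quick sanity check of this composition, the sketch below applies it to the scalar quadratic used on these slides and compares the result with the derivative worked out by hand, 4x + 3. The variable names `xs`, `numerical` and `analytic` are chosen for illustration only.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Scalar in, scalar out: this is what jax.grad expects
    return 2 * x**2 + 3 * x + 2

# grad differentiates the scalar function, vmap maps it over a batch,
# and jit compiles the batched gradient.
df_dx = jax.jit(jax.vmap(jax.grad(f)))

xs = jnp.linspace(0.0, 1.0, 5)
numerical = df_dx(xs)    # one derivative per entry of xs
analytic = 4 * xs + 3    # derivative computed by hand
```

Note that `jax.grad` is applied to the unbatched function first; the batching is added afterwards with `jax.vmap`.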
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(128)(x))
        x = nn.relu(nn.Dense(128)(x))
        x = nn.Dense(1)(x)
        return x

# Instantiate the Neural Network
model = MLP()

# Initialize the parameters (x: inputs of shape [batch, features])
params = model.init(jax.random.PRNGKey(0), x)
prediction = model.apply(params, x)

# Instantiate Optimizer
tx = optax.adam(learning_rate=0.001)
opt_state = tx.init(params)

# Define loss function (y: targets of shape [batch, 1])
def loss_fn(params, x, y):
    mse = (model.apply(params, x) - y)**2
    return jnp.mean(mse)

# Compute gradients
grads = jax.grad(loss_fn)(params, x, y)

# Update parameters
updates, opt_state = tx.update(grads, opt_state)
params = optax.apply_updates(params, updates)
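The snippet above performs a single update; in practice the gradient and update steps are wrapped in a jitted function and repeated. The sketch below shows that loop structure only. To stay dependency-free it swaps in plain gradient descent for `optax.adam` and a hand-written linear model for the Flax `MLP`; the toy data, `predict`, `train_step` and `LEARNING_RATE` are all assumptions made for this example.

```python
import jax
import jax.numpy as jnp

# Toy data, assumed for illustration: y = 2x + 1
x = jnp.linspace(0.0, 1.0, 32)[:, None]
y = 2.0 * x + 1.0

# Hypothetical stand-in for the Flax parameters: a plain dict pytree
params = {"w": jnp.zeros((1, 1)), "b": jnp.zeros((1,))}
LEARNING_RATE = 0.5

def predict(params, x):
    return x @ params["w"] + params["b"]

def loss_fn(params, x, y):
    return jnp.mean((predict(params, x) - y) ** 2)

@jax.jit
def train_step(params, x, y):
    # Loss and gradients in one pass, then a plain gradient descent update
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads)
    return params, loss

initial_loss = loss_fn(params, x, y)
for step in range(200):
    params, loss = train_step(params, x, y)
final_loss = loss_fn(params, x, y)
```

With Optax, the `tree_map` line would be replaced by the `tx.update` / `optax.apply_updates` pair, with `opt_state` passed in and returned alongside `params`.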
We will be using this notebook
Your goal: Build a regression model with a Mean Squared Error loss in JAX/Flax
def loss_fn(params, x, y):
    mse = (model.apply(params, x) - y)**2
    return jnp.mean(mse)
Raise your hand when you reach the cluster mass prediction plot
class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(128)(x))
        x = nn.relu(nn.Dense(128)(x))
        x = nn.tanh(nn.Dense(64)(x))
        x = nn.Dense(1)(x)
        return x

def loss_fn(params, x, y):
    prediction = model.apply(params, x)
    return jnp.mean((prediction - y)**2)
What is going on???
Let's try to understand the neural network output by looking at the loss function
$$ \mathcal{L} = \sum_{(x_i, y_i) \in \mathcal{D}} \parallel y_i - f_\theta(x_i)\parallel^2 \quad \simeq \quad \int \parallel y - f_\theta(x) \parallel^2 \ p(x,y) \ dx dy $$ $$\Longrightarrow \int \left[ \int \parallel y - f_\theta(x) \parallel^2 \ p(y|x) \ dy \right] p(x) dx$$
This is minimized when $$f_{\theta^\star}(x) = \int y \ p(y|x) \ dy $$
i.e. when the network is predicting the mean of p(y|x).
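This result can be checked numerically for the simplest possible "network": a single constant prediction c at one fixed x. The sketch below is only an illustration; it assumes a log-normal sample as a stand-in for draws from a skewed p(y|x), and the names `mse`, `grad_at_mean`, `loss_at_mean` and `loss_at_median` are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical skewed sample standing in for draws of y at one fixed x
key = jax.random.PRNGKey(0)
y = jnp.exp(jax.random.normal(key, (10_000,)))  # log-normal: mean differs from median

def mse(c):
    # Mean squared error of a constant prediction c
    return jnp.mean((y - c) ** 2)

grad_at_mean = jax.grad(mse)(jnp.mean(y))  # vanishes at the minimizer
loss_at_mean = mse(jnp.mean(y))
loss_at_median = mse(jnp.median(y))        # any other summary does worse
```

The gradient of the loss is zero at the sample mean, and the loss there is lower than at the median, so for a skewed distribution the MSE-trained prediction will not sit at the most probable value.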
There are intrinsic uncertainties in this problem: at each x there is a full distribution p(y|x).
I have a set of data points {x, y} where I observe x and want to predict y.
credit: Venkatesh Tata
=> This means expressing the posterior as a Bernoulli distribution whose parameter is predicted by a neural network
Step 1: We need some data
cat or dog image
label 1 for cat, 0 for dog
Probability of including cats and dogs in my dataset
Implicit prior
Image search results for cats and dogs
Implicit likelihood
A distance between distributions: the Kullback-Leibler Divergence
Step 2: We need a tool to compare distributions
Minimizing this KL divergence is equivalent to minimizing the negative log likelihood of the model
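The equivalence follows from writing the KL divergence as a cross entropy minus an entropy: the entropy of the data distribution does not depend on the model, so only the expected negative log likelihood is left to minimize. The sketch below checks this identity for two small discrete distributions; `p` and `q` are made-up stand-ins for the data distribution and the model.

```python
import jax.numpy as jnp

def kl_divergence(p, q):
    # KL(p || q) for discrete distributions with strictly positive entries
    return jnp.sum(p * (jnp.log(p) - jnp.log(q)))

def cross_entropy(p, q):
    # Expected negative log likelihood of the model q under samples from p
    return -jnp.sum(p * jnp.log(q))

p = jnp.array([0.7, 0.2, 0.1])  # stand-in for the data distribution
q = jnp.array([0.5, 0.3, 0.2])  # stand-in for the model

entropy_p = cross_entropy(p, p)  # does not depend on the model q
kl_pq = kl_divergence(p, q)      # equals cross_entropy(p, q) - entropy_p
```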
In our case of binary classification:
We recover the binary cross-entropy loss function!
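To make this concrete, the sketch below computes the negative log likelihood of a Bernoulli distribution whose parameter is the sigmoid of a network output, and compares it with the textbook binary cross-entropy formula. The function name `bernoulli_nll` and the example `logits` and `labels` are assumptions for illustration.

```python
import jax
import jax.numpy as jnp

def bernoulli_nll(logits, labels):
    # -log p(label) where p(label = 1) = sigmoid(logits)
    log_p = jax.nn.log_sigmoid(logits)       # log p(label = 1)
    log_not_p = jax.nn.log_sigmoid(-logits)  # log p(label = 0)
    return -(labels * log_p + (1.0 - labels) * log_not_p)

logits = jnp.array([-2.0, 0.0, 3.0])  # hypothetical network outputs
labels = jnp.array([0.0, 1.0, 1.0])   # 1 for cat, 0 for dog

nll = bernoulli_nll(logits, labels)

# Textbook binary cross entropy written with probabilities
p = jax.nn.sigmoid(logits)
textbook_bce = -(labels * jnp.log(p) + (1.0 - labels) * jnp.log(1.0 - p))
```

Working with logits and `log_sigmoid` rather than probabilities avoids taking the log of values that round to 0 or 1.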
We need a parametric conditional distribution to
compute
import flax.linen as nn
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

class MDN(nn.Module):
    num_components: int

    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(128)(x))
        x = nn.relu(nn.Dense(128)(x))
        x = nn.tanh(nn.Dense(64)(x))

        # Instead of regressing directly the value of the mass, the network
        # will now try to estimate the parameters of a mass distribution.
        categorical_logits = nn.Dense(self.num_components)(x)
        loc = nn.Dense(self.num_components)(x)
        scale = 1e-3 + nn.softplus(nn.Dense(self.num_components)(x))

        dist = tfd.MixtureSameFamily(
            mixture_distribution=tfd.Categorical(logits=categorical_logits),
            components_distribution=tfd.Normal(loc=loc, scale=scale))

        # To make it understand the batch dimension
        dist = tfd.Independent(dist)

        # Returns a distribution !
        return dist
Your goal: Implement a Mixture Density Network in JAX/Flax/TFP, and use it to get unbiased mass estimates.
When everyone is done, we will discuss the results.
def loss_fn(params, x, y):
    q = model.apply(params, x)
    return jnp.mean(-q.log_prob(y[:, 0]))
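To see what this loss evaluates, the sketch below writes out the log-density of a Gaussian mixture by hand in plain JAX, from the same three network outputs (logits, locations, scales). It is an illustration of the standard mixture formula, not a description of TFP internals; `mixture_log_prob`, `mdn_nll` and the example inputs are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm

def mixture_log_prob(y, logits, loc, scale):
    # y: (batch,), logits / loc / scale: (batch, num_components)
    log_weights = jax.nn.log_softmax(logits, axis=-1)
    component_log_probs = norm.logpdf(y[:, None], loc=loc, scale=scale)
    # log sum_k w_k N(y | loc_k, scale_k), computed stably in log space
    return logsumexp(log_weights + component_log_probs, axis=-1)

def mdn_nll(y, logits, loc, scale):
    # Negative log likelihood averaged over the batch
    return jnp.mean(-mixture_log_prob(y, logits, loc, scale))

# Example: three identical standard normal components collapse to one Gaussian
y = jnp.array([0.0, 1.0])
logits = jnp.zeros((2, 3))
loc = jnp.zeros((2, 3))
scale = jnp.ones((2, 3))
log_prob = mixture_log_prob(y, logits, loc, scale)
```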
Distribution of masses in our training data
We can reweight the predictions for a desired prior
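One way to sketch this reweighting: the network learns a posterior under the implicit prior of the training set, so dividing by that prior and multiplying by the desired one, then renormalizing, gives the posterior under the new prior. The example below does this on a grid; the grid, the Gaussian stand-in for the network output, the falling training prior and the flat target prior are all assumptions made for illustration.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

def reweight_posterior(y_grid, q_density, train_prior, new_prior):
    # q(y|x) * p_new(y) / p_train(y), renormalized on a uniform grid
    dy = y_grid[1] - y_grid[0]
    weights = q_density * new_prior / train_prior
    return weights / (jnp.sum(weights) * dy)

# Hypothetical log-mass grid and predicted density at a single x
y_grid = jnp.linspace(13.0, 15.0, 201)
dy = y_grid[1] - y_grid[0]
q_density = norm.pdf(y_grid, loc=14.2, scale=0.2)

train_prior = jnp.exp(-2.0 * (y_grid - 13.0))  # assumed: massive clusters are rare in training
new_prior = jnp.ones_like(y_grid)              # assumed target: flat prior

reweighted = reweight_posterior(y_grid, q_density, train_prior, new_prior)

original_mean = jnp.sum(y_grid * q_density) / jnp.sum(q_density)
reweighted_mean = jnp.sum(y_grid * reweighted) * dy
```

In this toy setup, removing a prior that favors low masses shifts the posterior mean toward higher mass.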
From this excellent tutorial:
Given a training set D = {X,Y}, the predictions from a Neural Network can be expressed as:
Weight Estimation by Maximum Likelihood
Weight Estimation by Variational Inference
TensorFlow Probability implementation
Quick reminder on dropout
Hinton 2012, Srivastava 2014
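As a reminder of the mechanics, the sketch below implements the commonly used "inverted" form of dropout in plain JAX: each unit is zeroed with probability `rate`, and the survivors are rescaled so that the expected activation is unchanged. The function name `dropout` and the example values are assumptions for illustration.

```python
import jax
import jax.numpy as jnp

def dropout(key, x, rate):
    # Inverted dropout: zero each unit with probability `rate`,
    # rescale the remaining units by 1 / (1 - rate).
    keep_prob = 1.0 - rate
    mask = jax.random.bernoulli(key, keep_prob, x.shape)
    return jnp.where(mask, x / keep_prob, 0.0)

key = jax.random.PRNGKey(0)
x = jnp.ones((10_000,))
dropped = dropout(key, x, rate=0.5)  # entries are either 0 or 2
```

Because the mask depends on the random key, calling the function with different keys yields different outputs for the same input.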