Feature vs. Lazy Training
Stefano Spigler
Matthieu Wyart
Mario Geiger
Arthur Jacot
https://arxiv.org/abs/1906.08034
Two different regimes in the dynamics of neural networks
feature training
the network learns features
lazy training
no features are learned
\(f(w) \approx f(w_0) + \nabla f(w_0) \cdot dw\)
\(f(w,x)\)
characteristic training time separating the two regimes
training procedure to force each regime
How does the network perform in the infinite width limit?
\(n\to\infty\)
There exist two limits in the literature
\(n\) neurons per layer
Overparametrized networks:
- perform well
- are theoretically tractable
The two limits can be understood in the context of the central limit theorem
Central Limit Theorem (\(X_i\) i.i.d.):
\(\displaystyle Y = \frac{1}{\color{red}\sqrt{n}} \sum_{i=1}^n X_i \quad \) as \({\color{red}n} \to \infty, \quad Y \longrightarrow \) Gaussian
\(\langle Y \rangle={\color{red}\sqrt{n}}\langle X_i \rangle\)
\(\langle Y^2\rangle - \langle Y\rangle^2 = \langle X_i^2 \rangle - \langle X_i\rangle^2\)
Law of large numbers:
\(\displaystyle Y = \frac{1}{\color{red}n} \sum_{i=1}^n X_i\)
\(\langle Y \rangle=\langle X_i \rangle\)
\(\langle Y^2\rangle - \langle Y\rangle^2 = {\color{red}\frac{1}{n}} (\langle X_i^2 \rangle - \langle X_i\rangle^2)\)
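As a quick numerical illustration of the two scalings, here is a sketch in Python (numpy; the exponential distribution and the sample sizes are arbitrary choices):

```python
import numpy as np

rng = np.random.default_rng(0)
n, trials = 10_000, 2_000
X = rng.exponential(size=(trials, n))   # i.i.d. X_i with mean 1 and variance 1

Y_clt = X.sum(axis=1) / np.sqrt(n)      # CLT scaling: mean grows like sqrt(n), variance stays O(1)
Y_lln = X.sum(axis=1) / n               # LLN scaling: mean stays O(1), variance shrinks like 1/n

print(Y_clt.mean(), Y_clt.var())        # ~ sqrt(n) = 100   and   ~ 1
print(Y_lln.mean(), Y_lln.var())        # ~ 1               and   ~ 1/n = 1e-4
```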
regime 1: kernel limit
[Diagram: network with input \(x\), weights \(W^1_{ij}\), output \(f(w,x)\)]
\(\displaystyle z^{\ell+1}_i = {\frac1{\color{red}\sqrt{n}}} \sum_{j=1}^n W^\ell_{ij} \ \phi(z^\ell_j)\)
At initialization
\(W_{ij} \longleftarrow \mathcal{N}(0, 1)\)
\(w\) = all the weights; \(z^1_j\), \(z^2_i\) = hidden activations
With this scaling, small correlations add up significantly => the weights will change only a little during training
[1995 R. M. Neal]
The terms in the sum are independent => the CLT applies
=> when \({\color{red}n} \to\infty\) the output is a Gaussian process
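A minimal sketch of this parameterization in code (Python/PyTorch assumed; the depth, width, and softplus nonlinearity are illustrative choices, and the input layer is normalized by its own fan-in):

```python
import torch


def init_weights(d, n, depth, seed=0):
    # all weights are drawn from N(0, 1); the 1/sqrt(fan-in) factor lives in the forward pass
    g = torch.Generator().manual_seed(seed)
    sizes = [d] + [n] * depth + [1]
    return [torch.randn(sizes[i + 1], sizes[i], generator=g, requires_grad=True)
            for i in range(len(sizes) - 1)]


def forward(weights, x):
    # z^{l+1}_i = (1/sqrt(n_l)) sum_j W^l_ij phi(z^l_j), with phi = softplus (identity on the input)
    z = x
    for i, W in enumerate(weights):
        h = z if i == 0 else torch.nn.functional.softplus(z)
        z = h @ W.t() / W.shape[1] ** 0.5
    return z.squeeze(-1)  # f(w, x)


x = torch.randn(5, 10)                 # 5 inputs in d = 10
w = init_weights(d=10, n=1000, depth=3)
print(forward(w, x))                   # O(1) outputs at initialization: each sum is CLT-dominated
```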
regime 1: kernel limit
[Diagram: in the space of functions \(\mathbb{R}^d \to \mathbb{R}\) (dimension \(\infty\)), the network manifold \(f(w)\) (dimension \(N\)) passes through \(f(w_0)\); its tangent space at \(f(w_0)\), \(f(w_0) + \nabla f(w_0) \cdot dw\) (dimension \(N\)), contains the kernel space \(f(w_0) + \sum_\mu c_\mu \Theta(w_0, x_\mu)\) (dimension \(m\)). Here \(d\) = input dimension, \(N\) = number of parameters, \(m\) = size of the trainset.]
[2018 Jacot et al.]
[2018 Du et al.]
[2019 Lee et al.]
The NTK is independent of the initialization
and is constant through learning
=> the network behaves like a kernel method
The weights barely change
\(\|dw\|\sim \mathcal{O}(1)\) and \(dW_{ij} \sim \mathcal{O}({\color{red}1/n})\) (\(\mathcal{O}({\color{red}1/\sqrt{n}})\) in the first and last layers)
The internal activations barely change
\( dz \sim \mathcal{O}({\color{red}1/\sqrt{n}})\)
=> no feature training
for \({\color{red}n}\to\infty\)
neural tangent kernel
\(\Theta(w,x_1,x_2) = \nabla_w f(w, x_1) \cdot \nabla_w f(w, x_2)\)
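The tangent kernel can be evaluated empirically with autograd, directly from this definition. A sketch (Python/PyTorch; the one-hidden-layer model below is just a stand-in for \(f(w,\cdot)\)):

```python
import torch


def ntk(f, params, x1, x2):
    # Theta(w, x1, x2) = grad_w f(w, x1) . grad_w f(w, x2)   (scalar output assumed)
    g1 = torch.autograd.grad(f(x1), params)
    g2 = torch.autograd.grad(f(x2), params)
    return sum((a * b).sum() for a, b in zip(g1, g2)).item()


n, d = 512, 10
W1 = torch.randn(n, d, requires_grad=True)
W2 = torch.randn(n, requires_grad=True)
f = lambda x: W2 @ torch.nn.functional.softplus(W1 @ x / d ** 0.5) / n ** 0.5

x1, x2 = torch.randn(d), torch.randn(d)
print(ntk(f, [W1, W2], x1, x2))
```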
regime 1: kernel limit
regime 2: mean field limit
\(\displaystyle f(w,x) = \frac{1}{\color{red}n} \sum_{i=1}^n W_i \; \phi \! \left(\frac{1}{\color{red}\sqrt{n}} \sum_{j=1}^n W_{ij} x_j \right)\)
(\(x\): input, \(W_{ij}\): first-layer weights, \(W_i\): output weights, \(f(w,x)\): output)
was \(\frac{1}{\color{red}\sqrt{n}}\) in the kernel limit
studied theoretically for 1 hidden layer
Another limit !
for \({\color{red}n}\longrightarrow \infty\)
\(\displaystyle f(w,x) = \frac{1}{\color{red}n} \sum_{i=1}^n W_i \; \phi \! \left(\frac{1}{\color{red}\sqrt{n}} \sum_{j=1}^n W_{ij} x_j \right)\)
\(\frac{1}{\color{red}n}\) instead of \(\frac{1}{\color{red}\sqrt{n}}\) implies
- no output fluctuations at initialization
- we can replace the sum by an integral
\(\displaystyle f(w,x) \approx f(\rho,x) = \int d\rho(W,\vec W) \; W \; \phi \! \left(\frac{1}{\sqrt{n}} \sum_{j=1}^n W_j x_j \right)\)
where \(\rho\) is the density of the neurons' weights
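A sketch of the same one-hidden-layer network in the \(1/n\) (mean-field) output scaling, checking that the output fluctuations at initialization vanish as \(n\) grows (Python/PyTorch; the width and softplus are illustrative, and the inner sum is normalized by the input dimension \(d\)):

```python
import torch


def f_mean_field(W_out, W_in, x):
    # f(w, x) = (1/n) sum_i W_i phi( (1/sqrt(d)) sum_j W_ij x_j )
    z = torch.nn.functional.softplus(x @ W_in.t() / x.shape[-1] ** 0.5)
    return z @ W_out / W_out.shape[0]   # note the 1/n prefactor (vs 1/sqrt(n) in the kernel limit)


d = 10
x = torch.randn(100, d)
for n in (100, 1_000, 10_000):
    f0 = f_mean_field(torch.randn(n), torch.randn(n, d), x)
    print(n, f0.std().item())           # fluctuations of f at initialization shrink like 1/sqrt(n)
```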
regime 2: mean field limit
\(\displaystyle f(\rho,x) = \int d\rho(W,\vec W) \; W \; \phi \! \left(\frac{1}{\sqrt{n}} \sum_{j=1}^n W_j x_j \right)\)
\(\rho\) follows a differential equation
[2018 S. Mei et al], [2018 Rotskoff and Vanden-Eijnden], [2018 Chizat and Bach]
In this limit, the internal activations do change
\(\langle Y \rangle=\langle X_i \rangle\) => feature training
regime 2: mean field limit
What is the difference between the two limits?
- which limit better describes finite \(n\) networks?
- are there corresponding regimes for finite \(n\)?
kernel limit and mean field limit
- kernel limit: prefactor \(\frac{1}{\color{red}\sqrt{n}}\), internal activations \(z^\ell\) frozen
- mean field limit: prefactor \(\frac{1}{\color{red}n}\), internal activations \(z^\ell\) change
\(\displaystyle{\color{red}\alpha} f(w,x) = \frac{{\color{red}\alpha}}{\sqrt{n}} \sum_{i=1}^n W_i \ \phi(z_i)\)
[2019 Chizat and Bach]
- if \(\alpha\) is a fixed constant and \(n\to\infty\) => kernel limit
- if \({\color{red}\alpha} \sim \frac{1}{\sqrt{n}}\) and \(n\to\infty\) => mean field limit
(\(z_i\): hidden activations, \(W_i\): output weights, \(f(w,x)\): output)
we use the scaling factor \(\alpha\) to investigate
the transition between the two regimes
\( \alpha \cdot ( f(w,x) - f(w_0, x) ) \)
linearize the network with \(f - f_0\)
We would like that, for any finite \(n\), the network behaves linearly in the limit \(\alpha \to \infty\)
The \(\alpha^2\) factor is there so that the convergence time does not scale with \(\alpha\) as \(\alpha \to \infty\)
\(\displaystyle \mathcal{L}(w) = \frac{1}{\alpha^2 |\mathcal{D}|} \sum_{(x,y)\in\ \mathcal{D}} \ell\left( {\color{red}\alpha (f(w,x) - f(w_0,x))}, y \right) \)
loss function
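A sketch of this objective in code (Python/PyTorch; the architecture and the hinge-type \(\ell\) below are placeholder choices, and `torch.nn.Linear` does not reproduce the \(1/\sqrt{n}\) scaling shown elsewhere in the slides):

```python
import copy

import torch


def alpha_loss(model, model0, alpha, x, y, ell):
    # L(w) = 1/(alpha^2 |D|) * sum_{(x,y) in D} ell( alpha * (f(w,x) - f(w0,x)), y )
    with torch.no_grad():
        f0 = model0(x)                              # frozen output at initialization, f(w0, x)
    return ell(alpha * (model(x) - f0), y).mean() / alpha ** 2


model = torch.nn.Sequential(torch.nn.Linear(10, 1000), torch.nn.Softplus(), torch.nn.Linear(1000, 1))
model0 = copy.deepcopy(model)                       # keep a copy of the network at t = 0
for p in model0.parameters():
    p.requires_grad_(False)

x = torch.randn(64, 10)
y = torch.randint(0, 2, (64,)).float() * 2 - 1              # labels in {-1, +1}
hinge = lambda out, y: torch.relu(1 - out.squeeze(-1) * y)  # one possible choice of ell
alpha_loss(model, model0, alpha=100.0, x=x, y=y, ell=hinge).backward()
```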
\(\displaystyle f(w,x) =\\ \frac{1}{{\color{red}\sqrt{n}}} \sum_{i=1}^n W^3_i \ \phi(z^3_i)\)
\(W_{ij}(t=0) \longleftarrow \mathcal{N}(0, 1)\)
\(\displaystyle z^{\ell+1}_i = {\frac1{\color{red}\sqrt{n}}} \sum_{j=1}^n W^\ell_{ij} \ \phi(z^\ell_j)\)
[Diagram: depth-3 fully connected network with input \(x_i\), weights \(W^0_{ij}, W^1_{ij}, W^2_{ij}, W^3_i\), and hidden activations \(z^1_i, z^2_i, z^3_i\)]
\(\dot w = -\nabla_w \mathcal{L}(w)\)
Implemented with a dynamical adaptation of the time step \(dt\) such that,
\(10^{-4} < \frac{\|\nabla \mathcal{L}(t_{i+1})- \nabla \mathcal{L}(t_i)\|^2}{\|\nabla \mathcal{L}(t_{i+1})\|\cdot\|\nabla \mathcal{L}(t_i)\|} < 10^{-2}\)
(works well only with full batch and smooth loss)
continuous dynamics
momentum dynamics
\(\dot v = -\frac{1}{\tau}(v + \nabla \mathcal{L})\)
\(\dot w = v\)
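A sketch of how such an integrator could look (Python/numpy, on a toy quadratic loss; the halve-and-retry / grow rule below is one simple way to keep the gradient-change ratio inside the stated window, not necessarily the exact scheme used for the experiments):

```python
import numpy as np


def momentum_flow(grad_L, w, tau=1.0, dt=1e-3, steps=2000):
    # Euler-integrate  dv/dt = -(v + grad L)/tau,  dw/dt = v,  adapting dt so that
    # 1e-4 < ||g_{i+1} - g_i||^2 / (||g_{i+1}|| ||g_i||) < 1e-2  at every accepted step
    v, g = np.zeros_like(w), grad_L(w)
    for _ in range(steps):
        v_new = v + dt * (-(v + g) / tau)
        w_new = w + dt * v_new
        g_new = grad_L(w_new)
        ratio = np.sum((g_new - g) ** 2) / (np.linalg.norm(g_new) * np.linalg.norm(g) + 1e-30)
        if ratio > 1e-2:          # gradient changed too much: shrink dt and retry the step
            dt /= 2
            continue
        if ratio < 1e-4:          # gradient barely changed: grow dt for the next step
            dt *= 1.1
        w, v, g = w_new, v_new, g_new
    return w


# toy full-batch loss L(w) = 0.5 ||A w - b||^2  =>  grad L = A^T (A w - b)
rng = np.random.default_rng(0)
A, b = rng.normal(size=(20, 5)), rng.normal(size=20)
print(momentum_flow(lambda w: A.T @ (A @ w - b), np.zeros(5)))
```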
there is a plateau for large values of \(\alpha\)
MNIST 10k parity, FC L=3, softplus, gradient flow with momentum
lazy regime
MNIST 10k parity, FC L=3, softplus, gradient flow with momentum
ensemble average
the ensemble average converges as \(n \to \infty\)
[Plot: ensemble-average curves vs \(\alpha\) for different \(n\); no overlap]
\(\displaystyle \bar f(x) = \int f(w(w_0),x) \;d\mu(w_0)\)
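A sketch of this ensemble average over initializations (Python/PyTorch; the tiny architecture, the MSE loss, and plain SGD are placeholders for the actual training setup):

```python
import torch


def train(seed, x, y, steps=200, lr=0.1):
    # train one network starting from its own random initialization w0(seed)
    torch.manual_seed(seed)
    model = torch.nn.Sequential(torch.nn.Linear(10, 256), torch.nn.Softplus(), torch.nn.Linear(256, 1))
    opt = torch.optim.SGD(model.parameters(), lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        (model(x).squeeze(-1) - y).pow(2).mean().backward()
        opt.step()
    return model


x, y = torch.randn(64, 10), torch.randn(64)
x_test = torch.randn(16, 10)
models = [train(seed, x, y) for seed in range(10)]   # one trained network per initialization w0
with torch.no_grad():
    # bar f(x) = average over w0 of f(w(w0), x)
    f_bar = torch.stack([m(x_test).squeeze(-1) for m in models]).mean(dim=0)
print(f_bar)
```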
MNIST 10k parity, FC L=3, softplus, gradient flow with momentum
plotting as a function of \(\sqrt{n}\,\alpha\) makes the curves overlap
[Plots: the same curves vs \(\alpha\) and vs \(\sqrt{n}\,\alpha\); they overlap in \(\sqrt{n}\,\alpha\)]
\(\frac{\|\Theta - \Theta_0 \|}{\|\Theta_0\|}\)
the kernel evolution displays two regimes
MNIST 10k parity, FC L=3, softplus, gradient flow with momentum
the phase space is split in two by \(\alpha^*\), which decays as \(\frac{1}{\sqrt{n}}\)
[Phase diagram in the \((n, \alpha)\) plane: lazy training above the line \(\alpha^* \sim \frac{1}{\sqrt{n}}\) (kernel limit as \(n \to \infty\)), feature training below (mean field limit)]
same for other datasets: the trend depends on the dataset
MNIST 10k
reduced to 2 classes
10PCA MNIST 10k
reduced to 2 classes
FC L=3, softplus, gradient flow with momentum
EMNIST 10k
reduced to 2 classes
Fashion MNIST 10k
reduced to 2 classes
FC L=3, softplus, gradient flow with momentum
same for other datasets: the trend depends on the dataset
CIFAR10 10k
reduced to 2 classes
CNN SGD ADAM
CNN: the tendency is inverted
\(\sqrt{n}\alpha\)
how do the learning curves depend on \(n\) and \(\alpha\)?
MNIST 10k parity, L=3, softplus, gradient flow with momentum
[Plot: learning curves labeled by \(\sqrt{n} \alpha\) overlap; in the lazy regime they converge at the same time]
MNIST 10k parity, L=3, softplus, gradient flow with momentum
there is a characteristic time in the learning curves
[Plot: learning curves labeled by \(\sqrt{n} \alpha\) overlap; the characteristic time \(t_1\) and the lazy curve are indicated]
\(t_1\) characterises the curvature of the network manifold
[Diagram: the same function-space picture as before (network manifold, tangent space, kernel space); the training trajectory \(f(w(t))\) leaves the tangent space at time \(t_1\)]
\(t_1\) is the time you need to drive to realize the earth is curved
(driving at speed \(v\) on a planet of radius \(R\): \(t_1 \sim R/v\))
the rate of change of \(W\) determines when we leave the tangent space, i.e. \(t_1 \sim \sqrt{n}\alpha\)
[Diagram: network with input \(x\), weights \(W_{ij}\), output \(f(w,x)\)]
\(\displaystyle z^{\ell+1}_i = \frac1{\sqrt{n}} \sum_{j=1}^n W^\ell_{ij} \ \phi(z^\ell_j)\)
\(W_{ij}\) and \(z_i\) are initialized \(\sim 1\)
estimate \(\dot W_{ij}\) and \(\dot z_i\) at initialization \(\Rightarrow t_1\)
\(\dot W^\ell = \mathcal{O}(\frac{1}{\alpha n})\), \(\quad \dot W^0 = \mathcal{O}(\frac{1}{\alpha \sqrt{n}})\), \(\quad \dot z = \mathcal{O}(\frac{1}{\alpha \sqrt{n}})\)
Upper bound: consider the weights of the last layer,
\(\dot W^L= -\frac{\partial\mathcal{L}}{\partial W^L} = \mathcal{O}\left(\frac{1}{\alpha \sqrt{n}}\right)\)
\(\Rightarrow t_1 \sim \sqrt{n} \alpha \quad\) (this upper bound is actually the correct scaling)
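Spelling out this last-layer estimate (a sketch of the bookkeeping, using the \(\alpha\)-scaled loss and the \(1/\sqrt{n}\) output scaling defined above, and assuming \(\ell' = \mathcal{O}(1)\)):
\(\displaystyle \dot W^L_i = -\frac{\partial \mathcal{L}}{\partial W^L_i} = -\frac{1}{\alpha^2 |\mathcal{D}|} \sum_{(x,y)\in\mathcal{D}} \ell'\!\big(\alpha (f - f_0), y\big)\, \alpha\, \frac{\partial f(w,x)}{\partial W^L_i} = \mathcal{O}\!\left(\frac{1}{\alpha}\right) \frac{\phi(z^L_i)}{\sqrt{n}} = \mathcal{O}\!\left(\frac{1}{\alpha \sqrt{n}}\right)\)
Since \(W^L_i = \mathcal{O}(1)\) at initialization, these weights need a time \(t \sim \sqrt{n}\alpha\) to change by an amount of order one.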
the convergence time is of order 1 in the lazy regime
- \(\displaystyle \mathcal{L}(w) = \frac{1}{\alpha^2 |\mathcal{D}|} \sum_{(x,y)\in \mathcal{D}} \ell\left( {\color{red}\alpha (f(w,x) - f(w_0,x))}, y \right) \)
- \(\dot f(w, x) = \nabla_w f(w,x) \cdot \dot w\)
- \(= - \nabla_w f(w,x) \cdot \nabla_w \mathcal{L} \)
- \(\sim - \nabla_w f \cdot (\frac{\partial}{\partial f} \mathcal{L} \; \nabla_w f)\)
- \(\sim \frac{1}{\alpha} \Theta\)
- "\( \alpha \dot f t = 1 \)" \(\Rightarrow\) "\( t_{lazy} = \|\Theta\|^{-1} \)"
\(\Longrightarrow\) \({\color{red}\alpha^* \sim \frac{1}{\sqrt{n}}}\)
when the dynamics stops before \(t_1\), we are in the lazy regime:
\(t_1 \sim \sqrt{n} \alpha \) (time to exit the tangent space)
\(t_{lazy} \sim 1 \) (time to converge in the lazy regime)
the crossover is at \(t_{lazy} \sim t_1\) \(\Rightarrow\) \(\alpha^* \sim 1/\sqrt{n}\)
then for large \(n\): \(\ \alpha^* f(w_0, x) \ll 1 \ \Rightarrow\ \alpha^* (f(w,x) - f(w_0, x)) \approx \alpha^* f(w,x) \)
for large \(n\) our conclusions should hold without this trick: subtracting \(f_0\) from the output (the \(f - f_0\) "linearization") was not necessary
arxiv.org/abs/1906.08034
github.com/mariogeiger/feature_lazy
[Summary phase diagram in the \((n, \alpha)\) plane:
- lazy training: training stops before \(t_1\) (kernel limit)
- feature training: training goes beyond \(t_1\) (mean field limit)
- crossover at \(\alpha^* \sim \frac{1}{\sqrt{n}}\)
- time to leave the tangent space \(t_1 \sim \sqrt{n}\alpha\) => the time needed to learn features]