Feature vs. Lazy Training
Stefano Spigler
Matthieu Wyart
Mario Geiger
Arthur Jacot
https://arxiv.org/abs/1906.08034
Two different regimes in the dynamics of neural networks
We want to understand why NNs are so good at learning
- versatile: images, songs, game agents, etc.
- generalize well
- overparametrized
- why do they generalize so well?
- how many parameters do we need?
- how does the architecture affect the ability to learn?
- there exist two different regimes in training NNs!
We want to understand why NNs are so good at learning
- we can measure everything
- variety of architectures
- variety of tasks
- variety of dynamics
How do we proceed?
how the number of parameters affects learning
[2017 B. Neyshabur et al.]
[2018 S. Spigler et al.]
[2019 M. Geiger et al.] ...
Test error decreases in over-parametrized regime
[plot: test error vs. number of parameters \(N \approx n^2 L\) for networks of width \(n\) and depth \(L\); \(N^*\) marks the over-parametrization threshold]
Above \(N^*\) test error depends on \(n\) through fluctuations
\(\mathit{Variance}(f) \sim 1 / n\)
parameters: \(N \approx n^2 L\)
\(\mathit{Variance}(f) = \| f - \langle f \rangle\|^2\)
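As an aside (not part of the original slides), the output variance above can be estimated from an ensemble of independently initialized and trained networks. A minimal PyTorch sketch, where `preds` is assumed to hold the trained outputs \(f(x)\) of each ensemble member on a common test set:

```python
import torch

def ensemble_variance(preds: torch.Tensor) -> torch.Tensor:
    # preds: (n_ensemble, n_test) outputs f(x) of independently trained networks
    # Variance(f) = ||f - <f>||^2 with <f> the ensemble average;
    # the norm is taken as the mean square over test points (one reasonable reading)
    mean = preds.mean(dim=0, keepdim=True)
    return ((preds - mean) ** 2).mean()
```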
How does the network perform in the infinite width limit?
\(n\to\infty\)
There exist two limits in the literature
The two limits can be understood in the context of the central limit theorem
Let \(X_i\) be i.i.d. random variables.
Central Limit Theorem:
\(\displaystyle Y = \frac{1}{\color{red}\sqrt{n}} \sum_{i=1}^n X_i \quad \)
As \({\color{red}n} \to \infty, \quad Y \longrightarrow \) Gaussian
\(\langle Y \rangle={\color{red}\sqrt{n}}\langle X_i \rangle\)
\(\langle Y^2\rangle - \langle Y\rangle^2 = \langle X_i^2 \rangle - \langle X_i\rangle^2\)
Law of large numbers:
\(\displaystyle Y = \frac{1}{\color{red}n} \sum_{i=1}^n X_i\)
\(\langle Y \rangle=\langle X_i \rangle\)
\(\langle Y^2\rangle - \langle Y\rangle^2 = {\color{red}\frac{1}{n}} (\langle X_i^2 \rangle - \langle X_i\rangle^2)\)
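A minimal numerical illustration of the two normalizations (not from the slides; plain NumPy, illustrative sizes):

```python
import numpy as np

rng = np.random.default_rng(0)
for n in (100, 10_000):
    X = rng.uniform(size=(2_000, n))        # 2000 realizations of n i.i.d. X_i, <X_i> = 1/2
    Y_clt = X.sum(axis=1) / np.sqrt(n)      # 1/sqrt(n): CLT normalization
    Y_lln = X.sum(axis=1) / n               # 1/n: law-of-large-numbers normalization
    print(f"n={n:6d}  CLT: mean={Y_clt.mean():8.2f} var={Y_clt.var():.3f}"
          f"   LLN: mean={Y_lln.mean():.3f} var={Y_lln.var():.2e}")
    # CLT: mean grows like sqrt(n)*<X_i>, variance stays ~ Var(X_i) = 1/12
    # LLN: mean stays at <X_i>, variance shrinks like Var(X_i)/n
```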
regime 1: kernel limit
[diagram: network with input \(x\), weights \(W_{ij}\), output \(f(w,x)\)]
\(\displaystyle z^{\ell+1}_i = \sigma \left({\frac1{\color{red}\sqrt{n}}} \sum_{j=1}^n W^\ell_{ij} z^\ell_j \right)\)
initialization
\(W_{ij} \longleftarrow \mathcal{N}(0, 1)\)
\(w\) = all the weights
[1995 R. M. Neal]
Independent terms in the sum => the CLT applies
=> when \({\color{red}n} \to\infty\) the output is a Gaussian process
[2018 Jacot et al.]
[2018 Du et al.]
[2019 Lee et al.]
When \({\color{red}n}\to\infty\), during training,
the kernel \(\Theta(w,x_1,x_2) = \nabla_w f(w, x_1) \cdot \nabla_w f(w, x_2)\)
- is independent of the initialization and is constant through learning
- the weights do not change
- the internal activations do not change
\(\langle Y \rangle={\color{red}\sqrt{n}}\langle X_i \rangle\) => no feature learning
=> the network behaves like a kernel method
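A minimal PyTorch sketch of the tangent kernel \(\Theta\) for a toy one-hidden-layer network with the \(\frac{1}{\sqrt{n}}\) scaling (sizes, the tanh activation, and the \(1/\sqrt{d}\) input scaling are illustrative choices, not taken from the slides):

```python
import torch

torch.manual_seed(0)
n, d = 512, 10                                  # width and input dimension (illustrative)
W1 = torch.randn(n, d, requires_grad=True)      # W_ij ~ N(0, 1)
W2 = torch.randn(n, requires_grad=True)

def f(x):
    z = torch.tanh(W1 @ x / d ** 0.5)           # hidden activations with 1/sqrt(d) scaling
    return W2 @ z / n ** 0.5                    # output with the 1/sqrt(n) scaling

def theta(x1, x2):
    # empirical tangent kernel: Theta(w, x1, x2) = grad_w f(w, x1) . grad_w f(w, x2)
    g1 = torch.autograd.grad(f(x1), (W1, W2))
    g2 = torch.autograd.grad(f(x2), (W1, W2))
    return sum((a * b).sum() for a, b in zip(g1, g2))

x1, x2 = torch.randn(d), torch.randn(d)
print(theta(x1, x2).item())                     # concentrates as n grows (kernel limit)
```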
regime 1: kernel limit
[diagram: in the space of functions \(x \mapsto \mathbb{R}\), the trained network \(f(w)\) stays close to the tangent space at \(f(w_0)\), i.e. functions \(f(w_0) + \nabla f(w_0) \cdot dw\), explored as \(f(w_0) + \sum_\mu c_\mu \Theta(w_0, x_\mu)\)]
regime 1: kernel limit
regime 2: mean field limit
\(\displaystyle f(w,x) = \frac{1}{\color{red}n} \sum_{i=1}^n W_i \; \sigma \! \left(\frac{1}{\color{red}\sqrt{n}} \sum_{j=1}^n W_{ij} x_j \right)\)
[diagram: one-hidden-layer network with input \(x\), hidden weights \(W_{ij}\), output weights \(W_i\), output \(f(w,x)\)]
(the output prefactor was \(\frac{1}{\color{red}\sqrt{n}}\) in the kernel limit)
Another limit!
studied theoretically for 1 hidden layer, for \(n\longrightarrow \infty\)
\(\displaystyle f(w,x) = \frac{1}{\color{red}n} \sum_{i=1}^n W_i \; \sigma \! \left(\frac{1}{\color{red}\sqrt{n}} \sum_{j=1}^n W_{ij} x_j \right)\)
\(\frac{1}{\color{red}n}\) instead of \(\frac{1}{\color{red}\sqrt{n}}\) implies
- no output fluctuations at initialization
- we can replace the sum by an integral
\(\displaystyle f(w,x) \approx f(\rho,x) = \int d\rho(W,\vec W) \; W \; \sigma \! \left(\frac{1}{\sqrt{n}} \sum_{j=1}^n \vec W_j x_j \right)\)
where \(\rho\) is the density of the neurons' weights \((W, \vec W)\)
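A quick numerical check (not from the slides) of the "no output fluctuations at initialization" claim: with the \(\frac{1}{n}\) output scaling the initial output concentrates, while the \(\frac{1}{\sqrt{n}}\) scaling keeps it of order one (toy network, illustrative sizes, \(1/\sqrt{d}\) input scaling):

```python
import torch

torch.manual_seed(0)
d = 20
x = torch.randn(d)

def init_output(n, out_scale):
    W_in = torch.randn(n, d)                    # hidden weights W_ij
    W_out = torch.randn(n)                      # output weights W_i
    z = torch.tanh(W_in @ x / d ** 0.5)
    return float(W_out @ z) * out_scale(n)

for n in (100, 10_000):
    f_kernel = torch.tensor([init_output(n, lambda m: m ** -0.5) for _ in range(200)])
    f_mf = torch.tensor([init_output(n, lambda m: 1.0 / m) for _ in range(200)])
    print(n, f_kernel.std().item(), f_mf.std().item())   # ~O(1) vs ~O(1/sqrt(n))
```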
regime 2: mean field limit
\(\displaystyle f(\rho,x) = \int d\rho(W,\vec W) \; W \; \sigma \! \left(\frac{1}{\sqrt{n}} \sum_{j=1}^n \vec W_j x_j \right)\)
\(\rho\) follows a differential equation
[2018 S. Mei et al], [2018 Rotskoff and Vanden-Eijnden]
the activations do change
\(\langle Y \rangle=\langle X_i \rangle\) => feature learning
regime 2: mean field limit
What is the difference between the two limits?
- which limit better describes finite \(n\) networks?
- are there corresponding regimes for finite \(n\)?
kernel limit: \(\frac{1}{\color{red}\sqrt{n}}\) vs. mean field limit: \(\frac{1}{\color{red}n}\)
we use the scaling factor \(\alpha\) to investigate
the transition between the two regimes
\({\color{red}\alpha} f(w,x) = \frac{{\color{red}\alpha}}{\sqrt{n}} \sum_{i=1}^n W_i z_i\)
[2019 Chizat and Bach]
- if \(\alpha\) is a fixed constant and \(n\to\infty\) => kernel limit
- if \({\color{red}\alpha} \sim \frac{1}{\sqrt{n}}\) and \(n\to\infty\) => mean field limit
the network output \(f(w,x)\) is replaced by \( \alpha \cdot ( f(w,x) - f(w_0, x) ) \)
Loss:
\(\displaystyle \mathcal{L}(w) = \frac{1}{\alpha^2 |\mathcal{D}|} \sum_{(x,y)\in \mathcal{D}} \ell\left( \alpha (f(w,x) - f(w_0,x)), y \right) \)
- we linearize the network with \(f - f_0\)
- we would like that, for any finite \(n\), the network behaves linearly in the limit \(\alpha \to \infty\)
- the \(\alpha^2\) factor makes the dynamics converge in a time that does not scale with \(\alpha\) in the limit \(\alpha \to \infty\)
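A minimal PyTorch sketch of this loss (the logistic-type loss for labels \(y=\pm 1\), the helper name, and the shapes are illustrative assumptions, not necessarily what the paper's code uses):

```python
import copy
import torch
import torch.nn.functional as F

def make_lazy_loss(model, alpha):
    model0 = copy.deepcopy(model)              # frozen copy of the network at initialization: f(w_0, x)
    for p in model0.parameters():
        p.requires_grad_(False)

    def loss(x, y):                            # y in {-1, +1}
        out = alpha * (model(x).squeeze() - model0(x).squeeze())
        # (1/alpha^2) * mean over the batch of l(alpha * (f - f_0), y),
        # with l(z, y) = log(1 + exp(-z*y)) as an illustrative choice
        return F.softplus(-out * y).mean() / alpha ** 2

    return loss
```

Usage would be along the lines of `loss_fn = make_lazy_loss(net, alpha=100.0)` right after initializing `net`, so that `model0` really stores \(f(w_0)\).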
there is a plateau for large values of \(\alpha\)
MNIST 10k parity, FC L=3, softplus, gradient flow with momentum
lazy regime
MNIST 10k parity, FC L=3, softplus, gradient flow with momentum
ensemble average
the ensemble average converges as \(n \to \infty\)
no overlap when plotted as a function of \(\alpha\)
MNIST 10k parity, FC L=3, softplus, gradient flow with momentum
plotted as a function of \(\sqrt{n} \alpha\), the curves overlap!
the phase space is split in two by \(\alpha^*\), which decays as \(\frac{1}{\sqrt{n}}\)
[phase diagram in the \((n, \alpha)\) plane: lazy learning above the line \(\alpha^* \sim \frac{1}{\sqrt{n}}\), feature learning below; the kernel limit and the mean field limit are reached as \(n \to \infty\)]
same analysis for other datasets: the trend depends on the dataset
[panels: MNIST 10k, 10PCA MNIST 10k, EMNIST 10k, Fashion MNIST 10k, each reduced to 2 classes; FC L=3, softplus, gradient flow with momentum]
same analysis for other datasets: the trend depends on the dataset
[panel: CIFAR10 10k, reduced to 2 classes; CNN trained with SGD and ADAM]
for the CNN, the trend is inverted
how do the learning curves depend on \(n\) and \(\alpha\)?
MNIST 10k parity, L=3, softplus, gradient flow with momentum
[training curves, labeled by \(\sqrt{n} \alpha\): they overlap, and in the lazy regime they converge at the same time]
MNIST 10k parity, L=3, softplus, gradient flow with momentum
there is a characteristic time in the learning curves
[training curves, labeled by \(\sqrt{n} \alpha\): they overlap and depart from the lazy dynamics at a characteristic time \(t_1\)]
\(t_1\) characterizes the curvature of the network manifold
[diagram: at time \(t_1\) the trajectory \(f(w)\) leaves the tangent space \(f(w_0) + \nabla f(w_0) \cdot dw\)]
analogy: \(t_1\) is the time you need to drive, at speed \(v\), to realize that the Earth (radius \(R\)) is curved: \(t_1 \sim R/v\)
\(\dot W^L= -\frac{\partial\mathcal{L}}{\partial W^L} = \mathcal{O}\left(\frac{1}{\alpha \sqrt{n}}\right)\)
\(\Rightarrow t_1 \sim \sqrt{n} \alpha \)
the rate of change of the weights determines when we leave the tangent space: \(t_1 \sim \sqrt{n}\alpha\)
[diagram: network with input \(x\), weights \(W_{ij}\), activations \(z_i\), output \(f(w,x)\)]
\(\displaystyle z^{\ell+1}_i = \sigma \left(\frac1{\sqrt{n}} \sum_{j=1}^n W^\ell_{ij} z^\ell_j \right)\)
\(W_{ij}\) and \(z_i\) are initialized \(\sim 1\)
the rates \(\dot W_{ij}\) and \(\dot z_i\) at initialization determine \(t_1\)
Upper bound: consider the weights of the last layer (this actually gives the correct scaling)
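A small numerical check (not from the slides) that the last-layer gradient at initialization scales like \(\frac{1}{\alpha\sqrt{n}}\), so that \(\alpha \sqrt{n}\, |\dot W^L|\) is roughly constant; the toy one-hidden-layer network, logistic-type loss, and sizes are all illustrative choices:

```python
import torch
import torch.nn.functional as F

def last_layer_grad(n, alpha, d=10, batch=256):
    torch.manual_seed(0)
    W1 = torch.randn(n, d)                       # hidden weights (held fixed here)
    W2 = torch.randn(n, requires_grad=True)      # last-layer weights W^L
    x = torch.randn(batch, d)
    y = torch.randint(0, 2, (batch,)).float() * 2 - 1
    z = torch.tanh(x @ W1.t() / d ** 0.5)
    f = z @ W2 / n ** 0.5
    f0 = f.detach()                              # at initialization f = f_0
    loss = F.softplus(-alpha * (f - f0) * y).mean() / alpha ** 2
    g = torch.autograd.grad(loss, W2)[0]
    return g.abs().mean().item()                 # typical size of one component of dW^L/dt

for n in (100, 400, 1600):
    for alpha in (1.0, 10.0):
        print(n, alpha, last_layer_grad(n, alpha) * alpha * n ** 0.5)  # roughly constant
```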
for large \(\alpha\)
- \(f\) is almost linear: \(f(w,x) \approx f(w_0,x) + \nabla f(w_0,x) \cdot dw\)
- the convergence time does not depend on \(n\) or \(\alpha\)
if \(t_1\) becomes smaller than the convergence time, the network is not linear until the end of training
\(t_1 \sim \sqrt{n} \alpha\) vs. \(1\) \(\Longrightarrow\) \({\color{red}\alpha^* \sim \frac{1}{\sqrt{n}}}\)
when the dynamics stops before \(t_1\) we are in the lazy regime
lazy dynamics
- \(\mathcal{L}(w) = \frac{1}{\alpha^2 |\mathcal{D}|} \sum_{(x,y)\in\mathcal{D}} \ell\left( \alpha (f(w,x) - f(w_0,x)), y \right) \)
- \(\dot f(w, x) = \nabla_w f(w,x) \cdot \dot w\)
- \(= - \nabla_w f(w,x) \cdot \nabla_w \mathcal{L} \)
- \(\sim - \nabla_w f \cdot (\frac{\partial}{\partial f} \mathcal{L} \; \nabla_w f)\)
- \(\sim \frac{1}{\alpha} \Theta\)
- "\( \alpha \dot f t = 1 \)" \(\Rightarrow\) "\( t = \|\Theta\|^{-1} \)"
\(\alpha^* \sim 1/\sqrt{n}\), so for large \(n\)
\(\Rightarrow\) \( \alpha^* f(w_0, x) \ll 1 \)
\(\Rightarrow\) \( \alpha^* (f(w,x) - f(w_0, x)) \approx \alpha^* f(w,x) \)
for large \(n\) our conclusions should hold without this trick:
the "linearize the network with \(f - f_0\)" subtraction was not necessary
arxiv.org/abs/1906.08034
github.com/mariogeiger/feature_lazy
[phase diagram in the \((n, \alpha)\) plane: lazy learning (training stops before \(t_1\)) above the line \(\alpha^* \sim \frac{1}{\sqrt{n}}\), feature learning (training stops after \(t_1\)) below; the kernel limit and the mean field limit are reached as \(n \to \infty\)]
time to leave the tangent space: \(t_1 \sim \sqrt{n}\alpha\)
=> the time needed for learning features
continuous dynamics
\(\dot w = -\nabla \mathcal{L}\)
Implemented with a dynamical adaptation of \(dt\) such that,
\(10^{-4} < \frac{\|\nabla \mathcal{L}(t_{i+1})- \nabla \mathcal{L}(t_i)\|^2}{\|\nabla \mathcal{L}(t_{i+1})\|\cdot\|\nabla \mathcal{L}(t_i)\|} < 10^{-2}\)
works well with full batch and smooth loss
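A minimal sketch of one way to implement this adaptive step (hypothetical helper; the accept/reject and growth factors are guesses, not the actual rule used in the repository):

```python
import torch

def adaptive_gradient_flow_step(params, loss_fn, dt, lo=1e-4, hi=1e-2):
    # one Euler step of w_dot = -grad L, with dt adapted so that the relative
    # change of the gradient stays between lo and hi (criterion of the slides)
    g0 = torch.autograd.grad(loss_fn(), params)
    with torch.no_grad():
        for p, g in zip(params, g0):
            p -= dt * g
    g1 = torch.autograd.grad(loss_fn(), params)
    num = sum(((b - a) ** 2).sum() for a, b in zip(g0, g1))
    den = (sum((a ** 2).sum() for a in g0).sqrt() *
           sum((b ** 2).sum() for b in g1).sqrt())
    r = (num / den).item()
    if r > hi:                       # gradient changed too much: undo the step and shrink dt
        with torch.no_grad():
            for p, g in zip(params, g0):
                p += dt * g
        dt *= 0.5
    elif r < lo:                     # gradient barely changed: grow dt
        dt *= 1.5
    return dt
```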
momentum dynamics
\(\dot v = -\frac{1}{\tau}(v + \nabla \mathcal{L})\)
\(\dot w = v\)
I took \(\tau \sim \frac{\sqrt{n}}{\alpha}t\) (where \(t\) is the proper time of the dynamics)
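A corresponding sketch of an Euler discretization of the momentum dynamics above (again a guess at the implementation, not the repository's code):

```python
import torch

def momentum_flow_step(params, velocities, loss_fn, dt, tau):
    # one Euler step of  v_dot = -(v + grad L) / tau,  w_dot = v
    grads = torch.autograd.grad(loss_fn(), params)
    with torch.no_grad():
        for p, v, g in zip(params, velocities, grads):
            v += dt * (-(v + g) / tau)
            p += dt * v
```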
kernel growth
[plot: relative change of the kernel \(\frac{\|\Theta - \Theta_0 \|}{\|\Theta_0\|}\) as a function of \(\alpha\) and of \(\sqrt{n}\alpha\)]
feature-lazy
By Mario Geiger