Stefano Spigler
Matthieu Wyart
Mario Geiger
Arthur Jacot
https://arxiv.org/abs/1906.08034
Two different regimes in the dynamics of neural networks
We want to understand why NNs are so good at learning
How do we proceed?
how the number of parameters affects learning
[2017 B. Neyshabur et al.]
[2018 S. Spigler et al.]
[2019 M. Geiger et al.] ...
Test error decreases in over-parametrized regime
\(n\)
\(L\)
parameters: \(N \approx n^2 L\)
\(N^*\)
Above \(N^*\) test error depends on \(n\) through fluctuations
\(\mathit{Variance}(f) \sim 1 / n\)
parameters: \(N \approx n^2 L\)
\(\mathit{Variance}(f) = \| f - \langle f \rangle\|^2\)
\(n\)
\(L\)
\(N^*\)
\(N^*\)
How does the network perform in the infinite width limit?
\(n\to\infty\)
There exist two limits in the literature
The two limits can be understood in the context of the central limit theorem
\(\displaystyle Y = \frac{1}{\color{red}\sqrt{n}} \sum_{i=1}^n X_i \quad \)
As \({\color{red}n} \to \infty, \quad Y \longrightarrow \) Gaussian
\(\langle Y \rangle={\color{red}\sqrt{n}}\langle X_i \rangle\)
\(\langle Y^2\rangle - \langle Y\rangle^2 = \langle X_i^2 \rangle - \langle X_i\rangle^2\)
\(\langle Y \rangle=\langle X_i \rangle\)
\(\langle Y^2\rangle - \langle Y\rangle^2 = {\color{red}\frac{1}{n}} (\langle X_i^2 \rangle - \langle X_i\rangle^2)\)
\(\displaystyle Y = \frac{1}{\color{red}n} \sum_{i=1}^n X_i\)
Central Limit Theorem:
\(X_i\)
(i.i.d.)
(Law of large numbers)
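A minimal numerical check of the two normalizations (a sketch with plain NumPy and uniform \(X_i\), chosen arbitrarily): with \(\frac{1}{\sqrt{n}}\) the mean of \(Y\) grows like \(\sqrt{n}\) while its variance stays fixed; with \(\frac{1}{n}\) the mean stays fixed while the variance shrinks like \(1/n\).

```python
import numpy as np

rng = np.random.default_rng(0)

for n in [100, 1_000, 10_000]:
    # 1000 independent realizations of the sum of n i.i.d. X_i ~ Uniform(0, 1)
    S = rng.uniform(0.0, 1.0, size=(1_000, n)).sum(axis=1)
    Y_clt = S / np.sqrt(n)   # 1/sqrt(n) normalization (central limit theorem)
    Y_lln = S / n            # 1/n normalization (law of large numbers)
    print(f"n={n:>6}  1/sqrt(n): mean={Y_clt.mean():7.2f} var={Y_clt.var():.3f}"
          f"   1/n: mean={Y_lln.mean():.3f} var={Y_lln.var():.2e}")
```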
regime 1: kernel limit
\(x\)
\(f(w,x)\)
\(W_{ij}\)
\(\displaystyle z^{\ell+1}_i = \sigma \left({\frac1{\color{red}\sqrt{n}}} \sum_{j=1}^n W^\ell_{ij} z^\ell_j \right)\)
initialization
\(W_{ij} \longleftarrow \mathcal{N}(0, 1)\)
\(w\) = all the weights
[1995 R. M. Neal]
\(z_j\)
\(z_i\)
Independent terms in the sum, CLT
=> when \({\color{red}n} \to\infty\) the output is a Gaussian process
[2018 Jacot et al.]
[2018 Du et al.]
[2019 Lee et al.]
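A quick numerical illustration of this claim (a sketch with arbitrary toy dimensions, plain NumPy): at a fixed input, the output of a randomly initialized one-hidden-layer network with the \(\frac{1}{\sqrt{n}}\) normalization is close to Gaussian across draws of the weights.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n, n_init = 20, 1024, 1000      # input dimension, width, number of random inits
x = rng.standard_normal(d)         # one fixed input

outs = []
for _ in range(n_init):
    W1 = rng.standard_normal((n, d))      # W_ij ~ N(0, 1)
    W2 = rng.standard_normal(n)
    z = np.tanh(W1 @ x / np.sqrt(d))      # hidden layer, 1/sqrt(fan-in) scaling
    outs.append(W2 @ z / np.sqrt(n))      # output with the 1/sqrt(n) normalization
outs = np.array(outs)

m, s = outs.mean(), outs.std()
print("mean", m, "std", s, "excess kurtosis", ((outs - m) ** 4).mean() / s**4 - 3)
```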
\(\langle Y \rangle={\color{red}\sqrt{n}}\langle X_i \rangle\) => no feature learning
During training, the kernel \(\Theta(w,x_1,x_2) = \nabla_w f(w, x_1) \cdot \nabla_w f(w, x_2)\) stays constant when \({\color{red}n}\to\infty\)
regime 1: kernel limit
\(f(w_0)\)
\(f(w_0) + \nabla f(w_0) \cdot dw\)
\(f(w_0) + \sum_\mu c_\mu \Theta(w_0, x_\mu)\)
space of functions: \(x \mapsto \mathbb{R}\)
\(f(w)\)
regime 1: kernel limit
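The kernel \(\Theta\) defined above can be estimated directly with automatic differentiation. A minimal PyTorch sketch on a hypothetical toy two-layer network (not the code of the paper's repository):

```python
import torch

torch.manual_seed(0)
d, n = 10, 512
# toy fully connected network with the 1/sqrt(n) scaling of the previous slides
W1 = torch.randn(n, d, requires_grad=True)
W2 = torch.randn(n, requires_grad=True)

def f(x):
    z = torch.tanh(W1 @ x / d**0.5)
    return W2 @ z / n**0.5

def grad_f(x):
    # gradient of the scalar output with respect to all weights, flattened
    g = torch.autograd.grad(f(x), (W1, W2))
    return torch.cat([gi.reshape(-1) for gi in g])

x1, x2 = torch.randn(d), torch.randn(d)
theta = grad_f(x1) @ grad_f(x2)    # Theta(w, x1, x2) = grad f(x1) . grad f(x2)
print(theta.item())
```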
regime 2: mean field limit
\(\displaystyle f(w,x) = \frac{1}{\color{red}n} \sum_{i=1}^n W_i \; \sigma \! \left(\frac{1}{\color{red}\sqrt{n}} \sum_{j=1}^n W_{ij} x_j \right)\)
\(W_{ij}\)
\(x\)
\(f(w,x)\)
\(W_{i}\)
was \(\frac{1}{\color{red}\sqrt{n}}\) in the kernel limit
studied theoretically for 1 hidden layer
for \(n\longrightarrow \infty\)
\(\displaystyle f(w,x) = \frac{1}{\color{red}n} \sum_{i=1}^n W_i \; \sigma \! \left(\frac{1}{\color{red}\sqrt{n}} \sum_{j=1}^n W_{ij} x_j \right)\)
\(\frac{1}{\color{red}n}\) instead of \(\frac{1}{\color{red}\sqrt{n}}\) implies
\(\displaystyle f(w,x) \approx f(\rho,x) = \int d\rho(W,\vec W) \; W \; \sigma \! \left(\frac{1}{\sqrt{n}} \sum_{j=1}^n W_j x_j \right)\)
where \(\rho\) is the density of the neurons' weights
regime 2: mean field limit
\(\displaystyle f(\rho,x) = \int d\rho(W,\vec W) \; W \; \sigma \! \left(\frac{1}{\sqrt{n}} \sum_{j=1}^n W_j x_j \right)\)
\(\rho\) follows a differential equation
[2018 S. Mei et al], [2018 Rotskoff and Vanden-Eijnden]
activations do change
\(\langle Y \rangle=\langle X_i \rangle\) => feature learning
regime 2: mean field limit
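One way to read this limit: the finite-width \(f(w,x)\) is an empirical average over \(n\) neurons sampled from \(\rho\), so with the \(\frac{1}{n}\) prefactor it concentrates around \(f(\rho,x)\) as \(n\) grows, while with the \(\frac{1}{\sqrt{n}}\) prefactor its fluctuations stay of order one. A minimal NumPy sketch (toy input dimension \(d\); the inner layer is normalized by its own fan-in \(d\), whereas the slides take the input dimension equal to \(n\)):

```python
import numpy as np

rng = np.random.default_rng(0)
d = 10
x = rng.standard_normal(d)          # one fixed input

def f(n, prefactor):
    # one randomly initialized two-layer network with n hidden neurons
    W_out = rng.standard_normal(n)          # W_i
    W_in = rng.standard_normal((n, d))      # components of vec W_i
    return prefactor * np.sum(W_out * np.tanh(W_in @ x / np.sqrt(d)))

for n in [100, 10_000]:
    mf = [f(n, 1 / n) for _ in range(200)]             # mean-field (1/n) scaling
    ker = [f(n, 1 / np.sqrt(n)) for _ in range(200)]   # kernel (1/sqrt(n)) scaling
    print(f"n={n:>6}   1/n: std={np.std(mf):.4f}   1/sqrt(n): std={np.std(ker):.3f}")
```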
What is the difference between the two limits?
kernel limit and mean field limit
\(\frac{1}{\color{red}\sqrt{n}}\)
\(\frac{1}{\color{red}n}\)
we use the scaling factor \(\alpha\) to investigate
the transition between the two regimes
\({\color{red}\alpha} f(w,x) = \frac{{\color{red}\alpha}}{\sqrt{n}} \sum_{i=1}^n W_i z_i\)
[2019 Chizat and Bach]
\(z_i\)
\(W_i\)
\(f(w,x)\)
\( \alpha \cdot ( f(w,x) - f(w_0, x) ) \)
\(\Longrightarrow\)
\(\displaystyle \mathcal{L}(w) = \frac{1}{\alpha^2 |\mathcal{D}|} \sum_{(x,y)\in\mathcal{D}} \ell\left( \alpha (f(w,x) - f(w_0,x)), y \right) \)
linearize the network with \(f - f_0\)
We would like the network to behave linearly in the limit \(\alpha \to \infty\), for any finite \(n\)
The \(\alpha^2\) factor is there so that the dynamics converges in a time that does not scale with \(\alpha\) in the limit \(\alpha \to \infty\)
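A minimal PyTorch sketch of this rescaled objective (hypothetical toy model, data and loss; the actual training code is in the repository linked at the end): the output at initialization is frozen once, and the loss is \(\ell(\alpha(f - f_0), y)/\alpha^2\).

```python
import copy
import torch

def make_alpha_loss(model, alpha, loss_fn):
    # frozen copy of the network at initialization, used to subtract f(w0, x)
    model0 = copy.deepcopy(model)
    for p in model0.parameters():
        p.requires_grad_(False)

    def loss(x, y):
        out = alpha * (model(x) - model0(x))    # alpha * (f(w, x) - f(w0, x))
        return loss_fn(out, y) / alpha**2       # 1/(alpha^2 |D|) sum of l(., y)
    return loss

# usage on a hypothetical toy setup
torch.manual_seed(0)
net = torch.nn.Sequential(torch.nn.Linear(10, 512), torch.nn.Softplus(),
                          torch.nn.Linear(512, 1))
x, y = torch.randn(32, 10), torch.randn(32, 1)
loss = make_alpha_loss(net, alpha=100.0, loss_fn=torch.nn.functional.mse_loss)
print(loss(x, y).item())
```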
Loss:
there is a plateau for large values of \(\alpha\)
MNIST 10k parity, FC L=3, softplus, gradient flow with momentum
lazy regime
MNIST 10k parity, FC L=3, softplus, gradient flow with momentum
ensemble average
the ensemble average converges as \(n \to \infty\)
no overlap
\(\alpha\)
MNIST 10k parity, FC L=3, softplus, gradient flow with momentum
plotted as a function of \(\sqrt{n} \alpha\), the lines overlap
overlap!
\(\sqrt{n}\alpha\)
\(\alpha\)
the phase space is split in two by \(\alpha^*\), which decays as \(1/\sqrt{n}\)
\(n\)
\(\alpha\)
feature learning
kernel limit
mean field limit
\(\alpha^* \sim \frac{1}{\sqrt{n}}\)
lazy learning
same for other datasets: the trend depends on the dataset
MNIST 10k
reduced to 2 classes
10PCA MNIST 10k
reduced to 2 classes
FC L=3, softplus, gradient flow with momentum
EMNIST 10k
reduced to 2 classes
Fashion MNIST 10k
reduced to 2 classes
FC L=3, softplus, gradient flow with momentum
same for other datasets: the trend depends on the dataset
CIFAR10 10k
reduced to 2 classes
CNN, SGD / ADAM
CNN: the tendency is inverted
how do the learning curves depend on \(n\) and \(\alpha\)?
MNIST 10k parity, L=3, softplus, gradient flow with momentum
\(\sqrt{n} \alpha\)
overlap!
same time in the lazy regime
lazy
MNIST 10k parity, L=3, softplus, gradient flow with momentum
there is a characteristic time in the learning curves
\(\sqrt{n} \alpha\)
overlap!
characteristic time \(t_1\)
lazy
\(f(w_0) + \nabla f(w_0) \cdot dw\)
\(f(w)\)
\(f(w_0)\)
\(t_1\) characterises the curvature of the network manifold
\(t_1\)
\(t\)
\(t_1\) is the time you need to drive before realizing that the Earth is curved
\(v\)
\(R\)
\(t_1 \sim R/v\)
\(\dot W^L= -\frac{\partial\mathcal{L}}{\partial W^L} = \mathcal{O}\left(\frac{1}{\alpha \sqrt{n}}\right)\)
\(\Rightarrow t_1 \sim \sqrt{n} \alpha \)
the rate of change of \(W\) determines when we leave the tangent space, i.e. \(t_1 \sim \sqrt{n}\alpha\)
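Reading these two estimates together: the last-layer weights are of order one at initialization (as recalled on the next slides), so the time they need to change by an amount comparable to their initial scale is

\[
t_1 \sim \frac{|W^L|}{|\dot W^L|} \sim \frac{1}{1/(\alpha \sqrt{n})} = \alpha \sqrt{n} .
\]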
\(x\)
\(f(w,x)\)
\(W_{ij}\)
\(\displaystyle z^{\ell+1}_i = \sigma \left(\frac1{\sqrt{n}} \sum_{j=1}^n W^\ell_{ij} z^\ell_j \right)\)
\(z_j\)
\(z_i\)
\(W_{ij}\) and \(z_i\) are initialized \(\sim 1\)
\(\dot W_{ij}\) and \(\dot z_i\) at initialization \(\Rightarrow t_1\)
Upper bound: consider the weights of the last layer
actually the correct scaling
for large \(\alpha\)
if \(t_1\) becomes smaller than the convergence time, the dynamics is not linear until the end
\(t_1 \sim \sqrt{n} \alpha\) v.s. 1 \(\Longrightarrow\) \({\color{red}\alpha^* \sim \frac{1}{\sqrt{n}}}\)
when the dynamics stops before \(t_1\) we are in the lazy regime
lazy dynamics
\(\alpha^* \sim 1/\sqrt{n}\)
then for large \(n\)
\(\Rightarrow\)
\( \alpha^* f(w_0, x) \ll 1 \)
\(\Rightarrow\)
\( \alpha^* (f(w,x) - f(w_0, x)) \approx \alpha^* f(w,x) \)
for large \(n\) our conclusions should hold without this trick
linearizing the network with \(f - f_0\) was not necessary
arxiv.org/abs/1906.08034
github.com/mariogeiger/feature_lazy
\(n\)
\(\alpha\)
lazy learning
stop before \(t_1\)
feature learning
stop after \(t_1\)
kernel limit
mean field limit
\(\alpha^* \sim \frac{1}{\sqrt{n}}\)
time to leave the tangent space \(t_1 \sim \sqrt{n}\alpha\)
=> time for learning features
continuous dynamics
\(\dot w = -\nabla \mathcal{L}\)
Implemented with a dynamical adaptation of \(dt\) such that,
\(10^{-4} < \frac{\|\nabla \mathcal{L}(t_{i+1})- \nabla \mathcal{L}(t_i)\|^2}{\|\nabla \mathcal{L}(t_{i+1})\|\cdot\|\nabla \mathcal{L}(t_i)\|} < 10^{-2}\)
works well with full batch and smooth loss
continuous dynamics
\(10^{-4} < \frac{\|\nabla \mathcal{L}(t_{i+1})- \nabla \mathcal{L}(t_i)\|^2}{\|\nabla \mathcal{L}(t_{i+1})\|\cdot\|\nabla \mathcal{L}(t_i)\|} < 10^{-2}\)
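A sketch of one way such an adaptive \(dt\) could be implemented (hypothetical code, not the repository's implementation; `loss_fn` is assumed to be a closure evaluating the full-batch loss of `model`):

```python
import torch
from torch.nn.utils import parameters_to_vector, vector_to_parameters

def gradient_flow(model, loss_fn, steps=1000, dt=1e-4, lo=1e-4, hi=1e-2):
    """Explicit Euler integration of dot w = -grad L with a dynamically adapted dt."""
    params = list(model.parameters())

    def flat_grad():
        g = torch.autograd.grad(loss_fn(model), params)
        return torch.cat([gi.reshape(-1) for gi in g])

    g_old = flat_grad()
    for _ in range(steps):
        with torch.no_grad():
            w = parameters_to_vector(params)
            vector_to_parameters(w - dt * g_old, params)   # w <- w - dt * grad L
        g_new = flat_grad()
        # criterion of the slide: relative change of the gradient over one step
        crit = (g_new - g_old).norm()**2 / (g_new.norm() * g_old.norm())
        if crit > hi:
            dt /= 2        # gradient changed too much within one step: shrink dt
        elif crit < lo:
            dt *= 1.1      # gradient barely changed: allow a larger step
        g_old = g_new
    return dt
```

Whether a step that violates the upper bound should also be rejected and retried is a detail of the actual implementation not reproduced here.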
momentum dynamics
\(\dot v = -\frac{1}{\tau}(v + \nabla \mathcal{L})\)
\(\dot w = v\)
I took \(\tau \sim \frac{\sqrt{h}}{\alpha}t\), where \(t\) is the proper time of the dynamics and \(h = n\) is the width
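One possible discretization of these equations, as a sketch (a single explicit Euler step; `w`, `v` and `grad` are flat parameter, velocity and gradient vectors, `h` the width; all names are hypothetical):

```python
def momentum_flow_step(w, v, grad, t, dt, alpha, h):
    """One explicit Euler step of  dv/dt = -(v + grad L)/tau,  dw/dt = v."""
    tau = (h ** 0.5 / alpha) * max(t, dt)   # tau ~ (sqrt(h)/alpha) * t, kept > 0 at t = 0
    v = v - (dt / tau) * (v + grad)         # dot v = -(v + grad L) / tau
    w = w + dt * v                          # dot w = v
    return w, v
```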
kernel growth
\(\frac{\|\Theta - \Theta_0 \|}{\|\Theta_0\|}\)
\(\alpha\)
\(\sqrt{h}\alpha\)