Feature vs. Lazy Training

Stefano Spigler

Matthieu Wyart

Mario Geiger

Arthur Jacot

https://arxiv.org/abs/1906.08034

Two different regimes in the dynamics of neural networks

We want to understand why NN are so good at learning

  • versatile: images, songs, game agents, etc.
  • generalize well
  • overparametrized

 

  • why do they generalize so well?
  • how many parameters do we need?
  • how does the architecture affect their ability to learn?
  • there exist two different regimes in training NN!


  • we can measure everything
  • variety of architectures
  • variety of tasks
  • variety of dynamics

How do we proceed?

how the number of parameters affects learning

[2017 B. Neyshabur et al.]

[2018 S. Spigler et al.]

[2019 M. Geiger et al.] ...

Test error decreases in the over-parametrized regime

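[Figure: test error as a function of the number of parameters of a fully connected network of width \(n\) and depth \(L\)]

parameters: \(N \approx n^2 L\)

\(N^*\): threshold above which the network is over-parametrized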

Above \(N^*\), the test error depends on \(n\) through the fluctuations of the network output

\(\mathit{Variance}(f) \sim 1 / n\)


\(\mathit{Variance}(f) = \| f - \langle f \rangle\|^2\)
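The variance above can be estimated from an ensemble of networks trained from different initializations. The sketch below is a minimal illustration, assuming each network is represented only by its outputs on a common set of test points and taking \(\|\cdot\|^2\) to be the mean squared difference over those points; `ensemble_variance` is a hypothetical helper, not code from the paper.

```python
from statistics import fmean


def ensemble_variance(outputs):
    """Estimate Variance(f) = <||f - <f>||^2> from an ensemble of networks.

    outputs[k][i] is the output of network k on test point i. <f> is the
    ensemble average on each test point, and ||.||^2 is taken here as the
    mean squared difference over the test points (an assumption).
    """
    n_points = len(outputs[0])
    # Ensemble average <f> on every test point
    mean_f = [fmean(f[i] for f in outputs) for i in range(n_points)]
    # Squared distance of each network to <f>, averaged over the test points
    dists = [fmean((f[i] - mean_f[i]) ** 2 for i in range(n_points)) for f in outputs]
    # Average over the members of the ensemble
    return fmean(dists)


# Two networks that sit at +/-1 around the ensemble mean on every test point
print(ensemble_variance([[1.0, -1.0, 2.0], [-1.0, 1.0, 0.0]]))  # 1.0
```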


How does the network perform in the infinite width limit?

\(n\to\infty\)

There exist two limits in the literature

The two limits can be understood in the context of the central limit theorem

Let \(X_i\) be i.i.d. random variables.

Central Limit Theorem:   \(\displaystyle Y = \frac{1}{\color{red}\sqrt{n}} \sum_{i=1}^n X_i\)

  • \(\langle Y \rangle={\color{red}\sqrt{n}}\langle X_i \rangle\)
  • \(\langle Y^2\rangle - \langle Y\rangle^2 = \langle X_i^2 \rangle - \langle X_i\rangle^2\)
  • as \({\color{red}n} \to \infty\), the fluctuations of \(Y\) around \(\langle Y \rangle\) become Gaussian

Law of large numbers:   \(\displaystyle Y = \frac{1}{\color{red}n} \sum_{i=1}^n X_i\)

  • \(\langle Y \rangle=\langle X_i \rangle\)
  • \(\langle Y^2\rangle - \langle Y\rangle^2 = {\color{red}\frac{1}{n}} (\langle X_i^2 \rangle - \langle X_i\rangle^2)\)
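A quick numerical check of the two scalings, as a minimal sketch assuming \(X_i\) uniform on \([-1, 1]\) (mean 0, variance \(1/3\)); `sample_y` is a hypothetical helper. With the \(1/\sqrt{n}\) prefactor the variance of \(Y\) stays close to \(1/3\), while with the \(1/n\) prefactor it shrinks like \(1/(3n)\).

```python
import random
from statistics import pvariance


def sample_y(n, power, trials=1000, seed=0):
    """Draw `trials` samples of Y = n**(-power) * sum_{i=1}^n X_i,
    with X_i i.i.d. uniform on [-1, 1] (mean 0, variance 1/3)."""
    rng = random.Random(seed)
    return [sum(rng.uniform(-1.0, 1.0) for _ in range(n)) / n ** power
            for _ in range(trials)]


for n in (10, 100, 1000):
    var_clt = pvariance(sample_y(n, 0.5))  # prefactor 1/sqrt(n): stays close to 1/3
    var_lln = pvariance(sample_y(n, 1.0))  # prefactor 1/n: shrinks like 1/(3n)
    print(n, round(var_clt, 3), round(var_lln, 6))
```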

regime 1: kernel limit

[Figure: fully connected network with input \(x\), weights \(W_{ij}\) and output \(f(w,x)\)]

\(\displaystyle z^{\ell+1}_i = \sigma \left({\frac1{\color{red}\sqrt{n}}} \sum_{j=1}^n W^\ell_{ij} z^\ell_j \right)\)

initialization: \(W_{ij} \sim \mathcal{N}(0, 1)\)

\(w\) = all the weights

[1995 R. M. Neal]

The terms in the sum are independent, so the CLT applies

=> when \({\color{red}n} \to\infty\), the output at initialization is a Gaussian process
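The Gaussian-process limit can be probed by Monte Carlo. The sketch below assumes a single hidden layer with a \(\tanh\) activation, inputs of dimension \(d\), and all weights drawn from \(\mathcal{N}(0,1)\) (choices made for illustration); it estimates the covariance \(\langle f(x_1) f(x_2) \rangle\) of the output over random initializations, which is the kernel of the limiting Gaussian process. `init_outputs` and `gp_covariance` are hypothetical helpers.

```python
import math
import random


def init_outputs(xs, n, rng):
    """Outputs on the inputs `xs` of ONE freshly initialized one-hidden-layer net
    f(x) = 1/sqrt(n) * sum_i a_i * tanh(1/sqrt(d) * sum_j W_ij x_j),
    with all weights a_i, W_ij drawn from N(0, 1)."""
    d = len(xs[0])
    outs = [0.0] * len(xs)
    for _ in range(n):
        w = [rng.gauss(0.0, 1.0) for _ in range(d)]  # weights of hidden neuron i
        a = rng.gauss(0.0, 1.0)                      # output weight a_i
        for k, x in enumerate(xs):
            pre = sum(wj * xj for wj, xj in zip(w, x)) / math.sqrt(d)
            outs[k] += a * math.tanh(pre)
    return [o / math.sqrt(n) for o in outs]


def gp_covariance(x1, x2, n, samples=400, seed=0):
    """Monte Carlo estimate of E[f(x1) f(x2)] over random initializations."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(samples):
        f1, f2 = init_outputs([x1, x2], n, rng)  # same network for both inputs
        total += f1 * f2
    return total / samples


x1, x2 = [1.0, 1.0, 1.0], [1.0, -1.0, 1.0]
print(gp_covariance(x1, x1, 100), gp_covariance(x1, x2, 100))
```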

[2018 Jacot et al.]

[2018 Du et al.]

[2019 Lee et al.]

During training, when \({\color{red}n}\to\infty\):

  • the kernel \(\Theta(w,x_1,x_2) = \nabla_w f(w, x_1) \cdot \nabla_w f(w, x_2)\) is independent of the initialization
    and is constant through learning

  • the weights do not change (each weight moves by a vanishing amount)
    => the network behaves like a kernel method

  • the internal activations do not change
    (\(\langle Y \rangle={\color{red}\sqrt{n}}\langle X_i \rangle\)) => no feature learning
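To make the definition of \(\Theta\) concrete, here is a minimal sketch that computes the empirical kernel as the dot product of two weight gradients. It assumes a one-hidden-layer network with \(\tanh\) activation and \(1/\sqrt{n}\) scaling, with gradients written out by hand; `forward`, `grad_f` and `ntk` are hypothetical helpers, not the code of the paper.

```python
import math
import random


def forward(a, W, x):
    """f(w, x) = 1/sqrt(n) * sum_i a_i * tanh(h_i),  h_i = 1/sqrt(d) * sum_j W_ij x_j."""
    n, d = len(a), len(x)
    h = [sum(Wi[j] * x[j] for j in range(d)) / math.sqrt(d) for Wi in W]
    return sum(ai * math.tanh(hi) for ai, hi in zip(a, h)) / math.sqrt(n)


def grad_f(a, W, x):
    """Flat gradient of f w.r.t. all weights: (a_1..a_n, then W_11..W_1d, W_21..W_nd)."""
    n, d = len(a), len(x)
    g_a, g_W = [], []
    for ai, Wi in zip(a, W):
        hi = sum(Wi[j] * x[j] for j in range(d)) / math.sqrt(d)
        t = math.tanh(hi)
        g_a.append(t / math.sqrt(n))                             # df/da_i
        dt = 1.0 - t * t                                         # tanh'(h_i)
        g_W.extend(ai * dt * xj / math.sqrt(n * d) for xj in x)  # df/dW_ij
    return g_a + g_W


def ntk(a, W, x1, x2):
    """Theta(w, x1, x2) = grad_w f(w, x1) . grad_w f(w, x2)."""
    return sum(u * v for u, v in zip(grad_f(a, W, x1), grad_f(a, W, x2)))


rng = random.Random(0)
n, d = 500, 3
a = [rng.gauss(0.0, 1.0) for _ in range(n)]
W = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]
print(ntk(a, W, [1.0, 0.0, 1.0], [0.5, 1.0, -1.0]))
```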


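[Figure: the space of functions \(x \mapsto \mathbb{R}\). Near \(f(w_0)\), the network manifold \(f(w)\) is approximated by its tangent space \(f(w_0) + \nabla f(w_0) \cdot dw\); in the kernel limit the trained network stays in this tangent space, at a function of the form \(f(w_0) + \sum_\mu c_\mu \Theta(w_0, x_\mu)\)]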
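The kernel-limit predictor \(f(w_0) + \sum_\mu c_\mu \Theta(w_0, x_\mu)\) can be computed by solving a linear system. The sketch below assumes the square loss, for which gradient flow on the linearized network converges to the interpolant with \(c = \Theta^{-1}(y - f(w_0))\) on the training set; a Gaussian kernel and \(f(w_0)=0\) are hypothetical stand-ins for \(\Theta(w_0,\cdot,\cdot)\) and the initial network.

```python
import math


def solve(A, b):
    """Solve the linear system A c = b by Gaussian elimination with partial pivoting."""
    m = len(b)
    M = [row[:] + [bi] for row, bi in zip(A, b)]  # augmented matrix [A | b]
    for col in range(m):
        piv = max(range(col, m), key=lambda r: abs(M[r][col]))
        M[col], M[piv] = M[piv], M[col]
        for r in range(col + 1, m):
            factor = M[r][col] / M[col][col]
            for k in range(col, m + 1):
                M[r][k] -= factor * M[col][k]
    c = [0.0] * m
    for r in reversed(range(m)):  # back substitution
        c[r] = (M[r][m] - sum(M[r][k] * c[k] for k in range(r + 1, m))) / M[r][r]
    return c


def kernel_predictor(kernel, f0, xs, ys):
    """Return x -> f0(x) + sum_mu c_mu * kernel(x_mu, x), where the coefficients
    c_mu are chosen so that the predictor interpolates the training set."""
    K = [[kernel(xm, xn) for xn in xs] for xm in xs]
    c = solve(K, [y - f0(x) for x, y in zip(xs, ys)])
    return lambda x: f0(x) + sum(cm * kernel(xm, x) for cm, xm in zip(c, xs))


def rbf(u, v):
    """Hypothetical stand-in for Theta(w_0, u, v): a Gaussian kernel."""
    return math.exp(-sum((ui - vi) ** 2 for ui, vi in zip(u, v)))


def zero(x):
    """Hypothetical stand-in for the initial network output f(w_0, x)."""
    return 0.0


xs, ys = [[0.0], [1.0], [2.0]], [1.0, -1.0, 1.0]
predict = kernel_predictor(rbf, zero, xs, ys)
print([round(predict(x), 6) for x in xs])  # [1.0, -1.0, 1.0]
```

Far from the training points the kernel terms vanish and the predictor falls back to \(f(w_0)\).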


regime 2: mean field limit

\(\displaystyle f(w,x) = \frac{1}{\color{red}n} \sum_{i=1}^n W_i \; \sigma \! \left(\frac{1}{\color{red}\sqrt{n}} \sum_{j=1}^n W_{ij} x_j \right)\)

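[Figure: one-hidden-layer network with input \(x\), first-layer weights \(W_{ij}\), output weights \(W_{i}\) and output \(f(w,x)\)]

the output prefactor \(\frac{1}{\color{red}n}\) was \(\frac{1}{\color{red}\sqrt{n}}\) in the kernel limit

studied theoretically for one hidden layer

Another limit!

for \(n\longrightarrow \infty\)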


\(\frac{1}{\color{red}n}\) instead of \(\frac{1}{\color{red}\sqrt{n}}\) implies

  • no output fluctuations at initialization
     
  • we can replace the sum by an integral

\(\displaystyle f(w,x) \approx f(\rho,x) = \int d\rho(W,\vec W) \; W \; \sigma \! \left(\frac{1}{\sqrt{n}} \sum_{j=1}^n W_j x_j \right)\)

where \(\rho\) is the distribution of the neurons' weights \((W, \vec W)\)
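The first consequence of the \(\frac{1}{n}\) prefactor, no output fluctuations at initialization, can be checked numerically. This minimal sketch assumes one hidden layer, \(\tanh\) activation, \(\mathcal{N}(0,1)\) weights and a fixed input of dimension \(d = 2\) (our notation for the input dimension); it compares the spread of \(f(w,x)\) over random initializations for both prefactors. `init_output` and `output_std` are hypothetical helpers.

```python
import math
import random
from statistics import pstdev


def init_output(x, n, power, rng):
    """f(w, x) = n**(-power) * sum_i W_i * tanh(1/sqrt(d) * sum_j W_ij x_j) at a random
    initialization (all weights N(0, 1)). power=0.5: kernel limit, power=1: mean field."""
    d = len(x)
    total = 0.0
    for _ in range(n):
        pre = sum(rng.gauss(0.0, 1.0) * xj for xj in x) / math.sqrt(d)
        total += rng.gauss(0.0, 1.0) * math.tanh(pre)
    return total / n ** power


def output_std(n, power, trials=200, seed=0):
    """Standard deviation of f(w, x) over `trials` random initializations."""
    rng = random.Random(seed)
    return pstdev([init_output([1.0, -1.0], n, power, rng) for _ in range(trials)])


for n in (10, 100, 1000):
    # 1/sqrt(n): spread stays of order 1;  1/n: spread decays like 1/sqrt(n)
    print(n, round(output_std(n, 0.5), 3), round(output_std(n, 1.0), 4))
```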


During training, \(\rho\) evolves according to a partial differential equation

[2018 S. Mei et al], [2018 Rotskoff and Vanden-Eijnden]

the activations do change

\(\langle Y \rangle=\langle X_i \rangle\) => feature learning


What is the difference between the two limits?

  • Which limit better describes networks of finite \(n\)?
  • Are there corresponding regimes for finite \(n\)?

kernel limit: \(\frac{1}{\color{red}\sqrt{n}}\)              and              mean field limit: \(\frac{1}{\color{red}n}\)

We use the scaling factor \(\alpha\) to investigate the transition between the two regimes:

\({\color{red}\alpha} f(w,x) = \frac{{\color{red}\alpha}}{\sqrt{n}} \sum_{i=1}^n W_i z_i\)

[2019 Chizat and Bach]

  • if \(\alpha\) is a fixed constant and \(n\to\infty\) => kernel limit
  • if \({\color{red}\alpha} \sim \frac{1}{\sqrt{n}}\) and \(n\to\infty\) => mean field limit


We would like that, for any finite \(n\), the network behaves linearly in the limit \(\alpha \to \infty\)

\(\Longrightarrow\) linearize the network with \(f - f_0\): the output becomes \( \alpha \cdot ( f(w,x) - f(w_0, x) ) \), which vanishes at initialization and therefore stays finite when \(\alpha \to \infty\)

Loss:

\(\displaystyle \mathcal{L}(w) = \frac{1}{\alpha^2 |\mathcal{D}|} \sum_{(x,y)\in \mathcal{D}} \ell\left( \alpha (f(w,x) - f(w_0,x)), y \right) \)   where \(\mathcal{D}\) is the training set

The factor \(\frac{1}{\alpha^2}\) makes the dynamics converge in a time that does not scale with \(\alpha\) in the limit \(\alpha \to \infty\)
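The loss above translates directly into code. In this sketch, the per-sample loss \(\ell\) is assumed to be the logistic loss \(\log(1 + e^{-yF})\) for labels \(y = \pm 1\) (the slides do not specify \(\ell\) here), and `toy_model` is a hypothetical linear model used only to exercise the formula. At \(w = w_0\) the rescaled output is zero, so the loss equals \(\ell(0, y)/\alpha^2 = \log 2/\alpha^2\).

```python
import math


def logistic(F, y):
    """Assumed per-sample loss l(F, y) = log(1 + exp(-y F)) for labels y = +1/-1,
    written in a numerically stable form."""
    z = -y * F
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def lazy_loss(f, w, w0, data, alpha, loss=logistic):
    """L(w) = 1/(alpha^2 |D|) * sum_{(x, y) in D} l(alpha * (f(w, x) - f(w0, x)), y)."""
    total = sum(loss(alpha * (f(w, x) - f(w0, x)), y) for x, y in data)
    return total / (alpha ** 2 * len(data))


def toy_model(w, x):
    """Hypothetical model f(w, x) = w[0] * x + w[1], only used to exercise the formula."""
    return w[0] * x + w[1]


data = [(1.0, 1), (-1.0, -1)]
w0 = [0.3, 0.1]
print(lazy_loss(toy_model, w0, w0, data, alpha=10.0))  # log(2) / alpha^2 at w = w0
```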

there is a plateau for large values of \(\alpha\)

MNIST 10k parity, FC L=3, softplus, gradient flow with momentum

lazy regime


ensemble average

the ensemble average converges as \(n \to \infty\)

[Figure: as a function of \(\alpha\), the curves for different \(n\) do not overlap]

Plotting as a function of \(\sqrt{n}\, \alpha\) makes the lines overlap!

[Figure: same curves as a function of \(\sqrt{n}\,\alpha\)]

The phase space is split in two by \(\alpha^*\), which decays as \(1/\sqrt{n}\)

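[Figure: phase diagram in the \((n, \alpha)\) plane, split by the line \(\alpha^* \sim \frac{1}{\sqrt{n}}\) into lazy learning (large \(\alpha\)) and feature learning (small \(\alpha\)); the kernel limit and the mean field limit are marked on it]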

Same for other datasets: the trend depends on the dataset

  • MNIST 10k, reduced to 2 classes
  • 10PCA MNIST 10k, reduced to 2 classes
  • EMNIST 10k, reduced to 2 classes
  • Fashion MNIST 10k, reduced to 2 classes

FC L=3, softplus, gradient flow with momentum


CIFAR10 10k, reduced to 2 classes

CNN, SGD, ADAM

For the CNN, the trend is inverted

How do the learning curves depend on \(n\) and \(\alpha\)?

MNIST 10k parity, L=3, softplus, gradient flow with momentum

[Figure: learning curves labeled by \(\sqrt{n}\, \alpha\): curves with the same \(\sqrt{n}\, \alpha\) overlap! In the lazy regime, the learning curves take the same time]

There is a characteristic time \(t_1\) in the learning curves

[Figure: the network manifold \(f(w)\) curves away from its tangent space \(f(w_0) + \nabla f(w_0) \cdot dw\) at \(f(w_0)\)]

\(t_1\) characterizes the curvature of the network manifold

\(t_1\) is like the time you need to drive before realizing that the Earth is curved: at speed \(v\) on a sphere of radius \(R\), \(t_1 \sim R/v\)

\(\dot W^L= -\frac{\partial\mathcal{L}}{\partial W^L} = \mathcal{O}\left(\frac{1}{\alpha \sqrt{n}}\right)\)

\(\Rightarrow t_1 \sim \sqrt{n} \alpha \)

the rate of change of \(W\) determines when we leave the tangent space: the weights are of order 1, so they change by order 1 after a time \(t_1 \sim \sqrt{n}\alpha\)

\(\displaystyle z^{\ell+1}_i = \sigma \left(\frac1{\sqrt{n}} \sum_{j=1}^n W^\ell_{ij} z^\ell_j \right)\)

\(W_{ij}\) and \(z_i\) are of order 1 at initialization

the rates \(\dot W_{ij}\) and \(\dot z_i\) at initialization \(\Rightarrow t_1\)

Upper bound: consider the weights of the last layer

this turns out to be the correct scaling

for large \(\alpha\)

  • \(f\) is almost linear: \(f(w,x) \approx f(w_0,x) + \nabla f(w_0,x) \cdot dw\)
  • convergence time does not depend on \(n\) and \(\alpha\)

if \(t_1\) becomes smaller than the convergence time, the dynamics does not remain linear until the end

\(t_1 \sim \sqrt{n} \alpha\)   v.s.   1          \(\Longrightarrow\)  \({\color{red}\alpha^* \sim \frac{1}{\sqrt{n}}}\)

when the dynamics stops before \(t_1\), we are in the lazy regime

lazy dynamics

  • \(\mathcal{L}(w) = \frac{1}{\alpha^2 |\mathcal{D}|} \sum_{(x,y)\in\mathcal{D}} \ell\left( \alpha (f(w,x) - f(w_0,x)), y \right) \)
     
  • \(\dot f(w, x) = \nabla_w f(w,x) \cdot \dot w\)
  • \(= - \nabla_w f(w,x) \cdot \nabla_w \mathcal{L} \)
  • \(\sim - \nabla_w f \cdot (\frac{\partial}{\partial f} \mathcal{L} \; \nabla_w f)\)
  • \(\sim \frac{1}{\alpha} \Theta\)
     
  • the output \(\alpha f\) must change by order 1: "\( \alpha \dot f t = 1 \)" \(\Rightarrow\) "\( t = \|\Theta\|^{-1} \)", independent of \(\alpha\)

Since \(\alpha^* \sim 1/\sqrt{n}\) and \(f(w_0, x)\) is of order 1, for large \(n\):

\( \alpha^* f(w_0, x) \ll 1 \)  \(\Rightarrow\)  \( \alpha^* (f(w,x) - f(w_0, x)) \approx \alpha^* f(w,x) \)

=> linearizing the network with \(f - f_0\) was not necessary: for large \(n\), our conclusions should hold without this trick

arxiv.org/abs/1906.08034

github.com/mariogeiger/feature_lazy

[Figure: phase diagram in the \((n, \alpha)\) plane, split by \(\alpha^* \sim \frac{1}{\sqrt{n}}\), with the kernel limit and the mean field limit marked]

  • lazy learning: the dynamics stops before \(t_1\)
  • feature learning: the dynamics stops after \(t_1\)

time to leave the tangent space: \(t_1 \sim \sqrt{n}\alpha\)
 => time needed to learn features

continuous dynamics

\(\dot w = -\nabla \mathcal{L}\)

Implemented with a dynamic adaptation of \(dt\) such that

\(10^{-4} < \frac{\|\nabla \mathcal{L}(t_{i+1})- \nabla \mathcal{L}(t_i)\|^2}{\|\nabla \mathcal{L}(t_{i+1})\|\cdot\|\nabla \mathcal{L}(t_i)\|} < 10^{-2}\)

works well with full batch and a smooth loss
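A minimal sketch of the adaptive time step, assuming explicit Euler steps and a simple rule of our own (reject the step and halve \(dt\) when the ratio exceeds \(10^{-2}\), enlarge \(dt\) by 10% when it falls below \(10^{-4}\)); the exact update rule used in the paper's code may differ. It is tried on a quadratic toy loss whose exact flow is \(w_k(t) = w_k(0)\, e^{-\lambda_k t}\).

```python
import math


def norm(v):
    return math.sqrt(sum(x * x for x in v))


def gradient_flow(grad, w, t_max, dt=1e-3, lo=1e-4, hi=1e-2):
    """Integrate dw/dt = -grad L(w) with explicit Euler steps, adapting dt so that
    lo < ||g_new - g||^2 / (||g_new|| * ||g||) < hi.
    The adaptation rule (reject and halve dt above hi, grow dt by 10% below lo)
    is our own choice."""
    t, g = 0.0, grad(w)
    while t < t_max:
        w_new = [wi - dt * gi for wi, gi in zip(w, g)]
        g_new = grad(w_new)
        denom = norm(g_new) * norm(g)
        ratio = sum((a - b) ** 2 for a, b in zip(g_new, g)) / denom if denom > 0 else 0.0
        if ratio > hi:  # the gradient changed too much: reject the step
            dt /= 2
            continue
        w, g, t = w_new, g_new, t + dt
        if ratio < lo:  # the gradient barely changed: take larger steps
            dt *= 1.1
    return w, t


def quadratic_grad(w, lam=(1.0, 10.0)):
    """Gradient of the toy loss L(w) = 1/2 * sum_k lam_k * w_k^2."""
    return [l * wk for l, wk in zip(lam, w)]


w, t = gradient_flow(quadratic_grad, [1.0, 1.0], t_max=5.0)
print(t, w, [math.exp(-t), math.exp(-10 * t)])  # compare with the exact solution
```

The criterion bounds the relative change of the gradient per step, so \(dt\) automatically shrinks where the loss is sharply curved and grows where it is flat.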


momentum dynamics

\(\dot v = -\frac{1}{\tau}(v + \nabla \mathcal{L})\)

\(\dot w = v\)

I took \(\tau \sim \frac{\sqrt{n}}{\alpha}t\) (where \(t\) is the proper time of the dynamics)
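The momentum dynamics can be integrated in the same way. This sketch keeps \(\tau\) fixed (a simplification, rather than the \(\tau\) schedule above) and uses explicit Euler steps with \(dt \ll \tau\); for \(\tau \to 0\) it reduces to the gradient flow \(\dot w = -\nabla \mathcal{L}\). `momentum_flow` and the quadratic toy loss are hypothetical.

```python
def momentum_flow(grad, w, tau, dt, t_max):
    """Explicit Euler integration of the momentum dynamics
        dv/dt = -(v + grad L(w)) / tau,    dw/dt = v,
    starting from v = 0. tau is kept fixed here (a simplifying assumption);
    dt must be small compared with tau for the scheme to be stable."""
    v = [0.0] * len(w)
    for _ in range(int(round(t_max / dt))):
        g = grad(w)
        v_new = [vi - dt * (vi + gi) / tau for vi, gi in zip(v, g)]
        w = [wi + dt * vi for wi, vi in zip(w, v)]
        v = v_new
    return w


def quadratic_grad(w, lam=(1.0, 4.0)):
    """Gradient of the toy loss L(w) = 1/2 * sum_k lam_k * w_k^2."""
    return [l * wk for l, wk in zip(lam, w)]


print(momentum_flow(quadratic_grad, [1.0, 1.0], tau=0.1, dt=0.01, t_max=10.0))
```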

kernel growth: \(\frac{\|\Theta - \Theta_0 \|}{\|\Theta_0\|}\)

[Figure: relative kernel growth as a function of \(\alpha\) and of \(\sqrt{n}\alpha\)]
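A minimal sketch of the kernel-growth measure, assuming \(\|\cdot\|\) is the Frobenius norm of the Gram matrix \(\Theta(w, x_\mu, x_\nu)\) evaluated on a fixed set of inputs; `kernel_growth` is a hypothetical helper and the matrices in the example are made-up numbers, not results.

```python
import math


def frobenius(M):
    """Frobenius norm of a matrix given as a list of rows."""
    return math.sqrt(sum(x * x for row in M for x in row))


def kernel_growth(theta, theta0):
    """Relative change ||Theta - Theta_0|| / ||Theta_0|| between two kernel Gram matrices
    (entries Theta(w, x_mu, x_nu) on the same inputs), measured with the Frobenius norm."""
    diff = [[a - b for a, b in zip(row, row0)] for row, row0 in zip(theta, theta0)]
    return frobenius(diff) / frobenius(theta0)


theta0 = [[3.0, 0.0], [0.0, 4.0]]  # kernel at initialization (made-up numbers)
theta = [[3.0, 0.0], [0.0, 4.5]]   # kernel after training (made-up numbers)
print(kernel_growth(theta, theta0))  # 0.1
```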

feature-lazy

By Mario Geiger


Presentation of the paper https://arxiv.org/abs/1906.08034
