Feature vs. Lazy Training

Stefano Spigler

Matthieu Wyart

Mario Geiger

Arthur Jacot

https://arxiv.org/abs/1906.08034

Two different regimes in the dynamics of neural networks

feature training

the network learns features

lazy training

no features are learned

\(f(w) \approx f(w_0) + \nabla f(w_0) \cdot dw\)

\(f(w,x)\)

characteristic training time separating the two regimes

training procedure to force each regime

How does the network perform in the infinite width limit?

\(n\to\infty\)

There exist two limits in the literature

\(n\) neurons per layer

Overparametrization

  • perform well
  • theoretically tractable

The 2 limits can be understood in the context of the central limit theorem

\(\displaystyle Y = \frac{1}{\color{red}\sqrt{n}} \sum_{i=1}^n X_i \quad \)

As \({\color{red}n} \to \infty, \quad Y \longrightarrow \) Gaussian

\(\langle Y \rangle={\color{red}\sqrt{n}}\langle X_i \rangle\)   

\(\langle Y^2\rangle - \langle Y\rangle^2 = \langle X_i^2 \rangle - \langle X_i\rangle^2\)

\(\langle Y \rangle=\langle X_i \rangle\)   

\(\langle Y^2\rangle - \langle Y\rangle^2 = {\color{red}\frac{1}{n}} (\langle X_i^2 \rangle - \langle X_i\rangle^2)\)

\(\displaystyle Y = \frac{1}{\color{red}n} \sum_{i=1}^n X_i\)

Central Limit Theorem:

\(X_i\)

(i.i.d.)

(Law of large numbers)
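To see the two scalings side by side, here is a minimal Python sketch (standard library only) that estimates the mean and variance of \(Y\) for both prefactors. Drawing \(X_i\) uniform on [0, 2] (mean 1, variance 1/3) and the function name `scaled_sum_stats` are illustrative assumptions, not part of the talk.

```python
import random
import statistics


def scaled_sum_stats(n, power, trials=500, seed=0):
    """Empirical mean and variance of Y = n**(-power) * sum_{i=1}^n X_i.

    X_i are i.i.d. uniform on [0, 2] (mean 1, variance 1/3), an arbitrary
    choice for illustration. power=0.5 is the 1/sqrt(n) (CLT) scaling,
    power=1.0 is the 1/n (law of large numbers) scaling.
    """
    rng = random.Random(seed)
    ys = []
    for _ in range(trials):
        s = sum(rng.uniform(0.0, 2.0) for _ in range(n))
        ys.append(s / n ** power)
    return statistics.fmean(ys), statistics.pvariance(ys)


for n in (10, 100, 1000):
    m_clt, v_clt = scaled_sum_stats(n, 0.5)
    m_lln, v_lln = scaled_sum_stats(n, 1.0)
    print(f"n={n}: 1/sqrt(n) -> mean {m_clt:.2f}, var {v_clt:.3f} | "
          f"1/n -> mean {m_lln:.3f}, var {v_lln:.5f}")
```

With the \(1/\sqrt{n}\) prefactor the variance stays near 1/3 while the mean grows like \(\sqrt{n}\); with \(1/n\) the mean stays near 1 while the variance shrinks like \(1/n\).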

regime 1: kernel limit

\(x\)

\(f(w,x)\)

\(W^1_{ij}\)

\(\displaystyle z^{\ell+1}_i = {\frac1{\color{red}\sqrt{n}}} \sum_{j=1}^n W^\ell_{ij} \ \phi(z^\ell_j)\)

At initialization

\(W_{ij} \longleftarrow \mathcal{N}(0, 1)\)

\(w\) = all the weights

\(z^1_j\)

\(z^2_i\)

With this scaling, small correlations add up significantly => the weights will change only a little during training

[1995 R. M. Neal]

Independent terms in the sum, CLT

=> when \({\color{red}n} \to\infty\) output is a Gaussian process
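A minimal sketch of this Gaussian-process behaviour at initialization, assuming a single hidden layer with softplus activation and a \(1/\sqrt{d}\) normalization of the input layer (both are our simplifications for illustration): across many initializations, the variance of \(f(w_0,x)\) stays of order 1 as \(n\) grows, instead of blowing up or vanishing.

```python
import math
import random
import statistics


def softplus(z):
    # numerically stable log(1 + e^z)
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def output_at_init(n, x, rng):
    """f(w0, x) = (1/sqrt(n)) * sum_i W_i * softplus(h_i), h_i = (1/sqrt(d)) * sum_j W_ij x_j,
    with every weight drawn i.i.d. from N(0, 1). One hidden layer, for illustration."""
    d = len(x)
    total = 0.0
    for _ in range(n):
        h = sum(rng.gauss(0.0, 1.0) * xj for xj in x) / math.sqrt(d)
        total += rng.gauss(0.0, 1.0) * softplus(h)
    return total / math.sqrt(n)


def init_output_variance(n, x, trials=400, seed=0):
    """Variance of f(w0, x) over `trials` independent initializations."""
    rng = random.Random(seed)
    samples = [output_at_init(n, x, rng) for _ in range(trials)]
    return statistics.pvariance(samples)


x = [1.0, 1.0, 1.0, 1.0]
for n in (10, 100, 400):
    print(n, round(init_output_variance(n, x), 3))  # stays of order 1 as n grows
```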

space of functions \(\mathbb{R}^d \to \mathbb{R}\)

(dimension \(\infty\))

regime 1: kernel limit

\(f(w_0)\)

\(f(w_0) + \sum_\mu c_\mu \Theta(w_0, x_\mu, \cdot)\)

kernel space

(dimension \(m\))

tangent space

(dimension \(N\))

\(f(w_0) + \nabla f(w_0) \cdot dw\)

\(d\) input dimension

\(N\) number of parameters

\(m\) size of the trainset

\(f(w)\)

network manifold

(dimension \(N\))

[2018 Jacot et al.]

[2018 Du et al.]

[2019 Lee et al.]

The NTK is independent of the initialization
and is constant through learning
=> the network behaves like a kernel method
 

The weights barely change
\(\|dw\|\sim \mathcal{O}(1)\) and \(dW_{ij} \sim \mathcal{O}({\color{red}1/n})\)   (\(\mathcal{O}({\color{red}1/\sqrt{n}})\) in the first and last layers)
 

The internal activations barely change

\( dz \sim \mathcal{O}({\color{red}1/\sqrt{n}})\)
=> no feature training

for \({\color{red}n}\to\infty\)

neural tangent kernel

\(\Theta(w,x_1,x_2) = \nabla_w f(w, x_1) \cdot \nabla_w f(w, x_2)\)
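A minimal sketch of the empirical NTK for a one-hidden-layer softplus network, a simplification of the deep network on these slides; writing \(a_i\) for the last-layer weights, the \(1/\sqrt{d}\) input normalization, and the helper names are our assumptions. With \(f = \frac{1}{\sqrt{n}}\sum_i a_i\,\phi(h_i)\), the gradient has components \(\phi(h_i)/\sqrt{n}\) for \(a_i\) and \(a_i\phi'(h_i)x_j/\sqrt{nd}\) for \(W_{ij}\), so \(\Theta\) is an average over neurons and its spread across initializations shrinks as \(n\) grows.

```python
import math
import random


def softplus(z):
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def sigmoid(z):
    # derivative of softplus, computed without overflow
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def init_params(n, d, rng):
    """First-layer weights W (n x d) and last-layer weights a (n), all i.i.d. N(0, 1)."""
    W = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]
    a = [rng.gauss(0.0, 1.0) for _ in range(n)]
    return W, a


def ntk(W, a, x1, x2):
    """Theta(w, x1, x2) = grad_w f(w, x1) . grad_w f(w, x2) for
    f(w, x) = (1/sqrt(n)) sum_i a_i softplus(h_i(x)),  h_i(x) = (1/sqrt(d)) sum_j W_ij x_j."""
    n, d = len(a), len(x1)
    x_dot = sum(u * v for u, v in zip(x1, x2)) / d
    theta = 0.0
    for Wi, ai in zip(W, a):
        h1 = sum(w * u for w, u in zip(Wi, x1)) / math.sqrt(d)
        h2 = sum(w * v for w, v in zip(Wi, x2)) / math.sqrt(d)
        theta += softplus(h1) * softplus(h2)                   # from df/da_i
        theta += ai * ai * sigmoid(h1) * sigmoid(h2) * x_dot   # from df/dW_ij, summed over j
    return theta / n


def ntk_spread(n, x1, x2, seeds=range(10)):
    """Max minus min of Theta(w0, x1, x2) over several random initializations w0."""
    values = []
    for s in seeds:
        W, a = init_params(n, len(x1), random.Random(s))
        values.append(ntk(W, a, x1, x2))
    return max(values) - min(values)


x1, x2 = [1.0, 0.0, -1.0], [0.5, 1.0, 0.0]
for n in (10, 100, 1000):
    print(n, round(ntk_spread(n, x1, x2), 4))  # shrinks as n grows
```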

regime 1: kernel limit

regime 2: mean field limit

\(\displaystyle f(w,x) = \frac{1}{\color{red}n} \sum_{i=1}^n W_i \; \phi \! \left(\frac{1}{\color{red}\sqrt{n}} \sum_{j=1}^n W_{ij} x_j \right)\)

\(W_{ij}\)

\(x\)

\(f(w,x)\)

\(W_{i}\)

was \(\frac{1}{\color{red}\sqrt{n}}\) in the kernel limit

studied theoretically for 1 hidden layer

Another limit!

for \({\color{red}n}\longrightarrow \infty\)


\(\frac{1}{\color{red}n}\) instead of \(\frac{1}{\color{red}\sqrt{n}}\) implies

  • no output fluctuations at initialization
     
  • we can replace the sum by an integral

\(\displaystyle f(w,x) \approx f(\rho,x) = \int d\rho(W,\vec W) \; W \; \phi \! \left(\frac{1}{\sqrt{n}} \sum_{j=1}^n W_j x_j \right)\)

where \(\rho\) is the density of the neurons' weights
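A minimal sketch of the mean field scaling at initialization, assuming one hidden layer, softplus activation, and a \(1/\sqrt{d}\) input normalization (our choices; `mean_field_output` is a hypothetical name): each neuron is one sample from \(\rho\), so \(f\) is a Monte Carlo estimate of the integral, and its fluctuations across initializations shrink like \(1/\sqrt{n}\).

```python
import math
import random
import statistics


def softplus(z):
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def mean_field_output(n, x, rng):
    """f(w, x) = (1/n) sum_i W_i softplus((1/sqrt(d)) sum_j W_ij x_j), weights i.i.d. N(0, 1).
    Each neuron (W_i, vec W_i) is one sample from rho, so f is a Monte Carlo
    estimate of  int d rho(W, vec W) W softplus(...)."""
    d = len(x)
    total = 0.0
    for _ in range(n):
        h = sum(rng.gauss(0.0, 1.0) * xj for xj in x) / math.sqrt(d)
        total += rng.gauss(0.0, 1.0) * softplus(h)
    return total / n


def init_output_std(n, x, trials=200, seed=1):
    """Standard deviation of f(w0, x) over `trials` independent initializations."""
    rng = random.Random(seed)
    return statistics.pstdev([mean_field_output(n, x, rng) for _ in range(trials)])


x = [1.0, -1.0, 0.5]
for n in (16, 160, 1600):
    print(n, round(init_output_std(n, x), 4))  # shrinks roughly like 1/sqrt(n)
```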

regime 2: mean field limit

\(\displaystyle f(\rho,x) = \int d\rho(W,\vec W) \; W \; \phi \! \left(\frac{1}{\sqrt{n}} \sum_{j=1}^n W_j x_j \right)\)

\(\rho\) follows a differential equation

[2018 S. Mei et al.], [2018 Rotskoff and Vanden-Eijnden], [2018 Chizat and Bach]

In this limit, the internal activations do change

\(\langle Y \rangle=\langle X_i \rangle\) => feature training

regime 2: mean field limit

What is the difference between the two limits?

  • which limit better describes finite \(n\) networks?
  • are there corresponding regimes for finite \(n\)?

kernel limit: output scaling \(\frac{1}{\color{red}\sqrt{n}}\), internal activations \(z^\ell\) frozen

mean field limit: output scaling \(\frac{1}{\color{red}n}\), internal activations \(z^\ell\) change

\(\displaystyle{\color{red}\alpha} f(w,x) = \frac{{\color{red}\alpha}}{\sqrt{n}} \sum_{i=1}^n W_i \ \phi(z_i)\)

[2019 Chizat and Bach]

  • if \(\alpha\) is a fixed constant and \(n\to\infty\) => kernel limit
  • if \({\color{red}\alpha} \sim \frac{1}{\sqrt{n}}\) and \(n\to\infty\) => mean field limit
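For the second case, substituting \(\alpha = 1/\sqrt{n}\) into the scaled output turns the \(1/\sqrt{n}\) prefactor into the \(1/n\) prefactor of the mean field limit:

```latex
\alpha f(w,x)
  = \frac{\alpha}{\sqrt{n}} \sum_{i=1}^n W_i \, \phi(z_i)
  \;\xrightarrow{\;\alpha = 1/\sqrt{n}\;}\;
  \frac{1}{n} \sum_{i=1}^n W_i \, \phi(z_i)
```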

\(z_i\)

\(W_i\)

\(f(w,x)\)

we use the scaling factor \(\alpha\) to investigate the transition between the two regimes

\( \alpha \cdot ( f(w,x) - f(w_0, x) ) \)

linearize the network with \(f - f_0\)

We want the network to behave linearly in the limit \(\alpha \to \infty\), for any finite \(n\)

The factor \(1/\alpha^2\) in the loss ensures that training converges in a time that does not scale with \(\alpha\) in the limit \(\alpha \to \infty\)

\(\displaystyle \mathcal{L}(w) = \frac{1}{\alpha^2 |\mathcal{D}|} \sum_{(x,y)\in\ \mathcal{D}} \ell\left( {\color{red}\alpha (f(w,x) - f(w_0,x))}, y \right) \)

loss function
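A minimal Python sketch of this \(\alpha\)-scaled loss. The slides do not fix \(\ell\) here; the logistic loss \(\ell(u,y)=\log(1+e^{-uy})\) and the names `alpha_loss` and `logistic_loss` are hypothetical choices for illustration. The predictors \(f(w,\cdot)\) and \(f(w_0,\cdot)\) are passed in as callables.

```python
import math


def logistic_loss(u, y):
    """Hypothetical choice ell(u, y) = log(1 + exp(-u * y)), computed stably."""
    z = -u * y
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def alpha_loss(f, f0, data, alpha, ell=logistic_loss):
    """L(w) = 1/(alpha^2 |D|) * sum_{(x, y) in D} ell(alpha * (f(w, x) - f(w0, x)), y).

    f  : callable x -> f(w, x) for the current weights
    f0 : callable x -> f(w0, x) for the weights at initialization
    """
    total = sum(ell(alpha * (f(x) - f0(x)), y) for x, y in data)
    return total / (alpha ** 2 * len(data))


data = [([0.0, 1.0], 1), ([1.0, 0.0], -1)]
f0 = lambda x: 0.3 * x[0] - 0.2 * x[1]
print(alpha_loss(f0, f0, data, alpha=10.0))  # log(2) / alpha^2 at w = w0
```

At \(w=w_0\) every term equals \(\ell(0,y)=\log 2\), so the loss is \(\log 2/\alpha^2\) regardless of \(f(w_0,x)\).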

\(\displaystyle f(w,x) = \frac{1}{{\color{red}\sqrt{n}}} \sum_{i=1}^n W^3_i \ \phi(z^3_i)\)

\(W_{ij}(t=0) \longleftarrow \mathcal{N}(0, 1)\)

\(\displaystyle z^{\ell+1}_i = {\frac1{\color{red}\sqrt{n}}} \sum_{j=1}^n W^\ell_{ij} \ \phi(z^\ell_j)\)

\(z^3_i\)

\(W^3_i\)

\(z^2_i\)

\(z^1_i\)

\(W^0_{ij}\)

\(W^1_{ij}\)

\(W^2_{ij}\)

\(x_i\)

\(\dot w = -\nabla_w \mathcal{L}(w)\)

Implemented with an adaptive time step \(dt\), chosen such that

\(10^{-4} < \frac{\|\nabla \mathcal{L}(t_{i+1})- \nabla \mathcal{L}(t_i)\|^2}{\|\nabla \mathcal{L}(t_{i+1})\|\cdot\|\nabla \mathcal{L}(t_i)\|} < 10^{-2}\)

(works well only with full batch and smooth loss)
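A minimal sketch of gradient flow with this adaptive step, on a plain list-of-floats parameter vector. Only the acceptance window \(10^{-4} < \cdot < 10^{-2}\) comes from the slide; the update rules (halve \(dt\) and retry when the ratio is too large, multiply \(dt\) by 1.1 when it is too small) and the name `gradient_flow_adaptive` are our assumptions.

```python
import math


def norm(v):
    return math.sqrt(sum(c * c for c in v))


def gradient_flow_adaptive(grad, w, t_max, dt=1e-3, low=1e-4, high=1e-2, max_steps=100_000):
    """Integrate dw/dt = -grad(w) with explicit Euler steps, adapting dt so that
    low < ||g_new - g_old||^2 / (||g_new|| * ||g_old||) < high."""
    t = 0.0
    g = grad(w)
    steps = 0
    while t < t_max and steps < max_steps:
        steps += 1
        w_new = [wi - dt * gi for wi, gi in zip(w, g)]
        g_new = grad(w_new)
        denom = norm(g_new) * norm(g)
        if denom == 0.0:                      # exact stationary point reached
            return w_new, t + dt
        ratio = norm([a - b for a, b in zip(g_new, g)]) ** 2 / denom
        if ratio > high:                      # gradient changed too much: reject, shrink dt
            dt /= 2
            continue
        w, g, t = w_new, g_new, t + dt        # accept the step
        if ratio < low:                       # gradient barely changed: enlarge dt
            dt *= 1.1
    return w, t


# Quadratic loss L(w) = 0.5 * (w[0]^2 + 10 * w[1]^2); exact solution w[0](t) = exp(-t).
w, t = gradient_flow_adaptive(lambda w: [w[0], 10.0 * w[1]], [1.0, 1.0], t_max=10.0)
print(t, w)
```

For a quadratic loss the ratio is roughly \((dt\,\lambda)^2\), with \(\lambda\) the relevant curvature, so the window keeps \(dt\,\lambda\) between about 0.01 and 0.1.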

continuous dynamics


momentum dynamics

\(\dot v = -\frac{1}{\tau}(v + \nabla \mathcal{L})\)

\(\dot w = v\)
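A minimal sketch of these momentum dynamics, integrated with a semi-implicit Euler step; the step scheme, the fixed \(dt\), and the name `momentum_flow` are our assumptions. Eliminating \(v\) gives \(\tau \ddot w + \dot w = -\nabla \mathcal{L}\), a damped oscillator; for \(\tau \to 0\), \(v\) relaxes to \(-\nabla\mathcal{L}\) and plain gradient flow is recovered.

```python
def momentum_flow(grad, w, tau, t_max, dt):
    """Integrate  dv/dt = -(v + grad L(w)) / tau,  dw/dt = v  starting from v = 0.

    Semi-implicit Euler: v is updated first, then w is advanced with the new v.
    """
    v = [0.0] * len(w)
    t = 0.0
    while t < t_max:
        g = grad(w)
        v = [vi - dt * (vi + gi) / tau for vi, gi in zip(v, g)]
        w = [wi + dt * vi for wi, vi in zip(w, v)]
        t += dt
    return w, v


# Quadratic loss L(w) = 0.5 * (w[0]^2 + 4 * w[1]^2), so grad L = (w[0], 4 * w[1]).
w, v = momentum_flow(lambda w: [w[0], 4.0 * w[1]], [1.0, -1.0], tau=0.5, t_max=30.0, dt=0.01)
print(w, v)  # both close to zero
```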

there is a plateau for large values of \(\alpha\)

MNIST 10k parity, FC L=3, softplus, gradient flow with momentum

lazy regime

MNIST 10k parity, FC L=3, softplus, gradient flow with momentum

ensemble average

the ensemble average converges as \(n \to \infty\)

no overlap

\(\alpha\)

the ensemble average

\(\displaystyle \bar f(x) = \int f(w(w_0),x) \;d\mu(w_0)\)
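A Monte Carlo sketch of the ensemble average: draw several initializations \(w_0 \sim \mu\), train each, and average the resulting predictors at \(x\). The training routine is abstracted as a hypothetical callable `train_from_seed` (seed to trained predictor); `toy_train` below is a stand-in, not a real network.

```python
import random
import statistics


def ensemble_average(train_from_seed, x, seeds):
    """Monte Carlo estimate of f_bar(x) = int f(w(w0), x) d mu(w0).

    train_from_seed: hypothetical callable mapping a seed (one draw w0 ~ mu)
    to the trained predictor x -> f(w(w0), x).
    """
    return statistics.fmean(train_from_seed(s)(x) for s in seeds)


def toy_train(seed):
    # Stand-in for training: the "trained" predictor is 2*x plus an
    # initialization-dependent offset of order 1.
    offset = random.Random(seed).gauss(0.0, 1.0)
    return lambda x: 2.0 * x + offset


print(ensemble_average(toy_train, 1.5, range(1000)))  # close to 3.0: the offsets average out
```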

MNIST 10k parity, FC L=3, softplus, gradient flow with momentum

plotted as a function of \(\sqrt{n} \alpha\), the curves overlap

overlap!

(x-axes: \(\alpha\) and \(\sqrt{n}\alpha\))

\(\frac{\|\Theta - \Theta_0 \|}{\|\Theta_0\|}\)

the kernel evolution displays two regimes
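A minimal sketch of the kernel-change measure, assuming \(\Theta\) and \(\Theta_0\) are Gram matrices evaluated on the same set of points and that \(\|\cdot\|\) is the Frobenius norm (the slide does not specify the norm). The kernels in the example are hypothetical.

```python
import math


def gram(kernel, xs):
    """Gram matrix K[mu][nu] = kernel(x_mu, x_nu) on a list of points."""
    return [[kernel(a, b) for b in xs] for a in xs]


def relative_kernel_change(theta, theta0):
    """||Theta - Theta0|| / ||Theta0|| with the Frobenius norm."""
    diff = math.sqrt(sum((a - b) ** 2
                         for ra, rb in zip(theta, theta0)
                         for a, b in zip(ra, rb)))
    base = math.sqrt(sum(b * b for rb in theta0 for b in rb))
    return diff / base


xs = [0.0, 1.0, -2.0]
theta0 = gram(lambda a, b: a * b + 1.0, xs)          # hypothetical kernel at t = 0
theta = gram(lambda a, b: 1.1 * (a * b + 1.0), xs)   # hypothetical kernel after training
print(relative_kernel_change(theta, theta0))          # about 0.1
```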

MNIST 10k parity, FC L=3, softplus, gradient flow with momentum

the phase space is split in two by \(\alpha^*\), which decays as \(1/\sqrt{n}\)

\(n\)

\(\alpha\)

feature training

kernel limit

mean field limit

\(\alpha^* \sim \frac{1}{\sqrt{n}}\)

lazy training

same for other datasets: the trend depends on the dataset

MNIST 10k

reduced to 2 classes

10PCA MNIST 10k

reduced to 2 classes

FC L=3, softplus, gradient flow with momentum

EMNIST 10k

reduced to 2 classes

Fashion MNIST 10k

reduced to 2 classes

FC L=3, softplus, gradient flow with momentum


CIFAR10 10k

reduced to 2 classes

 

CNN, SGD, Adam

CNN: the tendency is inverted

\(\sqrt{n}\alpha\)

How do the learning curves depend on \(n\) and \(\alpha\)?

MNIST 10k parity, L=3, softplus, gradient flow with momentum

\(\sqrt{n} \alpha\)

overlap!

same convergence time in the lazy regime

lazy

MNIST 10k parity, L=3, softplus, gradient flow with momentum

there is a characteristic time in the learning curves

\(\sqrt{n} \alpha\)

overlap!

characteristic time \(t_1\)

lazy

\(t_1\) characterizes the curvature of the network manifold

\(t_1\)

\(t\)


\(t_1\) is like the time you need to drive before realizing that the Earth is curved

\(v\)

\(R\)

\(t_1 \sim R/v\)

the rate of change of \(W\) determines when we leave the tangent space, i.e. \(t_1 \sim \sqrt{n}\alpha\)

\(x\)

\(f(w,x)\)

\(W_{ij}\)

\(\displaystyle z^{\ell+1}_i = \frac1{\sqrt{n}} \sum_{j=1}^n W^\ell_{ij} \ \phi(z^\ell_j)\)

\(z_j\)

\(z_i\)

\(W_{ij}\) and \(z_i\) are of order 1 at initialization

the rates \(\dot W_{ij}\) and \(\dot z_i\) at initialization determine \(t_1\)

\(\dot W^\ell = \mathcal{O}(\frac{1}{\alpha n})\)

\(\dot W^0 = \mathcal{O}(\frac{1}{\alpha \sqrt{n}})\)

\(\dot z = \mathcal{O}(\frac{1}{\alpha \sqrt{n}})\)

actually the correct scaling

Upper bound: consider the weights of the last layer

\(\Rightarrow t_1 \sim \sqrt{n} \alpha \)

\(\dot W^L= -\frac{\partial\mathcal{L}}{\partial W^L} = \mathcal{O}\left(\frac{1}{\alpha \sqrt{n}}\right)\)
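A quick numerical check of this scaling, under simplifying assumptions of ours: one hidden layer, softplus activation, \(1/\sqrt{d}\) input normalization, and the logistic loss \(\ell(u,y)=\log(1+e^{-uy})\); `last_layer_grad_at_init` is a hypothetical name. At \(w=w_0\) the loss derivative with respect to the output is \(\ell'(0,y)=-y/2\), so the last-layer weights themselves drop out of the gradient with respect to \(W^L_i\).

```python
import math
import random
import statistics


def softplus(z):
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def last_layer_grad_at_init(n, alpha, data, seed=0):
    """dL/dW^L_i at w = w0, for i = 1..n, with
    f(w, x) = (1/sqrt(n)) sum_i W^L_i softplus(h_i(x)),  h_i(x) = (1/sqrt(d)) sum_j W_ij x_j,
    L(w) = 1/(alpha^2 |D|) sum ell(alpha (f(w, x) - f(w0, x)), y),
    and the assumed logistic loss, for which ell'(0, y) = -y / 2."""
    rng = random.Random(seed)
    d = len(data[0][0])
    grads = []
    for _ in range(n):
        Wi = [rng.gauss(0.0, 1.0) for _ in range(d)]  # first-layer weights of neuron i
        g = 0.0
        for x, y in data:
            h = sum(w * xj for w, xj in zip(Wi, x)) / math.sqrt(d)
            # chain rule: (1/alpha^2) * ell'(0, y) * alpha * df/dW^L_i
            g += (-y / 2) * alpha * softplus(h) / (alpha ** 2 * math.sqrt(n))
        grads.append(g / len(data))
    return grads


data = [([1.0, 0.5], 1), ([-0.5, 1.0], -1)]
for n in (400, 1600, 6400):
    mean_abs = statistics.fmean(abs(g) for g in last_layer_grad_at_init(n, 1.0, data))
    print(n, mean_abs, mean_abs * math.sqrt(n))  # last column roughly constant
```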

convergence time is of order 1 in the lazy regime

  • \(\displaystyle \mathcal{L}(w) = \frac{1}{\alpha^2 |\mathcal{D}|} \sum_{(x,y)\in\ \mathcal{D}} \ell\left( {\color{red}\alpha (f(w,x) - f(w_0,x))}, y \right) \)
     
  • \(\dot f(w, x) = \nabla_w f(w,x) \cdot \dot w\)
  • \(= - \nabla_w f(w,x) \cdot \nabla_w \mathcal{L} \)
  • \(\sim - \nabla_w f \cdot (\frac{\partial}{\partial f} \mathcal{L} \; \nabla_w f)\)
  • \(\sim \frac{1}{\alpha} \Theta\)
     
  • "\( \alpha \dot f t = 1 \)" \(\Rightarrow\) "\( t_{lazy} = \|\Theta\|^{-1} \)"

  \(\Longrightarrow\)  \({\color{red}\alpha^* \sim \frac{1}{\sqrt{n}}}\)
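Putting the two time scales together, the scaling argument reads as follows (a sketch, assuming as on the previous slides that \(W^L_i\), \(\phi(z^L_i)\) and \(\|\Theta\|\) are of order 1; \(\ell'\) denotes the derivative of \(\ell\) with respect to its first argument):

```latex
\begin{aligned}
\dot W^L_i = -\frac{\partial \mathcal{L}}{\partial W^L_i}
  &= -\frac{1}{\alpha^2 |\mathcal{D}|} \sum_{(x,y)\in\mathcal{D}}
     \ell'\big(\alpha (f(w,x) - f(w_0,x)), y\big)\, \alpha\, \frac{\phi(z^L_i)}{\sqrt{n}}
   = \mathcal{O}\!\left(\frac{1}{\alpha \sqrt{n}}\right) \\
t_1 &\sim \frac{W^L_i}{\dot W^L_i} \sim \sqrt{n}\,\alpha,
  \qquad t_{lazy} \sim \|\Theta\|^{-1} \sim 1 \\
t_1 \sim t_{lazy} &\;\Longrightarrow\; \sqrt{n}\,\alpha^* \sim 1
  \;\Longrightarrow\; \alpha^* \sim \frac{1}{\sqrt{n}}
\end{aligned}
```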

when the dynamics stops before \(t_1\) we are in the lazy regime

\(t_1 \sim \sqrt{n} \alpha \) (time to exit the tangent space)

\(t_{lazy} \sim 1 \) (time to converge in the lazy regime)

\(\alpha^* \sim 1/\sqrt{n}\), so for large \(n\):

\(\Rightarrow\) \( \alpha^* f(w_0, x) \ll 1 \)

\(\Rightarrow\) \( \alpha^* (f(w,x) - f(w_0, x)) \approx \alpha^* f(w,x) \)

for large \(n\) our conclusions should hold without this trick

linearizing the network with \(f - f_0\) was not necessary

arxiv.org/abs/1906.08034

github.com/mariogeiger/feature_lazy

\(n\)

\(\alpha\)

lazy training

stop before \(t_1\)

feature training

go beyond \(t_1\)

kernel limit

mean field limit

\(\alpha^* \sim \frac{1}{\sqrt{n}}\)

time to leave the tangent space \(t_1 \sim \sqrt{n}\alpha\)
 => time for learning features