A jamming transition from under- to over-parametrization
affects loss landscape and generalization

Stefano Spigler, Mario Geiger

Stéphane d’Ascoli      Levent Sagun      Marco Baity-Jesi

Giulio Biroli      Matthieu Wyart

Deep learning: a high-dimensional, non-convex problem

  • What is the geometry of the landscape?
     
  • How does it depend on the number of parameters?
     
  • Why does training not get stuck in local minima?

Set-up

  • Dataset: \(\large \color{red}P\) points \(\mathbf{x}_1,\dots,\mathbf{x}_P\)
     
  • Binary classification: each point \(\mathbf{x}_\mu\) has a label \(y_\mu=\pm1\)
     
  • Fully-connected network, \(\large \color{red}N\) parameters \(\mathbf{W}\in\mathbb{R}^N\)

[Diagram: fully-connected network with \(L\) layers, mapping an input \(\mathbf{x}\) to a scalar output \(f(\mathbf{x}; \mathbf{W})\)]
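A minimal sketch of such a network (NumPy; the widths, the ReLU nonlinearity, and the absence of biases are illustrative assumptions, not the architecture used in the experiments):

```python
import numpy as np

def init_fc_net(d_in, hidden_widths, rng):
    """Random fully-connected net; `hidden_widths` is an assumed choice, biases omitted."""
    widths = [d_in] + hidden_widths + [1]          # scalar output f(x; W)
    return [rng.standard_normal((a, b)) / np.sqrt(a)
            for a, b in zip(widths[:-1], widths[1:])]

def forward(x, weights):
    """f(x; W): affine layers with ReLU, linear readout."""
    h = x
    for W in weights[:-1]:
        h = np.maximum(h @ W, 0.0)                 # ReLU nonlinearity
    return (h @ weights[-1]).squeeze(-1)

def count_params(weights):
    """N = total number of trainable parameters."""
    return sum(W.size for W in weights)

rng = np.random.default_rng(0)
weights = init_fc_net(d_in=10, hidden_widths=[32, 32], rng=rng)   # L = 3 weight layers
x = rng.standard_normal((5, 10))                                  # 5 example inputs
print(forward(x, weights).shape, count_params(weights))           # (5,) 1376
```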

Training

  • The goal of training is to fit all the points \(\mathbf{x}_\mu\):

    \[ y_\mu \cdot f(\mathbf{x}_\mu;\mathbf{W}) > 0\]
     
  • \(\mathbf{W}\) is fixed by minimizing a loss function

    \[\mathcal{L}(\mathbf{W}) = \frac1P \sum_\mu \ell\left(y_\mu \cdot f(\mathbf{x}_\mu; \mathbf{W})\right)\]
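In code, the empirical loss is a one-liner (sketch; `ell` stands for the per-sample loss \(\ell\) chosen below):

```python
import numpy as np

def empirical_loss(ell, f_vals, y):
    """L(W) = (1/P) * sum_mu ell(y_mu * f(x_mu; W)).

    ell    : per-sample loss, a function of the margin y*f
    f_vals : network outputs f(x_mu; W), shape (P,)
    y      : labels +-1, shape (P,)
    """
    return np.mean(ell(y * f_vals))

def all_points_fitted(f_vals, y):
    """Training goal: y_mu * f(x_mu; W) > 0 for every mu."""
    return bool(np.all(y * f_vals > 0))
```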

[Phase diagram: number of patterns \(P\) vs. number of parameters \(N\), separating the under-parametrized and over-parametrized regimes]

Loss landscape

[Sketches of the loss \(\mathcal{L}(\mathbf W)\) as a function of \(\mathbf W\) in the under-parametrized phase and in the over-parametrized phase (above the upper bound)]

Jamming in spheres

  • \(\large N =\) Number of degrees of freedom
  • \(\large N_\Delta =\) Number of contacts (constraints)

[Sphere packings: below jamming \(N_\Delta/N = 0\); at and above jamming \(N_\Delta/N \geq 1\)]

Loss function

Typical choice:

  • cross entropy:   \(\ell(yf) = \log(1+e^{-yf})\)   (infinite range)

\[\mathcal{L} = \frac1P \sum_\mu \ell\left(y_\mu f_\mu\right)\]

[Plot of \(\ell(yf)\) vs. \(yf\): the cross entropy never vanishes]

Loss function

Let's use instead:

  • quadratic hinge:   \(\ell(yf) = \frac12\, \theta(\Delta)\, \Delta^2\),   with \(\Delta = 1-yf\)   (finite range!)

Same performance as the cross entropy, but with a sharp jamming transition.

[Plot of \(\ell(yf)\) vs. \(yf\): the loss vanishes identically for \(yf \geq 1\)]
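A sketch of the two per-sample losses, showing the finite range of the quadratic hinge (zero whenever \(yf \geq 1\)):

```python
import numpy as np

def cross_entropy(margin):
    """ell(yf) = log(1 + exp(-yf)): positive for every finite margin (infinite range)."""
    return np.log1p(np.exp(-margin))

def quadratic_hinge(margin):
    """ell(yf) = (1/2) * Delta^2 * theta(Delta), with Delta = 1 - yf.

    Exactly zero once yf >= 1, so only points within the margin contribute
    (finite range), which is what makes the jamming transition sharp.
    """
    delta = 1.0 - margin
    return 0.5 * np.where(delta > 0, delta, 0.0) ** 2

margins = np.array([-1.0, 0.0, 0.5, 1.0, 2.0])
print(cross_entropy(margins))    # never reaches zero
print(quadratic_hinge(margins))  # zero for the last two margins
```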

Dataset:
 

  • \(\mathbf{x}_\mu=\) random points on the sphere,
     
  • \(y_\mu=\) random labels \(\pm1\)
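A minimal sketch of this dataset (the input dimension \(d\) below is an arbitrary illustrative choice):

```python
import numpy as np

def random_sphere_dataset(P, d, rng):
    """P random points on the unit sphere in R^d with random +-1 labels."""
    x = rng.standard_normal((P, d))
    x /= np.linalg.norm(x, axis=1, keepdims=True)   # project onto the sphere
    y = rng.choice([-1.0, 1.0], size=P)             # random labels
    return x, y

rng = np.random.default_rng(0)
x, y = random_sphere_dataset(P=1000, d=25, rng=rng)  # d = 25 is just an example
```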

Location of the transition

[Plot of the location of the transition in the \((N, P)\) plane]

 \({\large\color{red}N_\Delta} =\) number of unsatisfied constraints \(\Delta_\mu > 0\)
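In the same NumPy sketch, counting the unsatisfied constraints reads:

```python
import numpy as np

def count_unsatisfied(f_vals, y):
    """N_Delta = #{mu : Delta_mu = 1 - y_mu * f_mu > 0}, the points still inside the margin."""
    return int(np.sum(1.0 - y * f_vals > 0))
```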

Sharpness of the transition

\[\mathcal{L} = \frac1{P} \sum_\mu  \theta(\Delta_\mu) \frac12 \Delta_\mu^2 \]

[Plot of \(N_\Delta/N\) vs. \(P/N\) at \(N \approx 8000\)]

After minimization   \(\longrightarrow\)   net is in a minimum of \(\mathcal{L}(\mathbf{W})\)

 

Stability   \(\Longrightarrow\)   positive semi-definite Hessian \(\nabla^2 \mathcal{L}(\mathbf{W})\)

Upper bound:   \(N < \mathrm{const} \times P\)

Discontinuous jump:   \(N_\Delta > \mathrm{const} \times N\)

\[\mathcal{L} = \frac1{P} \sum_\mu \theta(\Delta_\mu)\; \frac12 \Delta_\mu^2, \qquad \Delta_\mu = 1 - y_\mu f(\mathbf x_\mu; \mathbf W)\]

\[ \mathcal{H} = \underbrace{\frac1P \sum_\mu \theta(\Delta_\mu)\; \nabla\Delta_\mu\otimes\nabla\Delta_\mu}_{\large\mathcal{H}_0} + \underbrace{\frac1P \sum_\mu \theta(\Delta_\mu)\; \Delta_\mu\, \nabla^2 \Delta_\mu}_{\large\mathcal{H}_p}\]

  • \(\mathcal{H}_0\) is positive semi-definite, with \(\mathrm{rank}\,\mathcal{H}_0\leq N_\Delta\)
     
  • Let \(N_- =\) number of negative eigenvalues of \(\mathcal{H}_p\)
     
  • Stability   \(\Longrightarrow \ \ \large N_-\leq N_\Delta\)
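One intermediate step, spelled out: differentiating the quadratic-hinge loss once gives

\[\nabla \mathcal{L} = \frac1P \sum_\mu \theta(\Delta_\mu)\, \Delta_\mu\, \nabla\Delta_\mu,\]

and differentiating again reproduces the two terms \(\mathcal{H}_0\) and \(\mathcal{H}_p\) above; the contributions from \(\theta'(\Delta_\mu)\) drop out because they come multiplied by \(\Delta_\mu\) (or \(\Delta_\mu^2\)), which vanishes on the support of \(\theta'\).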

Spectrum of \(\mathcal{H}_p\)

[Eigenvalue histogram of \(\mathcal{H}_p\): about half of the eigenvalues are negative, \(N_- \approx N/2\)]

  • Upper bound:   \(N < \frac{1}{C_0} \times P\)
     
  • Discontinuous jump:   \(N_\Delta > C_0 \times N\)
  • Stability   \(\Longrightarrow \ N_-\leq N_\Delta\leq P\)
     
  • Assumption:  \(N_- = \left\{\begin{array}{ll}C_0\times N & \mathcal{L} > 0 \\ 0 & \mathcal{L} = 0\end{array} \right. \)

\(C_0 \approx 0.5\)   in all observed cases
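Chaining the assumption with the stability bound gives the upper bound quoted above: for any stable minimum with \(\mathcal{L}>0\),

\[C_0\, N = N_- \leq N_\Delta \leq P \quad\Longrightarrow\quad N \leq \frac{P}{C_0},\]

so whenever \(N > P/C_0\) the only stable minima have \(\mathcal{L}=0\).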

  • \(\mathcal{L}=0\):   \(N_-=0\),   \(N_\Delta=0\)
     
  • \(\mathcal{L}>0\):   \(N_-= C_0\times N\),   \(N_\Delta>C_0\times N\)

Dataset:   \(\mathbf{x}_\mu = \) MNIST images,   \(y_\mu = \) parity of the digit

[Plot of \(P\) vs. \(N\) at the transition: slope \(\approx 3/4\)]

Spectrum of \(\mathcal{H}\)

  • Over-parametrized:   \(\mathcal{H}_p=0\),   \(N_\Delta=0\)
     
  • Jamming:   \(\mathcal{H}_p=0\),   \(N_\Delta>C_0\times N\)
     
  • Under-parametrized:   \(\mathcal{H}_p\neq 0\) (of magnitude \(\sim\sqrt{\mathcal{L}}\)),   \(N_\Delta\nearrow\)

[Eigenvalue histograms of \(\mathcal{H}\) in the three regimes]

Conclusions

  • Jamming transition delimits over- and under-parametrized networks


     
  • In the over-parametrized phase, only \(\mathcal{L}=0\) (perfect fitting) is stable

Generalization for MNIST

[Plot: generalization for MNIST in the \((N, P)\) plane]

Generalization

  • There is a cusp at jamming

     
  • Test error monotonically decreases in the over-parametrized phase

[Plots of the test error vs. \(N\) and vs. \(N/N_\mathrm{jamming}\)]
