A jamming transition from under- to over-parametrization
affects loss landscape and generalization
Stefano Spigler, Mario Geiger,
Stéphane d’Ascoli, Levent Sagun, Marco Baity-Jesi,
Giulio Biroli, Matthieu Wyart
Deep learning: a high-dimensional, non-convex problem
- What is the geometry of the landscape?
- How does it depend on the number of parameters?
- Why does training not get stuck in local minima?
Set-up
- Dataset: \(\large \color{red}P\) points \(\mathbf{x}_1,\dots,\mathbf{x}_P\)
- Binary classification: each point \(\mathbf{x}_\mu\) has a label \(y_\mu=\pm1\)
- Fully-connected network, \(\large \color{red}N\) parameters \(\mathbf{W}\in\mathbb{R}^N\)
[Diagram: network with an overbrace labeled \(L\), mapping \(\mathbf{x} \to f(\mathbf{x}; \mathbf{W})\)]
Training
- The goal of training is to fit all the points \(\mathbf{x}_\mu\):
\[ y_\mu \cdot f(\mathbf{x}_\mu;\mathbf{W}) > 0\]
- \(\mathbf{W}\) is fixed by minimizing a loss function
\[\mathcal{L}(\mathbf{W}) = \frac1P \sum_\mu \ell\left(y_\mu \cdot f(\mathbf{x}_\mu; \mathbf{W})\right)\]
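A minimal sketch of this set-up in plain Python, to make the symbols concrete. The ReLU activation, the absence of biases, and the helper names (`forward`, `count_parameters`, `all_fitted`, `mean_loss`) are assumptions made for illustration; the slides only specify a fully-connected network with \(N\) parameters.

```python
def forward(x, weights):
    """Scalar output f(x; W) of a fully-connected net.

    `weights` is a list of layers, each a list of rows (one row per unit).
    ReLU hidden units and no biases are assumptions of this sketch.
    """
    h = list(x)
    last = len(weights) - 1
    for layer, W in enumerate(weights):
        z = [sum(w * v for w, v in zip(row, h)) for row in W]
        # Linear output layer, ReLU on hidden layers.
        h = z if layer == last else [max(0.0, v) for v in z]
    return h[0]


def count_parameters(weights):
    """N: total number of scalar weights."""
    return sum(len(row) for W in weights for row in W)


def all_fitted(xs, ys, weights):
    """True when every point satisfies y * f(x; W) > 0."""
    return all(y * forward(x, weights) > 0 for x, y in zip(xs, ys))


def mean_loss(xs, ys, weights, ell):
    """L(W) = (1/P) * sum over points of ell(y * f(x; W))."""
    return sum(ell(y * forward(x, weights)) for x, y in zip(xs, ys)) / len(xs)
```

Here `ell` is passed in as a function of the margin \(y f\), so the same helper works for any per-point loss.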
[Figure: phase diagram with axes "number of patterns" and "number of parameters", split into an under-parametrized and an over-parametrized region (above upper bound); sketches of \(\mathcal{L}\) as a function of \(\mathbf{W}\) in each region]
Loss landscape
Jamming in spheres
- \(\large N =\) Number of degrees of freedom
- \(\large N_\Delta =\) Number of contacts (constraints)
[Figure: sphere packings at \(N_\Delta/N = 0\) (no contacts) and at \(N_\Delta/N \geq 1\) (jammed)]
Loss function
Typical choice:
- cross entropy: \(\ell(yf) = \log(1+e^{-yf})\), which has infinite range
\[\mathcal{L} = \frac1P \sum_\mu \ell\left(y_\mu f_\mu\right)\]
[Plot: \(\ell(yf)\) as a function of \(yf\)]
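A small sketch of what "infinite range" means in practice: the binary cross entropy on the margin, \(\log(1+e^{-yf})\), stays strictly positive however well a point is fitted, so no point ever stops contributing to the loss. The function name and the numerically stable rewriting are choices made here, not taken from the slides.

```python
import math


def cross_entropy(margin):
    """log(1 + exp(-margin)), written to avoid overflow for large |margin|."""
    return max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))


# The loss decays with the margin but never reaches zero.
for m in (0.0, 1.0, 5.0, 20.0):
    print(m, cross_entropy(m))
```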
Loss function
Let's use:
- quadratic hinge: \(\ell(yf) = \theta(\Delta) \frac12 \Delta^2\), with \(\Delta=1-yf\) and \(\theta\) the Heaviside step function; it has finite range!
- same performance, sharp jamming transition
[Plot: \(\ell(yf)\) as a function of \(yf\)]
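For contrast with a loss of infinite range, a sketch of the quadratic hinge: it is exactly zero once the margin reaches 1, so a fitted point drops out of the loss and of its gradient. The function names are illustrative.

```python
def quadratic_hinge(margin):
    """theta(Delta) * Delta**2 / 2 with Delta = 1 - margin."""
    delta = 1.0 - margin
    return 0.5 * delta * delta if delta > 0 else 0.0


def quadratic_hinge_grad(margin):
    """Derivative with respect to the margin: -Delta where Delta > 0, else 0."""
    delta = 1.0 - margin
    return -delta if delta > 0 else 0.0
```

A point with \(yf \geq 1\) therefore exerts no force on the weights, which is what makes counting unsatisfied constraints meaningful.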
Dataset:
- \(\mathbf{x}_\mu=\) random points on the sphere
- \(y_\mu=\) random labels \(\pm1\)
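One way such a dataset could be generated, as a sketch: normalizing a Gaussian vector gives a uniformly distributed point on the sphere. The unit radius, the seed argument, and the name `random_sphere_dataset` are assumptions; the slides do not give these details.

```python
import math
import random


def random_sphere_dataset(P, d, seed=0):
    """P points uniform on the unit sphere in d dimensions, with random +/-1 labels."""
    rng = random.Random(seed)
    xs, ys = [], []
    for _ in range(P):
        g = [rng.gauss(0.0, 1.0) for _ in range(d)]
        norm = math.sqrt(sum(v * v for v in g))
        xs.append([v / norm for v in g])  # project onto the unit sphere
        ys.append(rng.choice([-1, 1]))    # label carries no information about x
    return xs, ys
```

Because the labels are independent of the inputs, fitting them is pure memorization, which isolates the constraint-satisfaction aspect of training.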
Location of the transition
[Plot: location of the transition, axes \(P\) and \(N\)]
\({\large\color{red}N_\Delta} =\) number of unsatisfied constraints \(\Delta_\mu > 0\)
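A sketch of how \(N_\Delta\) can be read off from the margins \(y_\mu f_\mu\) at the end of training, together with the quadratic-hinge loss they imply. The function name is illustrative.

```python
def unsatisfied_count_and_loss(margins):
    """Return (N_Delta, L) for a list of margins y_mu * f_mu."""
    deltas = [1.0 - m for m in margins]
    active = [d for d in deltas if d > 0]  # unsatisfied constraints, Delta_mu > 0
    n_delta = len(active)
    loss = sum(0.5 * d * d for d in active) / len(margins)
    return n_delta, loss


# Two of the four constraints are unsatisfied (margins 0.5 and -1.0).
print(unsatisfied_count_and_loss([2.0, 1.0, 0.5, -1.0]))  # (2, 0.53125)
```

Note that \(\mathcal{L} = 0\) exactly when \(N_\Delta = 0\).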
Sharpness of the transition
\[\mathcal{L} = \frac1{P} \sum_\mu \theta(\Delta_\mu) \frac12 \Delta_\mu^2 \]
[Plot: axes \(P/N\) and \(N_\Delta/N\), for \(N \approx 8000\)]
After minimization \(\longrightarrow\) net is in a minimum of \(\mathcal{L}(\mathbf{W})\)
Stability \(\Longrightarrow\) positive semi-definite Hessian \(\nabla^2 \mathcal{L}(\mathbf{W})\)
Upper bound
\(N < \mathrm{const} \times P\)
Discontinuous jump
\(N_\Delta > \mathrm{const} \times N\)
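Let \(N_- =\) number of negative eigenvalues of \(\mathcal{H}_p\), the second (curvature) term of \(\mathcal{H}\) below
The first term of \(\mathcal{H}\) is positive semi-definite, with \(\mathrm{rank}\leq N_\Delta\)
Stability \(\Longrightarrow \ \ \large N_-\leq N_\Delta\)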
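The bound \(\mathrm{rank}\leq N_\Delta\) reflects a general fact: a sum of \(N_\Delta\) outer products \(\nabla\Delta_\mu\otimes\nabla\Delta_\mu\) cannot have rank above \(N_\Delta\). A small exact-arithmetic check, as a sketch; the helper names and the example gradient vectors are made up for illustration.

```python
from fractions import Fraction


def outer_sum(vectors):
    """Sum of g (outer) g over the given vectors, as an exact matrix."""
    n = len(vectors[0])
    H = [[Fraction(0)] * n for _ in range(n)]
    for g in vectors:
        for i in range(n):
            for j in range(n):
                H[i][j] += Fraction(g[i]) * Fraction(g[j])
    return H


def rank(matrix):
    """Rank by Gaussian elimination (exact with Fraction entries)."""
    M = [row[:] for row in matrix]
    n_rows, n_cols = len(M), len(M[0])
    r = 0
    for c in range(n_cols):
        pivot = next((i for i in range(r, n_rows) if M[i][c] != 0), None)
        if pivot is None:
            continue
        M[r], M[pivot] = M[pivot], M[r]
        for i in range(r + 1, n_rows):
            factor = M[i][c] / M[r][c]
            if factor != 0:
                M[i] = [a - factor * b for a, b in zip(M[i], M[r])]
        r += 1
        if r == n_rows:
            break
    return r


# Hypothetical example: N = 5 parameters, N_Delta = 2 unsatisfied constraints.
grads = [[1, 0, 2, -1, 3], [0, 1, 1, 4, -2]]
print(rank(outer_sum(grads)))  # 2
```

With only two gradient vectors the matrix has at most two non-zero eigenvalues, so it can compensate at most two negative directions of the remaining term.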
\[\mathcal{L} = \frac1{P} \sum_\mu \theta(\Delta_\mu)\; \frac12 \Delta_\mu^2\]
\[ \mathcal{H} = \frac1P \sum_\mu \theta(\Delta_\mu)\; \nabla\Delta_\mu\otimes\nabla\Delta_\mu + \frac1P \sum_\mu \theta(\Delta_\mu)\; \Delta_\mu\, \nabla^2 \Delta_\mu\]
\[\Delta_\mu = 1 - y_\mu f(\mathbf x_\mu; \mathbf W)\]
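A short derivation sketch connecting the loss to this Hessian, valid away from the non-smooth points where some \(\Delta_\mu = 0\). The label \(\mathcal{H}_0\) for the first term is notation assumed here, chosen to pair with \(\mathcal{H}_p\):

\[\nabla\mathcal{L} = \frac1P \sum_\mu \theta(\Delta_\mu)\, \Delta_\mu\, \nabla\Delta_\mu\]
\[\mathcal{H} = \nabla^2\mathcal{L} = \underbrace{\frac1P \sum_\mu \theta(\Delta_\mu)\, \nabla\Delta_\mu\otimes\nabla\Delta_\mu}_{\mathcal{H}_0} + \underbrace{\frac1P \sum_\mu \theta(\Delta_\mu)\, \Delta_\mu\, \nabla^2\Delta_\mu}_{\mathcal{H}_p}\]

Each term of \(\mathcal{H}_0\) is the outer product of a vector with itself, so \(\mathcal{H}_0\) is positive semi-definite and negative curvature can only come from \(\mathcal{H}_p\).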
Spectrum of \(\mathcal{H}_p\): \(N_- \approx N/2\)
[Plot: spectrum of \(\mathcal{H}_p\), axis: eigenvalue]
- Upper bound: \(N < \frac{1}{C_0} \times P\)
- Discontinuous jump: \(N_\Delta > C_0 \times N\)
- Stability \(\Longrightarrow \ N_-\leq N_\Delta\leq P\)
- Assumption: \(N_- = \left\{\begin{array}{ll}C_0\times N & \mathcal{L} > 0 \\ 0 & \mathcal{L} = 0\end{array} \right. \)
\(C_0 \approx 0.5\) in all observed cases
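Chaining the stability inequality with the assumption gives both bounds. A worked version, taking the observed \(C_0 \approx 0.5\) as the only numerical input:

\[\mathcal{L} > 0: \quad C_0 N = N_- \leq N_\Delta \leq P \;\Longrightarrow\; N_\Delta \geq C_0 N \quad\text{and}\quad N \leq \frac{P}{C_0} \approx 2P\]

Under this assumption, a minimum with \(\mathcal{L} > 0\) can only exist when \(N \lesssim 2P\); for larger \(N\) only \(\mathcal{L} = 0\) remains stable.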
- \(\mathcal{L}=0\): \(N_-=0\), \(N_\Delta=0\)
- \(\mathcal{L}>0\): \(N_-= C_0\times N\), \(N_\Delta>C_0\times N\)
\(\mathbf{x}_\mu = \) MNIST
\(y_\mu = \) parity
[Plot: axes \(P\) and \(N\); slope \(\approx\) 3/4]
Spectrum of \(\mathcal{H}\)
- Over-parametrized: \(\mathcal{H}_p=0\), \(N_\Delta=0\)
- Jamming: \(\mathcal{H}_p=0\), \(N_\Delta>C_0\times N\)
- Under-parametrized: \(\mathcal{H}_p\neq 0\), \(N_\Delta\nearrow\)
[Three spectrum plots, one per regime, axis: eigenvalue; one annotation marks a scale \(\sim\sqrt{\mathcal{L}}\) with a double arrow \(\leftrightarrow\)]
Conclusions
- Jamming transition delimits over- and under-parametrized networks
- In the over-parametrized phase only \(\mathcal{L}=0\) is stable (perfect fitting)
Generalization for MNIST
[Plot: axes \(P\) and \(N\)]
Generalization
- There is a cusp at jamming
- Test error monotonically decreases in the over-parametrized phase
[Plot: test error as a function of \(N\) and of \(N/N_\mathrm{jamming}\)]
By Stefano Spigler
HEP-AI talk on "A jamming transition from under- to over-parametrization affects loss landscape and generalization", arXiv: 1809.09349, 1810.09665