Stefano Spigler, Mario Geiger,
Stéphane d’Ascoli, Levent Sagun, Marco Baity-Jesi,
Giulio Biroli, Matthieu Wyart
Deep learning: high-dimensional,
non-convex problem
Why does training not get stuck in local minima?
[Diagram: \(\mathbf{x} \to\) network spanning \(L\) layers \(\to f(\mathbf{x}; \mathbf{W})\)]
\(P\): number of patterns
\(N\): number of parameters
[Figure: sketches of the loss \(\mathcal{L}\) as a function of \(\mathbf W\). Panels: under-parametrized; over-parametrized (above upper bound)]
Jamming in spheres
\(N_\Delta/N = 0\)
\(N_\Delta/N \geq 1\)
Typical choice: infinite-range loss
\[\mathcal{L} = \frac1P \sum_\mu \ell\left(y_\mu f_\mu\right)\]
[Plot: \(\ell(fy)\) versus \(fy\)]
Let's use a finite-range loss:
same performance, sharp jamming transition
[Plot: \(\ell(fy)\) versus \(fy\)]
\(\Delta=1-yf\)
Dataset:
\(P\)
\(N\)
\({\large\color{red}N_\Delta} =\) number of unsatisfied constraints \(\Delta_\mu > 0\)
\[\mathcal{L} = \frac1{P} \sum_\mu \theta(\Delta_\mu) \frac12 \Delta_\mu^2 \]
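The finite-range loss above and the count \(N_\Delta\) can be sketched in a few lines of NumPy. This is a minimal illustration, not the authors' code: the function name `quadratic_hinge` and the toy arrays `f` and `y` are assumptions made for the example, with labels taken in \(\{-1, +1\}\).

```python
import numpy as np

def quadratic_hinge(f, y):
    """Return (loss, n_delta) for network outputs f and labels y in {-1, +1}.

    Hypothetical helper for illustration only.
    """
    delta = 1.0 - y * f                  # Delta_mu = 1 - y_mu f_mu
    unsatisfied = delta > 0              # theta(Delta_mu): constraint not satisfied
    loss = 0.5 * np.sum(delta[unsatisfied] ** 2) / len(f)
    n_delta = int(np.count_nonzero(unsatisfied))
    return loss, n_delta

# Toy data (made up): four patterns, one of them fitted with margin >= 1.
f = np.array([2.0, 0.5, -1.0, 1.0])
y = np.array([1.0, 1.0, 1.0, -1.0])
loss, n_delta = quadratic_hinge(f, y)
```

Patterns with \(y_\mu f_\mu \geq 1\) contribute exactly zero to the loss, which is what makes \(\mathcal{L} = 0\) reachable at finite weights and \(N_\Delta\) a well-defined integer.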
[Plot: \(N_\Delta/N\) versus \(P/N\), for \(N \approx 8000\)]
After minimization \(\longrightarrow\) net is in a minimum of \(\mathcal{L}(\mathbf{W})\)
Stability \(\Longrightarrow\) positive semi-definite Hessian \(\nabla^2 \mathcal{L}(\mathbf{W})\)
Upper bound: \(N < \mathrm{const} \times P\)
Discontinuous jump: \(N_\Delta > \mathrm{const} \times N\)
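Let \(N_- =\) number of negative eigenvalues of \(\mathcal{H}_p\), the term of the Hessian that contains \(\nabla^2\Delta_\mu\)
Stability \(\Longrightarrow \ \ \large N_-\leq N_\Delta\)
since the remaining term, built from \(\nabla\Delta_\mu\otimes\nabla\Delta_\mu\), is positive semi-definite with \(\mathrm{rank}\leq N_\Delta\)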
\[ \mathcal{H} = \underbrace{\frac1P \sum_\mu \theta(\Delta_\mu)\; \nabla\Delta_\mu\otimes\nabla\Delta_\mu}_{\mathcal{H}_0} + \underbrace{\frac1P \sum_\mu \theta(\Delta_\mu)\; \Delta_\mu\, \nabla^2 \Delta_\mu}_{\mathcal{H}_p}\]
\[\Delta_\mu = 1 - y_\mu f(\mathbf x_\mu; \mathbf W)\]
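The first term of the Hessian is a sum of one rank-one, positive semi-definite matrix per unsatisfied constraint, so its rank cannot exceed \(N_\Delta\). The sketch below checks this numerically. It does not use a real network: the gradients \(\nabla\Delta_\mu\) and the values \(\Delta_\mu\) are replaced by random stand-ins, and the sizes `N` and `P` are arbitrary assumptions chosen so that \(N_\Delta \leq P < N\).

```python
import numpy as np

rng = np.random.default_rng(0)
N, P = 60, 50                        # hypothetical sizes, with P < N

# Random stand-ins (assumption): row mu plays the role of grad Delta_mu.
grads = rng.normal(size=(P, N))
delta = rng.normal(size=P)           # stand-ins for Delta_mu

active = delta > 0                   # theta(Delta_mu)
n_delta = int(active.sum())          # number of unsatisfied constraints

# First Hessian term: (1/P) * sum over active mu of grad (x) grad.
G = grads[active]                    # shape (n_delta, N)
H0 = G.T @ G / P

rank = np.linalg.matrix_rank(H0)
min_eig = np.linalg.eigvalsh(H0).min()
print(f"N_Delta = {n_delta}, rank = {rank}, smallest eigenvalue = {min_eig:.2e}")
```

The matrix has at least \(N - N_\Delta\) zero eigenvalues, which leaves that many directions in which only the second term of the Hessian determines the curvature.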
\(N_- \approx N/2\)
[Plot: eigenvalue spectrum]
\(C_0 \approx 0.5\) in all observed cases
\(\mathcal{L}=0\): \(N_-=0\), \(N_\Delta=0\)
\(\mathcal{L}>0\): \(N_-= C_0\times N\), \(N_\Delta>C_0\times N\)
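A short dimension-counting argument connects these counts to the upper bound on \(N\). It is a sketch, using the notation \(\mathcal{H} = \mathcal{H}_0 + \mathcal{H}_p\), where \(\mathcal{H}_0\) denotes the Hessian term built from \(\nabla\Delta_\mu\otimes\nabla\Delta_\mu\) and \(\mathcal{H}_p\) the term containing \(\nabla^2\Delta_\mu\). Since \(\mathcal{H}_0\) is a sum of \(N_\Delta\) rank-one matrices, its kernel has dimension at least \(N - N_\Delta\). If \(N_- > N_\Delta\), the negative eigenspace of \(\mathcal{H}_p\) and the kernel of \(\mathcal{H}_0\) must share a nonzero vector:

\[ N_- + (N - N_\Delta) > N \;\Longrightarrow\; \exists\, \mathbf v \neq 0: \quad \mathbf v^\top \mathcal{H}\, \mathbf v = \mathbf v^\top \mathcal{H}_p\, \mathbf v < 0 \]

which contradicts stability, hence \(N_- \leq N_\Delta\). Combining this with \(N_- = C_0 \times N\) and the trivial bound \(N_\Delta \leq P\) (there are only \(P\) constraints) gives

\[ C_0 \times N = N_- \leq N_\Delta \leq P \;\Longrightarrow\; N \leq \frac{P}{C_0} \]

so a minimum with \(\mathcal{L} > 0\) can only exist when \(N\) is below a constant times \(P\). The counting yields the non-strict inequality; the strict form appears in the statement above.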
\(\mathbf{x}_\mu = \) MNIST
\(y_\mu = \) parity
[Plot with axes \(P\) and \(N\); slope \(\approx\) 3/4]
Over-parametrized (perfect fitting): \(\mathcal{H}_p=0\), \(N_\Delta=0\)
Jamming: \(\mathcal{H}_p=0\), \(N_\Delta>C_0\times N\)
Under-parametrized: \(\mathcal{H}_p\neq 0\), \(N_\Delta\nearrow\)
[Plots: eigenvalue spectra for the three regimes, annotated \(\sim\sqrt{\mathcal{L}}\)]
[Plot with axes \(P\) and \(N\)]
[Plots: test error versus \(N\) and versus \(N/N_\mathrm{jamming}\)]