Gradient Descent on Neural Networks Typically Occurs at the Edge of Stability (arxiv.org/abs/2103.00065)
| Problem   | 5K subset of CIFAR10                    |
| Loss      | Quadratic (squared loss on 0-1 targets) |
| Model     | 2-hidden-layer MLP, tanh                |
| Optimizer | Full-batch GD                           |
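A minimal sketch of this setup, assuming standard PyTorch and torchvision; the hidden width (200) and the step size are illustrative choices, not taken from the slides:

```python
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T

# 5K subset of CIFAR10, flattened inputs and one-hot (0-1) targets.
data = torchvision.datasets.CIFAR10(".", train=True, download=True,
                                    transform=T.ToTensor())
X = torch.stack([data[i][0].flatten() for i in range(5000)])        # (5000, 3072)
Y = torch.eye(10)[torch.tensor([data[i][1] for i in range(5000)])]  # one-hot targets

# 2-hidden-layer tanh MLP (width 200 is an arbitrary illustrative choice).
model = nn.Sequential(nn.Linear(3072, 200), nn.Tanh(),
                      nn.Linear(200, 200), nn.Tanh(),
                      nn.Linear(200, 10))

eta = 0.01  # full-batch GD step size (illustrative)
for step in range(5000):
    loss = 0.5 * ((model(X) - Y) ** 2).sum(dim=1).mean()  # quadratic loss
    model.zero_grad()
    loss.backward()
    with torch.no_grad():
        for p in model.parameters():
            p -= eta * p.grad  # plain full-batch gradient descent
```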
Sharpness = 𝜆ₘₐₓ(∇²L(θ)), the largest eigenvalue of the training-loss Hessian
Figure 1: the top Hessian eigenvalue grows until it reaches 2/𝜂
(this is weird)
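One standard way to track this sharpness is power iteration with Hessian-vector products; a sketch (not the authors' code), with `model`, `X`, `Y` referring to the setup sketch above:

```python
import torch

def top_hessian_eigenvalue(model, loss_fn, n_iters=50):
    """Estimate lambda_max of the loss Hessian by power iteration on HVPs."""
    params = [p for p in model.parameters() if p.requires_grad]
    loss = loss_fn(model)
    grads = torch.autograd.grad(loss, params, create_graph=True)
    # Random unit starting vector, stored in the same shapes as the parameters.
    v = [torch.randn_like(p) for p in params]
    vnorm = torch.sqrt(sum((vi ** 2).sum() for vi in v))
    v = [vi / vnorm for vi in v]
    lam = 0.0
    for _ in range(n_iters):
        # Hessian-vector product: differentiate <grad, v> w.r.t. the parameters.
        gv = sum((g * vi).sum() for g, vi in zip(grads, v))
        hv = torch.autograd.grad(gv, params, retain_graph=True)
        lam = sum((h * vi).sum() for h, vi in zip(hv, v)).item()  # Rayleigh quotient
        norm = torch.sqrt(sum((h ** 2).sum() for h in hv))
        v = [h / norm for h in hv]
    return lam
```

For the setup above: `top_hessian_eigenvalue(model, lambda m: 0.5 * ((m(X) - Y) ** 2).sum(dim=1).mean())`.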
Intuition from quadratics
Most of optimization assumes 𝜆ₘₐₓ ≤ L
https://www.cs.ubc.ca/~schmidtm/Courses/5XX-S20/S1.pdf?page=26
For quadratics: 𝜆ₘₐₓ ≤ L ⇒ 𝜂 < 2/L works
For deep learning: 𝜂 = 2/L ?⇒? 𝜆ₘₐₓ ≈ L
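A one-dimensional check of the 2/L threshold: on f(x) = (L/2)x², a GD step multiplies x by (1 − 𝜂L), so iterates shrink when 𝜂 < 2/L, oscillate with constant amplitude at exactly 𝜂 = 2/L, and blow up beyond it. The numbers below are only for illustration:

```python
import numpy as np

L = 10.0                       # curvature of f(x) = (L/2) * x**2
for eta in [0.15, 0.20, 0.25]: # 2/L = 0.20 is the stability threshold
    x = 1.0
    for _ in range(50):
        x = x - eta * L * x    # GD step: x <- (1 - eta*L) * x
    print(f"eta={eta:.2f}  |x_50|={abs(x):.3e}")
# eta=0.15 converges, eta=0.20 oscillates at constant |x|, eta=0.25 diverges.
```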
"sharpness vs generalization"*
On Large-Batch Training for Deep Learning: Generalization Gap and Sharp Minima. Keskar, Mudigere, Nocedal, Smelyanskiy. Tang. https://arxiv.org/abs/1609.04836
"sharpness vs generalization"*
Three Factors Influencing Minima in SGD. Jastrzębski, Kenton, Arpit, Ballas, Fischer, Bengio, Storkey. arxiv.org/abs/1711.04623
It's not only on a toy dataset
What's happening?
All trajectories start the same (for similar step-sizes)
Same with logistic loss
Checked with Runge-Kutta
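One way to read the Runge-Kutta check: integrate the gradient flow ODE d𝜃/dt = −∇L(𝜃) with a high-order solver and compare it to GD iterates at matching times t = 𝜂·k. A toy 2-D sketch of that comparison (the quadratic loss here is my own example, not from the slides):

```python
import numpy as np
from scipy.integrate import solve_ivp

# Toy loss: L(theta) = 0.5 * theta^T A theta, so grad = A @ theta.
A = np.array([[3.0, 0.0], [0.0, 1.0]])
theta0 = np.array([1.0, 1.0])
grad = lambda th: A @ th

def gd(eta, n_steps):
    """Plain gradient descent, returning the whole trajectory."""
    th, traj = theta0.copy(), [theta0.copy()]
    for _ in range(n_steps):
        th = th - eta * grad(th)
        traj.append(th.copy())
    return np.array(traj)

# Gradient flow dtheta/dt = -grad(theta), integrated with RK45.
flow = solve_ivp(lambda t, th: -grad(th), (0.0, 2.0), theta0, dense_output=True)

for eta in [0.05, 0.1]:
    n = int(2.0 / eta)
    gd_traj = gd(eta, n)
    times = eta * np.arange(n + 1)       # GD step k corresponds to time t = eta*k
    gap = np.abs(gd_traj - flow.sol(times).T).max()
    print(f"eta={eta}: max |GD - gradient flow| = {gap:.4f}")
```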
Caveats
Tracks less well for ReLU/max-pooling, but sometimes works
What happens when 𝜆ₘₐₓ = 2/𝜂
(Oscillations are not due to stochasticity)
Key empirical observation
"What if", validation, NTK regime
What if: Momentum
"𝜂 < 2/L" equivalent for momentum on quadratics
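For heavy-ball momentum x₊ = x − 𝜂∇f(x) + β(x − x₋) on a quadratic with curvature 𝜆, the linear recurrence is stable exactly when 𝜂𝜆 < 2(1 + β), so the 2/L threshold becomes (2 + 2β)/L. A small numeric check (β = 0.9 and 𝜆 = 10 chosen for illustration):

```python
import numpy as np

lam, beta = 10.0, 0.9
threshold = 2 * (1 + beta) / lam          # stability threshold for eta
for eta in [0.9 * threshold, 1.1 * threshold]:
    x_prev, x = 1.0, 1.0
    for _ in range(200):
        # Heavy-ball step on f(x) = (lam/2) * x**2.
        x, x_prev = x - eta * lam * x + beta * (x - x_prev), x
    print(f"eta={eta:.3f} (threshold {threshold:.3f})  |x_200|={abs(x):.3e}")
# Just below the threshold the iterates decay; just above, they diverge.
```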
What if: Drop the learning rate
What if: Use 𝜂 = 1/𝜆ₜ (set the step size to the inverse of the current sharpness)
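What an "𝜂 = 1/𝜆ₜ" schedule amounts to: periodically re-estimate the current sharpness and set the step size to its inverse. A sketch reusing the hypothetical `top_hessian_eigenvalue` helper and the `model`, `X`, `Y` from the earlier snippets; the re-measurement interval is an arbitrary choice:

```python
import torch

loss_fn = lambda m: 0.5 * ((m(X) - Y) ** 2).sum(dim=1).mean()

for step in range(5000):
    if step % 50 == 0:                             # re-measure sharpness occasionally
        lam_t = top_hessian_eigenvalue(model, loss_fn)
    eta = 1.0 / lam_t                              # eta_t = 1 / lambda_t
    loss = loss_fn(model)
    model.zero_grad()
    loss.backward()
    with torch.no_grad():
        for p in model.parameters():
            p -= eta * p.grad
```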
Validation
Simplest architecture with this behavior
Deep (20-layer) linear network
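A minimal sketch of such a network: a 20-layer fully-linear MLP (no activations) trained with full-batch GD. The widths and the random regression data here are illustrative; the slides' exact configuration may differ:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d, width, depth, n = 32, 64, 20, 512
X = torch.randn(n, d)
Y = torch.randn(n, 1)

# 20 stacked Linear layers with no nonlinearity in between.
layers = ([nn.Linear(d, width)]
          + [nn.Linear(width, width) for _ in range(depth - 2)]
          + [nn.Linear(width, 1)])
model = nn.Sequential(*layers)

eta = 1e-3
for step in range(2000):
    loss = 0.5 * ((model(X) - Y) ** 2).mean()
    model.zero_grad()
    loss.backward()
    with torch.no_grad():
        for p in model.parameters():
            p -= eta * p.grad
```

Tracking the sharpness of this model with the earlier power-iteration helper shows the same growth toward 2/𝜂.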
Changes that don't affect behavior
Starting from the toy model on CIFAR10 (5K)
MLP, Tanh, Squared loss
MLP, Tanh, Logistic loss
MLP, ELU, Squared loss
MLP, ELU, Logistic loss
CNN, ReLU, Squared loss
CNN, ReLU, Logistic loss
Changes that matter
Last week: NTK and Infinite width
gradient flow → kernel regression solution with the kernel fixed at initialization
How does this match with "Edge of stability"?
Progressive Sharpening
| Initialization | Less sharpening with NTK init |
| Width          | Wider ⇒ less sharpening       |
| Depth          | Deeper ⇒ more sharpening      |
| Data           | More data ⇒ more sharpening   |
Changes that matter
Changes that matter: Width (standard initialization)
Changes that matter: Width (NTK initialization)
Changes that matter: Depth
(probably width/depth or width/data ratios)
Changes that matter: Dataset Size
Glossed over: BatchNorm, Transformers, SGD, Generalization