arxiv.org/abs/2103.00065
Figure 1
Problem   | 5K subset of CIFAR-10
Loss      | Quadratic (0-1 labels)
Model     | 2-hidden-layer MLP, tanh
Optimizer | Full-batch GD
Sharpness = 𝜆ₘₐₓ of the Hessian of the training loss
Figure 1: the sharpness 𝜆ₘₐₓ grows during training until it reaches 2/𝜂
(this is weird)
- Architectures: VGG-11, ResNet-32
- Run: full-batch gradient, computed by accumulation over minibatches
- Batch-Norm*: "ghost" batch statistics
- Eigenvalues approximated on 10% of the data
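Estimating 𝜆ₘₐₓ does not require forming the Hessian: Hessian-vector products plus an iterative eigensolver suffice (this is typically done with Lanczos or power iteration). Below is a minimal power-iteration sketch using a finite-difference HVP; the function names and the quadratic sanity check are illustrative, not the authors' code.

```python
import numpy as np

def hvp(grad_fn, x, v, eps=1e-5):
    """Hessian-vector product via central finite differences of the gradient."""
    return (grad_fn(x + eps * v) - grad_fn(x - eps * v)) / (2 * eps)

def lambda_max(grad_fn, x, dim, iters=100, seed=0):
    """Estimate the largest Hessian eigenvalue by power iteration on HVPs.

    Assumes the dominant eigenvalue is positive; for nonconvex losses one
    would use Lanczos or a spectral shift instead.
    """
    rng = np.random.default_rng(seed)
    v = rng.normal(size=dim)
    v /= np.linalg.norm(v)
    for _ in range(iters):
        hv = hvp(grad_fn, x, v)
        v = hv / np.linalg.norm(hv)
    return v @ hvp(grad_fn, x, v)  # Rayleigh quotient at the converged vector

# Sanity check on f(x) = 0.5 * x^T diag(1, 3) x, whose largest eigenvalue is 3.
grad = lambda x: np.array([1.0, 3.0]) * x
est = lambda_max(grad, x=np.array([0.5, -0.2]), dim=2)
```

On a quadratic the finite-difference HVP is exact up to rounding, so the estimate matches 𝜆ₘₐₓ = 3 closely.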
Intuition from quadratics
Most of optimization theory assumes smoothness: 𝜆ₘₐₓ ≤ L
https://www.cs.ubc.ca/~schmidtm/Courses/5XX-S20/S1.pdf?page=26
For quadratics: 𝜆ₘₐₓ ≤ L ⇒ 𝜂 < 2/L works
For deep learning: 𝜂 = 2/L ?⇒? 𝜆ₘₐₓ ≈ L (does the sharpness adapt to the step-size?)
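The 2/L rule is easy to check on a quadratic directly. A toy sketch (my own setup, not from the slides): on f(x) = ½ xᵀ diag(1, 10) x, each coordinate is multiplied by (1 − 𝜂𝜆ᵢ) per GD step, so 𝜂 < 2/𝜆ₘₐₓ contracts, 𝜂 > 2/𝜆ₘₐₓ blows up, and 𝜂 = 2/𝜆ₘₐₓ makes the sharp coordinate oscillate forever.

```python
import numpy as np

# GD on f(x) = 0.5 * x^T diag(1, 10) x, so lambda_max = 10 and 2/lambda_max = 0.2.
lam = np.array([1.0, 10.0])

def run_gd(eta, steps=100, x0=(1.0, 1.0)):
    x = np.array(x0)
    for _ in range(steps):
        x = x - eta * lam * x  # gradient of the quadratic is diag(lam) @ x
    return x

x_stable   = run_gd(eta=0.19)  # eta < 2/lambda_max: converges
x_unstable = run_gd(eta=0.21)  # eta > 2/lambda_max: diverges
x_edge     = run_gd(eta=0.20)  # eta = 2/lambda_max: sharp coordinate flips sign each step
```

At the boundary the sharp coordinate is multiplied by exactly −1 every step: it neither converges nor diverges, the quadratic analogue of the bounded oscillations seen in the paper.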
"sharpness vs generalization"*
On Large-Batch Training for Deep Learning: Generalization Gap and Sharp Minima. Keskar, Mudigere, Nocedal, Smelyanskiy, Tang. https://arxiv.org/abs/1609.04836
Three Factors Influencing Minima in SGD. Jastrzębski, Kenton, Arpit, Ballas, Fischer, Bengio, Storkey. https://arxiv.org/abs/1711.04623
It's not only on a toy dataset
What's happening?
All trajectories start the same
(for similar step-sizes)
Logistic loss
Checked against gradient flow integrated with Runge-Kutta
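The comparison works by integrating the gradient-flow ODE dθ/dt = −∇f(θ) with a high-order integrator and checking that GD trajectories with small step-sizes stay close to it. A minimal classical RK4 sketch, validated on a quadratic with a known closed-form solution (my toy, not the deck's experiment):

```python
import numpy as np

def rk4_flow(grad, x0, t_end, h=0.01):
    """Integrate gradient flow dx/dt = -grad(x) with classical 4th-order Runge-Kutta."""
    x = np.asarray(x0, dtype=float)
    for _ in range(int(round(t_end / h))):
        k1 = -grad(x)
        k2 = -grad(x + 0.5 * h * k1)
        k3 = -grad(x + 0.5 * h * k2)
        k4 = -grad(x + h * k3)
        x = x + (h / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)
    return x

# On f(x) = 0.5 * lam * x^2, gradient flow has exact solution x(t) = x0 * exp(-lam * t).
lam = 2.0
x_rk4 = rk4_flow(lambda x: lam * x, x0=[1.0], t_end=1.0)
x_exact = np.exp(-lam * 1.0)
```

RK4's global error is O(h⁴), so at h = 0.01 the numerical flow matches the exact solution to many digits, making it a trustworthy reference trajectory.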
Caveats
Tracks less well for ReLU/max-pooling, but sometimes works
What happens when 𝜆ₘₐₓ = 2/𝜂
(Oscillations are not due to stochasticity)
- The largest eigenvalue of the Hessian increases during training
("progressive sharpening")
- The maximum "stable" value depends on the step-size
(A larger step-size leads to a smaller 𝜆ₘₐₓ)
- GD oscillates when it hits that threshold, and sharpening stops
(but still makes progress)
Key empirical observation
"What if", validation, NTK regime
What if: Momentum
"𝜂 < 2/L" equivalent for momentum on quadratics
What if: Drop the learning rate
What if: Use 𝜂 = 1/𝜆ₜ
Validation
- Simplest architecture that does this?
- What changes don't affect behavior?
- What changes do?
Simplest architecture with this behavior
Deep (20) Linear Network
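A 1-dimensional caricature of that 20-layer linear network (each "layer" is a single scalar weight, so the output is the product of the weights) already shows progressive sharpening. The target, init, and step-size below are my choices for illustration; the Hessian of f(w) = (∏ᵢwᵢ − y)² is written out analytically.

```python
import numpy as np

def loss_and_hessian(w, y=1.0):
    """Loss and exact Hessian of f(w) = (prod(w) - y)^2 (assumes all w_i != 0)."""
    p = np.prod(w)
    u = p / w                              # u_i = d p / d w_i (p is linear in each w_i)
    H = 2.0 * np.outer(u, u)               # Gauss-Newton part
    # Cross term 2*(p - y) * d^2 p / dw_i dw_j, which is p/(w_i w_j) off-diagonal, 0 on it.
    H += (2.0 * (p - y) / p) * (np.outer(u, u) - np.diag(u * u))
    return (p - y) ** 2, H

depth, eta, y = 20, 0.02, 1.0
w = np.full(depth, 0.8)                    # balanced init; product is 0.8**20 ~ 0.012
sharp_init = np.linalg.eigvalsh(loss_and_hessian(w)[1])[-1]
for _ in range(2000):
    p = np.prod(w)
    w -= eta * 2.0 * (p - y) * (p / w)     # gradient descent on f
sharp_final = np.linalg.eigvalsh(loss_and_hessian(w)[1])[-1]
```

As the product climbs from ~0.01 toward the target y = 1, the per-weight sensitivities p/wᵢ all grow, and 𝜆ₘₐₓ rises by several orders of magnitude (to 2·depth = 40 at the minimum): sharpening is driven by fitting the data, not by anything nonlinear in the activations.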
Changes that don't affect behavior
- Loss function: Squared loss, Logistic loss
- Activation function: Tanh, Softplus, ELU, ReLU
- Block: MLP or Convolution
Starting from toy model on CIFAR10 (5K)
MLP, Tanh, Squared loss
MLP, Tanh, Logistic
MLP, ELU, Squared Loss
MLP, ELU, Logistic Loss
CNN, ReLU, Squared Loss
CNN, ReLU, Logistic Loss
Changes that matter
Last week: NTK and Infinite width
gradient flow → kernel regression solution with the NTK at initialization
How does this match with "Edge of stability"?
Progressive Sharpening
Initialization | Less sharpening with NTK init
Width          | Wider ⇒ less sharpening
Depth          | Deeper ⇒ more sharpening
Data           | More data ⇒ more sharpening
Changes that matter
- Start: 2-hidden layer MLP, 200 neurons, tanh
- Optimizer: Gradient flow (RK4)
- 5 restarts
Changes that matter: Width
(standard initialization)
Changes that matter: Width
(NTK initialization)
Changes that matter: Depth
(probably width/depth or width/data ratios)
Changes that matter: Dataset Size
Glossed over: BatchNorm, Transformers, SGD, Generalization
- Great* empirical paper
- NNs are not quadratics
- GD is not monotonic
- 𝜆ₘₐₓ increases
- In normal training, we're closer to this regime than Gradient Flow/NTK
stability
By fkunstner