Gradient Descent on Neural Networks Typically Occurs at the Edge of Stability. Cohen, Kaur, Li, Kolter, Talwalkar. arxiv.org/abs/2103.00065


Problem:    5K subset of CIFAR-10
Loss:       Quadratic (0-1)
Model:      2-hidden-layer MLP, tanh
Optimizer:  Full-batch GD
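A minimal sketch of this setup, not the paper's code (assumes torchvision for the data, one-hot 0-1 targets for the quadratic loss, and width 200 as in the later MLP slides):

```python
# Toy setup sketch: 5K CIFAR-10 subset, 2-hidden-layer tanh MLP,
# quadratic (squared) loss on one-hot 0-1 targets, full-batch GD.
import torch
import torch.nn as nn
from torchvision import datasets, transforms

data = datasets.CIFAR10("./data", train=True, download=True,
                        transform=transforms.ToTensor())
X = torch.stack([data[i][0] for i in range(5000)]).flatten(1)   # 5K subset
Y = nn.functional.one_hot(
        torch.tensor([data[i][1] for i in range(5000)]), 10).float()
model = nn.Sequential(nn.Linear(3072, 200), nn.Tanh(),
                      nn.Linear(200, 200), nn.Tanh(),
                      nn.Linear(200, 10))
loss = ((model(X) - Y) ** 2).mean()   # quadratic (0-1) loss, full batch
```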


Sharpness = \lambda_{\max}(\nabla^2 \mathcal{L}(w))

Figure 1:  The sharpness (largest Hessian eigenvalue) grows until it reaches 2/𝜂


(this is weird)

  • Architectures: VGG-11, ResNet-32
  • Run: Full-batch GD (with gradient accumulation)
  • Batch-Norm*: Ghost batch norm
  • Eigenvalue approximated on 10% of the data (see the sketch below)
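The eigenvalue estimate only needs Hessian-vector products. The paper uses Lanczos; the power iteration below is a simpler stand-in built on the same primitive, sketched in PyTorch:

```python
# Minimal sketch: estimate the sharpness lambda_max(∇²L(w)) by power
# iteration on Hessian-vector products (the paper uses Lanczos on a
# subset of the data; same primitive, simpler iteration).
import torch

def sharpness(loss, params, iters=50):
    """Power-iteration estimate of the largest Hessian eigenvalue."""
    grads = torch.autograd.grad(loss, params, create_graph=True)
    flat = torch.cat([g.reshape(-1) for g in grads])
    v = torch.randn_like(flat)
    v /= v.norm()
    lam = 0.0
    for _ in range(iters):
        # Hessian-vector product via double backprop: Hv = d(grad · v)/dw
        hv = torch.autograd.grad(flat @ v, params, retain_graph=True)
        hv = torch.cat([h.reshape(-1) for h in hv])
        lam = (v @ hv).item()        # Rayleigh quotient v^T H v
        v = hv / hv.norm()
    return lam
```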

Intuition from quadratics

Most of optimization assumes     𝜆ₘₐₓ ≤ L
(https://www.cs.ubc.ca/~schmidtm/Courses/5XX-S20/S1.pdf?page=26)

For quadratics:          𝜆ₘₐₓ ≤ L   ⇒   𝜂 < 2/L   works

For deep learning:          𝜂 = 2/L      ?⇒?      𝜆ₘₐₓ ≈ L
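The quadratic claim is easy to check numerically. A minimal sketch on the 1-D quadratic f(w) = (L/2)w², whose only Hessian eigenvalue is L:

```python
# GD on f(w) = (L/2) w^2: the update w <- w - eta*L*w multiplies w by
# (1 - eta*L), which contracts iff eta < 2/L; at exactly eta = 2/L the
# iterate oscillates between +-w0 forever, and beyond it diverges.
L = 4.0
for eta in [0.4, 0.5, 0.6]:          # 2/L = 0.5
    w = 1.0
    for _ in range(100):
        w -= eta * L * w             # w <- (1 - eta*L) * w
    print(f"eta={eta}: |w| = {abs(w):.2e}")
# eta=0.4 converges, eta=0.5 oscillates (|w| stays 1), eta=0.6 diverges
```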

"sharpness vs generalization"*

On Large-Batch Training for Deep Learning: Generalization Gap and Sharp Minima. Keskar, Mudigere, Nocedal, Smelyanskiy. Tang. https://arxiv.org/abs/1609.04836

"sharpness vs generalization"*

Three Factors Influencing Minima in SGD. Jastrzębski, Kenton, Arpit, Ballas, Fischer, Bengio, Storkey. ​arxiv.org/abs/1711.04623

It's not just on a toy dataset

What's happening?

All trajectories start the same

(for similar step-sizes)

Logistic loss


Checked against gradient flow, integrated with Runge-Kutta (RK4); see the sketch below
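A minimal sketch of that check: integrate gradient flow dw/dt = −∇f(w) with classical RK4 to get the continuous-time reference trajectory (the quadratic example is an illustration, not the slides' network):

```python
# RK4 integration of gradient flow dw/dt = -grad f(w), the
# continuous-time trajectory that small-step GD is compared against.
import numpy as np

def rk4_gradient_flow(grad, w, dt, steps):
    for _ in range(steps):
        k1 = -grad(w)
        k2 = -grad(w + 0.5 * dt * k1)
        k3 = -grad(w + 0.5 * dt * k2)
        k4 = -grad(w + dt * k3)
        w = w + (dt / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)
    return w

# Example on a quadratic f(w) = 0.5 * w^T A w, so grad f(w) = A w:
A = np.diag([4.0, 1.0])
w_end = rk4_gradient_flow(lambda w: A @ w, np.array([1.0, 1.0]),
                          dt=0.01, steps=1000)
```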

Caveats

 

Tracks less well for ReLU/Max-Pooling, but sometimes works

What happens when 𝜆ₘₐₓ = 2/𝜂?

(Oscillations are not due to stochasticity)
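A minimal deterministic illustration: on a 2-D quadratic with one eigenvalue pinned exactly at 2/𝜂, full-batch GD oscillates along the sharp direction (no noise anywhere) while still converging along the flat one:

```python
# GD on f(w) = 0.5 * (lam1*w1^2 + lam2*w2^2) with lam1 = 2/eta: the
# sharp coordinate flips sign forever (pure oscillation, fully
# deterministic), while the flat coordinate keeps converging.
import numpy as np

eta = 0.5
lams = np.array([2 / eta, 0.2])      # lam1 sits exactly at 2/eta
w = np.array([1.0, 1.0])
for _ in range(200):
    w = w - eta * lams * w           # full-batch GD, no stochasticity
print(w)                             # w[0] = +-1 forever, w[1] -> 0
```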

Key empirical observation

  • The largest eigenvalue of the Hessian increases during training
    ("progressive sharpening")

  • The maximum "stable" value depends on the step-size
    (larger step-sizes lead to smaller 𝜆ₘₐₓ)

  • GD oscillates when it hits that threshold and sharpening stops
    (but still makes progress; see the sketch below)
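Putting the pieces together, a minimal end-to-end sketch (assumed toy regression task, not the paper's setup; reuses the `sharpness` power-iteration estimator sketched earlier) to check the three bullets above:

```python
# Train a small tanh MLP with full-batch GD and track the sharpness;
# progressive sharpening predicts lambda_max rising toward 2/eta and
# then hovering there. Random-data task assumed, for illustration only.
import torch
import torch.nn as nn

torch.manual_seed(0)
X, y = torch.randn(256, 10), torch.randn(256, 1)
model = nn.Sequential(nn.Linear(10, 64), nn.Tanh(),
                      nn.Linear(64, 64), nn.Tanh(),
                      nn.Linear(64, 1))
params = list(model.parameters())
eta = 0.05
for step in range(2001):
    loss = nn.functional.mse_loss(model(X), y)
    if step % 200 == 0:
        lam = sharpness(loss, params)   # estimator sketched earlier
        print(f"step {step}: loss {loss.item():.4f}, "
              f"lambda_max {lam:.1f} (2/eta = {2 / eta:.1f})")
    grads = torch.autograd.grad(loss, params)
    with torch.no_grad():
        for p, g in zip(params, grads):
            p -= eta * g                 # full-batch GD step
```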

"What if",        validation,         NTK regime

What if: Momentum

"𝜂 < 2/L" equivalent for momentum on quadratics


What if: Drop the learning rate

What if: Use 𝜂 = 1/𝜆ₜ
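A minimal sketch of this experiment (hypothetical, not the paper's code; reuses `model`, `X`, `Y` from the setup sketch and the `sharpness` estimator):

```python
# Re-estimate the sharpness every step and take eta_t = 1/lambda_t,
# i.e. a step-size that tracks the current top curvature.
import torch

params = list(model.parameters())
for step in range(1000):
    loss = ((model(X) - Y) ** 2).mean()
    lam = sharpness(loss, params)            # current lambda_max
    grads = torch.autograd.grad(loss, params)
    with torch.no_grad():
        for p, g in zip(params, grads):
            p -= (1.0 / lam) * g             # eta_t = 1 / lambda_t
```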

Validation

  • Simplest architecture that does this?
     
  • What changes don't affect behavior?
     
  • What changes do?

Simplest architecture with this behavior

Deep (20-layer) linear network
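A minimal sketch of that architecture (dimensions are assumptions: width 200 matches the MLP setup, input/output sized for CIFAR-10):

```python
# 20-layer purely linear network: no activations anywhere, so the
# function is just a product of matrices, yet per the slides it still
# shows progressive sharpening / edge-of-stability behavior.
import torch.nn as nn

depth, width = 20, 200
hidden = [nn.Linear(width, width, bias=False) for _ in range(depth - 2)]
model = nn.Sequential(nn.Linear(3072, width, bias=False), *hidden,
                      nn.Linear(width, 10, bias=False))
```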

Changes that don't affect behavior

  • Loss function: Squared loss, Logistic loss 
  • Activation function: Tanh, Softplus, ELU, ReLU
  • Block: MLP or Convolution

Starting from the toy model on CIFAR-10 (5K)

MLP, Tanh, Squared loss

MLP, Tanh, Logistic

MLP, ELU, Squared Loss

MLP, ELU, Logistic Loss

CNN, ReLU, Squared Loss

CNN, ReLU, Logistic Loss

Changes that matter

Last week: NTK and Infinite width

gradient flow → kernel solution, with the kernel frozen at initialization
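As a formula (standard NTK result for squared loss on the training set; Θ₀ is the empirical NTK at initialization):

```latex
% Infinite-width gradient flow on the training outputs f_t:
\frac{d f_t}{dt} = -\Theta_0\,(f_t - y)
\quad\Longrightarrow\quad
f_t = y + e^{-\Theta_0 t}\,(f_0 - y)
% The kernel, and hence the curvature the optimizer sees, is frozen
% at initialization: no progressive sharpening in this regime.
```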

How does this match with "Edge of stability"?

Progressive Sharpening

Initialization:  Less sharpening with NTK init
Width:           Wider ⇒ Less sharpening
Depth:           Deeper ⇒ More sharpening
Data:            More data ⇒ More sharpening

Changes that matter

  • Start: 2-hidden-layer MLP, 200 neurons, tanh
  • Optimizer: Gradient flow (RK4)
  • 5 restarts

Changes that matter: Width

(standard initialization)

Changes that matter: Width

(NTK initialization)

Changes that matter: Depth

(probably width/depth or width/data ratios)

Changes that matter: Dataset Size

Glossed over: BatchNorm, Transformers, SGD, Generalization

  • Great* empirical paper

  • NNs are not quadratics
  • GD is not monotonic
  • 𝜆ₘₐₓ increases during training

  • In normal training, we're closer to this regime than to gradient flow / NTK