Journal Club @EPFL, 05.01.21

Prevalence of neural collapse during the

terminal phase of deep learning training

Vardan Papyan, X. Y. Han and David L. Donoho

  • Training modern architectures involves a terminal phase of training (TPT), in which $$\epsilon_\text{train} = 0, \qquad \text{while} \qquad \mathcal{L}(\text{train-set}) \to 0$$
  • In this phase, the authors observe an inductive bias in last layer activations.
    They call it Neural Collapse (NC).

[Figure: train error and train loss curves during training]


Some notation

  • Classes: \(\quad c \in \{1, \dots, C\}\)
  • Data-point: \(\quad x_{i,c},\quad\) \(i-\)th example in class \(c\)
     
  • Net (focus on the last layer): the input \(x\) is mapped to features \(h(x)\); for each class \(c\), the classifier weights \(w_c\) and bias \(b_c\) produce the classifier output \(\langle w_c, h(x)\rangle + b_c\)
  • Net prediction: $$\widehat{c}(x) = \argmax_c\langle w_c, h(x)\rangle + b_c$$

Some more notation

  • Global mean: \(\quad \mu_G = \langle h_{i,c} \rangle_{i,c}\)
  • Class mean: \(\;\;\;\quad \mu_c = \langle h_{i,c} \rangle_{i}\)
  • Within-class covariance: \(\quad \Sigma_W = \langle (h_{i,c}-\mu_c)(h_{i,c}-\mu_c)^T \rangle_{i,c}\)
  • Between-class covariance: \(\quad \Sigma_B = \langle (\mu_c-\mu_G)(\mu_c-\mu_G)^T \rangle_{c}\)

Notice that the features \(h_{i,c}\) and the classifier weights \(w_c\) live in the same space.

That's the space in which we observe Neural Collapse.
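As a concrete illustration, the four statistics above can be computed from a matrix of last-layer features. A minimal numpy sketch with my own naming, not the authors' code:

```python
import numpy as np

def nc_statistics(H, y, C):
    """Global mean, class means, and within-/between-class covariances
    of last-layer features H (N x d) with labels y in {0, ..., C-1}."""
    mu_G = H.mean(axis=0)                                       # global mean
    mu = np.stack([H[y == c].mean(axis=0) for c in range(C)])   # class means (C x d)
    # within-class covariance: outer products of features centered at their class mean
    Sigma_W = np.mean([np.outer(h - mu[c], h - mu[c]) for h, c in zip(H, y)], axis=0)
    # between-class covariance: outer products of class means centered at the global mean
    Sigma_B = np.mean([np.outer(m - mu_G, m - mu_G) for m in mu], axis=0)
    return mu_G, mu, Sigma_W, Sigma_B
```

If every feature in a class sits exactly on its class mean, \(\Sigma_W\) is the zero matrix, which is the limit [NC1] describes.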

[Figure: the shared features & classifiers space, showing the classifier weights \(w_c\), the features \(h_{i,c}\), and the class means \(\mu_c\)]


Simplex Equiangular Tight Frame (ETF):

A set of vectors which are

  1. Equinorm
  2. Equiangular
  3. Separated by the maximal possible common angle (pairwise cosine \(-1/(C-1)\))
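One standard construction, \(\sqrt{C/(C-1)}\,(\mathbf I - \tfrac{1}{C}\mathbf{1}\mathbf{1}^\intercal)\), lets us check all three properties numerically. This is a sketch I am adding, not code from the paper:

```python
import numpy as np

def simplex_etf(C):
    """Columns form a simplex ETF of C vectors in R^C (they span only a
    (C-1)-dimensional subspace): equinorm, equiangular, maximal angle."""
    return np.sqrt(C / (C - 1)) * (np.eye(C) - np.ones((C, C)) / C)

M = simplex_etf(4)
norms = np.linalg.norm(M, axis=0)           # property 1: all norms equal (here, 1)
cosines = M.T @ M / np.outer(norms, norms)
# properties 2 & 3: every off-diagonal cosine equals -1/(C-1) = -1/3,
# the most negative common value C unit vectors can share
```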

Neural Collapse

A net trained past \(\epsilon_\text{train} = 0\), while \(\mathcal{L}(\text{train-set}) \to 0\), frequently exhibits neural collapse (NC), characterized by:

  • [NC1] Within-class variability collapses: $$\Sigma^t_W \to 0 \text{ as } t \to \infty$$
  • [NC2] Class means converge to Simplex ETF
  • [NC3] Convergence to self-duality:
    $$\left\Vert\dfrac{\mu^t_c}{||\mu^t_c||} - \dfrac{w^t_c}{||w^t_c||} \right\Vert \to 0$$
  • [NC4] Classification in the last layer is equivalent to nearest-neighbor classifier w.r.t. class means.

VGG-13 trained on CIFAR-10

 

Evidence of Neural Collapse

  • Related to "compression" of features irrelevant for the task.
  • For ImageNet one could argue that this is not exactly going to zero. It could be that
    • one just needs to spend more time in TPT
    • noisy labels play a role
    • ...

[NC1] Within-class variability collapses

[NC2a] Class means become equinorm

y-axis: \(\dfrac{\text{Std}_c[||\mu_c - \mu_G||]}{\mathbb{E}_c[||\mu_c - \mu_G||]}\)

Define the pairwise cosine \(\cos_\mu(c, c') = \dfrac{\langle \mu_c - \mu_G, \mu_{c'} - \mu_G \rangle}{||\mu_c - \mu_G|| \,||\mu_{c'} - \mu_G||}\).

[NC2b] Class means approach equiangularity

y-axis: \(\text{Std}_{c, c' \neq c}\cos_\mu(c, c')\)

[NC2c] Class means approach maximal angle

y-axis: \(\mathbb{E}_{c, c'} |\cos_\mu(c, c') + 1/(C-1)|\)
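The [NC2] y-axis quantities can be computed directly from the class means. A minimal numpy sketch under my own function names:

```python
import numpy as np
from itertools import combinations

def nc2_metrics(mu, mu_G):
    """Equinorm, equiangularity, and maximal-angle metrics for class means
    mu (C x d); all three tend to 0 under [NC2]."""
    D = mu - mu_G                                   # centered class means
    norms = np.linalg.norm(D, axis=1)
    C = len(mu)
    cos = np.array([D[c] @ D[k] / (norms[c] * norms[k])
                    for c, k in combinations(range(C), 2)])
    equinorm = norms.std() / norms.mean()           # [NC2a]
    equiangle = cos.std()                           # [NC2b]
    max_angle = np.abs(cos + 1 / (C - 1)).mean()    # [NC2c]
    return equinorm, equiangle, max_angle
```

Feeding in class means that already form a simplex ETF drives all three metrics to zero.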

[NC3] Self-duality

Class-means matrix: \(\mathbf{M} = [\mu_c - \mu_G,\: c \in \{1, \dots, C\}]\)

Classifier weights matrix: \(\mathbf{W}\)

y-axis: \(\left\Vert\dfrac{\mathbf{M}}{||\mathbf{M}||_F} - \left(\dfrac{\mathbf{W}}{||\mathbf{W}||_F}\right)^\intercal \right\Vert_F^2\)
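The self-duality gap is a few lines of numpy (my own sketch; note that `np.linalg.norm` on a matrix defaults to the Frobenius norm):

```python
import numpy as np

def nc3_metric(M, W):
    """Self-duality gap between the class-means matrix M (d x C) and the
    classifier weights W (C x d), each rescaled to unit Frobenius norm."""
    A = M / np.linalg.norm(M)          # Frobenius norm by default for matrices
    B = (W / np.linalg.norm(W)).T
    return np.linalg.norm(A - B) ** 2  # -> 0 under [NC3]
```

Because both matrices are normalized, the metric is zero whenever \(\mathbf W^\intercal\) is any positive multiple of \(\mathbf M\).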

[NC4] Last layer ~ nearest-neighbor classifier w.r.t. class means

Net classifier:         \(\argmax_c\langle w_c, h(x)\rangle + b_c\)

Nearest-neighbor classifier:         \(\argmin_c ||h(x) - \mu_c||_2\)

y-axis:   mismatches between the two classifiers
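A quick numerical check of this equivalence (my own sketch): when \(w_c = \mu_c\) with equinorm means and \(b_c = 0\), expanding \(||h - \mu_c||^2 = ||h||^2 - 2\langle \mu_c, h\rangle + ||\mu_c||^2\) shows the two rules must agree on every input.

```python
import numpy as np

def nc4_mismatch(H, W, b, mu):
    """Fraction of inputs H (N x d) on which the net's linear rule
    argmax_c <w_c, h> + b_c disagrees with argmin_c ||h - mu_c||."""
    net_pred = np.argmax(H @ W.T + b, axis=1)
    dists = np.linalg.norm(H[:, None, :] - mu[None, :, :], axis=2)
    ncc_pred = np.argmin(dists, axis=1)
    return float(np.mean(net_pred != ncc_pred))
```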

Training beyond zero-error improves test error

Why the simplex ETF?
Is it optimal, in some sense?

Simplex ETF emergence in the presence of [NC1]

  • We are given observations \(h = \mu_\gamma + \mathbf z\), with \(\mathbf z \sim \mathcal{N}(\mathbf 0, \sigma^2 \mathbf I)\) and \(\gamma \sim \text{Unif}\{1, \dots, C\}\).
  • Goal: recover \(\gamma\) from \(h\) with a linear classifier:
    $$\hat{\gamma}(h) = \argmax_\gamma\:\langle w_\gamma, h\rangle + b_\gamma$$
  • Task: design a classifier \((\mathbf w, \mathbf b)\) and choose the class means \(\mathbf M = \{\mu_c\}_{c=1}^C\) which minimize the classification error, subject to \(||\mu_c||_2 \leq 1\).
  • Optimality criterion: the large-deviations error exponent

    $$\beta(\mathbf M, \mathbf w, \mathbf b) = -\lim_{\sigma\to 0} \sigma^2 \log P_\sigma\{\hat \gamma (h) \neq \gamma\}$$

  • Theorem. Optimal error exponent \(\beta^* = \max_{(\mathbf M, \mathbf w, \mathbf b)}\beta(\mathbf M, \mathbf w, \mathbf b)\) is achieved by the simplex ETF, \(\mathbf M^*\):
    $$\beta(\mathbf M^*,\mathbf M^*,0) = \beta^*$$
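The theorem's setting is easy to simulate (my own sketch, not the paper's experiment): draw \(h = \mu_\gamma + \mathbf z\) with simplex-ETF means and the self-dual classifier \(\mathbf W = \mathbf M\), \(\mathbf b = 0\), and watch the error vanish as \(\sigma\) shrinks.

```python
import numpy as np

def error_rate(sigma, C=4, n=20000, seed=0):
    """Empirical misclassification rate of the self-dual linear classifier
    (W = M, b = 0) on h = mu_gamma + N(0, sigma^2 I), with M a simplex ETF."""
    rng = np.random.default_rng(seed)
    M = np.sqrt(C / (C - 1)) * (np.eye(C) - np.ones((C, C)) / C)  # ETF columns
    gamma = rng.integers(C, size=n)                 # uniform class labels
    H = M[:, gamma].T + sigma * rng.normal(size=(n, C))
    pred = np.argmax(H @ M, axis=1)                 # scores <mu_c, h>
    return float(np.mean(pred != gamma))
```

As \(\sigma \to 0\) the empirical error drops to zero, consistent with a strictly positive error exponent.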

Information theory perspective [Shannon, 1959]

  • Codewords (the class means) are transmitted over a noisy channel and recovered by a linear decoder
  • The norm constraint corresponds to a limit on signal strength
  • Goal: design the codebook and decoder for optimal retrieval
  • Tight concentration of the features around the class means corresponds to [NC1]

The simplex ETF emerges as the optimal structure in the presence of [NC1] and Gaussian noise.

Relation to previous results

Sharpening previous results (1)

Slides from Papyan's talk @ MIT 9.520/6.860: Statistical Learning Theory and Applications, Fall 2020

Sharpening previous results (2)


The margin is the same for each point in the dataset and it is as large as it can possibly be.

Conclusions

  • The paper studies the canonical deep-net training protocol
  • During the terminal phase of training, NC takes place:
    • [NC1] Within-class variability vanishes
    • [NC2-3] Class means and classifier weights converge to a simplex ETF
    • [NC4] The last-layer classifier reduces to a nearest-class-mean classifier
  • NC is optimal under Gaussian noise
  • NC sharpens previous insights

[Journal Club @EPFL] Prevalence of neural collapse during the terminal phase of deep learning training

By Leonardo Petrini
