Xross Mutual Learning

Deep Mutual Learning

(CVPR 2018)

Net1 Loss = D_KL(y2||y1)

Algorithm

How do we fit the intermediate neurons/features?

FitNet

(ICLR 2015)

Algorithm

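[Figure: the Teacher Net and the Student Net each map input x to logits, y_{t} and y_{s}. The student layer to be fit is passed through a regressor W_{r} and matched to the corresponding teacher layer with an L2 loss, plus the BSKD loss.]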
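A minimal sketch of the L2 fitting term, assuming W_{r} is a plain linear map from the student's feature width to the teacher's. The function names, the 0.5 factor, and the toy values are illustrative assumptions, not taken from the slides.

```python
def matvec(matrix, vector):
    """Multiply a matrix (given as a list of rows) by a vector."""
    return [sum(w * v for w, v in zip(row, vector)) for row in matrix]


def hint_l2_loss(teacher_feat, student_feat, w_r):
    """0.5 * ||teacher_feat - W_r @ student_feat||^2 for one example."""
    projected = matvec(w_r, student_feat)  # lift student features to teacher width
    return 0.5 * sum((t - p) ** 2 for t, p in zip(teacher_feat, projected))


# Hypothetical toy values: a 2-d student feature regressed to a 3-d teacher feature.
student_feat = [1.0, 2.0]
teacher_feat = [1.0, 2.0, 3.0]
w_r = [
    [1.0, 0.0],
    [0.0, 1.0],
    [1.0, 1.0],
]

print(hint_l2_loss(teacher_feat, student_feat, w_r))
```

With this particular W_{r} the projected student feature equals the teacher feature, so the loss is zero; any mismatch is penalized element by element, which is what makes the constraint "hard".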

Potential Hazard

[Figure: the Teacher Net and the Student Net each map input x to logits, y_{t} and y_{s}; the intermediate neurons to be fit are marked in both nets.]

1. The teacher net's neurons may contain a lot of redundancy, so the L2-loss constraint is too strict.

[Figure: the teacher and student nets, each split into an up-half network and a down-half network.]

2. L2-loss is superficial

Why talk about FitNet?

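In knowledge distillation, FitNet is almost the only method that fits the features themselves.

Other KD methods roughly do one of the following:
fit attention maps / fit the relations between features (e.g. FSP) / fit the relations between outputs (e.g. Graph)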

Xross Learning

Stage 1 - Cross the networks

[Figure: input x feeds the Net1 and Net2 up-half networks, which produce the Net1 and Net2 half neurons. Each half neuron is passed to both the Net1 and Net2 down-half networks (the cross networks), giving four outputs: y_{1}, y_{12}, y_{21}, and y_{2}.]

Why cross the nets? To push neuron1 ~ neuron2, but without a hard constraint.

must be close to predict y21
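A minimal sketch of the crossed forward pass, assuming y_{12} means "Net1 up-half into Net2 down-half" and y_{21} the reverse. The function `cross_forward` and the toy halves below are hypothetical stand-ins for real networks.

```python
def cross_forward(x, up1, down1, up2, down2):
    """Run both up-halves once, then feed each half neuron to both down-halves."""
    h1 = up1(x)  # Net1 half neuron
    h2 = up2(x)  # Net2 half neuron
    return {
        "y1": down1(h1),   # Net1 up -> Net1 down
        "y12": down2(h1),  # Net1 up -> Net2 down (crossed)
        "y21": down1(h2),  # Net2 up -> Net1 down (crossed)
        "y2": down2(h2),   # Net2 up -> Net2 down
    }


# Hypothetical toy halves operating on plain lists of floats.
def up1(x):
    return [2.0 * v for v in x]


def up2(x):
    return [2.0 * v + 0.1 for v in x]


def down1(h):
    return sum(h)


def down2(h):
    return max(h)


outputs = cross_forward([1.0, 3.0], up1, down1, up2, down2)
print(outputs)
```

If the two half neurons were identical, y_{21} would equal y_{1} and y_{12} would equal y_{2}. Training the crossed outputs therefore encourages similar half neurons without ever comparing them directly.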

y_{1}

Net2 down-half networks

Net1 down-half networks

About Loss - DML

Net2 down-half networks

y_{12}

Net1 down-half networks

y_{21}
y_{2}

Mutual (update 1)

\text{CE}(y_{1}, y) + \text{KL}(y_2 \| y_1)

Net1 up-half networks

Net2 up-half networks

\sim y_2
y_{1}

Net2 down-half networks

Net1 down-half networks
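A minimal sketch of the Net1 loss in the DML update, CE(y_1, y) + KL(y_2 || y_1), for a single example. The helper names and the toy logits are assumptions for illustration; in the update for Net1, y_2 acts as a fixed target.

```python
import math


def softmax(logits):
    """Convert raw logits into a probability distribution."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(probs, label):
    """CE(y_1, y) for one example with an integer class label."""
    return -math.log(probs[label])


def kl_divergence(p, q):
    """KL(p || q) = sum_i p_i * log(p_i / q_i)."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def dml_loss_net1(logits1, logits2, label):
    """Net1 loss: CE(y_1, y) + KL(y_2 || y_1)."""
    y1, y2 = softmax(logits1), softmax(logits2)
    return cross_entropy(y1, label) + kl_divergence(y2, y1)


# Hypothetical logits for a 3-class problem.
loss = dml_loss_net1([2.0, 0.5, -1.0], [1.5, 1.0, -0.5], label=0)
print(loss)
```

The Net2 update would be symmetric, with the roles of y_1 and y_2 swapped.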

About Loss - XML (update 1)

Net2 down-half networks

y_{12}

Net1 down-half networks

y_{21}
y_{2}

Net1 up-half networks

Net2 up-half networks

\sim y_2
\sim y_2
\sim y_2
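The slide only marks three outputs as "~ y_2". A rough sketch of one possible reading: in the Net1 update, y_{1}, y_{12}, and y_{21} are each pulled toward y_{2} with a KL term, on top of the cross-entropy on y_{1}. Which outputs the three marks attach to, the use of KL for each, and the equal weighting are all assumptions here, not something the slides confirm.

```python
import math


def softmax(logits):
    """Convert raw logits into a probability distribution."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl(p, q):
    """KL(p || q) = sum_i p_i * log(p_i / q_i)."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def xml_loss_net1(logits1, logits12, logits21, logits2, label):
    """Assumed form: CE(y_1, y) + KL(y_2||y_1) + KL(y_2||y_12) + KL(y_2||y_21)."""
    y1, y12, y21, y2 = (softmax(z) for z in (logits1, logits12, logits21, logits2))
    ce = -math.log(y1[label])
    return ce + kl(y2, y1) + kl(y2, y12) + kl(y2, y21)


# Hypothetical logits for the four outputs of the crossed networks.
loss = xml_loss_net1(
    [2.0, 0.5, -1.0], [1.8, 0.7, -0.9], [1.6, 0.9, -0.6], [1.5, 1.0, -0.5], label=0
)
print(loss)
```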

Let's Look at It as a Teacher-Student Architecture

Teacher-Student's viewpoint

Teacher down-half networks

y_{12}

Student up-half networks

\sim y_2

Fixed

  •  The student's up-half needs to learn:
    How to produce mid-neurons that fit the teacher's content?

Student down-half networks

y_{21}

Teacher up-half networks

\sim y_2

Fixed

  •  The student's down-half needs to learn:
    How to use the teacher's mid-neurons to answer the final question (or to fit the teacher's answer)?

Result - 1

ResNet18 & MobileNet V1

Parameters Distribution

Net1 (ResNet18) up-half network: 3 residual blocks, 2,775,104 parameters

Net1 (ResNet18) down-half network: 1 residual block + FC, 8,445,028 parameters

Net2 (MobileNet V1) up-half network: 2 conv blocks, 135,040 parameters

Net2 (MobileNet V1) down-half network: 2 conv blocks + FC, 3,180,388 parameters
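A quick arithmetic check on the split, assuming the four counts pair with the networks in the order the labels appear on the slide (Net1 = ResNet18, Net2 = MobileNet V1). The dictionary layout and function names are just for illustration.

```python
# Parameter counts copied from the slide; the network/half pairing is an assumption.
params = {
    "ResNet18": {"up": 2_775_104, "down": 8_445_028},
    "MobileNet V1": {"up": 135_040, "down": 3_180_388},
}


def total(net):
    """Total parameter count of a network (up-half + down-half)."""
    return params[net]["up"] + params[net]["down"]


def up_share(net):
    """Fraction of a network's parameters that sit in its up-half."""
    return params[net]["up"] / total(net)


for net in params:
    print(f"{net}: total={total(net):,}, up-half share={up_share(net):.1%}")
```

Under this pairing the up-half holds roughly a quarter of ResNet18's parameters but only about 4% of MobileNet V1's, so the two up-halves differ far more in capacity than the two down-halves.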

Validation Accuracy - Mobile

Independent: 62.6, DML: 65.9 (+3.1), XML: 68.3 (+2.4)

(Net 1: ResNet18, Net 2: MobileNet V1). For Net1, all three results are quite close.

Validation Accuracy - Res18

Independent: 74, DML: 74.3 (+0.3), XML: 73.6 (-0.7)

 

XML Val Accuracy (2 - best model)

[Figure: validation accuracy of the best XML model for each pairing of an up-half network (Net1 or Net2) with a down-half network (Net1 or Net2). Values in slide order: 0.7320, 0.7021, 0.7225, 0.6889.]

XML Acc Brainstorming

Net1 up-half networks

Net2 up-half networks

0.7225 -> 0.7320 (+0.0095)

0.6889 -> 0.7021 (+0.0132)

Net1 down-half networks

Net2 down-half networks

0.7021 -> 0.7320 (+0.0299)

0.6889 -> 0.7225 (+0.0336)

The up-half networks mimic each other well, but the down-half networks do not.

2,775,104

8,445,028

135,040

3,180,388
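The gaps on this slide can be recomputed directly from the four accuracies. The sketch below takes the (before, after) pairs as listed under the up-half and down-half headings; the variable names are made up for illustration.

```python
# (before, after) validation accuracies copied from the slide.
up_half_pairs = [(0.7225, 0.7320), (0.6889, 0.7021)]
down_half_pairs = [(0.7021, 0.7320), (0.6889, 0.7225)]


def gains(pairs):
    """Accuracy gain for each (before, after) pair, rounded to 4 decimals."""
    return [round(after - before, 4) for before, after in pairs]


def mean(values):
    """Arithmetic mean of a non-empty list."""
    return sum(values) / len(values)


up_gains = gains(up_half_pairs)
down_gains = gains(down_half_pairs)
ratio = mean(down_gains) / mean(up_gains)

print("up-half gaps:", up_gains)
print("down-half gaps:", down_gains)
print("down/up ratio:", round(ratio, 2))
```

One way to read the numbers: exchanging the up-half moves accuracy by about 0.01, while exchanging the down-half moves it by about 0.03, so the two up-halves look far more interchangeable than the two down-halves.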

Result - 2

ResNet18 & ResNet34

Parameters Distribution

Net1 (ResNet18) up-half network: 2 residual blocks, 675,392 parameters

Net1 (ResNet18) down-half network: 2 residual blocks + FC, 10,544,740 parameters

Net2 (ResNet34) up-half network: 2 residual blocks, 1,340,224 parameters

Net2 (ResNet34) down-half network: 2 residual blocks + FC, 19,988,068 parameters

Validation Accuracy - Res18

Independent: 73.9, DML: 75.7 (+1.8), XML: 76.6 (+0.9)

(Net 1: ResNet18, Net 2: ResNet34)

Validation Accuracy - Res34

Independent: 75.5, DML: 76.6 (+1.1), XML: 77 (+0.4)

(Net 1: ResNet18, Net 2: ResNet34). For Net2, all three results are quite close.

XML Val Accuracy (2 - best model)

[Figure: validation accuracy of the best XML model for each pairing of an up-half network (Res18 or Res34) with a down-half network (Res18 or Res34). Values in slide order: 0.7606, 0.7625, 0.7661, 0.7661.]

What Else?

Dynamic Computation

[Figure: three networks (Net1, Net2, Net3), each split into three parts (part1, part2, part3).]

Time Cost

Conclusion

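  1. Similar architectures may cause XML to learn only about as well as DML?

  2. Sometimes the back half learns better and sometimes the front half does, so there is no fixed rule.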

Conclusion

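Improving the front end

Find a way to make the small model fit the large model.

However, under the original Xross Learning the two models are already quite close in distance (average difference: 0.004).

Improving the back end

This reduces to the original KD/Mutual problem.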

Why Does It Work?

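  • A single up-net or down-net has to fit two outputs or two inputs at the same time, which forces the model into a multi-task problem made of similar tasks. Because the two tasks are so close, flatness or sensitivity is reduced, which in turn raises validation accuracy. (Generalization)
  • Compared with FitNet, the XrossNet approach shows that feature similarity can be reached without such a hard constraint. In addition, FitNet has to be two-stage (first fit, then knowledge distillation).
  • It may have an effect similar to ensemble training?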

Up-half part experiment

Is fitting the neurons needed?

What does the fitting learning curve look like?

Net1

half neuron

Net2

half neuron

\alpha = 0
\alpha = 1
\alpha = 2
\alpha = -1

interpolate

2 * net2 - net1

net2

net1

2 * net1 - net2

Best

Almost best, but the distance is larger than at \alpha = 0

Catastrophe

This phenomenon occurs in both the Net1 and Net2 down-half networks.

Dis1 - connect down 1

Dis1 - connect down 2

Dis2 - connect down 1

Dis2 - connect down 2
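A minimal sketch of the interpolation, assuming the mixed half neuron is (1 - \alpha) * net1 + \alpha * net2. That parametrization reproduces the four labels on the slide (net1, net2, 2 * net2 - net1, 2 * net1 - net2), but the slides do not say which end \alpha = 0 refers to, so the direction is an assumption.

```python
def mix_features(h1, h2, alpha):
    """Return (1 - alpha) * h1 + alpha * h2, element by element."""
    return [(1 - alpha) * a + alpha * b for a, b in zip(h1, h2)]


# Hypothetical half-neuron activations from Net1 and Net2.
h1 = [1.0, 2.0]
h2 = [3.0, 0.0]

for alpha in (-1, 0, 1, 2):
    print(alpha, mix_features(h1, h2, alpha))
```

Values of \alpha outside [0, 1] extrapolate past either network's half neuron. Feeding each mixed feature to a down-half network and recording its accuracy gives one point per \alpha on the curve.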

Experiment 0

Add distance

x

Net1 up-half networks

Net2 up-half networks

Like Fitnet + Mutual

Net1

half neuron

Net2

half neuron

L1/L2 loss

Validation Score

original XML: 68.3, XML + L2-loss: 68.5 (+0.2)

distance: 0.002 -> 0.001

Experiment 1

Down half net -> Discriminator

Change Down-Half Network

Net1 down-half networks

Net2 down-half networks

Classification Score

Synthesis Score (1=like Net1)

Classification Score

Synthesis Score (1=like Net2)

Train Net 1

Net1 down-half networks

Classification Score

Synthesis Score

(1=like Net1)

Net1 up-half networks

Net2 down-half networks

Classification Score

Synthesis Score

(1=like Net2)

Net2 up-half networks

Only Down Half

Generator_1

Discriminator_1

Generator_2

Discriminator_2

Experiment 2

Independent Discriminator

Train Net 1

Net1 down-half networks

Classification Score

Net1 up-half networks

Net2 down-half networks

Generator_1

Classification Score

Discriminator

Synthesis Score

Result

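No improvement: better than DML, but not better than XML.

The two Generators were not pulled closer together as a result; instead, training became unstable.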