Xross Mutual Learning
Deep Mutual Learning
(CVPR 2018)
Net1 loss = CE(label, y1) + D_KL(y2 || y1)
Algorithm
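A minimal sketch of Net1's DML loss in plain Python: cross-entropy on Net1's own prediction plus D_KL(y2 || y1), with Net2's prediction treated as a fixed target. The helper names and the example logits are illustrative, not taken from the DML code. In a real framework y2 would typically be detached before the KL term.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(probs, label):
    """Cross-entropy of a probability vector against an integer class label."""
    return -math.log(probs[label])


def kl_divergence(p, q):
    """D_KL(p || q) = sum_i p_i * log(p_i / q_i)."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def dml_loss_net1(logits1, logits2, label):
    """Net1's DML loss: CE on its own prediction plus D_KL(y2 || y1).

    y2 is only used as a target here; no update flows into Net2.
    """
    y1 = softmax(logits1)
    y2 = softmax(logits2)
    return cross_entropy(y1, label) + kl_divergence(y2, y1)


print(dml_loss_net1([2.0, 0.5, -1.0], [1.5, 1.0, -0.5], label=0))
```

Net2's loss is the mirror image, D_KL(y1 || y2) plus its own cross-entropy, and the two networks are updated in turn.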
How do we fit the neurons/attributes in the middle of the network?
FitNet
(ICLR 2015)
Algorithm
Figure: Teacher Net and Student Net, each ending in logits.
L2-loss on the layer we want to fit, then + BSKD loss.
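The FitNet hint stage can be sketched as follows, assuming a plain linear regressor r that maps the student's guided-layer features to the teacher's hint-layer size (the paper uses a convolutional regressor for convolutional features). The loss is 0.5 * ||u_h - r(v_g)||^2, where u_h is the teacher's hint output and v_g the student's guided-layer output; the function names are illustrative.

```python
def linear_regressor(weights, features):
    """r(v) = W v: project the student's guided-layer features to the teacher's hint size."""
    return [sum(w * f for w, f in zip(row, features)) for row in weights]


def hint_loss(teacher_hint, student_guided, regressor_weights):
    """FitNet-style hint loss: 0.5 * ||u_h - r(v_g)||^2 (the L2-loss in the figure)."""
    projected = linear_regressor(regressor_weights, student_guided)
    assert len(projected) == len(teacher_hint), "regressor must output the hint size"
    return 0.5 * sum((t - p) ** 2 for t, p in zip(teacher_hint, projected))


# Student guided layer has 2 units, teacher hint layer has 3.
W = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
print(hint_loss([1.0, 2.0, 3.0], [1.0, 2.0], W))  # -> 0.0
```

In FitNet this first stage trains only the student up to the guided layer (plus the regressor); the second stage then trains the whole student with knowledge distillation.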
Potential Hazard
Figure: Teacher Net and Student Net, each ending in logits, with the neuron we want to fit in the middle.
1. The teacher net's intermediate neurons may contain a lot of redundancy, so the L2-loss constraint is too strict.
Figure: Teacher and Student networks, each split into up-half and down-half networks.
2. L2-loss is superficial
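A toy illustration of hazard 1, with made-up numbers: if the teacher's down-half only reads the sum of two redundant neurons, a student mid-layer can produce exactly the same logit as the teacher and still pay a non-zero L2 penalty.

```python
def teacher_down_half(h):
    """Toy teacher down-half: the logit depends only on h[0] + h[1] (redundant neurons)."""
    return h[0] + h[1]


def squared_l2(a, b):
    """Squared L2 distance, i.e. what an L2 fitting loss would penalize."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


h_teacher = [1.0, 1.0]
h_student = [2.0, 0.0]  # different neurons, same sum

print(teacher_down_half(h_teacher), teacher_down_half(h_student))  # -> 2.0 2.0
print(squared_l2(h_teacher, h_student))  # -> 2.0
```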
Why talk about FitNet?
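In knowledge distillation, FitNet is almost the only method that fits the features directly.
Other KD methods, briefly:
fitting attention maps / fitting the relations between features (e.g., FSP) / fitting the relations between outputs (e.g., graph-based)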
Xross Learning
Stage 1 - Cross the networks
Figure: Net1 and Net2, each split into an up-half and a down-half network.
Cross Networks
Figure: crossed networks, where the Net1 half neuron feeds Net2's down-half network and the Net2 half neuron feeds Net1's down-half network.
Why cross the networks? To make neuron1 ≈ neuron2 without a hard constraint: the two half neurons must be close for the crossed network to predict y21.
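The crossing itself is function composition: every up-half can feed every down-half. The sketch below uses toy halves (hypothetical stand-ins for the real networks) and the key convention (i, j) = Net i up-half feeding Net j down-half, which is chosen here and may not match the slide's y21 indexing.

```python
from itertools import product


def cross_outputs(ups, downs, x):
    """Return {(i, j): down_j(up_i(x))} for every up-half i and down-half j.

    (i, i) is an original network; (i, j) with i != j is a crossed network.
    """
    return {(i, j): downs[j](ups[i](x)) for i, j in product(ups, downs)}


# Toy halves: each up-half returns a 2-neuron "half neuron" vector.
ups = {
    1: lambda x: [x, 2.0 * x],
    2: lambda x: [x + 1.0, 2.0 * x - 1.0],
}
downs = {
    1: lambda h: h[0] + h[1],
    2: lambda h: 2.0 * h[0] + 0.5 * h[1],
}

print(cross_outputs(ups, downs, 1.0))
# -> {(1, 1): 3.0, (1, 2): 3.0, (2, 1): 3.0, (2, 2): 4.5}
```

If each crossed output is also trained on the label (an assumption in this sketch), a down-half only does well when the other net's half neuron is usable in place of its own, which is the soft constraint described above.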
About Loss - DML
Figure: Net1 and Net2 down-half networks.
Mutual (update 1)
Figure: Net1 and Net2 up-half and down-half networks.
About Loss - XML (update 1)
Figure: Net1 and Net2 up-half and down-half networks, crossed.
Let's Look at It as a Teacher-Student Architecture
The teacher-student viewpoint
Figure: Student up-half networks feeding the Teacher down-half networks (fixed).
- The student's up-half needs to learn:
How to predict mid-neurons that fit the teacher's content?
Figure: Teacher up-half networks (fixed) feeding the Student down-half networks.
- The student's down-half needs to learn:
How to use the teacher's mid-neurons to answer the final question (or to fit the teacher's answer)?
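The two teacher-student roles can be seen on a toy model with one weight per half (all values hypothetical): mid = a * x, output = b * mid, loss = (output - y)^2. When the student's up-half a_s trains through the fixed teacher down-half b_t, only a_s moves; when the student's down-half b_s trains on the fixed teacher up-half a_t, only b_s moves.

```python
def step_student_up(a_s, b_t, x, y, lr=0.1):
    """SGD step on the student's up-half weight a_s through the FIXED teacher down-half b_t."""
    output = b_t * (a_s * x)
    grad_a_s = 2.0 * (output - y) * b_t * x  # dL/da_s; b_t is not updated
    return a_s - lr * grad_a_s


def step_student_down(b_s, a_t, x, y, lr=0.1):
    """SGD step on the student's down-half weight b_s on top of the FIXED teacher up-half a_t."""
    mid = a_t * x  # teacher's mid-neuron, treated as a constant input
    output = b_s * mid
    grad_b_s = 2.0 * (output - y) * mid  # dL/db_s
    return b_s - lr * grad_b_s


a_s, b_s = 0.0, 0.0
for _ in range(50):
    a_s = step_student_up(a_s, b_t=2.0, x=1.0, y=4.0)
    b_s = step_student_down(b_s, a_t=2.0, x=1.0, y=4.0)
print(round(a_s, 6), round(b_s, 6))  # -> 2.0 2.0
```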
Result - 1
ResNet18 & MobileNet V1
Parameters Distribution
Net1 (ResNet18): up-half = 3 residual blocks (2,775,104 params), down-half = 1 residual block + FC (8,445,028 params)
Net2 (MobileNet V1): up-half = 2 conv blocks (135,040 params), down-half = 2 conv blocks + FC (3,180,388 params)
Validation Accuracy - Mobile
Independent: 62.6 DML: 65.9(+3.3) XML: 68.3(+2.4)
(Net 1: ResNet18, Net 2: MobileNet V1); for Net1, all three results are quite close.
Validation Accuracy - Res18
Independent: 74 DML: 74.3(+0.3) XML: 73.6(-0.7)
XML Val Accuracy (2 - best model)
Net1 up-half + Net1 down-half: 0.7320
Net1 up-half + Net2 down-half: 0.7021
Net2 up-half + Net1 down-half: 0.7225
Net2 up-half + Net2 down-half: 0.6889
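XML Acc Brainstorming
Swapping the up-half (down-half fixed):
Net1 down-half: Net2 up-half -> Net1 up-half, 0.7225 -> 0.7320 (+0.0095)
Net2 down-half: Net2 up-half -> Net1 up-half, 0.6889 -> 0.7021 (+0.0132)
Swapping the down-half (up-half fixed):
Net1 up-half: Net2 down-half -> Net1 down-half, 0.7021 -> 0.7320 (+0.0299)
Net2 up-half: Net2 down-half -> Net1 down-half, 0.6889 -> 0.7225 (+0.0336)
The up-halves mimic each other well (swapping them costs only about 0.01), but the down-halves do not (swapping them costs about 0.03).
Parameters: Net1 up-half 2,775,104, down-half 8,445,028; Net2 up-half 135,040, down-half 3,180,388.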
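The brainstorming arithmetic can be reproduced from the four combination accuracies (0.7320, 0.7021, 0.7225, 0.6889). The (up-half, down-half) assignment in the dictionary is inferred from the swap comparisons on this slide, so treat it as a reconstruction rather than a logged result.

```python
# Accuracy of each (up-half, down-half) pair; the assignment is inferred from the swap comparisons.
ACC = {
    ("Net1", "Net1"): 0.7320,
    ("Net1", "Net2"): 0.7021,
    ("Net2", "Net1"): 0.7225,
    ("Net2", "Net2"): 0.6889,
}


def up_swap_gain(down):
    """Gain from replacing Net2's up-half with Net1's while keeping `down` fixed."""
    return ACC[("Net1", down)] - ACC[("Net2", down)]


def down_swap_gain(up):
    """Gain from replacing Net2's down-half with Net1's while keeping `up` fixed."""
    return ACC[(up, "Net1")] - ACC[(up, "Net2")]


for net in ("Net1", "Net2"):
    print(f"up-swap (down={net}): {up_swap_gain(net):+.4f}   "
          f"down-swap (up={net}): {down_swap_gain(net):+.4f}")
```

Both up-swap gains land near 0.01 and both down-swap gains near 0.03, which is the gap the slide's conclusion rests on.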
Result - 2
ResNet18 & ResNet34
Parameters Distribution
Net1 (ResNet18): up-half = 2 residual blocks (675,392 params), down-half = 2 residual blocks + FC (10,544,740 params)
Net2 (ResNet34): up-half = 2 residual blocks (1,340,224 params), down-half = 2 residual blocks + FC (19,988,068 params)
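Summing the halves is a quick sanity check on the two parameter tables: ResNet18 adds up to the same 11,220,132 parameters under both splits, and the up-half share shows how unevenly each split divides the weights.

```python
# (up-half, down-half) parameter counts from the "Parameters Distribution" slides.
splits = {
    "ResNet18 (Result 1)": (2_775_104, 8_445_028),
    "MobileNet V1 (Result 1)": (135_040, 3_180_388),
    "ResNet18 (Result 2)": (675_392, 10_544_740),
    "ResNet34 (Result 2)": (1_340_224, 19_988_068),
}

for name, (up, down) in splits.items():
    total = up + down
    print(f"{name}: total {total:,}, up-half share {up / total:.1%}")

# ResNet18 (Result 1): total 11,220,132, up-half share 24.7%
# MobileNet V1 (Result 1): total 3,315,428, up-half share 4.1%
# ResNet18 (Result 2): total 11,220,132, up-half share 6.0%
# ResNet34 (Result 2): total 21,328,292, up-half share 6.3%
```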
Validation Accuracy - Res18
Independent: 73.9 DML: 75.7(+1.8) XML: 76.6(+0.9)
(Net 1: ResNet18, Net 2: ResNet34)
Validation Accuracy - Res34
Independent: 75.5 DML: 76.6(+1.1) XML: 77(+0.4)
(Net 1: ResNet18, Net 2: ResNet34); for Net2, all three results are quite close.
XML Val Accuracy (2 - best model)
Figure: the four up-half/down-half combinations of Res18 and Res34; reported accuracies 0.7606, 0.7625, 0.7661, 0.7661.
What Else?
Dynamic Computation
Figure: Net1, Net2 and Net3, each split into part1, part2 and part3 networks.
Time Cost
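One way to read the Dynamic Computation slide: once Net1, Net2 and Net3 are split into three interchangeable parts, a forward pass can take each part from any network, which gives 3^3 = 27 paths with different time costs. The sketch enumerates the paths that fit a time budget; the per-part costs in PART_COST are invented placeholders, not measurements.

```python
from itertools import product

# Hypothetical time cost (ms) of each part of each network; NOT numbers from the slides.
PART_COST = {
    "Net1": [1.0, 1.0, 1.0],
    "Net2": [2.0, 2.0, 2.0],
    "Net3": [4.0, 4.0, 4.0],
}


def paths_within_budget(budget, num_parts=3):
    """Return every (path, cost) where path picks one network per part and cost <= budget."""
    result = []
    for path in product(PART_COST, repeat=num_parts):
        cost = sum(PART_COST[net][k] for k, net in enumerate(path))
        if cost <= budget:
            result.append((path, cost))
    return result


print(len(paths_within_budget(float("inf"))))  # -> 27
print(paths_within_budget(4.0))
```

Choosing among the affordable paths would also need an accuracy estimate per path, which this sketch does not model.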
Conclusion
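- Similar architectures may lead XML to learn only about as well as DML?
- Sometimes the down-half learns better and sometimes the up-half does, so there is no consistent pattern.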
Conclusion
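Improve the front end (up-half)
Find a way to make the small model fit the large model.
However, in the original Xross Learning the two models are already quite close: the average distance is 0.004.
Improve the back end (down-half)
This is essentially the original KD/mutual-learning problem.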
Why Does It Work?
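- Each up-net or down-net must fit two outputs or two inputs at the same time, which forces the model into a multi-task setting with similar tasks. Because the two tasks are so close, the model's sensitivity drops (its minimum is flatter), which raises validation accuracy. (Generalization)
- Compared with FitNet, the XrossNet approach shows that feature mimicking can be achieved without such a hard constraint. Moreover, FitNet requires two stages (fit first, then knowledge distillation).
- Could there be an effect similar to ensemble training?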
Up-half part experiments
Is fitting the neurons necessary?
What does the fitting learning curve look like?
Figure: interpolate between the Net1 half neuron and the Net2 half neuron, evaluated along 2 * net2 - net1, net2, net1, 2 * net1 - net2. Annotations: "Best"; "almost best, but dis > a = 0"; "catastrophe".
This phenomenon occurs in both the Net1 and Net2 down-half networks.
Panels: Dis1 - connect down 1, Dis1 - connect down 2, Dis2 - connect down 1, Dis2 - connect down 2.
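The interpolation axis can be written as h(a) = net1 + a * (net2 - net1): a = 0 gives net1, a = 1 gives net2, and a = -1 and a = 2 give the extrapolated points 2 * net1 - net2 and 2 * net2 - net1. A minimal sketch of the sweep, where the down_half callable and the alpha grid are hypothetical placeholders for a real down-half network and its evaluation:

```python
def interpolate(h1, h2, a):
    """h(a) = h1 + a * (h2 - h1); a=0 -> h1, a=1 -> h2, a=-1 -> 2*h1 - h2, a=2 -> 2*h2 - h1."""
    return [x1 + a * (x2 - x1) for x1, x2 in zip(h1, h2)]


def sweep(h1, h2, down_half, alphas):
    """Feed each interpolated half neuron to a down-half and collect its output."""
    return [(a, down_half(interpolate(h1, h2, a))) for a in alphas]


net1 = [1.0, 0.0]
net2 = [3.0, 2.0]
print(interpolate(net1, net2, -1.0))  # -> [-1.0, -2.0]  (2 * net1 - net2)
print(interpolate(net1, net2, 2.0))   # -> [5.0, 4.0]    (2 * net2 - net1)
print(sweep(net1, net2, sum, [-1.0, 0.0, 0.5, 1.0, 2.0]))
```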
Experiment 0
Add a distance loss
Like FitNet + mutual learning
Figure: Net1 and Net2 up-half networks, with an L1/L2 loss between the Net1 half neuron and the Net2 half neuron.
Validation Score
original XML: 68.3, XML + L2-loss: 68.5 (+0.2)
distance: 0.002 -> 0.001
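Experiment 0 amounts to adding a FitNet-style distance term to the XML objective. A sketch, assuming the distance is averaged over neurons and scaled by a hypothetical weight; the slides report the L2 variant (68.3 -> 68.5) but do not give the weighting, and xml_loss stands in for whatever XML loss is already computed.

```python
def distance(h1, h2, norm="l2"):
    """Mean L1 or squared-L2 distance between the Net1 and Net2 half neurons."""
    if norm == "l1":
        diffs = [abs(a - b) for a, b in zip(h1, h2)]
    elif norm == "l2":
        diffs = [(a - b) ** 2 for a, b in zip(h1, h2)]
    else:
        raise ValueError(f"unknown norm: {norm}")
    return sum(diffs) / len(diffs)


def xml_plus_distance_loss(xml_loss, h1, h2, weight=1.0, norm="l2"):
    """Existing XML loss plus a weighted distance term pulling the half neurons together."""
    return xml_loss + weight * distance(h1, h2, norm)


print(xml_plus_distance_loss(1.0, [0.0, 0.0], [1.0, 3.0], weight=0.5))             # -> 3.5
print(xml_plus_distance_loss(1.0, [0.0, 0.0], [1.0, 3.0], weight=0.5, norm="l1"))  # -> 2.0
```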
Experiment 1
Down-half net -> discriminator
Change the down-half network
Figure: the Net1 and Net2 down-half networks each output a Classification Score and a Synthesis Score (1 = like Net1 for the Net1 down-half, 1 = like Net2 for the Net2 down-half).
Train Net 1: the Net1 and Net2 up-half networks feed the Net1 and Net2 down-half networks, each producing a Classification Score and a Synthesis Score.
Only Down Half: Generator_1, Discriminator_1, Generator_2, Discriminator_2.
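If the synthesis head is trained like a standard GAN discriminator (an assumption; the slides only give the roles and the meaning of the score), the losses for one generator/discriminator pair look like this. Discriminator_1 is taken to be Net1's down-half synthesis head (1 = like Net1) and Generator_2 is assumed to be Net2's up-half trying to look like Net1; scores are probabilities in (0, 1) and the function names are illustrative.

```python
import math


def bce(p, target):
    """Binary cross-entropy for a single probability p in (0, 1)."""
    return -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))


def discriminator1_loss(score_net1_mid, score_net2_mid):
    """Discriminator_1: score Net1's half neuron as 1 (real) and Net2's as 0 (fake)."""
    return bce(score_net1_mid, 1.0) + bce(score_net2_mid, 0.0)


def generator2_loss(score_net2_mid):
    """Generator_2 (assumed: Net2's up-half) wants Discriminator_1 to score its half neuron as 1."""
    return bce(score_net2_mid, 1.0)


print(discriminator1_loss(0.5, 0.5))  # 2 * ln 2: the discriminator cannot tell them apart
print(generator2_loss(0.9) < generator2_loss(0.1))  # -> True
```

The pair (Discriminator_2, Generator_1) mirrors this with the roles of Net1 and Net2 swapped.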
Experiment 2
Independent Discriminator
Train Net 1: the Net1 up-half networks (Generator_1) feed the Net1 and Net2 down-half networks, each producing a Classification Score; an independent Discriminator produces the Synthesis Score.
Result
No improvement: better than DML, but not better than XML.
The two generators were not pulled closer; instead, training became unstable.
Xross Mutual Learning
By Arvin Liu