Predicting Attention Sparsity
in Transformers
May 27, 2022
Marcos Treviso
António Góis
Patrick Fernandes
Erick Fonseca
André F. T. Martins
DEEPSPIN
softmax transformer (Voita et al., 2019): heads attending to rare tokens
α-entmax transformer (Correia et al., 2019): heads attending to rare tokens and to neighboring tokens
both are O(n²)
can we save computation? O(n²) … O(n) … O(1)
Sparsefinder
q′ = W_h^Q q,  k′ = W_h^K k
projections learned in a lower dimensionality, with a margin separating the distance to positive pairs from the distance to negative pairs
distance-based, O(n²): Ĝ_h = {(q_i, k_j) | ‖q′_i − k′_j‖ ≤ t}
bucketing, O(n√n): Ĝ_h = {(q_i, k_j) | f(q′_i) ∩ f(k′_j) ≠ ∅}
clustering, O(n√n): Ĝ_h = {(q_i, k_j) | c(q′_i) = c(k′_j)}
f yields bucket ids element-wise; c yields a cluster id (see the sketch below)
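Below is a minimal Python sketch (not the authors' released code) of the three grouping strategies, assuming q_proj and k_proj hold the learned low-dimensional projections with shape [n, r]; the threshold t, number of bins, and centroids are illustrative parameters, and the bucketing here uses equal-width bins rather than balanced buckets.

```python
import numpy as np

def distance_graph(q_proj, k_proj, t):
    # O(n^2): keep pairs whose projected distance is at most t
    dists = np.linalg.norm(q_proj[:, None, :] - k_proj[None, :, :], axis=-1)
    return dists <= t  # boolean adjacency matrix for the predicted graph

def bucket_graph(q_proj, k_proj, n_bins):
    # connect pairs that share a bucket id in at least one projected dimension
    # (equal-width bins here; the slides describe balanced buckets)
    edges = np.linspace(q_proj.min(), q_proj.max(), n_bins + 1)[1:-1]
    q_ids = np.digitize(q_proj, edges)  # [n, r] bucket ids per dimension
    k_ids = np.digitize(k_proj, edges)
    return (q_ids[:, None, :] == k_ids[None, :, :]).any(-1)

def cluster_graph(q_proj, k_proj, centroids):
    # connect pairs assigned to the same (closest) centroid
    def assign(x):
        return np.linalg.norm(x[:, None, :] - centroids[None, :, :], axis=-1).argmin(-1)
    return assign(q_proj)[:, None] == assign(k_proj)[None, :]
```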
• G: gold α-entmax graph
• Ĝ: predicted graph
• |G|: number of edges
recall(G, Ĝ) = |G ∩ Ĝ| / |G|
sparsity(Ĝ) = 1 − |Ĝ| / (n·m)
sparse consistency property: if G ⊆ Ĝ, then recall(G, Ĝ) = 100%
efficiency: as sparsity(Ĝ) → 1, we save computation (theoretically); see the sketch below
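A minimal sketch of the two evaluation metrics, assuming the gold and predicted graphs are boolean [n, m] adjacency matrices (entry (i, j) is True iff query i attends to key j):

```python
import numpy as np

def recall(G, G_hat):
    # fraction of gold α-entmax edges recovered by the predicted graph
    return np.logical_and(G, G_hat).sum() / G.sum()

def sparsity(G_hat):
    # 1 - |G_hat| / (n * m): fraction of query-key pairs we avoid scoring
    n, m = G_hat.shape
    return 1.0 - G_hat.sum() / (n * m)

# sparse consistency property: if every gold edge is predicted (G ⊆ Ĝ),
# then recall(G, Ĝ) == 1.0 and the approximation is exact.
```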
• MT: Transformer trained on IWSLT 2017 with α = 1.5
• MLM: RoBERTa finetuned on WikiText-103 with α = 1.5
• Evaluation: replace the full α-entmax attention by the approximation at test time
• Compare all methods with Pareto curves (see the sketch below)
  - window size ∈ {0, 1, 3, 5, ..., 27}
  - distance threshold ∈ {0.5, 1.0, 1.5, ..., 5.0}
  - num. of buckets / clusters ∈ {2, 4, 6, ..., 20}
  - num. of random blocks / global tokens ∈ {2, 4, 6, ..., 20}
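The Pareto comparison can be traced with a simple sweep: evaluate every hyper-parameter setting, record a (sparsity, recall) point for each, and keep the non-dominated ones. This is only an illustrative sketch with made-up points, not results or code from the paper:

```python
def pareto_front(points):
    # points: list of (sparsity, recall); keep points not dominated by another
    front = []
    for s, r in sorted(points, reverse=True):  # highest sparsity first
        if not front or r > front[-1][1]:      # keep only if recall improves
            front.append((s, r))
    return front

# illustrative (made-up) points from a hyper-parameter sweep
points = [(0.2, 0.99), (0.5, 0.97), (0.8, 0.90), (0.7, 0.85)]
print(pareto_front(points))  # [(0.8, 0.9), (0.5, 0.97), (0.2, 0.99)]
```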
Results: Pareto curves for EN→FR and EN→DE
• Attention head example: a head focusing on coreference tokens
  - Ground-truth
  - Sparsefinder k-means
  - After applying 1.5-entmax
• Learn projections of contiguous-chunked tokens: blocks!
• Resembles BigBird, but the random pattern is replaced by clustering (see the sketch below)
  - v1: top-k queries and keys closest to each centroid
  - v2: top-k queries for each cluster × top-k keys for each cluster
  - k = number of attended blocks
  - window size = 3
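A hedged sketch of the block-level variant (v1 only; names and details are illustrative, not the authors' implementation): tokens are mean-pooled into contiguous blocks, block representations are clustered, and the top-k query and key blocks closest to each centroid attend to each other.

```python
import numpy as np

def blockify(x, block_size):
    # mean-pool contiguous chunks of tokens into block representations
    n, d = x.shape
    n_blocks = n // block_size
    return x[: n_blocks * block_size].reshape(n_blocks, block_size, d).mean(axis=1)

def v1_block_edges(q_blocks, k_blocks, centroids, top_k):
    # v1: for each centroid, the top-k closest query blocks attend to the
    # top-k closest key blocks (the local window pattern is added on top)
    edges = set()
    for c in centroids:
        q_idx = np.linalg.norm(q_blocks - c, axis=-1).argsort()[:top_k]
        k_idx = np.linalg.norm(k_blocks - c, axis=-1).argsort()[:top_k]
        edges.update((int(i), int(j)) for i in q_idx for j in k_idx)
    return edges  # block-level edges of the predicted graph
```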
Taxonomy of Efficient Transformer Architectures (Tay et al., 2020)
Sparsefinder
• Avoid the full computation of the score matrix
  - distance-based
  - bucketing
  - clustering
• Favorable sparsity-recall and sparsity-accuracy tradeoff curves
• Attention heads can remain interpretable
• Lower bound for computational sparsity
Questions?
• Bottleneck in transformers: self-attention over all n tokens → O(n²)
O(n²) … O(n log n) … O(n)
-💰 +🚀 -💾 +🌱
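A minimal sketch of the bottleneck: the score matrix QK⊤ grows quadratically with the sequence length n (sizes below are illustrative).

```python
import torch

n, d = 1024, 64
Q = torch.randn(n, d)
K = torch.randn(n, d)
scores = Q @ K.T                      # [n, n]: O(n^2) time and memory
attn = torch.softmax(scores, dim=-1)  # full dense attention
```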
1️⃣ Extract heads from a pre-trained α-entmax transformer
2️⃣ Learn lower-dimensionality projections (see the sketch after this list)
   q′ = W_h^Q q,  k′ = W_h^K k
   trained so that the distance to positive pairs is smaller than the distance to negative pairs by a margin
3️⃣ Group ⟨q, k⟩ and compute attention within groups
   - distance-based, O(n²): Ĝ_h = {(q_i, k_j) | ‖q′_i − k′_j‖ ≤ t}
   - bucketing, O(n√n): quantize each of the 1, ..., r dimensions into β bins of size ⌈n/β⌉
   - clustering, O(n√n): learn centroids c_1, ..., c_B and assign points to the closest cluster
4️⃣ Add window and global patterns to Ĝ
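A hedged sketch of step 2️⃣: one low-dimensional projection pair per head, trained with a margin-based (triplet-style) loss so that projected positive pairs, i.e. edges of the gold α-entmax graph, end up closer than negative pairs. Module and argument names are illustrative; the exact loss in the paper may differ.

```python
import torch
import torch.nn as nn

class HeadProjection(nn.Module):
    def __init__(self, d_model, r):
        super().__init__()
        self.proj_q = nn.Linear(d_model, r, bias=False)  # W_h^Q
        self.proj_k = nn.Linear(d_model, r, bias=False)  # W_h^K

    def forward(self, q, k):
        return self.proj_q(q), self.proj_k(k)

def margin_loss(q_proj, k_pos, k_neg, margin=1.0):
    # pull positive pairs together, push negative pairs apart by a margin
    d_pos = (q_proj - k_pos).pow(2).sum(-1)  # distance to positive pairs
    d_neg = (q_proj - k_neg).pow(2).sum(-1)  # distance to negative pairs
    return torch.relu(margin + d_pos - d_neg).mean()
```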
α-entmax(QK⊤ scores) → G: the α-entmax attention graph
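A minimal sketch of how the gold graph G is obtained: 1.5-entmax assigns exactly zero probability to low-scoring keys, so the nonzero entries of the attention map define the edges of G. This assumes the entmax package (pip install entmax); tensor sizes are illustrative.

```python
import torch
from entmax import entmax15

n, d = 8, 16
Q, K = torch.randn(n, d), torch.randn(n, d)
scores = Q @ K.T / d ** 0.5     # QK^T scores
p = entmax15(scores, dim=-1)    # 1.5-entmax puts exact zeros on weak scores
G = p > 0                       # boolean α-entmax attention graph
```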
• Distribution of part-of-speech tags on the validation set (see the sketch below)
  - for each cluster
  - for each head
• verbs, nouns, auxiliary verbs attend to each other
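A hedged sketch of this analysis (input formats are hypothetical): count the part-of-speech tags of the tokens assigned to each cluster of each head, then normalize into per-cluster distributions.

```python
from collections import Counter

def pos_distribution(cluster_ids, pos_tags):
    # cluster_ids: per-head lists of cluster assignments (one per token)
    # pos_tags: the POS tag of each token
    counts = {}
    for h, assignments in enumerate(cluster_ids):          # per head
        for c, tag in zip(assignments, pos_tags):          # per token
            counts.setdefault((h, c), Counter())[tag] += 1
    return {hc: {t: n / sum(cnt.values()) for t, n in cnt.items()}
            for hc, cnt in counts.items()}

print(pos_distribution([[0, 1, 0]], ["VERB", "NOUN", "AUX"]))
```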
N = sequence length
C = number of clusters
T = ⌈N/C⌉ (max size of balanced clusters)
O(C × T²) = O(C × N²/C²) = O(N²/C)
If C = √N, then O(N²/C) = O(N√N)
N = 20, C = 4, balanced clusters (worked example below)
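Worked example of the saving with balanced clusters (N = 20, C = 4): each cluster holds T = ⌈N/C⌉ = 5 tokens, so attention inside clusters costs C·T² = 100 scores instead of N² = 400.

```python
import math

N, C = 20, 4
T = math.ceil(N / C)         # 5 tokens per balanced cluster
print(C * T**2, "vs", N**2)  # 100 vs 400 -> O(N^2 / C)
```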