Attention Mechanisms
December 2, 2020
Marcos V. Treviso
Deep Structured Learning
Fall 2020
Neural Attention Mechanisms by Ben Peters
Learning with Sparse Latent Structure by Vlad Niculae
Seq2Seq and Attention by Lena Voita
The elephant in the interpretability room by Jasmijn Bastings
The illustrated transformer
http://jalammar.github.io/illustrated-transformer/
The annotated transformer
http://nlp.seas.harvard.edu/2018/04/03/attention.html
Łukasz Kaiser’s presentation
https://www.youtube.com/watch?v=rBCqOTEfxvg
Quick recap on RNN-based seq2seq models
Self-attention networks: The Transformer
Attention and its different flavors: dense • sparse • soft • hard • structured
Attention interpretability: Can we consider attention maps as explanation?
"dynamic alignments"
"biological retina" glimpses
"visual attention"
"inner attention"
"word attention"
"memory networks"
📸 Vision
🎙 Speech
📄 NLP
[Figure: RNN-based seq2seq: a BiLSTM encoder reads \(x_1 \, x_2 \,...\, x_n\) and compresses it into a single context vector, from which an LSTM decoder generates \(y_1 \, y_2 \,...\, y_m\).]
this bottleneck is a problem!
[Figure: the same seq2seq model, but now a new context vector is computed at every decoding step.]
Attention Mechanism!
[Figure: attention inside the RNN seq2seq model. The previous decoder state \(q_{j-1}\) is the query; the encoder states \(h_1 \, ...\,h_i \,...\, h_n\) are the keys and values. A score function produces \(\theta_1 \, ...\,\theta_i \,...\, \theta_n\), a softmax maps them to probabilities \(p_1 \, ...\,p_i \,...\, p_n\), and the context vector \(c_j\) is the weighted sum of the encoder states; the decoder then computes \(q_j\) and emits \(y_j\).]
query, keys, values:
$$\mathbf{q} \in \mathbb{R}^{d_q}, \quad \mathbf{K} \in \mathbb{R}^{n \times d_k}, \quad \mathbf{V} \in \mathbb{R}^{n \times d_v}$$
1. Compute a score between \(\mathbf{q}\) and each \(\mathbf{k}_j\):
$$\boldsymbol{\theta} = \mathrm{score}(\mathbf{q}, \mathbf{K}) \in \mathbb{R}^{n} $$
dot-product: $$\mathbf{k}_j^\top \mathbf{q}, \quad (d_q = d_k)$$
bilinear: $$\mathbf{k}_j^\top \mathbf{W} \mathbf{q}, \quad \mathbf{W} \in \mathbb{R}^{d_k \times d_q}$$
additive: $$\mathbf{v}^\top \mathrm{tanh}(\mathbf{W}_1 \mathbf{k}_j + \mathbf{W}_2 \mathbf{q})$$
neural net: $$\mathrm{MLP}(\mathbf{q}, \mathbf{k}_j); \quad \mathrm{CNN}(\mathbf{q}, \mathbf{K}); \quad ...$$
2. Map scores to probabilities:
$$\mathbf{p} = \pi(\boldsymbol{\theta}) \in \triangle^{n} $$
softmax: $$ \exp(\boldsymbol{\theta}_j) / \sum_k \exp(\boldsymbol{\theta}_k) $$
sparsemax: $$ \mathrm{argmin}_{\mathbf{p} \in \triangle^n} \,||\mathbf{p} - \boldsymbol{\theta}||_2^2 $$
Note: \(\mathbf{p}\) is not necessarily in the simplex, e.g. \(\mathbf{p} = \mathrm{sigmoid}(\boldsymbol{\theta})\); but in this lecture we assume \(\sum_i \mathbf{p}_i = 1\) and \(\mathbf{p}_i \geq 0\) for all \(i\).
3. Combine values:
$$\mathbf{z} = \mathbf{V}^\top \mathbf{p} = \sum\limits_{i=1}^{n} \mathbf{p}_i \mathbf{V}_i \in \mathbb{R}^{d_v}$$
import torch

def attention(query, keys, values=None):
    """
    query.shape is (batch_size, 1, d)
    keys.shape is (batch_size, n, d)
    values.shape is (batch_size, n, d)
    """
    # use keys as values
    if values is None:
        values = keys
    # STEP 1. dot-product scores; scores.shape is (batch_size, 1, n)
    scores = torch.matmul(query, keys.transpose(-1, -2))
    # STEP 2. map scores to probabilities; probas.shape is (batch_size, 1, n)
    probas = torch.softmax(scores, dim=-1)
    # STEP 3. weighted sum of values; c_vector.shape is (batch_size, 1, d)
    c_vector = torch.matmul(probas, values)
    return c_vector
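The function above uses the dot-product scorer. As an illustrative sketch (the class name and sizes below are my own, not from the lecture), an additive (Bahdanau-style) scorer could replace STEP 1:

import torch
import torch.nn as nn

class AdditiveScore(nn.Module):
    """Illustrative additive scorer: score(q, k_j) = v^T tanh(W1 k_j + W2 q)."""
    def __init__(self, d_q, d_k, d_hidden):
        super().__init__()
        self.W1 = nn.Linear(d_k, d_hidden, bias=False)
        self.W2 = nn.Linear(d_q, d_hidden, bias=False)
        self.v = nn.Linear(d_hidden, 1, bias=False)

    def forward(self, query, keys):
        # query.shape is (batch_size, 1, d_q); keys.shape is (batch_size, n, d_k)
        hidden = torch.tanh(self.W1(keys) + self.W2(query))  # (batch_size, n, d_hidden)
        return self.v(hidden).transpose(-1, -2)              # scores: (batch_size, 1, n)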
Dense: \( |\mathrm{supp}(\mathbf{p})| = n\)
Sparse: \( |\mathrm{supp}(\mathbf{p})| < n\)
Fundamental Thm. Lin. Prog. (Dantzig et al., 1955)
[Figure: argmax as a linear program over the simplex. For \(n=2\), \(\boldsymbol{\theta} = [0.4, 1.4]\) gives \(\mathbf{p}^\star = [0,1]\); for \(n=3\), \(\boldsymbol{\theta} = [0.3, 0.1, 1.5]\) gives \(\mathbf{p}^\star = [0,0,1]\).]
The maximizer of a linear objective over the simplex sits at a vertex, so argmax returns a one-hot vector.
[Figure: mappings onto the simplex \(\triangle^3\): argmax, softmax, sparsemax, \(\alpha\)-entmax.]
sparsemax (Martins and Astudillo, 2016) and \(\alpha\)-entmax:
Forward pass: just compute the threshold \(\boldsymbol{\tau}\), in \(O(n\log n)\) or \(O(n)\)*
* (Peters, Niculae, and Martins, 2019)
Jacobian:
\(p^\star = [.99, .01, 0] \implies \mathbf{s}=[1,1,0]\)
\(p^\star = [.50, .50, 0] \implies \mathbf{s}=[1,1,0]\)
The Jacobian of sparsemax (\(\alpha=2\)) depends only on the support \(\mathbf{s}\), not on the actual values of \(\mathbf{p}^\star\)
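A minimal sketch (my own, not the lecture's or the original papers' code) of the \(O(n\log n)\) sparsemax forward pass: sort the scores, find the support size \(k^\star\), compute the threshold \(\tau\), and clamp.

import torch

def sparsemax(theta, dim=-1):
    """Rough O(n log n) sketch of the sparsemax forward pass."""
    # sort scores in descending order along `dim`
    z, _ = torch.sort(theta, dim=dim, descending=True)
    cumsum = z.cumsum(dim) - 1
    # k = 1, 2, ..., n, shaped for broadcasting along `dim`
    k = torch.arange(1, theta.shape[dim] + 1, device=theta.device, dtype=theta.dtype)
    shape = [1] * theta.dim()
    shape[dim] = -1
    k = k.view(shape)
    # support: positions where k * z_(k) > cumsum_k(z) - 1
    support = (k * z > cumsum).to(theta.dtype)
    k_star = support.sum(dim=dim, keepdim=True)
    # threshold: tau = (sum of the k_star largest scores - 1) / k_star
    tau = cumsum.gather(dim, k_star.long() - 1) / k_star
    # p_i = max(theta_i - tau, 0)
    return torch.clamp(theta - tau, min=0.0)

For example, sparsemax(torch.tensor([0.3, 0.1, 1.5])) returns [0., 0., 1.]: when the largest score is far from the rest, sparsemax coincides with the argmax solution from the earlier slide.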
Tsallis \(\alpha\)-entropy regularizer
[Figure: softmax vs. sparsemax as mappings of \(\theta\).]
[Figure: the attention node inside the BiLSTM encoder / LSTM decoder computation graph, mapping scores \(\boldsymbol{\theta}\) to weights \(\mathbf{p}\).]
differentiable node (e.g. softmax/sparsemax): gradients flow through \(\mathbf{p}\) back to \(\boldsymbol{\theta}\)
argmax node: \(\mathbf{p}\) is one-hot (e.g. \(p_3 = 1\) for the selected \(\theta_3\)), so the node is non-differentiable and we cannot simply backprop through it
Soft attention ("smooth selection"): continuous representation, "soft" decisions; differentiable, so just backprop!
Hard attention ("subset selection"): discrete representation, "binary" decisions; non-differentiable, so we need REINFORCE / surrogate gradients / the reparameterization trick / perturb-and-MAP / etc.
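As an illustrative toy sketch (entirely my own setup, not from the lecture), the REINFORCE / score-function route samples a hard selection and backprops through a surrogate loss:

import torch

# scores over 5 positions and a toy downstream objective
theta = torch.randn(5, requires_grad=True)
values = torch.randn(5, 4)
target = torch.randn(4)

probas = torch.softmax(theta, dim=-1)
dist = torch.distributions.Categorical(probs=probas)
index = dist.sample()                  # hard, non-differentiable selection
selected = values[index]               # the "context vector" is a single value row

reward = -torch.sum((selected - target) ** 2)             # toy reward
surrogate_loss = -dist.log_prob(index) * reward.detach()  # REINFORCE surrogate
surrogate_loss.backward()              # populates theta.grad with the estimator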
I am going to the store → Vou à loja
[Figure: mappings onto \(\triangle^3\): argmax, softmax, sparsemax, \(\alpha\)-entmax, fusedmax.]
fusedmax: penalize weight differences between adjacent positions
\(\pi(\theta)\): independent \(z_i\) vs. linear dependencies on \(z_i\)
Cons:
no sparsity
specialized backprop implementation
An alternative: SparseMAP
structured counterpart of sparsemax
model any kind of structure (just need MAP)
plug & play implementation
(finite-tape) Turing Machine
[Figure: a memory tape \(x_1 \quad x_2 \quad x_3 \quad \dots \quad x_{n}\) addressed with attention.]
Pointer network
[Figure: encode \(x_1 \, x_2 \,...\, x_n\) into \(\mathbf{r}_1 \, \mathbf{r}_2 \,...\, \mathbf{r}_n\); decode \(y_1 \, y_2 \,...\, y_m\) by pointing back to input positions.]
"The animal didn't cross the street because it was too tired"
Self-attention: \(\mathbf{Q}_j = \mathbf{K}_j = \mathbf{V}_j \in \mathbb{R}^{d}\), with a dot-product scorer:
$$\mathbf{S} = \mathrm{score}(\mathbf{Q}, \mathbf{K}) \in \mathbb{R}^{n \times n} $$
$$\mathbf{P} = \pi(\mathbf{S}) \in \triangle^{n \times n} $$
$$\mathbf{Z} = \mathbf{P} \mathbf{V} \in \mathbb{R}^{n \times d}$$
In matrix form, with scaling:
$$\mathbf{Z} = \mathrm{softmax}\Big(\frac{\mathbf{Q} \mathbf{K}^\top}{\sqrt{d_k}}\Big) \mathbf{V}, \qquad \mathbf{Q}, \mathbf{K}, \mathbf{V} \in \mathbb{R}^{n \times d}$$
[Figure: attention weights for one token, shown for 2 heads and for all heads (8).]
import torch
import torch.nn as nn
from math import sqrt

class MultiHeadAttention(nn.Module):
    def __init__(self, d_size, num_heads, dropout=0.0):
        super().__init__()
        assert d_size % num_heads == 0
        self.num_heads = num_heads
        self.h_size = d_size // num_heads
        # project to d_size = num_heads * h_size, then split into heads
        self.linear_q = nn.Linear(d_size, d_size)
        self.linear_k = nn.Linear(d_size, d_size)
        self.linear_v = nn.Linear(d_size, d_size)
        self.linear_o = nn.Linear(d_size, d_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values=None):
        """
        queries.shape is (batch_size, m, d)
        keys.shape is (batch_size, n, d)
        values.shape is (batch_size, n, d)
        """
        # use keys as values
        if values is None:
            values = keys
        # do all linear projections
        queries = self.linear_q(queries)
        keys = self.linear_k(keys)
        values = self.linear_v(values)
        # split heads
        batch_size = queries.shape[0]
        queries = queries.view(batch_size, -1, self.num_heads, self.h_size).transpose(1, 2)
        keys = keys.view(batch_size, -1, self.num_heads, self.h_size).transpose(1, 2)
        values = values.view(batch_size, -1, self.num_heads, self.h_size).transpose(1, 2)
        # new shapes:
        # queries (batch_size, num_heads, m, h_size)
        # keys    (batch_size, num_heads, n, h_size)
        # values  (batch_size, num_heads, n, h_size)
        # scores.shape is (batch_size, num_heads, m, n)
        scores = torch.matmul(queries, keys.transpose(-1, -2)) / sqrt(self.h_size)
        # probas.shape is (batch_size, num_heads, m, n)
        probas = torch.softmax(scores, dim=-1)
        probas = self.dropout(probas)
        # c_vector.shape is (batch_size, num_heads, m, h_size)
        c_vector = torch.matmul(probas, values)
        # reshape c_vector to (batch_size, m, d)
        c_vector = c_vector.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.h_size)
        # apply final linear projection
        c_vector = self.linear_o(c_vector)
        return c_vector
encoder self-attn
decoder self-attn (masked): scores.masked_fill_(~mask, float('-inf'))
context attention: the decoder attends over the encoder outputs
$$\mathbf{R}_{enc} = \mathrm{Encoder}(\mathbf{x}) \in \mathbb{R}^{n \times d} $$
$$\mathbf{S} = \mathrm{score}(\mathbf{Q}, \mathbf{R}_{enc}) \in \mathbb{R}^{m \times n} $$
$$\mathbf{P} = \pi(\mathbf{S}) \in \triangle^{m \times n} $$
$$\mathbf{Z} = \mathbf{P} \mathbf{R}_{enc} \in \mathbb{R}^{m \times d}$$
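A small usage sketch of the masking line above for decoder self-attention (shapes and sizes here are assumptions for illustration):

import torch

batch_size, num_heads, m = 2, 8, 5
scores = torch.randn(batch_size, num_heads, m, m)
# mask[i, j] is True when position i may attend to position j (i.e. j <= i)
mask = torch.tril(torch.ones(m, m)).bool()
scores = scores.masked_fill(~mask, float('-inf'))
probas = torch.softmax(scores, dim=-1)  # future positions get probability 0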
(\(n\) = seq. length, \(d\) = hidden dim, \(k\) = kernel size)
learn an \(\alpha\) in entmax for each head:
🔥 very active research topic! Why?
Faster and more memory-efficient attention ⟹ cheaper (−💰) and greener (+🌱) models
$$O(n^2) \quad \dots \quad O(n\log n) \quad \dots \quad O(n)$$
High Performance NLP - EMNLP 2020 (Ilharco et al., 2020)
🦕 an early example in NLP: alignments ⇄ attention
expose which tokens are the most important ones for a particular prediction => saliency map
BiLSTM with attention - basic architecture for text classification tasks
Gradient: $$\nabla_{\mathbf{x}_i} f(\mathbf{x}_{1:n}) \cdot \mathbf{x}_i$$
Leave-one-out: $$ f(\mathbf{x}_{1:n}) - f(\mathbf{x}_{-i})$$
Adversarial attention:
$$\max_{\tilde{\alpha} \in \triangle^n} f_{\tilde{\alpha}}(\mathbf{x}_{1:n})$$
$$\mathrm{s.t.} \quad |f_{\tilde{\alpha}}(\mathbf{x}_{1:n}) - f_{\alpha^\star}(\mathbf{x}_{1:n})| < \epsilon$$
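A minimal sketch of the gradient \(\times\) input saliency above, assuming a hypothetical PyTorch model f that maps token embeddings of shape (1, n, d) to a scalar prediction score:

import torch

def gradient_x_input(f, embeddings):
    # `f` and its input layout are assumptions for this sketch
    embeddings = embeddings.clone().detach().requires_grad_(True)
    score = f(embeddings)
    score.backward()
    # one saliency value per token: (grad_{x_i} f) . x_i, summed over the hidden dim
    return (embeddings.grad * embeddings).sum(dim=-1)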
Fraction of original attention weights removed before first decision flip
"the number of zeroed attended items is often too large to be helpful as an explanation"
Overall, attention can be manipulated with a negligible drop in performance
"we hold that attention scores are used as providing an explanation; not the explanation."
"Jain and Wallace provide alternative distributions which may result in similar predictions, but [...] (ignore the) fact that the model was trained to attend to the tokens it chose"
"Train an adversary that minimizes change in prediction scores, while maximizing changes in the learned attention distributions"
[Figure: rationale extraction: a generator \((\phi)\) makes a masked selection over the input (word embeddings / hidden states, masked vs. predicted words), and a predictor \((\theta)\) predicts from the selected words; training uses a lower bound on the log-likelihood plus penalties.]
input gating
Q: Where did the Broncos practice for the Super Bowl?
A: The Panthers used the San Jose State practice facility and stayed at the San Jose Marriott. The Broncos practiced at Stanford University and stayed at the Santa Clara Marriott.
hidden states gating
special tokens!
"when the sequence length is larger than the attention head dimension (\(n > d\)), self-attention is not unique"