Marcos V. Treviso
Instituto de Telecomunicações
December 19, 2019
Attention is all you need
https://arxiv.org/abs/1706.03762
The illustrated transformer
http://jalammar.github.io/illustrated-transformer/
The annotated transformer
http://nlp.seas.harvard.edu/2018/04/03/attention.html
Łukasz Kaiser’s presentation
https://www.youtube.com/watch?v=rBCqOTEfxvg
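[Figure: encoder-decoder architecture. A recurrent Encoder (BiLSTM) reads $x_1 x_2 \dots x_n$ and a recurrent Decoder (LSTM) produces $y_1 y_2 \dots y_m$, connected through a context vector.]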
query: $q \in \mathbb{R}^{d_q}$
keys: $K \in \mathbb{R}^{n \times d_k}$
values: $V \in \mathbb{R}^{n \times d_v}$
1. Compute a score between $q$ and each $k_j$
$s = \mathrm{score}(q, K) \in \mathbb{R}^{n}$
dot-product: $k_j^\top q$ (requires $d_q = d_k$)
bilinear: $k_j^\top W q$, with $W \in \mathbb{R}^{d_k \times d_q}$
additive: $v^\top \tanh(W_1 k_j + W_2 q)$
neural net: $\mathrm{MLP}(q, k_j)$; $\mathrm{CNN}(q, K)$; ...
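The first three scorers can be sketched in plain Python as follows. This is only an illustration under stated assumptions: vectors are lists of floats, matrices are lists of rows, and the function names (`dot_score`, `bilinear_score`, `additive_score`, `score`) are hypothetical helpers chosen here, not part of any library. The parameters `W`, `W1`, `W2`, and `v` would normally be learned; here they are passed in by hand.

```python
import math

def matvec(W, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def dot_score(q, k):
    """Dot-product score k^T q; only defined when d_q == d_k."""
    if len(q) != len(k):
        raise ValueError("dot-product scoring needs d_q == d_k")
    return sum(ki * qi for ki, qi in zip(k, q))

def bilinear_score(q, k, W):
    """Bilinear score k^T W q, with W of shape (d_k, d_q)."""
    return dot_score(matvec(W, q), k)

def additive_score(q, k, W1, W2, v):
    """Additive score v^T tanh(W1 k + W2 q)."""
    hidden = [math.tanh(a + b) for a, b in zip(matvec(W1, k), matvec(W2, q))]
    return dot_score(hidden, v)

def score(q, K, scorer=dot_score, **params):
    """Apply a scorer between q and every key k_j, giving s in R^n."""
    return [scorer(q, k, **params) for k in K]

q = [1.0, 2.0]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
s_dot = score(q, K)
s_bilinear = score(q, K, scorer=bilinear_score, W=[[2.0, 0.0], [0.0, 2.0]])
```

In every case the result has one entry per key, so `s` has length n regardless of which scorer is used.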
2. Map scores to probabilities
$p = \pi(s) \in \triangle^{n}$
softmax: $\pi(s)_j = \exp(s_j) / \sum_k \exp(s_k)$
sparsemax: $\pi(s) = \arg\min_{p \in \triangle^{n}} \lVert p - s \rVert_2^2$
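Both mappings can be sketched in a few lines of plain Python. The softmax follows the formula directly (subtracting the maximum is a standard numerical-stability trick and does not change the result). For sparsemax, this sketch assumes the usual sort-and-threshold solution of the Euclidean projection onto the simplex; the function names are illustrative.

```python
import math

def softmax(s):
    """exp(s_j) / sum_k exp(s_k), computed stably."""
    top = max(s)
    exps = [math.exp(x - top) for x in s]
    total = sum(exps)
    return [e / total for e in exps]

def sparsemax(s):
    """Euclidean projection of s onto the probability simplex."""
    z = sorted(s, reverse=True)
    cumsum = 0.0
    tau = 0.0
    for k, zk in enumerate(z, start=1):
        cumsum += zk
        # The k largest scores stay in the support while this holds.
        if 1.0 + k * zk > cumsum:
            tau = (cumsum - 1.0) / k
    # Shift by the threshold and clip at zero.
    return [max(x - tau, 0.0) for x in s]

p_soft = softmax([1.0, 0.8, 0.1])
p_sparse = sparsemax([1.0, 0.8, 0.1])
```

Softmax always assigns a strictly positive probability to every position, whereas sparsemax can assign exactly zero: in the example above the lowest-scoring entry drops out of `p_sparse` entirely.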
3. Combine values via a weighted sum
$z = \sum_{i=1}^{n} p_i V_i \in \mathbb{R}^{d_v}$
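The three steps can be put together in one small function. This is a minimal sketch assuming a dot-product scorer and softmax as the probability mapping; `attend` is a hypothetical name, and vectors and matrices are plain Python lists.

```python
import math

def attend(q, K, V):
    """Single-query attention: scores, probabilities, weighted sum."""
    # 1. Score q against every key (dot product, so d_q == d_k).
    s = [sum(ki * qi for ki, qi in zip(k, q)) for k in K]
    # 2. Map the scores to a probability distribution (softmax).
    top = max(s)
    exps = [math.exp(x - top) for x in s]
    total = sum(exps)
    p = [e / total for e in exps]
    # 3. Weighted sum of the value vectors, one output per value dimension.
    d_v = len(V[0])
    z = [sum(p[i] * V[i][j] for i in range(len(V))) for j in range(d_v)]
    return z, p

K = [[1.0, 0.0], [0.0, 1.0]]
V = [[1.0, 0.0], [0.0, 1.0]]
z_uniform, p_uniform = attend([0.0, 0.0], K, V)   # all scores equal
z_peaked, p_peaked = attend([100.0, 0.0], K, V)   # first key dominates
```

With equal scores the output is the plain average of the values; when one score dominates, the output is close to the corresponding value vector.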
[Figure: input sequence $x_1, x_2, \dots, x_n$]
encode: $x_1 x_2 \dots x_n \to r_1 r_2 \dots r_n$
decode: $r_1 r_2 \dots r_n \to y_1 y_2 \dots y_m$
"The animal didn't cross the street because it was too tired"
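$Q_j = K_j = V_j \in \mathbb{R}^{d} \iff$ dot-product scorer!
$S = \mathrm{score}(Q, K) \in \mathbb{R}^{n \times n}$
$P = \pi(S) \in \triangle^{n \times n}$
$Z = PV \in \mathbb{R}^{n \times d}$
$Z = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$
$Q, K, V \in \mathbb{R}^{n \times d}$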
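In matrix form, the scaled dot-product formula can be sketched as below. The sketch uses plain Python lists of rows rather than a tensor library, applies the softmax row by row, and leaves out the learned projections that a full transformer layer would apply before this step. All function names are illustrative.

```python
import math

def matmul(A, B):
    """Matrix product of two lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def softmax_rows(S):
    """Apply softmax independently to each row of S."""
    out = []
    for row in S:
        top = max(row)
        exps = [math.exp(x - top) for x in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

def scaled_dot_product_attention(Q, K, V):
    """Z = softmax(Q K^T / sqrt(d_k)) V."""
    d_k = len(K[0])
    S = [[x / math.sqrt(d_k) for x in row] for row in matmul(Q, transpose(K))]
    P = softmax_rows(S)
    return matmul(P, V), P

# Self-attention: queries, keys, and values all come from the same sequence.
X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
Z, P = scaled_dot_product_attention(X, X, X)
```

Here `P` is n x n (each position attends to every position) and `Z` keeps the input shape n x d.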
[Figure: self-attention visualization, shown for 2 heads and for all heads (8)]
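encoder self-attn
context attention
$R_{enc} = \mathrm{Encoder}(x) \in \mathbb{R}^{n \times d}$
$S = \mathrm{score}(Q, R_{enc}) \in \mathbb{R}^{m \times n}$
$P = \pi(S) \in \triangle^{m \times n}$
$Z = P R_{enc} \in \mathbb{R}^{m \times d}$
decoder self-attn (masked)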
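The two decoder-side attention types can be sketched as follows, assuming a scaled dot-product scorer and softmax for $\pi$. In context attention the m decoder queries attend over the n encoder states, which serve as both keys and values. In masked self-attention, position i may only attend to positions up to i, which is implemented here by setting the scores of later positions to negative infinity before the softmax. Function names are illustrative, and the learned projections are omitted.

```python
import math

def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def softmax_rows(S):
    out = []
    for row in S:
        top = max(row)
        exps = [math.exp(x - top) for x in row]  # exp(-inf) evaluates to 0.0
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

def scaled_scores(Q, K):
    d = len(K[0])
    return [[x / math.sqrt(d) for x in row] for row in matmul(Q, transpose(K))]

def context_attention(Q, R_enc):
    """Decoder queries (m x d) attend over encoder states (n x d)."""
    P = softmax_rows(scaled_scores(Q, R_enc))  # m x n
    return matmul(P, R_enc), P                 # Z is m x d

def masked_self_attention(Y):
    """Decoder self-attention where position i cannot see positions j > i."""
    S = scaled_scores(Y, Y)
    for i, row in enumerate(S):
        for j in range(i + 1, len(row)):
            row[j] = float("-inf")
    P = softmax_rows(S)
    return matmul(P, Y), P

R_enc = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # n = 3 encoder states
Q_dec = [[0.0, 0.0]]                           # m = 1 decoder query
Z_ctx, P_ctx = context_attention(Q_dec, R_enc)
Z_self, P_self = masked_self_attention([[1.0, 0.0], [0.0, 1.0]])
```

The mask makes `P_self` lower triangular, so the first decoder position attends only to itself.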
n = seq. length, d = hidden dim, k = kernel size