Shen Shen
November 1, 2024
[video credit Lena Voita]
Recap: convolution
Enduring principles:
[image credit: Fredo Durand]
Transformers follow similar principles:
1. Chop up signal into patches (divide and conquer)
2. Process each patch identically, and in parallel
Enduring principles:
but not independently; each patch's processing depends on all the other patches, allowing the model to take the full context into account.
Large Language Models (LLMs) are trained in a self-supervised way
"To date, the cleverest thinker of all time was Isaac."

feature                                                label
"To date, the"                                         "cleverest"
"To date, the cleverest"                               "thinker"
"To date, the cleverest thinker of all time"           "was"
"To date, the cleverest thinker of all time was"       "Isaac"
e.g., train to predict the next word (this is called auto-regressive training)
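As a minimal sketch of this self-supervised setup (the helper name is just for illustration), the (feature, label) pairs can be built directly from raw text:

```python
def next_word_pairs(sentence):
    """Build (feature, label) training pairs for next-word prediction."""
    words = sentence.split()
    pairs = []
    for i in range(1, len(words)):
        feature = " ".join(words[:i])  # everything seen so far
        label = words[i]               # the next word to predict
        pairs.append((feature, label))
    return pairs

pairs = next_word_pairs("To date, the cleverest thinker of all time was Isaac")
# e.g. one pair is ("To date, the", "cleverest")
```

No human labeling is needed: the text itself supplies the labels.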
For LLMs (and many other applications), the models used are transformers
[video edited from 3b1b]
[image edited from 3b1b]
[Figure: the input is a sequence of \(n\) tokens; each token is a \(d\)-dimensional input embedding (e.g. via a fixed encoder).]
Cross-entropy loss encourages the internal weights to update so as to make this probability higher
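For intuition, here is a toy sketch (made-up vocabulary and numbers, not from the lecture) of what cross-entropy measures at a single position:

```python
import numpy as np

def cross_entropy(probs, target_index):
    """Cross-entropy loss at one position: minus the log of the
    probability the model assigned to the correct next token."""
    return -np.log(probs[target_index])

# toy distribution over the vocabulary ["a", "robot", "must", "obey"]
probs = np.array([0.1, 0.7, 0.1, 0.1])
loss = cross_entropy(probs, target_index=1)  # correct next token: "robot"
# pushing Prob("robot") toward 1 drives this loss toward 0
```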
[Figure: training on "A robot must obey the orders given it by human beings ...": the Transformer maps each input token (a, robot, must, obey, ...) to a distribution over the vocabulary, and training pushes for Prob("robot"), Prob("must"), Prob("obey"), Prob("the"), ... to be high at the corresponding positions.]
[Figure: the tokens (a, robot, must, obey, ...) pass through an input embedding, then \(L\) stacked transformer blocks, then an output embedding.]
[Figure: between blocks, the representation stays a sequence of \(n\) tokens, each token an embedding in \(\mathbb{R}^{d}\).]
[Figure: inside each transformer block: an attention layer followed by a fully-connected network.]
[Figure: within the block, the fully-connected network uses the usual learned weights; the attention layer is built from a new ingredient, the attention mechanism.]
[Figure: inside the attention layer, each input token's embedding is mapped to a \((q,k,v)\) triple of embeddings, which feed the attention mechanism.]
Let's think about dictionary look-up:
[Figure: keys and values: apple : pomme, banana : banane, lemon : citron.]

dict_en2fr = {
    "apple" : "pomme",
    "banana": "banane",
    "lemon" : "citron"}
Having good (query, key, value) embeddings enables effective attention.
dict_en2fr = {
    "apple" : "pomme",
    "banana": "banane",
    "lemon" : "citron"}
query = "lemon"
output = dict_en2fr[query]

[Figure: the query "lemon" exactly matches the key "lemon", so the output is "citron".]
What if we run

dict_en2fr = {
    "apple" : "pomme",
    "banana": "banane",
    "lemon" : "citron"}
query = "orange"
output = dict_en2fr[query]

Python would complain with a KeyError: "orange" is not one of the keys. 🤯

[Figure: the query "orange" matches none of the keys (apple, banana, lemon), so the output is ???]
But we can probably see the rationale behind this:
[Figure: although "orange" is not a key, it is similar to the existing keys, most of all "lemon", so a sensible output is a weighted combination of the values: 0.1·pomme + 0.1·banane + 0.8·citron.]
We implicitly assumed that the (query, key, value) triples are represented in 'good' embeddings.
If we are to generalize this idea, we need to:
[Figure: to generalize this idea, we need a way to get this sort of percentage: weights like 0.1, 0.1, 0.8 over the keys, used to combine the values.]
very roughly, the attention mechanism does exactly this kind of "soft" look-up:
[Figure: each word (apple, banana, lemon, orange) gets an embedding; the query embedding for "orange" is compared against each key embedding via dot-product similarity.]
[Figure: the dot-product similarity scores are passed through a softmax, producing the weights 0.1, 0.1, 0.8.]
[Figure: the softmax weights give a weighted average over the values, 0.1·pomme + 0.1·banane + 0.8·citron, and this weighted average is the output.]
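The whole soft look-up fits in a few lines of NumPy; the 2-d key embeddings and one-hot values below are made up purely for illustration:

```python
import numpy as np

def soft_lookup(query, keys, values):
    """Soft dictionary look-up: softmax of dot-product similarities
    between the query and each key gives weights over the values."""
    scores = keys @ query                    # dot-product similarity per key
    weights = np.exp(scores - scores.max())  # numerically stable softmax
    weights = weights / weights.sum()
    return weights @ values                  # weighted average of values

# toy 2-d embeddings (invented): "orange" is closest to "lemon"
keys = np.array([[1.0, 0.0],    # apple
                 [0.0, 1.0],    # banana
                 [0.7, 0.7]])   # lemon
values = np.array([[1.0, 0.0, 0.0],   # pomme
                   [0.0, 1.0, 0.0],   # banane
                   [0.0, 0.0, 1.0]])  # citron
query = np.array([0.8, 0.6])          # "orange"
out = soft_lookup(query, keys, values)
```

Because "orange" is most similar to "lemon", the output puts the most weight on citron's value.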
[Figure: applying this soft look-up to the token sequence (a, robot, must, obey): each token's query is scored against every token's key, each row of scores goes through a softmax (so each row sums up to 1), and stacking these rows gives the attention matrix. This is the attention mechanism.]
[Figure: one attention mechanism applied to the sequence (a, robot, must, obey) is one attention head; an attention layer runs several such heads in parallel, making a multi-headed attention layer, where each head has its own learned weights.]
We then learn yet another weight matrix \(W_h\) to combine the outputs from the individual heads into the multi-headed attention layer's output.
[Figure: the multi-headed attention layer applied to (a, robot, must, obey); the input tokens and the layer's outputs are all in \(\mathbb{R}^{d}\).]
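Putting the pieces together, here is a sketch of one attention head and a two-head layer in NumPy (random toy weights; dividing the scores by \(\sqrt{d}\) before the softmax is the common scaled dot-product variant):

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, d_head = 4, 8, 8          # 4 tokens ("a robot must obey"), d-dim embeddings
X = rng.normal(size=(n, d))     # toy input token embeddings

def softmax_rows(S):
    """Row-wise softmax: each row of the result sums to 1."""
    E = np.exp(S - S.max(axis=1, keepdims=True))
    return E / E.sum(axis=1, keepdims=True)

def attention_head(X, Wq, Wk, Wv):
    """One attention head: project to (q, k, v), form the attention
    matrix with a row-wise softmax, then average the values."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    A = softmax_rows(Q @ K.T / np.sqrt(K.shape[1]))  # attention matrix
    return A @ V

# two heads, each with its own weights; combined via a learned W_h
heads = [attention_head(X, *[rng.normal(size=(d, d_head)) for _ in range(3)])
         for _ in range(2)]
W_h = rng.normal(size=(2 * d_head, d))
out = np.concatenate(heads, axis=1) @ W_h   # multi-head layer output, (n, d)
```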
Some other ideas commonly used in practice:
- Masking
- Layer normalization
- Residual connections
- Positional encoding
We will see the details in hw/lab
Neural networks are representation learners
Deep nets transform datapoints, layer by layer
Each layer gives a different representation (aka embedding) of the data
Recall: a deep net maps training data from a complex data space to a simple embedding space.
We can tokenize anything.
General strategy: chop the input up into chunks, and project each chunk to an embedding; the result is a sequence of \(n\) tokens, each token an embedding in \(\mathbb{R}^{d}\).
This projection (e.g. a fixed or learned linear transformation) can be fixed from a pre-trained model, or trained jointly with the downstream task.
Example: chopping a 100-by-100 image into 20-by-20 patches gives a sequence of \(n = 25\) tokens; if we just flatten each patch, each token is in \(\mathbb{R}^{400}\).
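The 100-by-100 example, sketched with NumPy (the array here is just a stand-in for a real image):

```python
import numpy as np

image = np.arange(100 * 100, dtype=float).reshape(100, 100)  # toy 100x100 "image"
p = 20                                                       # 20-by-20 patches

# chop into a 5-by-5 grid of patches, then flatten each patch into a token
patches = image.reshape(5, p, 5, p).swapaxes(1, 2).reshape(25, p * p)
# patches.shape == (25, 400): a sequence of n = 25 tokens, each in R^400
```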
Multi-modality (text + image)
[Bahdanau et al. 2015]
Input sentence: “The agreement on the European Economic Area was signed in August 1992”
Output sentence: “L’accord sur la zone économique européenne a été signé en août 1992”
[“DINO”, Caron et al. 2021]
Success mode:
[Show, Attend and Tell: Neural Image Caption Generation with Visual Attention. Xu et al. CVPR (2016)]
Failure mode:
[Show, Attend and Tell: Neural Image Caption Generation with Visual Attention. Xu et al. CVPR (2016)]
Thanks for your attention!
We'd love to hear your thoughts.
[Figure: comparing views: an array of neurons vs. an array of tokens; a set of neurons vs. a set of tokens; a sequence of tokens in \(\mathbb{R}^d\) with learned embeddings.]
[Figure: GPT: the tokens (a, robot, must, obey) pass through attention heads, and the final representation is mapped, via learned projection weights, to a distribution over the entire vocabulary.]
[Figure: the same next-token prediction in Chinese: given 命, predict the continuation of 命運我操縱 ("my fate is mine to control") character by character.]
[Figure (recap): the tokens (a, robot, must, obey) pass through the input embedding, the stacked transformer blocks (each an attention layer followed by a fully-connected network), and the output embedding.]