Lecture 9: Transformers

 

Shen Shen

November 1, 2024

Intro to Machine Learning

Outline

  • Recap, convolutional neural networks
  • Transformers example use case: large language models
  • Transformers key ideas:
    • Input, and (query, key, value) embedding
    • Attention mechanism
  • (Applications and case studies)

[video credit Lena Voita]

Recap: convolution

  • Looking locally
  • Parameter sharing
  • Template matching
  • Translational equivariance

Enduring principles:

  1. Chop up signal into patches (divide and conquer)
  2. Process each patch independently and identically (and in parallel)

[image credit: Fredo Durand]

Transformers follow similar principles:

  1. Chop up signal into patches (divide and conquer)
  2. Process each patch identically, and in parallel

but not independently; each patch's processing depends on all other patches, allowing us to take the full context into account.

Outline

  • Recap, convolutional neural networks
  • Transformers example use case: large language models
  • Transformers key ideas:
    • Input, and (query, key, value) embedding
    • Attention mechanism
  • (Applications and case studies)

Large Language Models (LLMs) are trained in a self-supervised way

  • Scrape the internet for unlabeled plain texts.
  • Cook up “labels” (prediction targets) from the unlabeled texts.
  • Convert “unsupervised” problem into “supervised” setup.

"To date, the cleverest thinker of all time was Issac. "

feature → label

"To date, the" → "cleverest"
"To date, the cleverest" → "thinker"
"To date, the cleverest thinker" → "was"
…
"To date, the cleverest thinker of all time was" → "Isaac"

e.g., train to predict the next word

Auto-regressive
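To make the label-cooking concrete, here is a minimal sketch (not the actual LLM data pipeline; whitespace tokenization is an assumption for illustration, real models use learned subword tokenizers):

text = "To date, the cleverest thinker of all time was Isaac."
tokens = text.split()

# cook up (feature, label) pairs: predict each word from everything before it
pairs = []
for i in range(1, len(tokens)):
    feature = " ".join(tokens[:i])   # the text seen so far
    label = tokens[i]                # the next word, used as the prediction target
    pairs.append((feature, label))

for feature, label in pairs:
    print(repr(feature), "->", repr(label))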

For LLMs (and many other applications), the models used are transformers

[video edited from 3b1b]

[image edited from 3b1b]

[figure: the input is embedded into a sequence of n tokens, each of dimension d]

input embedding (e.g. via a fixed encoder)


Cross-entropy loss encourages the internal weights to update so as to make this probability higher.

Outline

  • Recap, convolutional neural networks
  • Transformers example use case: large language models
  • Transformers key ideas:
    • Input, and (query, key, value) embedding
    • Attention mechanism
  • (Applications and case studies)

[figure: the tokens "a", "robot", "must", "obey" go into the Transformer, which outputs a distribution over the vocabulary]

"A robot must obey the orders given it by human beings ..."

push for Prob("robot") to be high, push for Prob("must") to be high, push for Prob("obey") to be high, push for Prob("the") to be high, …

[figure: the input tokens pass through an input embedding, a stack of L transformer blocks, and an output embedding]

[video edited from 3b1b]

[figure: the input embedding maps the tokens "a robot must obey" to vectors x^{(1)}, x^{(2)}, x^{(3)}, x^{(4)}]

A sequence of n tokens, each token in \mathbb{R}^{d}

[figure: each transformer block consists of an attention layer followed by a fully-connected network, applied to the token sequence x^{(1)}, …, x^{(4)}]

[video edited from 3b1b]

[figure: inside the attention layer, each token x^{(i)} is processed with learned weights W_q, W_k, W_v -- "the usual weights"]

attention mechanism

[figure: each input token x^{(i)} of "a robot must obey" is mapped by shared weights W_q, W_k, W_v to a query q_i, key k_i, and value v_i, each of dimension d_k]
  • sequence of d-dimensional input tokens x
  • learnable weights W_q, W_v, W_k, all in \mathbb{R}^{d \times d_k}
  • map the input sequence into a d_k-dimensional (q, k, v) sequence, e.g., q_1 = W_q^T x^{(1)}
  • the weights are shared across the sequence of tokens -- parallel processing (a sketch follows below)
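A minimal numpy sketch of this (q, k, v) embedding step, using the convention q_1 = W_q^T x^{(1)} from the bullets above; the sizes n = 4, d = 6, d_k = 3 and the random numbers are placeholders for illustration:

import numpy as np

n, d, d_k = 4, 6, 3                      # 4 tokens "a robot must obey", illustrative sizes
rng = np.random.default_rng(0)

X = rng.normal(size=(n, d))              # input sequence, one token x^{(i)} per row
W_q = rng.normal(size=(d, d_k))          # learnable weights, shared across all tokens
W_k = rng.normal(size=(d, d_k))
W_v = rng.normal(size=(d, d_k))

# q_i = W_q^T x^{(i)}, etc., applied to every token at once (parallel processing)
Q = X @ W_q                              # (n, d_k)
K = X @ W_k                              # (n, d_k)
V = X @ W_v                              # (n, d_k)
print(Q.shape, K.shape, V.shape)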

(q, k, v) embedding

[figure: the input embedding x^{(1)}, …, x^{(4)} is mapped to the (q, k, v) embedding, which feeds the attention mechanism]

Let's think about dictionary look-up:

Key : Value
apple : pomme
banana : banane
lemon : citron

dict_en2fr = { 
  "apple" : "pomme",
  "banana": "banane", 
  "lemon" : "citron"}

Having good (query, key, value) embedding enables effective attention. 

dict_en2fr = { 
  "apple" : "pomme",
  "banana": "banane", 
  "lemon" : "citron"}

query = "lemon" 
output = dict_en2fr[query]

Query: lemon
Key : Value (apple : pomme, banana : banane, lemon : citron)
Output: citron

What if we run

dict_en2fr = { 
  "apple" : "pomme",
  "banana": "banane", 
  "lemon" : "citron"}

query = "orange" 
output = dict_en2fr[query]

Python would complain. 🤯

[figure: querying "orange" against the keys apple, banana, lemon finds no exact match -- Output: ???]

But we can probably see the rationale behind this:

Query: orange
Key : Value (apple : pomme, banana : banane, lemon : citron)
Output: 0.1 · pomme + 0.1 · banane + 0.8 · citron

We implicitly assumed the (query, key, value)  are represented in 'good' embeddings.

If we are to generalize this idea, we need to: 

  • get this sort of percentage (weights over the keys, e.g., 0.1, 0.1, 0.8)
  • and output the weighted average over values

very roughly, the attention mechanism does exactly this kind of "soft" look-up:

softmax( [ (orange · apple), (orange · banana), (orange · lemon) ] ) = [ 0.1, 0.1, 0.8 ]

(dot-product similarity between the query embedding and each key embedding, pushed through a softmax)

output = 0.1 · pomme + 0.1 · banane + 0.8 · citron

(weighted average over values)

[video edited from 3b1b]
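A minimal sketch of this "soft" look-up with made-up 2-d word embeddings (the numbers below are placeholders and will not reproduce the 0.1 / 0.1 / 0.8 example exactly):

import numpy as np

def softmax(z):
    e = np.exp(z - z.max())              # numerically stable softmax
    return e / e.sum()

# made-up embeddings standing in for "good" (query, key) embeddings
key_embed = {"apple": np.array([1.0, 0.1]),
             "banana": np.array([0.2, 1.0]),
             "lemon": np.array([0.9, 0.8])}
values = {"apple": "pomme", "banana": "banane", "lemon": "citron"}

query = np.array([0.8, 0.9])             # a made-up embedding for "orange"

# dot-product similarity of the query against every key, pushed through softmax
weights = softmax(np.array([query @ k for k in key_embed.values()]))

# the output is a weighted combination over the values
for w, v in zip(weights, values.values()):
    print(f"{w:.2f} * {v}")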

In the attention mechanism, for each query the same "soft" look-up runs against all the keys and values in the sequence. For the first token of "a robot must obey":

[ a_{11}, a_{12}, a_{13}, a_{14} ] = softmax( [ q_1^T k_1, q_1^T k_2, q_1^T k_3, q_1^T k_4 ] / \sqrt{d_k} )

output for token 1 = a_{11} v_1 + a_{12} v_2 + a_{13} v_3 + a_{14} v_4

and likewise with q_2, q_3, q_4 for the other tokens.

Stacking the queries and keys row-wise into Q \in \mathbb{R}^{n \times d_k} and K \in \mathbb{R}^{n \times d_k}, all the attention weights form the attention matrix

A = softmax( Q K^T / \sqrt{d_k} ) \in \mathbb{R}^{n \times n}

where the softmax is applied row-wise, so each row sums up to 1.

[figure: the attention layer output for each token is the corresponding weighted average of the values, e.g. a_{11} v_1 + a_{12} v_2 + a_{13} v_3 + a_{14} v_4 for the first token, each output \in \mathbb{R}^{d_k}; all n outputs are computed by the same attention mechanism]
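Putting the pieces together, a minimal numpy sketch of one head's computation, A = softmax(Q K^T / \sqrt{d_k}) followed by the weighted averages A V; sizes and numbers are placeholders:

import numpy as np

def softmax_rows(Z):
    E = np.exp(Z - Z.max(axis=1, keepdims=True))
    return E / E.sum(axis=1, keepdims=True)

n, d, d_k = 4, 6, 3
rng = np.random.default_rng(0)
X = rng.normal(size=(n, d))                        # "a robot must obey", embedded

W_q, W_k, W_v = (rng.normal(size=(d, d_k)) for _ in range(3))
Q, K, V = X @ W_q, X @ W_k, X @ W_v                # (n, d_k) each

A = softmax_rows(Q @ K.T / np.sqrt(d_k))           # attention matrix, (n, n)
print(A.sum(axis=1))                               # each row sums up to 1

out = A @ V                                        # weighted averages over values, (n, d_k)
print(out.shape)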

one attention head

[figure: what we have described so far -- the (q, k, v) embedding plus the attention mechanism -- is one attention head; an attention layer can run several such heads side by side on the same input sequence]

multi-headed attention layer

each head

  • can be processed in parallel to all other heads
  • learns its own independent W_q, W_k, W_v
  • creates its own (q, k, v) sequence
    • inside each head, the sequence of n (q, k, v) tokens can be processed in parallel
  • computes its own attention output sequence
    • inside each head, the sequence of n output tokens can be processed in parallel

we then learn yet another weight W_h to combine the outputs from the individual heads into the multi-headed attention layer output (a sketch follows below).
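A minimal sketch of a multi-headed attention layer following the bullets above; combining the per-head outputs by concatenation followed by a learned projection (called W_h here) is one common convention, an assumption in this sketch:

import numpy as np

def softmax_rows(Z):
    E = np.exp(Z - Z.max(axis=1, keepdims=True))
    return E / E.sum(axis=1, keepdims=True)

def one_head(X, W_q, W_k, W_v):
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    A = softmax_rows(Q @ K.T / np.sqrt(W_k.shape[1]))
    return A @ V                                    # (n, d_k)

n, d, d_k, n_heads = 4, 8, 4, 2
rng = np.random.default_rng(0)
X = rng.normal(size=(n, d))

head_outputs = []
for _ in range(n_heads):                            # each head: its own independent weights
    W_q, W_k, W_v = (rng.normal(size=(d, d_k)) for _ in range(3))
    head_outputs.append(one_head(X, W_q, W_k, W_v)) # heads are independent -> parallelizable

W_h = rng.normal(size=(n_heads * d_k, d))           # combines the heads' outputs
out = np.concatenate(head_outputs, axis=1) @ W_h    # (n, d), back to d-dimensional tokens
print(out.shape)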

[figure: a multi-headed attention layer -- several attention heads process the input tokens "a robot must obey" in parallel, and the per-head outputs are combined via W_h into one output token per position, all in \mathbb{R}^{d}]

Some other ideas commonly used in practice:

  • Masking
  • Layer normalization
  • Residual connections
  • Positional encoding

We will see the details in hw/lab.

Neural networks are representation learners

  • Deep nets transform datapoints, layer by layer
  • Each layer gives a different representation (aka embedding) of the data

Recall

Training data: x \in \mathbb{R}^2

z_1 = \text{linear}(x)
a_1 = \text{ReLU}(z_1)
z_2 = \text{linear}(a_1)
g = \text{softmax}(z_2)

maps from complex data space to simple embedding space
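A minimal sketch of that recalled forward pass, with made-up layer sizes and random weights; "linear" is taken to mean an affine map here:

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=2)                         # x in R^2

W1, b1 = rng.normal(size=(2, 3)), np.zeros(3)  # made-up sizes
W2, b2 = rng.normal(size=(3, 4)), np.zeros(4)

z1 = x @ W1 + b1                               # z_1 = linear(x)
a1 = np.maximum(z1, 0)                         # a_1 = ReLU(z_1)
z2 = a1 @ W2 + b2                              # z_2 = linear(a_1)
g = np.exp(z2 - z2.max())                      # g = softmax(z_2)
g = g / g.sum()
print(g)                                       # each layer gives a new representation of x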


We can tokenize anything.
General strategy: chop the input up into chunks, project each chunk to an embedding

this projection can be fixed (from a pre-trained model), or trained jointly with the downstream task

[figure: a sequence of n tokens, each token \in \mathbb{R}^{d}, obtained via a projection, e.g. a fixed or learned linear transformation]

Example: chop a 100-by-100 image into 20-by-20 patches; suppose we just flatten each patch, then we get a sequence of n = 25 tokens, each token \in \mathbb{R}^{400}.
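A minimal sketch of this chop-and-flatten tokenization for the 100-by-100 example; the random array stands in for a real image:

import numpy as np

img = np.random.default_rng(0).normal(size=(100, 100))   # stand-in for a 100-by-100 image
patch = 20

tokens = []
for r in range(0, img.shape[0], patch):
    for c in range(0, img.shape[1], patch):
        tokens.append(img[r:r + patch, c:c + patch].reshape(-1))  # flatten each 20-by-20 patch

tokens = np.stack(tokens)
print(tokens.shape)              # (25, 400): n = 25 tokens, each in R^400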

Multi-modality (text + image)

  • (query, key, value) come from different input modalities
  • cross-attention (see the sketch below)
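A minimal sketch of cross-attention under the assumption sketched in these bullets: queries come from one modality's tokens, keys and values from the other's; all sizes and weights are placeholders:

import numpy as np

def softmax_rows(Z):
    E = np.exp(Z - Z.max(axis=1, keepdims=True))
    return E / E.sum(axis=1, keepdims=True)

d, d_k = 8, 4
rng = np.random.default_rng(0)
text_tokens = rng.normal(size=(5, d))      # e.g. 5 word tokens
image_tokens = rng.normal(size=(25, d))    # e.g. 25 image-patch tokens

W_q, W_k, W_v = (rng.normal(size=(d, d_k)) for _ in range(3))

Q = text_tokens @ W_q                      # queries from one modality
K = image_tokens @ W_k                     # keys ...
V = image_tokens @ W_v                     # ... and values from the other modality

out = softmax_rows(Q @ K.T / np.sqrt(d_k)) @ V   # (5, d_k): each word attends over the patches
print(out.shape)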

[Bahdanau et al. 2015]

Input sentence: “The agreement on the European Economic Area was signed in August 1992”

Output sentence: "L’accord sur la zone économique européenne a été signé en août 1992"

[“DINO”, Caron et al. 2021]

Success mode:

[Show, Attend and Tell: Neural Image Caption Generation with Visual Attention. Xu et al. CVPR (2016)]

Failure mode:

[Show, Attend and Tell: Neural Image Caption Generation with Visual Attention. Xu et al. CVPR (2016)]


Summary

  • Transformers combine many of the best ideas from earlier architectures (convolutional patch-wise processing, ReLU nonlinearities, residual connections) with several new innovations, in particular, embedding and attention layers.
  • Transformers start with some generic hard-coded embeddings and, layer by layer, create better and better embeddings.
  • Everything in attention is processed in parallel: each head is processed in parallel, and within each head, the (q, k, v) token sequence is created in parallel, the attention scores are computed in parallel, and the attention output is computed in parallel.

Thanks!

for your attention!

We'd love to hear your thoughts.

  • A token is just transformer lingo for a vector of neurons
  • But the connotation is that a token is an encapsulated bundle of information
  • with transformers we will operate over tokens rather than over neurons

[figure: array of neurons vs. array of tokens; set of neurons vs. set of tokens; a sequence of tokens in \mathbb{R}^d; learned embeddings]

[figure: an attention head applied to "a robot must obey" -- each token gets its own (q_i, k_i, v_i)]

GPT

[figure: the tokens "a robot must obey" are mapped to a distribution over the entire vocabulary]

x^{(i)} \in \mathbb{R}^{d}, \quad q^{(i)} \in \mathbb{R}^{d_k}, \quad k^{(i)} \in \mathbb{R}^{d_k}, \quad v^{(i)} \in \mathbb{R}^{d_k}

  • Which color is query/key/value respectively?
  • How do we go from x to q, k, v?

via learned projection weights W_q, W_k, W_v
  • Importantly, all these learned projection weights W are shared along the token sequence:
  • These three weights W -- once learned -- do not change based on the input token x.
  • If the input sequence had been longer, we could still use the same weights in the same fashion --- they would just map to a longer output sequence.
  • This is yet another form of parallel processing (similar to convolution).
  • But each (q, k, v) does depend on the corresponding input x (this can be interpreted as dynamically changing convolution filter weights).
[figure: the same shared weights W_q, W_k, W_v applied to a longer input sequence x^{(1)}, …, x^{(5)}; the transformer block (attention layer + fully-connected network) between the input and output embeddings is unchanged]

6.390 IntroML (Fall24) - Lecture 9 Transformers

By Shen Shen