Attention & Transformers
Spring 2024
Many materials from CSE447@UW (Liwei Jiang), COS 484@Princeton, and CS224n@Stanford with special thanks!
Transformers
[Figure: a lookup table with keys mapping to values, compared with attention over hidden states h1 … hT.]

In a lookup table, we have a table of keys that map to values. The query matches one of the keys, returning its value.

In attention, the query matches all keys softly, with a weight between 0 and 1. The keys' values are multiplied by the weights and summed.
Self-Attention Layer
[Figure: a self-attention layer, in which every output position can attend to every input a1 … a4.]

How to compute the attention score α? Two common choices:
• Dot product: α = q ⋅ k
• Additive: α = W ⋅ tanh(q + k)

Each input vector is projected into a query and a key:
q1 = WQ a1,  and  ki = WK ai for every position i.
The scores α1,i = q1 ⋅ ki are passed through a softmax to obtain normalized weights α′1,i.
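A small sketch of the two scoring functions above, under toy dimensions; the names `dot_product_score` and `additive_score`, and the parameter `W` of the additive variant, are illustrative choices here.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 4
q, k = rng.standard_normal(d), rng.standard_normal(d)

def dot_product_score(q, k):
    """Dot-product attention: alpha = q . k"""
    return q @ k

# Additive attention: alpha = W . tanh(q + k), as drawn on the slide;
# many variants first project q and k with separate matrices.
W = rng.standard_normal(d)

def additive_score(q, k):
    return W @ np.tanh(q + k)

print(dot_product_score(q, k), additive_score(q, k))   # two scalar scores
```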





These weights denote how relevant each token is to a1!
Use attention scores to extract information
Each input also produces a value vector vi = WV ai. The softmax weights α′1,1, …, α′1,4 are then used to take a weighted sum of the values:

b1 = ∑i α′1,i vi
The higher the attention score α′1,i is, the more important ai is to composing b1.
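A tiny NumPy sketch of this weighted-sum step for query position 1; all names and the random toy data are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 4, 4
a = rng.standard_normal((n, d))                       # inputs a_1 ... a_4 as rows
W_Q, W_K, W_V = (rng.standard_normal((d, d)) for _ in range(3))

q1 = a[0] @ W_Q                                       # query for position 1
k = a @ W_K                                           # keys k_1 ... k_4
v = a @ W_V                                           # values v_1 ... v_4

scores = k @ q1                                       # alpha_{1,i} = q_1 . k_i
alpha = np.exp(scores - scores.max())                 # softmax -> alpha'_{1,i}
alpha /= alpha.sum()
b1 = alpha @ v                                        # b_1 = sum_i alpha'_{1,i} v_i
print(b1.shape)                                       # (d,)
```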
The same procedure with query q2 gives b2 = ∑i α′2,i vi, and likewise for every other position. Note that the computation of the bi can be parallelized, as they are independent of each other.
In matrix form, stack the inputs as the rows of I = [a1; a2; a3; a4]. Then all queries, keys, and values can be computed at once:

Q = I WQ,   K = I WK,   V = I WV

The scores of q1 against all keys form the row [α1,1, α1,2, α1,3, α1,4] = q1 K^T; stacking all queries gives the full score matrix A = Q K^T, and a row-wise softmax turns A into A′ with entries α′i,j. Each output bi is then the weighted sum of the value vectors, i.e. the row α′i,: times V, so that O = [b1; b2; b3; b4] = A′ V.
In summary:

A = Q K^T = I WQ (I WK)^T = I WQ WK^T I^T        A, A′ ∈ ℝ^(n×n)
A′ = softmax(A)
O = A′ V                                          O ∈ ℝ^(n×d)
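Putting the matrix form together, here is a minimal NumPy sketch of single-head self-attention with the shapes annotated as above (the function name `self_attention` and the random test data are illustrative).

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)    # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(I, W_Q, W_K, W_V):
    Q = I @ W_Q                    # n x d
    K = I @ W_K                    # n x d
    V = I @ W_V                    # n x d
    A = Q @ K.T                    # n x n attention scores
    A_prime = softmax(A, axis=-1)  # row-wise softmax
    O = A_prime @ V                # n x d outputs
    return O

n, d = 4, 8
rng = np.random.default_rng(0)
I = rng.standard_normal((n, d))                        # inputs a_1 ... a_n as rows
W_Q, W_K, W_V = (rng.standard_normal((d, d)) for _ in range(3))
print(self_attention(I, W_Q, W_K, W_V).shape)          # (4, 8) = n x d
```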
Sinusoidal positional encoding: the position t in the sequence is encoded as a vector whose entries alternate sines and cosines of different frequencies (one axis of the figure indexes the dimension, the other the index in the sequence):

pt = [ sin(t / 10000^(2⋅1/d)),
       cos(t / 10000^(2⋅1/d)),
       sin(t / 10000^(2⋅2/d)),
       cos(t / 10000^(2⋅2/d)),
       … ]

https://timodenk.com/blog/linear-relationships-in-the-transformers-positional-encoding/
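A hedged sketch of the sinusoidal encoding above; the name `positional_encoding` is illustrative, and the exact indexing (sines on even dimensions, cosines on odd, exponent 2i/d with i starting at 0) follows the common convention rather than anything fixed by the slide.

```python
import numpy as np

def positional_encoding(max_len, d):
    """PE[t, 2i] = sin(t / 10000^(2i/d)), PE[t, 2i+1] = cos(t / 10000^(2i/d)).
    Assumes d is even."""
    t = np.arange(max_len)[:, None]              # index in the sequence
    i = np.arange(d // 2)[None, :]               # dimension-pair index
    angles = t / np.power(10000.0, 2 * i / d)
    pe = np.zeros((max_len, d))
    pe[:, 0::2] = np.sin(angles)
    pe[:, 1::2] = np.cos(angles)
    return pe

print(positional_encoding(max_len=6, d=8).shape)     # (6, 8): one row per position
```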
Easy Fix: add a feed-forward network to post-process each output vector of the self-attention layer.

[Figure: a1 … an → Self-Attention → b1 … bn → FF applied to each position.]
https://jalammar.github.io/illustrated-gpt2/
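A minimal sketch of such a position-wise feed-forward post-processing step, assuming the usual two-layer form with a ReLU in between; the dimensions and names are illustrative, not prescribed by the slide.

```python
import numpy as np

def feed_forward(B, W1, b1, W2, b2):
    """Apply the same 2-layer MLP to each output vector b_i (each row of B)."""
    return np.maximum(0.0, B @ W1 + b1) @ W2 + b2   # ReLU, then project back to d

n, d, d_ff = 4, 8, 32
rng = np.random.default_rng(0)
B = rng.standard_normal((n, d))                     # self-attention outputs b_1 ... b_n
W1, b1 = rng.standard_normal((d, d_ff)), np.zeros(d_ff)
W2, b2 = rng.standard_normal((d_ff, d)), np.zeros(d)
print(feed_forward(B, W1, b1, W2, b2).shape)        # (4, 8): one post-processed vector per position
```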
Looking into the Future → Masking
• In decoders (language modeling: producing the next word given the previous context), we need to ensure we don't peek at the future.
• To enable parallelization, we mask out attention to future words by setting their attention scores to −∞:

αi,j = qi ⋅ kj if j ≤ i,  and  αi,j = −∞ if j > i

[Figure: when encoding each word of "[START] The chef who …", the scores for later (greyed-out) words are set to −∞ before the softmax, so each word can only attend to the words that precede it (the words that are not greyed out).]
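A small sketch of the masking rule: scores for future positions (j > i) are set to −∞ before the softmax, so their weights become exactly zero (the helper name `causal_mask` is illustrative).

```python
import numpy as np

def causal_mask(A):
    """Set alpha_{i,j} = -inf for j > i so the softmax gives future words zero weight."""
    n = A.shape[0]
    mask = np.triu(np.ones((n, n), dtype=bool), k=1)   # True strictly above the diagonal
    return np.where(mask, -np.inf, A)

A = np.arange(16, dtype=float).reshape(4, 4)           # toy score matrix Q K^T
A_masked = causal_mask(A)
weights = np.exp(A_masked - A_masked.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
print(np.round(weights, 2))                            # upper triangle is exactly 0
```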
Components so far:
• Self-attention
• Positional Encoding
[Figure: Transformer block diagram — Inputs → Input Embeddings + Position Embedding → … → Linear → Output Probabilities.]

• Replace self-attention with multi-head self-attention.
Multi-head Attention
“The Beast with Many Heads”
[Figure: multiple attention heads H0 through H7.]
https://jalammar.github.io/illustrated-transformer/
Multi-Head Attention: Walk-through
[Figure: with two heads, each qi, ki, vi is split into (qi,1, qi,2), (ki,1, ki,2), (vi,1, vi,2). Head 1 uses qi,1, ki,1, vi,1 (and kj,1, vj,1 at the other positions) to compute weights α′i,i,1, α′i,j,1 and the head-1 output bi,1.]
[Figure: head 2 does the same with qi,2, ki,2, vi,2, producing weights α′i,i,2, α′i,j,2 and the head-2 output bi,2.]
[Figure: the per-head outputs bi,1 and bi,2 are concatenated and passed through a transformation Y, giving the final output bi = [bi,1; bi,2] Y.]
Recall the Matrices Form of Self-Attention

Q = I WQ        I = {a1, …, an} ∈ ℝ^(n×d), where ai ∈ ℝ^d
K = I WK        WQ, WK, WV ∈ ℝ^(d×d)
V = I WV        Q, K, V ∈ ℝ^(n×d)

A = Q K^T = I WQ (I WK)^T = I WQ WK^T I^T        A, A′ ∈ ℝ^(n×n)
A′ = softmax(A)
O = A′ V        O ∈ ℝ^(n×d)
• Multiple attention "heads" can be defined via multiple WQ, WK, WV matrices.
• Let WQ^l, WK^l, WV^l ∈ ℝ^(d×(d/h)), where h is the number of attention heads and l ranges from 1 to h.
• Each attention head performs attention independently:
  O^l = softmax(I WQ^l WK^(l T) I^T) I WV^l
• The outputs of the different attention heads are then concatenated and mixed:
  O = [O^1; …; O^h] Y,  where Y ∈ ℝ^(d×d), [O^1; …; O^h] ∈ ℝ^(n×d), and O ∈ ℝ^(n×d)
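A per-head sketch of these formulas: each head l gets its own d×(d/h) projections, runs attention independently, and the head outputs are concatenated and mixed by Y (the function name `multi_head_attention` and the toy sizes are illustrative).

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(I, W_Q, W_K, W_V, Y):
    """W_Q, W_K, W_V are lists of h matrices, each of shape d x (d/h)."""
    heads = []
    for W_Q_l, W_K_l, W_V_l in zip(W_Q, W_K, W_V):
        Q_l, K_l, V_l = I @ W_Q_l, I @ W_K_l, I @ W_V_l   # each n x (d/h)
        A_l = softmax(Q_l @ K_l.T, axis=-1)               # n x n per-head weights
        heads.append(A_l @ V_l)                           # O^l: n x (d/h)
    O_cat = np.concatenate(heads, axis=-1)                # [O^1; ...; O^h]: n x d
    return O_cat @ Y                                      # O: n x d

n, d, h = 4, 8, 2
rng = np.random.default_rng(0)
I = rng.standard_normal((n, d))
make_heads = lambda: [rng.standard_normal((d, d // h)) for _ in range(h)]
W_Q, W_K, W_V = make_heads(), make_heads(), make_heads()
Y = rng.standard_normal((d, d))
print(multi_head_attention(I, W_Q, W_K, W_V, Y).shape)    # (4, 8)
```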
The Matrices Form of Multi-head Attention
Q^l = I WQ^l        I = {a1, …, an} ∈ ℝ^(n×d), where ai ∈ ℝ^d
K^l = I WK^l        WQ^l, WK^l, WV^l ∈ ℝ^(d×(d/h))
V^l = I WV^l        Q^l, K^l, V^l ∈ ℝ^(n×(d/h))

A^l = Q^l K^(l T)        A^l, A^l′ ∈ ℝ^(n×n)
A^l′ = softmax(A^l)
O^l = A^l′ V^l        O^l ∈ ℝ^(n×(d/h))

O = [O^1; …; O^h] Y        Y ∈ ℝ^(d×d), [O^1; …; O^h] ∈ ℝ^(n×d), O ∈ ℝ^(n×d)
Multi-head Attention is Computationally Efficient
• Even though we compute h many attention heads, it's not more costly.
• We compute I WQ ∈ ℝ^(n×d), and then reshape to ℝ^(n×h×(d/h)). Likewise for I WK and I WV.
• Then we transpose to ℝ^(h×n×(d/h)); now the head axis acts like a batch axis.
• Almost everything else is identical. All we need to do is reshape the tensors!

softmax(I WQ WK^T I^T) I WV = O′,  and O′ Y = O ∈ ℝ^(n×d)
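A sketch of this reshape trick, assuming one full-width projection per role that is then split into heads (the name `multi_head_attention_batched` is illustrative); up to how the projections are laid out, it computes the same thing as a per-head loop.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention_batched(I, W_Q, W_K, W_V, Y, h):
    n, d = I.shape
    # One n x d matmul per role, then reshape to (n, h, d/h) and move heads first.
    Q = (I @ W_Q).reshape(n, h, d // h).transpose(1, 0, 2)   # h x n x (d/h)
    K = (I @ W_K).reshape(n, h, d // h).transpose(1, 0, 2)
    V = (I @ W_V).reshape(n, h, d // h).transpose(1, 0, 2)
    A = softmax(Q @ K.transpose(0, 2, 1), axis=-1)           # h x n x n: head axis acts like a batch axis
    O = (A @ V).transpose(1, 0, 2).reshape(n, d)             # concatenate heads back to n x d
    return O @ Y

n, d, h = 4, 8, 2
rng = np.random.default_rng(0)
I = rng.standard_normal((n, d))
W_Q, W_K, W_V, Y = (rng.standard_normal((d, d)) for _ in range(4))
print(multi_head_attention_batched(I, W_Q, W_K, W_V, Y, h).shape)   # (4, 8)
```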
Residual Connections
• Residual connections are a trick to help models train better.
• Instead of X^(i) = Layer(X^(i−1)) (where i represents the layer), we let X^(i) = X^(i−1) + Layer(X^(i−1)), so the layer only has to learn the residual relative to the previous representation.
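A one-line sketch of the residual pattern, with an arbitrary stand-in `layer` function (illustrative; any Transformer sub-layer would take its place).

```python
import numpy as np

def layer(X):
    # Stand-in for any sub-layer (e.g. self-attention or the feed-forward network).
    return np.tanh(X)

X = np.random.default_rng(0).standard_normal((4, 8))
X_next = X + layer(X)      # residual connection: X^(i) = X^(i-1) + Layer(X^(i-1))
print(X_next.shape)
```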
Layer normalization first computes per-vector statistics (for x ∈ ℝ^d):
• Let μ = (1/d) ∑_{j=1}^{d} xj; this is the mean; μ ∈ ℝ.
• Let σ = sqrt( (1/d) ∑_{j=1}^{d} (xj − μ)^2 ); this is the standard deviation; σ ∈ ℝ.
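A minimal sketch that applies these statistics as layer normalization to each vector independently; the learnable gain and bias of full layer norm are omitted, and `layer_norm` is an illustrative name.

```python
import numpy as np

def layer_norm(X, eps=1e-5):
    """Normalize each row x in R^d to zero mean and unit standard deviation."""
    mu = X.mean(axis=-1, keepdims=True)       # mean over the d dimensions
    sigma = X.std(axis=-1, keepdims=True)     # standard deviation over the d dimensions
    return (X - mu) / (sigma + eps)

X = np.random.default_rng(0).standard_normal((4, 8))
out = layer_norm(X)
print(np.round(out.mean(axis=-1), 6), np.round(out.std(axis=-1), 6))   # ~0 mean, ~1 std per row
```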
• Easier to parallelize:

Q = X WQ,   K = X WK,   V = X WV
with Q ∈ ℝ^(n×dq), K^T ∈ ℝ^(dk×n), V ∈ ℝ^(n×dv).

• Attention needs to compute n^2 pairs of scores (dot products), i.e. O(n^2 d) running time.
• RNNs only require O(n d^2) running time:  ht = g(W ht−1 + U xt + b)
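A rough flop-count sketch of this comparison (illustrative arithmetic, not a benchmark): the attention score matrix costs about n²·d multiply-adds, while the recurrent matmuls of an RNN cost about n·d².

```python
n, d = 1024, 512                      # sequence length and hidden size (illustrative)

attention_scores = n * n * d          # Q K^T: an n x d matrix times a d x n matrix
rnn_recurrence   = n * d * d          # n steps, each multiplying a d x d matrix by h_{t-1}

print(f"attention ~ {attention_scores:,} multiply-adds")   # grows quadratically in n
print(f"RNN       ~ {rnn_recurrence:,} multiply-adds")     # grows quadratically in d
```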