
The Transformer Family Version 2.0


Date: January 27, 2023 | Estimated Reading Time: 45 min | Author: Lilian Weng


Many new Transformer architecture improvements have been proposed since my last post on “The Transformer Family” about three years ago. Here I did a big refactoring and enrichment of that 2020 post: restructuring the hierarchy of sections and improving many sections with more recent papers. Version 2.0 is a superset of the old version and about twice the length.

Notations

| Symbol | Meaning |
| --- | --- |
| $d$ | The model size / hidden state dimension / positional encoding size. |
| $h$ | The number of heads in multi-head attention layer. |
| $L$ | The segment length of input sequence. |
| $N$ | The total number of attention layers in the model; not considering MoE. |
| $\mathbf{X} \in \mathbb{R}^{L \times d}$ | The input sequence where each element has been mapped into an embedding vector of shape $d$, same as the model size. |
| $\mathbf{W}^k \in \mathbb{R}^{d \times d_k}$ | The key weight matrix. |
| $\mathbf{W}^q \in \mathbb{R}^{d \times d_k}$ | The query weight matrix. |
| $\mathbf{W}^v \in \mathbb{R}^{d \times d_v}$ | The value weight matrix. Often we have $d_k = d_v = d$. |
| $\mathbf{W}^q_i, \mathbf{W}^k_i \in \mathbb{R}^{d \times d_k/h}; \mathbf{W}^v_i \in \mathbb{R}^{d \times d_v/h}$ | The weight matrices per head. |
| $\mathbf{W}^o \in \mathbb{R}^{d_v \times d}$ | The output weight matrix. |
| $\mathbf{Q} = \mathbf{X}\mathbf{W}^q \in \mathbb{R}^{L \times d_k}$ | The query embedding inputs. |
| $\mathbf{K} = \mathbf{X}\mathbf{W}^k \in \mathbb{R}^{L \times d_k}$ | The key embedding inputs. |
| $\mathbf{V} = \mathbf{X}\mathbf{W}^v \in \mathbb{R}^{L \times d_v}$ | The value embedding inputs. |
| $\mathbf{q}_i, \mathbf{k}_i \in \mathbb{R}^{d_k}, \mathbf{v}_i \in \mathbb{R}^{d_v}$ | Row vectors in the query, key, value matrices $\mathbf{Q}$, $\mathbf{K}$ and $\mathbf{V}$. |
| $S_i$ | A collection of key positions for the $i$-th query $\mathbf{q}_i$ to attend to. |
| $\mathbf{A} \in \mathbb{R}^{L \times L}$ | The self-attention matrix between an input sequence of length $L$ and itself. $\mathbf{A} = \text{softmax}(\mathbf{Q}\mathbf{K}^\top / \sqrt{d_k})$. |
| $a_{ij} \in \mathbf{A}$ | The scalar attention score between query $\mathbf{q}_i$ and key $\mathbf{k}_j$. |
| $\mathbf{P} \in \mathbb{R}^{L \times d}$ | The position encoding matrix, where the $i$-th row $\mathbf{p}_i$ is the positional encoding for input $\mathbf{x}_i$. |

Transformer Basics
The Transformer (which will be referred to as “vanilla Transformer” to distinguish it from other enhanced versions; Vaswani, et al., 2017) model has an encoder-decoder architecture, as commonly used in many NMT models. Later, simplified Transformer variants were shown to achieve great performance in language modeling tasks, as in the encoder-only BERT or the decoder-only GPT.

Attention and Self-Attention


Attention is a mechanism in neural networks by which a model learns to make predictions by selectively attending to a given set of data. The amount of attention is quantified by learned weights and thus the output is usually formed as a weighted average.
Self-attention is a type of attention mechanism where the model makes predictions for one part of a data sample using other parts of the observation about the same sample. Conceptually, it feels quite similar to non-local means. Also note that self-attention is permutation-invariant; in other words, it is an operation on sets.
There are various forms of attention / self-attention; the Transformer (Vaswani et al., 2017) relies on the scaled dot-product attention: given a query matrix $\mathbf{Q}$, a key matrix $\mathbf{K}$ and a value matrix $\mathbf{V}$, the output is a weighted sum of the value vectors, where the weight assigned to each value slot is determined by the dot-product of the query with the corresponding key:

$$\text{attn}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}\Big(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{d_k}}\Big)\mathbf{V}$$

And for a query and a key vector $\mathbf{q}_i, \mathbf{k}_j \in \mathbb{R}^d$ (row vectors in the query and key matrices), we have a scalar score:

$$a_{ij} = \text{softmax}\Big(\frac{\mathbf{q}_i \mathbf{k}_j^\top}{\sqrt{d_k}}\Big) = \frac{\exp(\mathbf{q}_i \mathbf{k}_j^\top / \sqrt{d_k})}{\sum_{r \in S_i} \exp(\mathbf{q}_i \mathbf{k}_r^\top / \sqrt{d_k})}$$

where $S_i$ is a collection of key positions for the $i$-th query to attend to.
See my old post for other types of attention if interested.
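To make the formula concrete, here is a minimal numpy sketch of scaled dot-product attention over one sequence; the function name and toy shapes are illustrative only, not from any particular library.

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """Q, K: (L, d_k); V: (L, d_v). Returns the (L, d_v) weighted sum of values."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                       # (L, L) scaled dot products
    scores = scores - scores.max(axis=-1, keepdims=True)  # for numerical stability
    A = np.exp(scores)
    A = A / A.sum(axis=-1, keepdims=True)                 # row-wise softmax = attention matrix A
    return A @ V

# Tiny example with L = 4 tokens and d = 8
rng = np.random.default_rng(0)
X = rng.normal(size=(4, 8))
Wq, Wk, Wv = rng.normal(size=(3, 8, 8))
out = scaled_dot_product_attention(X @ Wq, X @ Wk, X @ Wv)  # shape (4, 8)
```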

Multi-Head Self-Attention
The multi-head self-attention module is a key component in Transformer. Rather than only computing the
attention once, the multi-head mechanism splits the inputs into smaller chunks and then computes the scaled dot-
product attention over each subspace in parallel. The independent attention outputs are simply concatenated and
linearly transformed into expected dimensions.

$$\text{MultiHeadAttn}(\mathbf{X}_q, \mathbf{X}_k, \mathbf{X}_v) = [\text{head}_1; \dots; \text{head}_h]\mathbf{W}^o$$
$$\text{where head}_i = \text{Attention}(\mathbf{X}_q\mathbf{W}^q_i, \mathbf{X}_k\mathbf{W}^k_i, \mathbf{X}_v\mathbf{W}^v_i)$$

where $[\,\cdot\,;\,\cdot\,]$ is a concatenation operation. $\mathbf{W}^q_i, \mathbf{W}^k_i \in \mathbb{R}^{d \times d_k/h}$, $\mathbf{W}^v_i \in \mathbb{R}^{d \times d_v/h}$ are weight matrices to map input embeddings of size $L \times d$ into query, key and value matrices. And $\mathbf{W}^o \in \mathbb{R}^{d_v \times d}$ is the output linear transformation. All the weights should be learned during training.

Fig. 1. Illustration of the multi-head scaled dot-product attention mechanism. (Image source:
Figure 2 in Vaswani, et al., 2017)
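A minimal numpy sketch of the multi-head mechanism above, splitting the $d$-dimensional projections into $h$ subspaces of size $d/h$; the helper and variable names are illustrative only.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, Wq, Wk, Wv, Wo, h):
    """X: (L, d); Wq, Wk, Wv, Wo: (d, d); h: number of heads (d divisible by h)."""
    L, d = X.shape
    Q, K, V = X @ Wq, X @ Wk, X @ Wv                           # each (L, d)
    heads = []
    for i in range(h):                                          # slice out the i-th subspace
        sl = slice(i * d // h, (i + 1) * d // h)
        A = softmax(Q[:, sl] @ K[:, sl].T / np.sqrt(d // h))    # (L, L) per-head attention
        heads.append(A @ V[:, sl])                              # (L, d // h)
    return np.concatenate(heads, axis=-1) @ Wo                  # concat heads, output projection
```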

Encoder-Decoder Architecture
The encoder generates an attention-based representation with the capability to locate a specific piece of information from a large context. It consists of a stack of 6 identical modules, each containing two submodules, a multi-head self-attention layer and a point-wise fully connected feed-forward network. Point-wise means that the same linear transformation (with the same weights) is applied to each element in the sequence. This can also be viewed as a convolutional layer with filter size 1. Each submodule has a residual connection and layer normalization. All the submodules output data of the same dimension $d$.
The function of the Transformer decoder is to retrieve information from the encoded representation. The architecture is quite similar to the encoder, except that the decoder contains two multi-head attention submodules instead of one in each identical repeating module. The first multi-head attention submodule is masked to prevent positions from attending to the future.


Fig. 2. The architecture of the vanilla Transformer model. (Image source: Figure 17)

Positional Encoding
Because the self-attention operation is permutation-invariant, it is important to use proper positional encoding to provide order information to the model. The positional encoding $\mathbf{P} \in \mathbb{R}^{L \times d}$ has the same dimension as the input embedding, so it can be added to the input directly. The vanilla Transformer considered two types of encodings:

Sinusoidal Positional Encoding

Sinusoidal positional encoding is defined as follows, given the token position $i = 1, \dots, L$ and the dimension $\delta = 1, \dots, d$:

$$\text{PE}(i, \delta) = \begin{cases} \sin\big(\frac{i}{10000^{2\delta'/d}}\big) & \text{if } \delta = 2\delta' \\ \cos\big(\frac{i}{10000^{2\delta'/d}}\big) & \text{if } \delta = 2\delta' + 1 \end{cases}$$

In this way, each dimension of the positional encoding corresponds to a sinusoid of a different wavelength, with wavelengths forming a geometric progression from $2\pi$ to $10000 \cdot 2\pi$.
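A minimal numpy sketch of this encoding (illustrative names; even dimensions take the sine term, odd dimensions the cosine term):

```python
import numpy as np

def sinusoidal_positional_encoding(L, d):
    """Returns P of shape (L, d) with even d, following the definition above."""
    P = np.zeros((L, d))
    position = np.arange(L)[:, None]              # token positions i
    div = 10000 ** (np.arange(0, d, 2) / d)       # 10000^(2*delta'/d)
    P[:, 0::2] = np.sin(position / div)           # delta = 2*delta'
    P[:, 1::2] = np.cos(position / div)           # delta = 2*delta' + 1
    return P

P = sinusoidal_positional_encoding(L=32, d=128)   # same setup as Fig. 3 below
```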


Fig. 3. Sinusoidal positional encoding with L = 32 and d = 128. The value is between -1
(black) and 1 (white) and the value 0 is in gray.

Learned Positional Encoding

Learned positional encoding assigns each element with a learned column vector which encodes its absolute position (Gehring, et al. 2017), and furthermore this encoding can be learned differently per layer (Al-Rfou et al. 2018).

Relative Position Encoding

Shaw et al. (2018) incorporated relative positional information into $\mathbf{W}^k$ and $\mathbf{W}^v$. The relative position is clipped to a maximum absolute value of $k$, and this clipping operation enables the model to generalize to unseen sequence lengths. Therefore, $2k + 1$ unique edge labels are considered; let us denote $\mathbf{P}^k, \mathbf{P}^v \in \mathbb{R}^{2k+1}$ as learnable relative position representations.

$$A^k_{ij} = P^k_{\text{clip}(j - i,\, k)} \qquad A^v_{ij} = P^v_{\text{clip}(j - i,\, k)} \qquad \text{where clip}(x, k) = \text{clip}(x, -k, k)$$

Transformer-XL (Dai et al., 2019) proposed a type of relative positional encoding based on a reparametrization of the dot-product of keys and queries. To keep the positional information flowing coherently across segments, Transformer-XL encodes the relative position instead, as it is sufficient to know the position offset between one key vector $\mathbf{k}_{\tau,j}$ and its query $\mathbf{q}_{\tau,i}$, i.e. $i - j$, for making good predictions.
If we omit the scalar $1/\sqrt{d_k}$ and the normalizing term in softmax but include positional encodings, we can write the attention score between the query at position $i$ and the key at position $j$ as:

$$\begin{aligned} a_{ij} &= \mathbf{q}_i \mathbf{k}_j^\top = (\mathbf{x}_i + \mathbf{p}_i)\mathbf{W}^q \big((\mathbf{x}_j + \mathbf{p}_j)\mathbf{W}^k\big)^\top \\ &= \mathbf{x}_i\mathbf{W}^q {\mathbf{W}^k}^\top\mathbf{x}_j^\top + \mathbf{x}_i\mathbf{W}^q {\mathbf{W}^k}^\top\mathbf{p}_j^\top + \mathbf{p}_i\mathbf{W}^q {\mathbf{W}^k}^\top\mathbf{x}_j^\top + \mathbf{p}_i\mathbf{W}^q {\mathbf{W}^k}^\top\mathbf{p}_j^\top \end{aligned}$$

Transformer-XL reparameterizes the above four terms as follows:

$$a^\text{rel}_{ij} = \underbrace{\mathbf{x}_i\mathbf{W}^q {\mathbf{W}^k_E}^\top\mathbf{x}_j^\top}_\text{content-based addressing} + \underbrace{\mathbf{x}_i\mathbf{W}^q {\mathbf{W}^k_R}^\top\mathbf{r}_{i-j}^\top}_\text{content-dependent positional bias} + \underbrace{\mathbf{u} {\mathbf{W}^k_E}^\top\mathbf{x}_j^\top}_\text{global content bias} + \underbrace{\mathbf{v} {\mathbf{W}^k_R}^\top\mathbf{r}_{i-j}^\top}_\text{global positional bias}$$


Replace $\mathbf{p}_j$ with relative positional encoding $\mathbf{r}_{i-j} \in \mathbb{R}^d$;
Replace $\mathbf{p}_i\mathbf{W}^q$ with two trainable parameters $\mathbf{u}$ (for content) and $\mathbf{v}$ (for location) in two different terms;
Split $\mathbf{W}^k$ into two matrices, $\mathbf{W}^k_E$ for content information and $\mathbf{W}^k_R$ for location information.

Rotary Position Embedding

Rotary position embedding (RoPE; Su et al. 2021) encodes the absolute position with a rotation matrix and multiplies the key and query matrices of every attention layer with it to inject relative positional information at every layer.
When encoding relative positional information into the inner product of the $i$-th key and the $j$-th query, we would like to formulate the function in a way that the inner product is only about the relative position $i - j$. Rotary Position Embedding (RoPE) makes use of the rotation operation in Euclidean space and frames the relative position embedding as simply rotating the feature matrix by an angle proportional to its position index.
Given a vector z , if we want to rotate it counterclockwise by θ, we can multiply it by a rotation matrix to get Rz
where the rotation matrix R is defined as:

$$R = \begin{bmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{bmatrix}$$
When generalizing to higher dimensional space, RoPE divides the $d$-dimensional space into $d/2$ subspaces and constructs a rotation matrix $R$ of size $d \times d$ for the token at position $i$:

$$R^d_{\Theta, i} = \begin{bmatrix} \cos i\theta_1 & -\sin i\theta_1 & 0 & 0 & \dots & 0 & 0 \\ \sin i\theta_1 & \cos i\theta_1 & 0 & 0 & \dots & 0 & 0 \\ 0 & 0 & \cos i\theta_2 & -\sin i\theta_2 & \dots & 0 & 0 \\ 0 & 0 & \sin i\theta_2 & \cos i\theta_2 & \dots & 0 & 0 \\ \vdots & \vdots & \vdots & \vdots & \ddots & \vdots & \vdots \\ 0 & 0 & 0 & 0 & \dots & \cos i\theta_{d/2} & -\sin i\theta_{d/2} \\ 0 & 0 & 0 & 0 & \dots & \sin i\theta_{d/2} & \cos i\theta_{d/2} \end{bmatrix}$$

where in the paper we have $\Theta = \{\theta_i = 10000^{-2(i-1)/d}, i \in [1, 2, \dots, d/2]\}$. Note that this is essentially equivalent to sinusoidal positional encoding but formulated as a rotation matrix.
Then both the key and query matrices incorporate the positional information by multiplying with this rotation matrix:

$$\mathbf{q}_i^\top \mathbf{k}_j = (R^d_{\Theta,i}\mathbf{W}^q\mathbf{x}_i)^\top (R^d_{\Theta,j}\mathbf{W}^k\mathbf{x}_j) = \mathbf{x}_i^\top{\mathbf{W}^q}^\top R^d_{\Theta,j-i}\mathbf{W}^k\mathbf{x}_j \quad \text{where } R^d_{\Theta,j-i} = (R^d_{\Theta,i})^\top R^d_{\Theta,j}$$


Fig. 4. Visual illustration of how rotary position embedding is implemented. (Image source: Su et al., 2021)
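A minimal numpy sketch of applying RoPE to a query or key matrix: instead of materializing the sparse $d \times d$ rotation matrix, each consecutive pair of feature dimensions is rotated by the angle $i\theta_j$, which is equivalent to the formula above. The pairing convention and names are illustrative (implementations differ in how they group dimensions).

```python
import numpy as np

def rope(x):
    """x: (L, d) query or key matrix with even d; returns the rotated matrix."""
    L, d = x.shape
    theta = 10000 ** (-2 * np.arange(d // 2) / d)        # theta_1 ... theta_{d/2}
    angles = np.arange(L)[:, None] * theta[None, :]      # (L, d/2), angle = i * theta_j
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[:, 0::2], x[:, 1::2]                      # consecutive pairs of dimensions
    out = np.empty_like(x)
    out[:, 0::2] = x1 * cos - x2 * sin                   # 2D rotation applied to each pair
    out[:, 1::2] = x1 * sin + x2 * cos
    return out

# Positions enter rope(Q) @ rope(K).T only through the offset i - j,
# matching the relative form q_i^T R_{Theta, j-i} k_j above.
```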

Longer Context
The length of an input sequence for transformer models at inference time is upper-bounded by the context length used for training. Naively increasing the context length leads to high consumption in both time ($O(L^2 d)$) and memory ($O(L^2)$) and may not be supported due to hardware constraints.
This section introduces several improvements in the transformer architecture to better support long context at inference time, e.g. using additional memory, designing for better context extrapolation, or adding recurrence mechanisms.

Context Memory
The vanilla Transformer has a fixed and limited attention span. The model can only attend to other elements in the
same segments during each update step and no information can flow across separated fixed-length segments. This
context segmentation causes several issues:
The model cannot capture very long term dependencies.
It is hard to predict the first few tokens in each segment given no or thin context.
The evaluation is expensive. Whenever the segment is shifted to the right by one, the new segment is re-processed
from scratch, although there are a lot of overlapped tokens.
Transformer-XL (Dai et al., 2019; “XL” means “extra long”) modifies the architecture to reuse hidden states
between segments with an additional memory. The recurrent connection between segments is introduced into the
model by continuously using the hidden states from the previous segments.


Fig. 5. A comparison between the training phase of vanilla Transformer & Transformer-XL with a segment length of 4. (Image source: left part of Figure 2 in Dai et al., 2019).

Let’s label the hidden state of the $n$-th layer for the $(\tau+1)$-th segment in the model as $\mathbf{h}^{(n)}_{\tau+1} \in \mathbb{R}^{L \times d}$. In addition to the hidden state of the last layer for the same segment $\mathbf{h}^{(n-1)}_{\tau+1}$, it also depends on the hidden state of the same layer for the previous segment $\mathbf{h}^{(n)}_{\tau}$. By incorporating information from the previous hidden states, the model extends the attention span much longer in the past, over multiple segments.

$$\tilde{\mathbf{h}}^{(n-1)}_{\tau+1} = [\text{stop-gradient}(\mathbf{h}^{(n-1)}_{\tau}) \circ \mathbf{h}^{(n-1)}_{\tau+1}]$$
$$\mathbf{Q}^{(n)}_{\tau+1} = \mathbf{h}^{(n-1)}_{\tau+1}\mathbf{W}^q$$
$$\mathbf{K}^{(n)}_{\tau+1} = \tilde{\mathbf{h}}^{(n-1)}_{\tau+1}\mathbf{W}^k$$
$$\mathbf{V}^{(n)}_{\tau+1} = \tilde{\mathbf{h}}^{(n-1)}_{\tau+1}\mathbf{W}^v$$
$$\mathbf{h}^{(n)}_{\tau+1} = \text{transformer-layer}(\mathbf{Q}^{(n)}_{\tau+1}, \mathbf{K}^{(n)}_{\tau+1}, \mathbf{V}^{(n)}_{\tau+1})$$

Note that both keys and values rely on the extended hidden states, while queries only consume hidden states at the current step. The concatenation operation $[\cdot \circ \cdot]$ is along the sequence length dimension. And Transformer-XL needs to use relative positional encoding because previous and current segments would be assigned the same encoding if we encode absolute positions, which is undesired.
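A minimal sketch of the key/value extension above; numpy has no autograd, so the stop-gradient in the formula is a no-op here, and all names are illustrative.

```python
import numpy as np

def xl_attention_inputs(h_prev, h_curr, Wq, Wk, Wv):
    """h_prev, h_curr: (L, d) hidden states of the same layer for segments tau and tau+1."""
    h_ext = np.concatenate([h_prev, h_curr], axis=0)   # extended context (stop-gradient on h_prev)
    Q = h_curr @ Wq                                    # (L, d_k): queries from the current segment only
    K = h_ext @ Wk                                     # (2L, d_k): keys over the extended context
    V = h_ext @ Wv                                     # (2L, d_v)
    return Q, K, V
```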
Compressive Transformer (Rae et al. 2019) extends Transformer-XL by compressing past memories to support longer sequences. It explicitly adds memory slots of size $m_m$ per layer for storing past activations of this layer to preserve long context. When some past activations become old enough, they are compressed and saved in an additional compressed memory of size $m_{cm}$ per layer.


Fig. 6. Compressive transformer maintains two types of memory slots, memory and
compressed memory, to support long context. (Image source: Rae et al. 2019).

Both memory and compressed memory are FIFO queues. Given the model context length $L$, the compression function of compression rate $c$ is defined as $f_c: \mathbb{R}^{L \times d} \to \mathbb{R}^{\lceil \frac{L}{c}\rceil \times d}$, mapping the $L$ oldest activations to $\lceil \frac{L}{c}\rceil$ compressed memory elements. There are several choices of compression functions:

1. Max/mean pooling of kernel and stride size c;


2. 1D convolution with kernel and stride size c (need to learn additional parameters);
3. Dilated convolution (need to learn additional parameters). In their experiments, convolution compression
works out the best on EnWik8 dataset;
4. Most used memories.

Compressive transformer has two additional training losses:

1. Auto-encoding loss (lossless compression objective) measures how well we can reconstruct the original memories from compressed memories:

$$\mathcal{L}_{ac} = \|\text{old\_mem}^{(i)} - g(\text{new\_cm}^{(i)})\|_2$$

where $g: \mathbb{R}^{\lceil \frac{L}{c}\rceil \times d} \to \mathbb{R}^{L \times d}$ reverses the compression function $f$.

2. Attention-reconstruction loss (lossy objective) reconstructs content-based attention over memory vs compressed memory and minimizes the difference:

$$\mathcal{L}_{ar} = \|\text{attn}(\mathbf{h}^{(i)}, \text{old\_mem}^{(i)}) - \text{attn}(\mathbf{h}^{(i)}, \text{new\_cm}^{(i)})\|_2$$
Transformer-XL with a memory of size $m$ has a maximum temporal range of $m \times N$, where $N$ is the number of layers in the model, and attention cost $O(L^2 + Lm)$. In comparison, Compressive Transformer has a temporal range of $(m_m + c \cdot m_{cm}) \times N$ and attention cost $O(L^2 + L(m_m + m_{cm}))$. A larger compression rate $c$ gives a better tradeoff between temporal range and attention cost.
Attention weights, from oldest to newest, are stored in three locations: compressed memory → memory → causally
masked sequence. In the experiments, they observed an increase in attention weights from oldest activations


stored in the regular memory, to activations stored in the compressed memory, implying that the network is
learning to preserve salient information.

Fig. 7. Attention weights with one standard deviation as error bars versus memory positions,
from oldest (left) to newest (right). (Image source: Rae et al. 2019).

Non-Differentiable External Memory


kNN-LM (Khandelwal et al. 2020) enhances a pretrained LM with a separate kNN model by linearly interpolating
the next token probabilities predicted by both models. The kNN model is built upon an external key-value store
which can store any large pre-training dataset or OOD new dataset. This datastore is preprocessed to save a large
number of pairs, (LM embedding representation of context, next token) and the nearest neighbor retrieval happens
in the LM embedding space. Because the datastore can be gigantic, we need to rely on libraries for fast dense
vector search such as FAISS or ScaNN. The indexing process only happens once and parallelism is easy to
implement at inference time.
At inference time, the next token probability is a weighted sum of two predictions:

$$p(y \mid \mathbf{x}) = \lambda\, p_\text{kNN}(y \mid \mathbf{x}) + (1 - \lambda)\, p_\text{LM}(y \mid \mathbf{x})$$
$$p_\text{kNN}(y \mid \mathbf{x}) \propto \sum_{(k_i, w_i) \in \mathcal{N}} \mathbb{1}[y = w_i] \exp(-d(k_i, f(\mathbf{x})))$$

where $\mathcal{N}$ contains a set of nearest neighbor data points retrieved by kNN; $d(\cdot, \cdot)$ is a distance function such as the L2 distance.
According to the experiments, a larger datastore size or a larger $k$ is correlated with better perplexity. The weighting scalar $\lambda$ should be tuned, but in general it is expected to be larger for out-of-domain data than for in-domain data, and a larger datastore can afford a larger $\lambda$.
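A minimal sketch of the interpolation step above, assuming the nearest neighbors have already been retrieved (e.g. with FAISS); the function names, distance choice and $\lambda$ value are illustrative.

```python
import numpy as np

def knn_lm_probs(p_lm, neighbors, query_emb, vocab_size, lam=0.25):
    """p_lm: (vocab_size,) LM distribution; neighbors: list of (key_embedding, next_token_id)."""
    p_knn = np.zeros(vocab_size)
    for key, token_id in neighbors:
        dist = np.sum((key - query_emb) ** 2)          # squared L2 distance d(k_i, f(x))
        p_knn[token_id] += np.exp(-dist)               # accumulate mass on the stored next token
    p_knn /= p_knn.sum() + 1e-9                        # normalize the kNN distribution
    return lam * p_knn + (1 - lam) * p_lm              # linear interpolation of the two models
```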
SPALM (Adaptive semiparametric language models; Yogatama et al. 2021) incorporates both (1) Transformer-XL
style memory for hidden states from external context as short-term memory and (2) kNN-LM style key-value store
as long memory.


Fig. 8. Illustration of how SPALM combines context memory of past hidden states (short
term memory) with an external key-value datastore (long term memory) to support longer
context. (Image source: Yogatama et al. 2021).

SPALM runs kNN search to fetch $k$ tokens with the most relevant context. For each token we can get the same embedding representation provided by a pretrained LM, denoted as $\{\mathbf{y}_i\}_{i=1}^k$. The gating mechanism first aggregates the retrieved token embeddings with a simple attention layer using $\mathbf{h}^R_t$ (the hidden state for token $x_t$ at layer $R$) as a query and then learns a gating parameter $\mathbf{g}_t$ to balance between the local information $\mathbf{h}^R_t$ and the long-term information $\mathbf{m}_t$.

$$\mathbf{m}_t = \sum_{i=1}^{k} \frac{\exp(\mathbf{y}_i^\top \mathbf{h}^R_t)}{\sum_{j=1}^{k} \exp(\mathbf{y}_j^\top \mathbf{h}^R_t)} \cdot \mathbf{y}_i$$
$$\mathbf{g}_t = \sigma(\mathbf{w}_g^\top \mathbf{h}^R_t)$$
$$\mathbf{z}_t = (1 - \mathbf{g}_t) \odot \mathbf{m}_t + \mathbf{g}_t \odot \mathbf{h}^R_t$$
$$p(x_{t+1} \mid \mathbf{x}_{\le t}) = \text{softmax}(\mathbf{z}_t; \mathbf{W})$$

where w g is a parameter vector to learn; σ( . ) is sigmoid; W is the word embedding matrix shared between both
input and output tokens. Different from kNN-LM, they didn’t find the nearest neighbor distance to be helpful in the
aggregation of retrieved tokens.
During training, the key representations in the long-term memory stay constant, produced by a pretrained LM, but
the value encoder, aka the word embedding matrix, gets updated.
Memorizing Transformer (Wu et al. 2022) adds a kNN-augmented attention layer near the top stack of a
decoder-only Transformer. This special layer maintains a Transformer-XL style FIFO cache of past key-value pairs.
The same QKV values are used for both local attention and kNN mechanisms. The kNN lookup returns top-k (key,
value) pairs for each query in the input sequence and then they are processed through the self-attention stack to


compute a weighted average of retrieved values. Two types of attention are combined with a learnable per-head
gating parameter. To prevent large distributional shifts in value magnitude, both keys and values in the cache are
normalized.
What they found during experiments with Memorizing Transformer:
It is observed in some experiments that training models with a small memory and then finetuning with a larger memory works better than training with a large memory from scratch.
The smaller Memorizing Transformer with just 8k tokens in memory can match the perplexity of a larger vanilla Transformer with 5X more trainable parameters.
Increasing the size of external memory provided consistent gains up to a size of 262K.
A non-memory transformer can be finetuned to use memory.

Fig. 9. Fine-tuning a vanilla Transformer with a key-value memory can achieve similar
performance as training a memorizing transformer from scratch. (Image source: Wu et al.
2022).

Distance-Enhanced Attention Scores


Distance Aware Transformer (DA-Transformer; Wu, et al. 2021) and Attention with Linear Biases (ALiBi; Press et al. 2022) are motivated by similar ideas: in order to encourage the model to extrapolate over longer context than what the model is trained on, we can explicitly attach positional information to every attention score based on the distance between key and query tokens.
Note that the default positional encoding in the vanilla Transformer only adds positional information to the input sequence, while later improved encoding mechanisms alter the attention scores of every layer, such as rotary position embedding, and they take a form very similar to distance-enhanced attention scores.
DA-Transformer (Wu, et al. 2021) multiplies attention scores at each layer by a learnable bias that is formulated as a
function of the distance between key and query. Different attention heads use different parameters to distinguish
diverse preferences to short-term vs long-term context. Given two positions, i , j , DA-Transformer uses the
following weighting function to alter the self-attention score:

$$\mathbf{R}^{(i)} = \alpha_i \mathbf{R} \quad \text{where } R_{ij} = |i - j|$$
$$f(\mathbf{R}^{(i)}; \beta_i) = \frac{1 + \exp(\beta_i)}{1 + \exp(\beta_i - \mathbf{R}^{(i)})}$$
$$\text{attn}(\mathbf{Q}^{(i)}, \mathbf{K}^{(i)}, \mathbf{V}^{(i)}) = \text{row-softmax}\Big(\frac{\text{ReLU}(\mathbf{Q}^{(i)}{\mathbf{K}^{(i)}}^\top) f(\mathbf{R}^{(i)})}{\sqrt{d}}\Big)\mathbf{V}^{(i)}$$
where $\alpha_i$ is a learnable parameter to weight relative distances differently per head, where the head is indexed by the superscript $(i)$; $\beta_i$ is a learnable parameter to control the upper bound and ascending slope w.r.t. the distance for the $i$-th attention head. The weighting function $f(\cdot)$ is designed in a way that: (1) $f(0) = 1$; (2) $f(\mathbf{R}^{(i)}) = 0$ when $\mathbf{R}^{(i)} \to -\infty$; (3) $f(\mathbf{R}^{(i)})$ is bounded when $\mathbf{R}^{(i)} \to +\infty$; (4) the scale is tunable; (5) the function is monotonic.

The extra time complexity brought by $f(\mathbf{R}^{(i)})$ is $O(L^2)$ and it is small relative to the self-attention time complexity $O(L^2 d)$. The extra memory consumption is minimal, $\sim O(2h)$.
Instead of multipliers, ALiBi (Press et al. 2022) adds a constant bias term on query-key attention scores,
proportional to pairwise distances. The bias introduces a strong recency preference and penalizes keys that are too
far away. The penalties are increased at different rates within different heads.

$$\text{softmax}(\mathbf{q}_i\mathbf{K}^\top + \alpha_i \cdot [0, -1, -2, \dots, -(i-1)])$$

where $\alpha_i$ is a head-specific weighting scalar. Different from DA-Transformer, $\alpha_i$ is not learned but fixed as a geometric sequence; for example, for 8 heads, $\alpha_i = \frac{1}{2}, \frac{1}{2^2}, \dots, \frac{1}{2^8}$. The overall idea is very similar to what relative positional encoding aims to solve.
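A minimal sketch of constructing this bias for causal attention, using the 8-head geometric slopes from the example above; names and shapes are illustrative.

```python
import numpy as np

def alibi_bias(L, num_heads):
    """Returns a (num_heads, L, L) bias to add to the attention logits q_i K^T."""
    slopes = np.array([2.0 ** -(i + 1) for i in range(num_heads)])  # alpha_i = 1/2, 1/4, ...
    distance = np.arange(L)[None, :] - np.arange(L)[:, None]        # j - i, <= 0 for past keys
    return slopes[:, None, None] * np.minimum(distance, 0)          # penalty grows with |i - j|

# usage sketch: logits = Q @ K.T / np.sqrt(d) + alibi_bias(L, h)[head_index]
```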

Fig. 10. Illustration of how ALiBi enhances attention scores with a positional bias term.
(Image source: Press et al. 2021).

With ALiBi, Press et al. (2022) trained a 1.3B model on context length 1024 and extrapolated to 2046 at inference time.


Fig. 11. Extrapolation experiments for running inference with Transformers of different
configs, including sinusoidal positional encoding, rotary positional encoding, simplified
relative positional encoding in T5 and ALiBi. All models were trained with small context
length but inference ran for much longer context. (Image source: Press et al. 2021).

Make it Recurrent
Universal Transformer (Dehghani, et al. 2019) combines self-attention in Transformer with the recurrent
mechanism in RNN, aiming to benefit from both a long-term global receptive field of Transformer and learned
inductive biases of RNN. Rather than going through a fixed number of layers, Universal Transformer dynamically adjusts the number of steps using adaptive computation time (ACT). If we fix the number of steps, a Universal Transformer is equivalent to a multi-layer Transformer with shared parameters across layers.
On a high level, the universal transformer can be viewed as a recurrent function for learning the hidden state
representation per token. The recurrent function evolves in parallel across token positions and the information
between positions is shared through self-attention.

Fig. 12. How the Universal Transformer refines a set of hidden state representations
repeatedly for every position in parallel. (Image source: Figure 1 in Dehghani, et al. 2019).

Given an input sequence of length $L$, Universal Transformer iteratively updates the representation $\mathbf{h}^t \in \mathbb{R}^{L \times d}$ at step $t$ for an adjustable number of steps. At step 0, $\mathbf{h}^0$ is initialized to be the same as the input embedding matrix. All the positions are processed in parallel in the multi-head self-attention mechanism and then go through a recurrent transition function.

$$\mathbf{A}^t = \text{LayerNorm}(\mathbf{h}^{t-1} + \text{MultiHeadAttention}(\mathbf{h}^{t-1} + \mathbf{P}^t))$$
$$\mathbf{h}^t = \text{LayerNorm}(\mathbf{A}^{t-1} + \text{Transition}(\mathbf{A}^t))$$


where $\text{Transition}(\cdot)$ is either a separable convolution or a fully-connected neural network that consists of two position-wise (i.e. applied to each row of $\mathbf{A}^t$ individually) affine transformations + one ReLU.
The positional encoding $\mathbf{P}^t$ uses a sinusoidal position signal but with an additional time dimension:

$$\text{PE}(i, t, \delta) = \begin{cases} \sin\big(\frac{i}{10000^{2\delta'/d}}\big) \oplus \sin\big(\frac{t}{10000^{2\delta'/d}}\big) & \text{if } \delta = 2\delta' \\ \cos\big(\frac{i}{10000^{2\delta'/d}}\big) \oplus \cos\big(\frac{t}{10000^{2\delta'/d}}\big) & \text{if } \delta = 2\delta' + 1 \end{cases}$$

Fig. 13. A simplified illustration of Universal Transformer. The encoder and decoder share the
same basic recurrent structure. But the decoder also attends to final encoder representation
h T . (Image source: Figure 2 in Dehghani, et al. 2019)

In the adaptive version of Universal Transformer, the number of recurrent steps T is dynamically determined by
ACT. Each position is equipped with a dynamic ACT halting mechanism. Once a per-token recurrent block halts, it
stops taking more recurrent updates but simply copies the current value to the next step until all the blocks halt or
until the model reaches a maximum step limit.

Adaptive Modeling
Adaptive modeling refers to a mechanism that can adjust the amount of computation according to different inputs.
For example, some tokens may only need local information and thus demand a shorter attention span; or some tokens are relatively easy to predict and do not need to be processed through the entire attention stack.


Adaptive Attention Span


One key advantage of Transformer is the capability of capturing long-term dependencies. Depending on the context, the model may prefer to attend further back at some times than at others; or one attention head may have a different attention pattern from another. If the attention span could adapt its length flexibly and only attend further back when needed, it would help reduce both computation and memory cost to support a longer maximum context size in the model.
This is the motivation for Adaptive Attention Span. Sukhbaatar, et al. (2019) proposed a self-attention mechanism that seeks an optimal attention span. They hypothesized that different attention heads might assign scores differently within the same context window (see Fig. 14) and thus the optimal span would be trained separately per head.

Fig. 14. Two attention heads in the same model, A & B, assign attention differently within the same context window. Head A attends more to the recent tokens, while head B looks further back into the past uniformly. (Image source: Sukhbaatar, et al. 2019)

Given the $i$-th token, we need to compute the attention weights between this token and the other keys within its attention span of size $s$:

$$e_{ij} = \mathbf{q}_i\mathbf{k}_j^\top$$
$$a_{ij} = \text{softmax}(e_{ij}) = \frac{\exp(e_{ij})}{\sum_{r=i-s}^{i-1}\exp(e_{ir})}$$
$$\mathbf{y}_i = \sum_{r=i-s}^{i-1} a_{ir}\mathbf{v}_r = \sum_{r=i-s}^{i-1} a_{ir}\mathbf{x}_r\mathbf{W}^v$$

A soft mask function $m_z$ is added to control for an effective adjustable attention span, which maps the distance between query and key into a $[0, 1]$ value. $m_z$ is parameterized by $z \in [0, s]$ and $z$ is to be learned:

$$m_z(x) = \text{clip}\Big(\frac{1}{R}(R + z - x), 0, 1\Big)$$

where $R$ is a hyper-parameter which defines the softness of $m_z$.


Fig. 15. The soft masking function used in the adaptive attention span. (Image source:
Sukhbaatar, et al. 2019.)

The soft mask function is applied to the softmax elements in the attention weights:

$$a_{ij} = \frac{m_z(i-j)\exp(s_{ij})}{\sum_{r=i-s}^{i-1} m_z(i-r)\exp(s_{ir})}$$

In the above equation, $z$ is differentiable so it is trained jointly with other parts of the model. Parameters $z^{(i)}, i = 1, \dots, h$ are learned separately per head. Moreover, the loss function has an extra L1 penalty on $\sum_{i=1}^h z^{(i)}$.
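A minimal sketch of the soft mask $m_z$ and how it reweights the softmax over a span of $s$ keys, per the equations above; names are illustrative and $z$, $R$ are treated as plain scalars.

```python
import numpy as np

def soft_mask(distance, z, R):
    """distance: i - j (>= 1 within the span); z: learned span; R: softness hyperparameter."""
    return np.clip((R + z - distance) / R, 0.0, 1.0)

def masked_attention_weights(scores, z, R):
    """scores: (s,) raw logits for keys j = i-s ... i-1, newest key last."""
    s = scores.shape[0]
    dist = np.arange(s, 0, -1)                 # distances i - j = s, ..., 1
    m = soft_mask(dist, z, R)
    w = m * np.exp(scores - scores.max())      # mask applied to the softmax numerator
    return w / (w.sum() + 1e-9)
```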
Using Adaptive Computation Time, the approach can be further enhanced to have a flexible attention span length, adaptive to the current input dynamically. The span parameter $z_t$ of an attention head at time $t$ is a sigmoidal function, $z_t = S\sigma(\mathbf{v} \cdot \mathbf{x}_t + b)$, where the vector $\mathbf{v}$ and the bias scalar $b$ are learned jointly with other parameters.
In the experiments of Transformer with adaptive attention span, Sukhbaatar, et al. (2019) found a general tendency
that lower layers do not require very long attention spans, while a few attention heads in higher layers may use
exceptionally long spans. Adaptive attention span also helps greatly reduce the number of FLOPS, especially in a
big model with many attention layers and a large context length.

Depth-Adaptive Transformer
At inference time, it is natural to assume that some tokens are easier to predict and thus do not require as much computation as others. Therefore we may only process their predictions through a limited number of layers to achieve a good balance between speed and performance.
Both Depth-Adaptive Transformer (Elabyad et al. 2020) and Confident Adaptive Language Model (CALM;
Schuster et al. 2022) are motivated by this idea and learn to predict optimal numbers of layers needed for different
input tokens.
Depth-adaptive transformer (Elabyad et al. 2020) attaches an output classifier to every layer to produce exit predictions based on the activations of that layer. The classifier weight matrices can be different per layer or shared across layers. During training, the model samples different sequences of exits such that the model is optimized with hidden states of different layers. The learning objective incorporates the likelihood probabilities predicted at different layers, $n = 1, \dots, N$:

$$\text{LL}^n_t = \log p(y_t \mid \mathbf{h}^{n-1}_t) \qquad \text{LL}^n = \sum_{t=1}^{|\mathbf{y}|}\text{LL}^n_t$$

The adaptive depth classifier outputs a parametric distribution $q_t$. It is trained with a cross-entropy loss against an oracle distribution $q_t^*$. The paper explored three configurations for how to learn such a classifier $q_t$.

Fig. 16. Illustration of three types of adaptive depth classifiers. (Image source: Elabyad et al. 2020).

1. Sequence-specific depth classifier: All tokens of the same sequence share the same exit block. It depends on the average of the encoder representation of the sequence. Given an input sequence $\mathbf{x}$ of length $L$, the classifier takes $\bar{\mathbf{x}} = \frac{1}{L}\sum_{t=1}^L \mathbf{x}_t$ as input and outputs a multinomial distribution of $N$ dimensions, corresponding to $N$ layers.

$$q(n \mid \mathbf{x}) = \text{softmax}(\mathbf{W}_n \bar{\mathbf{x}} + b_n) \in \mathbb{R}^N$$
$$q^*_\text{lik}(\mathbf{x}, \mathbf{y}) = \delta(\arg\max_n \text{LL}^n - \lambda n)$$
$$\text{or } q^*_\text{corr}(\mathbf{x}, \mathbf{y}) = \delta(\arg\max_n C^n - \lambda n) \quad \text{where } C^n = |\{t \mid y_t = \arg\max_y p(y \mid \mathbf{h}^{n-1}_t)\}|$$

where $\delta$ is the Dirac delta (unit impulse) function and $-\lambda n$ is a regularization term to encourage lower layer exits. The ground truth $q^*$ can be prepared in two ways, based on maximum likelihood $q^*_\text{lik}$ or correctness $q^*_\text{corr}$.

2. Token-specific depth classifier (multinomial): Each token is decoded with a different exit block, predicted conditioned on the first decoder hidden state $\mathbf{h}^1_t$:

$$q_t(n \mid \mathbf{x}, y_{<t}) = \text{softmax}(\mathbf{W}_n \mathbf{h}^1_t + b_n)$$

3. Token-specific depth classifier (geometric-like): A binary exit prediction distribution is made per layer per token, $\mathcal{X}^n_t$. The RBF kernel $\kappa(t, t') = \exp(-\frac{|t - t'|^2}{\sigma})$ is used to smooth the predictions to incorporate the impact of the current decision on future time steps.

$$\mathcal{X}^n_t = \text{sigmoid}(\mathbf{w}_n^\top \mathbf{h}^n_t + b_n) \quad \forall n \in [1, \dots, N-1]$$
$$q_t(n \mid \mathbf{x}, y_{<t}) = \begin{cases} \mathcal{X}^n_t \prod_{n' < n}(1 - \mathcal{X}^{n'}_t) & \text{if } n < N \\ \prod_{n' < N}(1 - \mathcal{X}^{n'}_t) & \text{otherwise} \end{cases}$$
$$q^*_\text{lik}(\mathbf{x}, \mathbf{y}) = \delta(\arg\max_n \widetilde{\text{LL}}^n_t - \lambda n) \quad \text{where } \widetilde{\text{LL}}^n_t = \sum_{t'=1}^{|\mathbf{y}|} \kappa(t, t')\, \text{LL}^n_{t'}$$
$$\text{or } q^*_\text{corr}(\mathbf{x}, \mathbf{y}) = \delta(\arg\max_n \widetilde{C}^n_t - \lambda n) \quad \text{where } C^n_t = \mathbb{1}[y_t = \arg\max_y p(y \mid \mathbf{h}^{n-1}_t)],\ \widetilde{C}^n_t = \sum_{t'=1}^{|\mathbf{y}|} \kappa(t, t')\, C^n_{t'}$$

At inference time, the confidence threshold for making an exit decision needs to be calibrated. Depth-adaptive transformer finds such a threshold on a validation set via grid search. CALM (Schuster et al. 2022) applied the Learn then Test (LTT) framework (Angelopoulos et al. 2021) to identify a subset of valid thresholds and chose the minimum value as the threshold for inference. Besides training a per-layer exit classifier, CALM also explored other methods for adaptive depth prediction, including the softmax responses (i.e. the difference between the top two softmax outputs) and hidden state saturation (i.e. $\cos(\mathbf{h}^n_t, \mathbf{h}^{n+1}_t)$) as confidence scores for exit decisions. They found softmax responses result in the best inference speedup.

Efficient Attention
The computation and memory cost of the vanilla Transformer grows quadratically with sequence length and hence it is hard to apply it to very long sequences. Many efficiency improvements for the Transformer architecture have something to do with the self-attention module, making it cheaper, smaller or faster to run. See the survey paper on Efficient Transformers (Tay et al. 2020).

Sparse Attention Patterns


Fixed Local Context

A simple alteration to make self-attention less expensive is to restrict the attention span of each token to local context only, so that the cost of self-attention grows linearly with the sequence length.
The idea was introduced by Image Transformer (Parmer, et al 2018), which formulates image generation as
sequence modeling using an encoder-decoder transformer architecture:
The encoder generates a contextualized, per-pixel-channel representation of the source image;
Then the decoder autoregressively generates an output image, one channel per pixel at each time step.
Let’s label the representation of the current pixel to be generated as the query $\mathbf{q}$. The other positions whose representations will be used for computing $\mathbf{q}$ are key vectors $\mathbf{k}_1, \mathbf{k}_2, \dots$ and they together form a memory matrix $\mathbf{M}$. The scope of $\mathbf{M}$ defines the context window for the pixel query $\mathbf{q}$.
Image Transformer introduced two types of localized $\mathbf{M}$, as illustrated below.


Fig. 17. Illustration of 1D and 2D attention span for visual inputs in Image Transformer. The
black line marks a query block and the cyan outlines the actual attention span for pixel q.
(Image source: Figure 2 in Parmer et al, 2018)

1. 1D Local Attention: The input image is flattened in the raster scanning order, that is, from left to right and top
to bottom. The linearized image is then partitioned into non-overlapping query blocks. The context window
consists of pixels in the same query block as q and a fixed number of additional pixels generated before this
query block.

2. 2D Local Attention: The image is partitioned into multiple non-overlapping rectangular query blocks. The
query pixel can attend to all others in the same memory blocks. To make sure the pixel at the top-left corner
can also have a valid context window, the memory block is extended to the top, left and right by a fixed
amount, respectively.

Strided Context

Sparse Transformer (Child et al., 2019) introduced factorized self-attention, through sparse matrix factorization,
making it possible to train dense attention networks with hundreds of layers on sequence length up to 16,384,
which would be infeasible on modern hardware otherwise.
Given a set of attention connectivity patterns $\mathcal{S} = \{S_1, \dots, S_n\}$, where each $S_i$ records the set of key positions that the $i$-th query vector attends to,

$$\text{Attend}(\mathbf{X}, \mathcal{S}) = \Big(a(\mathbf{x}_i, S_i)\Big)_{i \in \{1, \dots, L\}}$$
$$\text{where } a(\mathbf{x}_i, S_i) = \text{softmax}\Big(\frac{(\mathbf{x}_i\mathbf{W}^q)\big((\mathbf{x}_j\mathbf{W}^k)_{j \in S_i}\big)^\top}{\sqrt{d_k}}\Big)(\mathbf{x}_j\mathbf{W}^v)_{j \in S_i}$$

Note that although the size of $S_i$ is not fixed, $a(\mathbf{x}_i, S_i)$ is always of size $d_v$ and thus $\text{Attend}(\mathbf{X}, \mathcal{S}) \in \mathbb{R}^{L \times d_v}$.


In auto-regressive models, one attention span is defined as $S_i = \{j : j \le i\}$, as it allows each token to attend to all the positions in the past.
In factorized self-attention, the set $S_i$ is decomposed into a tree of dependencies, such that for every pair of $(i, j)$ where $j \le i$, there is a path connecting $i$ back to $j$ and $i$ can attend to $j$ either directly or indirectly.
Precisely, the set $S_i$ is divided into $p$ non-overlapping subsets, where the $m$-th subset is denoted as $A^{(m)}_i \subset S_i, m = 1, \dots, p$. Therefore the path between the output position $i$ and any $j$ has a maximum length of $p + 1$. For example, if $(j, a, b, c, \dots, i)$ is a path of indices between $i$ and $j$, we would have $j \in A^{(1)}_a, a \in A^{(2)}_b, b \in A^{(3)}_c$, and so on.
Sparse Factorized Attention
Sparse Transformer proposed two types of factorized attention. It is easier to understand the concepts as illustrated in Fig. 18 with 2D image inputs as examples.

Fig. 18. The top row illustrates the attention connectivity patterns in (a) Transformer, (b) Sparse Transformer with strided attention, and (c) Sparse Transformer with fixed attention. The bottom row contains the corresponding self-attention connectivity matrices. Note that the top and bottom rows are not in the same scale. (Image source: Child et al., 2019 + a few extra annotations.)

1. Strided attention with stride ℓ ∼ √ n . This works well with image data as the structure is aligned with strides.
In the image case, each pixel would attend to all the previous ℓ pixels in the raster scanning order (naturally
cover the entire width of the image) and then those pixels attend to others in the same column (defined by
another attention connectivity subset).

$$A^{(1)}_i = \{t, t+1, \dots, i\}, \text{ where } t = \max(0, i - \ell)$$
$$A^{(2)}_i = \{j : (i - j) \bmod \ell = 0\}$$


2. Fixed attention. A small set of tokens summarize previous locations and propagate that information to all
future locations.

$$A^{(1)}_i = \{j : \lfloor j/\ell \rfloor = \lfloor i/\ell \rfloor\}$$
$$A^{(2)}_i = \{j : j \bmod \ell \in \{\ell - c, \dots, \ell - 1\}\}$$

where $c$ is a hyperparameter. If $c = 1$, it restricts the representation whereas many depend on a few positions. The paper chose $c \in \{8, 16, 32\}$ for $\ell \in \{128, 256\}$. A small code sketch of both patterns follows below.
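A minimal sketch that enumerates the two connectivity sets $A^{(1)}_i$ and $A^{(2)}_i$ for a query position $i$ under the strided and fixed patterns above, restricted to $j \le i$ for the autoregressive case; names are illustrative.

```python
def strided_pattern(i, stride):
    A1 = set(range(max(0, i - stride), i + 1))                  # the previous `stride` positions
    A2 = {j for j in range(i + 1) if (i - j) % stride == 0}     # same "column" positions
    return A1, A2

def fixed_pattern(i, stride, c):
    A1 = {j for j in range(i + 1) if j // stride == i // stride}                 # same block
    A2 = {j for j in range(i + 1) if j % stride >= stride - c}                   # summary positions
    return A1, A2
```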

Use Factorized Self-Attention in Transformer


There are three ways to use sparse factorized attention patterns in Transformer architecture:

1. One attention type per residual block and then interleave them, $\text{attn}(\mathbf{X}) = \text{Attend}(\mathbf{X}, A^{(n \bmod p)})\mathbf{W}^o$, where $n$ is the index of the current residual block.
2. Set up a single head which attends to locations that all the factorized heads attend to, $\text{attn}(\mathbf{X}) = \text{Attend}(\mathbf{X}, \cup_{m=1}^p A^{(m)})\mathbf{W}^o$.
3. Use a multi-head attention mechanism, but different from vanilla Transformer, each head might adopt a pattern presented above, 1 or 2. → This option often performs the best.

Sparse Transformer also proposed a set of changes so as to train the Transformer up to hundreds of layers,
including gradient checkpointing, recomputing attention & FF layers during the backward pass, mixed precision
training, efficient block-sparse implementation, etc. Please check the paper for more details or my previous post on
techniques for scaling up model training.
Blockwise Attention (Qiu et al. 2019) introduces a sparse block matrix to only allow each token to attend to a small set of other tokens. Each attention matrix of size $L \times L$ is partitioned into $n \times n$ smaller blocks of size $\frac{L}{n} \times \frac{L}{n}$ and a sparse block matrix $\mathbf{M} \in \{0, 1\}^{L \times L}$ is defined by a permutation $\pi$ of $\{1, \dots, n\}$, which records the column index per row in the block matrix.

$$\text{attn}(\mathbf{Q}, \mathbf{K}, \mathbf{V}, \mathbf{M}) = \text{softmax}\Big(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{d}} \odot \mathbf{M}\Big)\mathbf{V}$$
$$(\mathbf{A} \odot \mathbf{M})_{ij} = \begin{cases} A_{ij} & \text{if } M_{ij} = 1 \\ -\infty & \text{if } M_{ij} = 0 \end{cases}$$
$$\text{where } M_{ij} = \begin{cases} 1 & \text{if } \pi\big(\lfloor\frac{(i-1)n}{L}\rfloor + 1\big) = \lfloor\frac{(j-1)n}{L}\rfloor + 1 \\ 0 & \text{otherwise} \end{cases}$$
The actual implementation of Blockwise Attention only stores QKV as block matrices, each of size $n \times n$:

$$\text{Blockwise-attn}(\mathbf{Q}, \mathbf{K}, \mathbf{V}, \mathbf{M}) = \begin{bmatrix} \text{softmax}\big(\frac{\hat{\mathbf{q}}_1\hat{\mathbf{k}}_{\pi(1)}^\top}{\sqrt{d}}\big)\hat{\mathbf{v}}_{\pi(1)} \\ \vdots \\ \text{softmax}\big(\frac{\hat{\mathbf{q}}_n\hat{\mathbf{k}}_{\pi(n)}^\top}{\sqrt{d}}\big)\hat{\mathbf{v}}_{\pi(n)} \end{bmatrix}$$

where $\hat{\mathbf{q}}_i$, $\hat{\mathbf{k}}_i$ and $\hat{\mathbf{v}}_i$ are the $i$-th rows in the QKV block matrices respectively. Each $\hat{\mathbf{q}}_i\hat{\mathbf{k}}_{\pi(i)}^\top, \forall i = 1, \dots, n$ is of size $\frac{L}{n} \times \frac{L}{n}$ and therefore Blockwise Attention is able to reduce the memory complexity of the attention matrix from $O(L^2)$ to $O(\frac{L}{n} \times \frac{L}{n} \times n) = O(L^2/n)$.

Combination of Local and Global Context

ETC (Extended Transformer Construction; Ainslie et al. 2019), Longformer (Beltagy et al. 2020) and Big Bird
(Zaheer et al. 2020) models combine both local and global context when building an attention matrix. All these
models can be initialized from existing pretrained models.
Global-Local Attention of ETC (Ainslie et al. 2019) takes two inputs, (1) the long input $\mathbf{x}^l$ of size $n_l$, which is the regular input sequence, and (2) the global input $\mathbf{x}^g$ of size $n_g$, which contains a smaller number of auxiliary tokens, $n_g \ll n_l$. Attention is thus split into four components based on directional attention across these two inputs: g2g, g2l, l2g and l2l. Because the l2l attention piece can be very large, it is restricted to a fixed-size attention span of radius $w$ (i.e. the local attention span) and the l2l matrix can be reshaped to $n_l \times (2w + 1)$.
ETC utilizes four binary matrices to handle structured inputs, $\mathbf{M}^{g2g}$, $\mathbf{M}^{g2l}$, $\mathbf{M}^{l2g}$ and $\mathbf{M}^{l2l}$. For example, each element $z^g_i \in \mathbb{R}^d$ in the attention output $z^g = (z^g_1, \dots, z^g_{n_g})$ for the g2g attention piece is formatted as:

$$a^{g2g}_{ij} = \frac{1}{\sqrt{d}} x^g_i W^Q (x^g_j W^K + P^K_{ij})^\top - (1 - M^{g2g}_{ij})C$$
$$A^{g2g}_{ij} = \frac{\exp(a^{g2g}_{ij})}{\sum_{k=1}^{n_g}\exp(a^{g2g}_{ik})} \qquad z^g_i = \sum_{j=1}^{n_g} A^{g2g}_{ij}\, x^g_j W^V$$

where $P^K_{ij}$ is a learnable vector for relative position encoding and $C$ is a very large constant ($C = 10000$ in the paper) to offset any attention weights when the mask is off.

Fig. 19. Attention patterns of ETC, Longformer and Big Bird.

One more update in ETC is to incorporate a CPC (contrastive predictive coding) task using NCE loss into the
pretraining stage, besides the MLM task: The representation of one sentence should be similar to the
representation of context around it when this sentence is masked.
The global input x g for ETC is constructed as follows: Assuming there are some segments within the long inputs


(e.g. by sentence), each segment is attached with one auxiliary token to learn global inputs. Relative position
encoding is used to mark the global segment tokens with the token position. Hard masking in one direction (i.e.,
tokens before vs after are labeled differently) is found to bring performance gains in some datasets.
Attention pattern in Longformer contains three components:

1. Local attention: Similar to ETC, local attention is controlled by a sliding window of fixed size w;
2. Global attention of preselected tokens: Longformer has a few pre-selected tokens (e.g. [CLS] token)
assigned with global attention span, that is, attending to all other tokens in the input sequence.
3. Dilated attention: Dilated sliding window of fixed size r and gaps of dilation size d, similar to Sparse
Transformer;

Big Bird is quite similar to Longformer, equipped with both local attention and a few preselected tokens with global
attention span, but Big Bird replaces dilated attention with a new mechanism where all tokens attend to a set of
random tokens. The design is motivated by the fact that attention pattern can be viewed as a directed graph and a
random graph has the property that information is able to rapidly flow between any pair of nodes.
Longformer uses smaller window size at lower layers and larger window sizes at higher layers. Ablation studies
showed that this setup works better than reversed or fixed size config. Lower layers do not have dilated sliding
windows to better learn to use immediate local context. Longformer also has a staged training procedure where
initially the model is trained with small window size to learn from local context and then subsequent stages of
training have window sizes increased and learning rate decreased.

Content-based Attention
The improvements proposed by Reformer (Kitaev, et al. 2020) aim to solve the following pain points in vanilla
Transformer:
Quadratic time and memory complexity within self-attention module.
Memory in a model with N layers is N -times larger than in a single-layer model because we need to store
activations for back-propagation.
The intermediate FF layers are often quite large.
Reformer proposed two main changes:

1. Replace the dot-product attention with locality-sensitive hashing (LSH) attention, reducing the complexity
from O ( L 2) to O ( L log L ) .
2. Replace the standard residual blocks with reversible residual layers, which allows storing activations only once
during training instead of N times (i.e. proportional to the number of layers).

Locality-Sensitive Hashing Attention


In the $\mathbf{Q}\mathbf{K}^\top$ part of the attention formula, we are only interested in the largest elements, as only large elements contribute a lot after softmax. For each query $\mathbf{q}_i \in \mathbf{Q}$, we are looking for the row vectors in $\mathbf{K}$ closest to $\mathbf{q}_i$. In order to find nearest neighbors quickly in high-dimensional space, Reformer incorporates Locality-Sensitive Hashing (LSH) into its attention mechanism.
A hashing scheme $\mathbf{x} \mapsto h(\mathbf{x})$ is locality-sensitive if it preserves the distance information between data points, such that close vectors obtain similar hashes while distant vectors have very different ones. Reformer adopts such a hashing scheme: given a fixed random matrix $\mathbf{R} \in \mathbb{R}^{d \times b/2}$ (where $b$ is a hyperparameter), the hash function is $h(\mathbf{x}) = \arg\max([\mathbf{x}\mathbf{R}; -\mathbf{x}\mathbf{R}])$.
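A minimal sketch of this angular hash; the shapes and names are illustrative.

```python
import numpy as np

def lsh_hash(x, R):
    """x: (L, d) vectors to hash; R: (d, b//2) fixed random projection; returns bucket ids in [0, b)."""
    proj = x @ R                                                        # (L, b/2)
    return np.argmax(np.concatenate([proj, -proj], axis=-1), axis=-1)   # nearby vectors tend to collide

rng = np.random.default_rng(0)
R = rng.normal(size=(64, 4))                      # d = 64, b = 8 buckets
buckets = lsh_hash(rng.normal(size=(16, 64)), R)  # one bucket id per vector
```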

Fig. 20. Illustration of Locality-Sensitive Hashing (LSH) attention. (Image source: right part of
Figure 1 in Kitaev, et al. 2020).

In LSH attention, a query can only attend to positions in the same hashing bucket, $S_i = \{j : h(\mathbf{q}_i) = h(\mathbf{k}_j)\}$. It is carried out in the following process, as illustrated in Fig. 20:
(a) The attention matrix for full attention is often sparse.
(b) Using LSH, we can sort the keys and queries to be aligned according to their hash buckets.
(c) Set $\mathbf{Q} = \mathbf{K}$ (precisely $\mathbf{k}_j = \mathbf{q}_j / \|\mathbf{q}_j\|$), so that there are equal numbers of keys and queries in one bucket, which is easier for batching. Interestingly, this “shared-QK” config does not affect the performance of the Transformer.
(d) Apply batching where chunks of $m$ consecutive queries are grouped together.

Fig. 21. The LSH attention consists of 4 steps: bucketing, sorting, chunking, and attention
computation. (Image source: left part of Figure 1 in Kitaev, et al. 2020).


Reversible Residual Network


Another improvement by Reformer is to use reversible residual layers (Gomez et al. 2017). The motivation for
reversible residual network is to design the architecture in a way that activations at any given layer can be
recovered from the activations at the following layer, using only the model parameters. Hence, we can save
memory by recomputing the activation during backprop rather than storing all the activations.
Given a layer $x \mapsto y$, the normal residual layer does $y = x + F(x)$, but the reversible layer splits both input and output into pairs $(x_1, x_2) \mapsto (y_1, y_2)$ and then executes the following:

$$y_1 = x_1 + F(x_2), \quad y_2 = x_2 + G(y_1)$$

and reversing is easy:

$$x_2 = y_2 - G(y_1), \quad x_1 = y_1 - F(x_2)$$
Reformer applies the same idea to Transformer by combining the attention ($F$) and feed-forward ($G$) layers within a reversible net block:

$$Y_1 = X_1 + \text{Attention}(X_2), \quad Y_2 = X_2 + \text{FeedForward}(Y_1)$$


The memory can be further reduced by chunking the feed-forward computation:

$$Y_2 = [Y_2^{(1)}; \dots; Y_2^{(c)}] = [X_2^{(1)} + \text{FeedForward}(Y_1^{(1)}); \dots; X_2^{(c)} + \text{FeedForward}(Y_1^{(c)})]$$

The resulting reversible Transformer does not need to store activations in every layer.
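A minimal sketch of the forward and inverse passes of a reversible block, with arbitrary callables standing in for the attention ($F$) and feed-forward ($G$) sublayers; names are illustrative.

```python
import numpy as np

def reversible_forward(x1, x2, F, G):
    y1 = x1 + F(x2)
    y2 = x2 + G(y1)
    return y1, y2

def reversible_inverse(y1, y2, F, G):
    x2 = y2 - G(y1)          # recover inputs exactly, so activations need not be stored
    x1 = y1 - F(x2)
    return x1, x2

# Round-trip check with simple stand-in functions for F and G:
x1, x2 = np.ones(4), np.arange(4.0)
F = lambda x: np.tanh(x)
G = lambda x: 0.5 * x
assert np.allclose((x1, x2), reversible_inverse(*reversible_forward(x1, x2, F, G), F, G))
```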
Routing Transformer (Roy et al. 2021) is also built on content-based clustering of keys and queries. Instead of
using a static hashing function like LSH, it utilizes online k-means clustering and combines it with local, temporal
sparse attention to reduce the attention complexity from O( L 2) to O( L 1.5) .
Within routing attention, both keys and queries are clustered with k-means clustering method and the same set of
centroids μ = ( μ1, … , μk) ∈ R k× d. Queries are routed to keys that get assigned to the same centroid. The total
complexity is O( Lkd + L 2d/ k) , where O( Lkd) is for running clustering assignments and O( L 2d/ k) is for
attention computation. The cluster centroids are updated by EMA (exponential moving average) using all
associated keys and queries.
In the experiments for Routing Transformer, the best config only has routing attention enabled in the last two layers of the model and for half of the attention heads, while the other half utilize local attention. They also observed that local attention is a pretty strong baseline and a larger attention window always leads to better results.

Low-Rank Attention
Linformer (Wang et al. 2020) approximates the full attention matrix with a low rank matrix, reducing the time &
space complexity to be linear. Instead of using expensive SVD to identify low rank decomposition, Linformer adds
two linear projections E i, F i ∈ R L× k for key and value matrices, respectively, reducing their dimensions from
L × d to k × d. As long as k ≪ L , the attention memory can be greatly reduced.

$$\text{head}_i = \text{attn}(\mathbf{X}_q\mathbf{W}^q_i, \mathbf{E}_i\mathbf{X}_k\mathbf{W}^k_i, \mathbf{F}_i\mathbf{X}_v\mathbf{W}^v_i) = \underbrace{\text{softmax}\Big(\frac{\mathbf{X}_q\mathbf{W}^q_i (\mathbf{E}_i\mathbf{X}_k\mathbf{W}^k_i)^\top}{\sqrt{d}}\Big)}_{\text{low-rank attention matrix } \bar{A} \in \mathbb{R}^{L \times k}} \mathbf{F}_i\mathbf{X}_v\mathbf{W}^v_i$$
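A minimal single-head sketch of this low-rank attention, with the projections written as $(k, L)$ matrices applied on the left of the key and value matrices; all names are illustrative.

```python
import numpy as np

def linformer_head(Q, K, V, E, F):
    """Q, K: (L, d_k); V: (L, d_v); E, F: (k, L) learned projections with k << L."""
    K_proj, V_proj = E @ K, F @ V                    # (k, d_k), (k, d_v)
    scores = Q @ K_proj.T / np.sqrt(Q.shape[-1])     # (L, k) low-rank attention scores
    A = np.exp(scores - scores.max(axis=-1, keepdims=True))
    A /= A.sum(axis=-1, keepdims=True)               # row-wise softmax over only k slots
    return A @ V_proj                                # (L, d_v)
```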

Additional techniques can be applied to further improve efficiency of Linformer:


Parameter sharing between projection layers, such as head-wise, key-value and layer-wise (across all layers)
sharing.
Use different k at different layers, as heads in higher layers tend to have a more skewed distribution (lower rank)
and thus we can use smaller k at higher layers.
Use different types of projections; e.g. mean/max pooling, convolution layer with kernel and stride L / k.

Fig. 22. (Left) Linformer has two projection layers added for keys and values. (Right) Plot of inference time as a function of sequence length. (Image source: Wang et al. 2020).

Random Feature Attention (RFA; Peng et al. 2021) relies on random feature methods (Rahimi & Recht, 2007) to
approximate softmax operation in self-attention with low rank feature maps in order to achieve linear time and
space complexity. Performers (Choromanski et al. 2021) also adopts random feature attention with improvements
on the kernel construction to further reduce the kernel approximation error.
The main theorem behind RFA is from Rahimi & Recht, 2007:

Let $\phi: \mathbb{R}^d \to \mathbb{R}^{2D}$ be a nonlinear transformation:

$$\phi(\mathbf{x}) = \frac{1}{\sqrt{D}}\big[\sin(\mathbf{w}_1^\top\mathbf{x}), \dots, \sin(\mathbf{w}_D^\top\mathbf{x}), \cos(\mathbf{w}_1^\top\mathbf{x}), \dots, \cos(\mathbf{w}_D^\top\mathbf{x})\big]^\top$$

When $d$-dimensional random vectors $\mathbf{w}_i$ are i.i.d. from $\mathcal{N}(\mathbf{0}, \sigma^2\mathbf{I}_d)$,

$$\mathbb{E}_{\mathbf{w}_i}[\phi(\mathbf{x})\cdot\phi(\mathbf{y})] = \exp\Big(-\frac{\|\mathbf{x}-\mathbf{y}\|^2}{2\sigma^2}\Big)$$
An unbiased estimation of $\exp(\mathbf{x} \cdot \mathbf{y}/\sigma^2)$ is:

$$
\begin{aligned}
\exp(\mathbf{x} \cdot \mathbf{y}/\sigma^2)
&= \exp\Big(\frac{1}{2\sigma^2}\big(\|\mathbf{x}\|^2 + \|\mathbf{y}\|^2 - \|\mathbf{x} - \mathbf{y}\|^2\big)\Big) \\
&= \exp\Big(\frac{\|\mathbf{x}\|^2}{2\sigma^2}\Big)\exp\Big(\frac{\|\mathbf{y}\|^2}{2\sigma^2}\Big)\exp\Big(-\frac{\|\mathbf{x} - \mathbf{y}\|^2}{2\sigma^2}\Big) \\
&\approx \exp\Big(\frac{\|\mathbf{x}\|^2}{2\sigma^2}\Big)\exp\Big(\frac{\|\mathbf{y}\|^2}{2\sigma^2}\Big)\,\phi(\mathbf{x})\cdot\phi(\mathbf{y}) \\
&= \exp\Big(\frac{1}{\sigma^2}\Big)\,\phi(\mathbf{x})\cdot\phi(\mathbf{y}) & \text{; if } \mathbf{x}, \mathbf{y} \text{ are unit vectors}
\end{aligned}
$$
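A quick numerical sanity check of this estimator, written as a sketch with arbitrary choices of `d` and `D`, $\sigma = 1$, and unit-norm `x`, `y` as in the last line above:

```python
import numpy as np

def random_feature_map(x, W):
    # W: (D, d) with rows drawn from N(0, I_d); phi(x) lives in R^{2D}.
    proj = W @ x
    return np.concatenate([np.sin(proj), np.cos(proj)]) / np.sqrt(W.shape[0])

rng = np.random.default_rng(0)
d, D = 16, 10_000
x = rng.normal(size=d)
x /= np.linalg.norm(x)            # unit vectors, matching the last line above
y = rng.normal(size=d)
y /= np.linalg.norm(y)
W = rng.normal(size=(D, d))       # sigma = 1 for simplicity

approx = random_feature_map(x, W) @ random_feature_map(y, W)
exact = np.exp(-np.linalg.norm(x - y) ** 2 / 2)
print(f"{approx:.4f} vs {exact:.4f}")  # agree to roughly 1e-2 at D = 10,000
```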
Then we can write the attention function as follows, where $\otimes$ is the outer product operation and $\sigma^2$ is the temperature:

$$
\begin{aligned}
\text{attn}(\mathbf{q}_t, \{\mathbf{k}_i\}, \{\mathbf{v}_i\})
&= \sum_i \frac{\exp(\mathbf{q}_t \cdot \mathbf{k}_i/\sigma^2)}{\sum_j \exp(\mathbf{q}_t \cdot \mathbf{k}_j/\sigma^2)}\mathbf{v}_i^\top
\approx \sum_i \frac{\phi(\mathbf{q}_t)^\top\phi(\mathbf{k}_i)}{\sum_j \phi(\mathbf{q}_t)^\top\phi(\mathbf{k}_j)}\mathbf{v}_i^\top \\
&= \frac{\phi(\mathbf{q}_t)^\top\sum_i \phi(\mathbf{k}_i)\otimes\mathbf{v}_i}{\phi(\mathbf{q}_t)^\top\sum_j \phi(\mathbf{k}_j)}
= \text{RFA}(\mathbf{q}_t, \{\mathbf{k}_i\}, \{\mathbf{v}_i\})
\end{aligned}
$$

Fig. 23. (Left) The order of computation for default softmax operation. (Right) The order of
computation when using random feature attention, a lot cheaper than default softmax.
(Image source: Peng et al. 2021).

Causal Attention RFA has the token at time step $t$ only attend to earlier keys and values $\{\mathbf{k}_i\}_{i \leq t}, \{\mathbf{v}_i\}_{i \leq t}$. Let us use a tuple of variables, $(\mathbf{S}_t \in \mathbb{R}^{2D \times d}, \mathbf{z}_t \in \mathbb{R}^{2D})$, to track the hidden state history at time step $t$, similar to RNNs:

$$\text{causal-RFA}(\mathbf{q}_t, \{\mathbf{k}_i\}_{i \leq t}, \{\mathbf{v}_i\}_{i \leq t}) = \frac{\phi(\mathbf{q}_t)^\top\mathbf{S}_t}{\phi(\mathbf{q}_t) \cdot \mathbf{z}_t}$$

$$\text{where } \mathbf{S}_t = \mathbf{S}_{t-1} + \phi(\mathbf{k}_t)\otimes\mathbf{v}_t, \quad \mathbf{z}_t = \mathbf{z}_{t-1} + \phi(\mathbf{k}_t)$$

where $2D$ is the size of $\phi(\cdot)$ and $D$ should be no less than the model size $d$ for a reasonable approximation.
RFA leads to significant speedup in autoregressive decoding and the memory complexity mainly depends on the
choice of D when constructing the kernel ϕ ( . ) .
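Below is a sketch of this recurrent update in NumPy for a single head; `phi` is the sin/cos feature map from above and all shapes are illustrative.

```python
import numpy as np

def phi(x, W):
    proj = W @ x
    return np.concatenate([np.sin(proj), np.cos(proj)]) / np.sqrt(W.shape[0])

def causal_rfa(queries, keys, values, W):
    # queries, keys: (T, d); values: (T, d_v); W: (D, d) random projections.
    S = np.zeros((2 * W.shape[0], values.shape[-1]))  # running sum of phi(k_t) outer v_t
    z = np.zeros(2 * W.shape[0])                      # running sum of phi(k_t)
    outputs = []
    for q_t, k_t, v_t in zip(queries, keys, values):
        S += np.outer(phi(k_t, W), v_t)
        z += phi(k_t, W)
        q_feat = phi(q_t, W)
        outputs.append((q_feat @ S) / (q_feat @ z))
    return np.stack(outputs)
```

Each decoding step costs a constant amount of work in the sequence length, which is where the autoregressive speedup comes from.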
Performer modifies the random feature attention with positive random feature maps to reduce the estimation error. It also keeps the randomly sampled $\mathbf{w}_1, \dots, \mathbf{w}_D$ orthogonal to further reduce the variance of the estimator.

Fig. 24. Comparison of approximation error when using (Left) i.i.d vs orthogonal features
and (Right) sin/cos vs positive random features. (Image source: Choromanski et al. 2021).

Transformers for Reinforcement Learning


The self-attention mechanism avoids compressing the whole past into a fixed-size hidden state and does not suffer from vanishing or exploding gradients as much as RNNs do. Reinforcement learning tasks can certainly benefit from these traits. However, it is quite difficult to train a Transformer even in supervised learning, let alone in an RL context. After all, it can be quite challenging to stabilize and train an LSTM agent by itself.
The Gated Transformer-XL (GTrXL; Parisotto et al. 2019) is one attempt to use the Transformer for RL. GTrXL succeeded in stabilizing training with two changes on top of Transformer-XL:

1. The layer normalization is only applied on the input stream in a residual module, but NOT on the shortcut stream. A key benefit of this reordering is that it allows the original input to flow from the first layer to the last.
2. The residual connection is replaced with a GRU-style (Gated Recurrent Unit; Chung et al., 2014) gating mechanism.

$$
\begin{aligned}
r &= \sigma(W_r^{(l)} y + U_r^{(l)} x) \\
z &= \sigma(W_z^{(l)} y + U_z^{(l)} x - b_g^{(l)}) \\
\hat{h} &= \tanh(W_g^{(l)} y + U_g^{(l)}(r \odot x)) \\
g^{(l)}(x, y) &= (1 - z) \odot x + z \odot \hat{h}
\end{aligned}
$$

The gating function parameters are explicitly initialized to be close to an identity map; this is why there is a $b_g^{(l)}$ term. A $b_g^{(l)} > 0$ greatly helps to speed up learning.
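A minimal NumPy sketch of this gate as a drop-in replacement for the residual sum $x + y$; the row-vector convention, parameter shapes, and the default `b_g` value are illustrative choices, not the paper's implementation.

```python
import numpy as np

def sigmoid(u):
    return 1.0 / (1.0 + np.exp(-u))

def gru_gate(x, y, Wr, Ur, Wz, Uz, Wg, Ug, b_g=2.0):
    # x: skip-connection input, y: sublayer output; all W*/U* are (d, d).
    r = sigmoid(y @ Wr + x @ Ur)
    z = sigmoid(y @ Wz + x @ Uz - b_g)   # b_g > 0 pushes z toward 0 at init,
                                         # so the gate starts near the identity map
    h_hat = np.tanh(y @ Wg + (r * x) @ Ug)
    return (1 - z) * x + z * h_hat       # replaces the plain residual x + y
```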


Fig. 25. Comparison of the model architecture of Transformer-XL, Transformer-XL with the
layer norm reordered, and Gated Transformer-XL. (Image source: Figure 1 in Parisotto, et al.
2019)

Decision Transformer (DT; Chen et al. 2021) formulates Reinforcement Learning problems as a process of conditional sequence modeling, outputting the optimal actions conditioned on the desired return, past states and actions. It therefore becomes straightforward to use the Transformer architecture. Decision Transformer is designed for offline RL, where the model only has access to a fixed collection of trajectories collected by other policies.
To encourage the model to learn how to act in order to achieve a desired return, it feeds the model the desired future return $\hat{R} = \sum_{t'=t}^T r_{t'}$ (the return-to-go) instead of the current reward. The trajectory consists of a list of triplets (return-to-go $\hat{R}_t$, state $s_t$, action $a_t$), and it is used as an input sequence for the Transformer:

$$\tau = (\hat{R}_1, s_1, a_1, \hat{R}_2, s_2, a_2, \dots, \hat{R}_T, s_T, a_T)$$
Three linear layers are added and trained for return-to-go, state and action respectively to extract token embeddings. The prediction head learns to predict $a_t$ given the input token $s_t$. The training uses a cross-entropy loss for discrete actions or MSE for continuous actions. Predicting the states or returns-to-go was not found to improve performance in their experiments.
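A sketch of how one trajectory could be packed into the token sequence above; the embedding callables are hypothetical placeholders, and the actual model additionally adds timestep embeddings before feeding the sequence into a GPT-style Transformer.

```python
import numpy as np

def build_dt_tokens(rewards, states, actions, embed_R, embed_s, embed_a):
    # Return-to-go at step t is the sum of rewards from t to the episode end.
    returns_to_go = np.cumsum(np.asarray(rewards)[::-1])[::-1]
    tokens = []
    for R_hat, s, a in zip(returns_to_go, states, actions):
        tokens += [embed_R(R_hat), embed_s(s), embed_a(a)]
    # Shape (3T, d): (R_1, s_1, a_1, ..., R_T, s_T, a_T), matching tau above.
    return np.stack(tokens)
```

At test time, the sequence is seeded with the desired target return and then rolled out, appending the observed states and the predicted actions.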
The experiments compared DT with several model-free RL algorithm baselines and showed that:

- DT is more efficient than behavior cloning in the low-data regime;
- DT can model the distribution of returns very well;
- having a long context is crucial for obtaining good results;
- DT can work with sparse rewards.

Citation


Cited as:

Weng, Lilian. (Jan 2023). The transformer family version 2.0. Lil’Log. https://lilianweng.github.io/posts/2023-01-27-the-transformer-family-v2/.

Or

@article{weng2023transformer,
title = "The Transformer Family Version 2.0",
author = "Weng, Lilian",
journal = "lilianweng.github.io",
year = "2023",
month = "Jan",
url = "https://lilianweng.github.io/posts/2023-01-27-the-transformer-family-v2/"
}

References
[1] Ashish Vaswani, et al. “Attention is all you need.” NIPS 2017.
[2] Rami Al-Rfou, et al. “Character-level language modeling with deeper self-attention.” AAAI 2019.
[3] Olah & Carter, “Attention and Augmented Recurrent Neural Networks”, Distill, 2016.
[4] Sainbayar Sukhbaatar, et al. “Adaptive Attention Span in Transformers”. ACL 2019.
[5] Rewon Child, et al. “Generating Long Sequences with Sparse Transformers” arXiv:1904.10509 (2019).
[6] Nikita Kitaev, et al. “Reformer: The Efficient Transformer” ICLR 2020.
[7] Alex Graves. “Adaptive Computation Time for Recurrent Neural Networks.” arXiv:1603.08983 (2016).
[8] Niki Parmar, et al. “Image Transformer” ICML 2018.
[9] Zihang Dai, et al. “Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context.” ACL 2019.
[10] Aidan N. Gomez, et al. “The Reversible Residual Network: Backpropagation Without Storing Activations” NIPS
2017.
[11] Mostafa Dehghani, et al. “Universal Transformers” ICLR 2019.
[12] Emilio Parisotto, et al. “Stabilizing Transformers for Reinforcement Learning” arXiv:1910.06764 (2019).
[13] Rae et al. “Compressive Transformers for Long-Range Sequence Modelling.” 2019.
[14] Press et al. “Train Short, Test Long: Attention With Linear Biases Enables Input Length Extrapolation.” ICLR 2022.
[15] Wu, et al. “DA-Transformer: Distance Aware Transformer” 2021.
[16] Elabyad et al. “Depth-Adaptive Transformer.” ICLR 2020.
[17] Schuster et al. “Confident Adaptive Language Modeling” 2022.
[18] Qiu et al. “Blockwise self-attention for long document understanding” 2019
[19] Roy et al. “Efficient Content-Based Sparse Attention with Routing Transformers.” 2021.
[20] Ainslie et al. “ETC: Encoding Long and Structured Inputs in Transformers.” EMNLP 2019.
[21] Beltagy et al. “Longformer: The long-document transformer.” 2020.
[22] Zaheer et al. “Big Bird: Transformers for Longer Sequences.” 2020.
[23] Wang et al. “Linformer: Self-Attention with Linear Complexity.” arXiv preprint arXiv:2006.04768 (2020).
[24] Tay et al. “Sparse Sinkhorn Attention.” ICML 2020.
[25] Peng et al. “Random Feature Attention.” ICLR 2021.
[26] Choromanski et al. “Rethinking Attention with Performers.” ICLR 2021.
[27] Khandelwal et al. “Generalization through memorization: Nearest neighbor language models.” ICLR 2020.
[28] Yogatama et al. “Adaptive semiparametric language models.” ACL 2021.
[29] Wu et al. “Memorizing Transformers.” ICLR 2022.
[30] Su et al. “Roformer: Enhanced transformer with rotary position embedding.” arXiv preprint arXiv:2104.09864
(2021).
[31] Shaw et al. “Self-attention with relative position representations.” arXiv preprint arXiv:1803.02155 (2018).
[32] Tay et al. “Efficient Transformers: A Survey.” ACM Computing Surveys 55.6 (2022): 1-28.
[33] Chen et al., “Decision Transformer: Reinforcement Learning via Sequence Modeling” arXiv preprint
arXiv:2106.01345 (2021).

architecture attention transformer foundation long-read reinforcement-learning
