The Animated Transformer

by Praveen Sampath

The Transformer is foundational to the recent advancements in large language models (LLMs). In this article, we will attempt to unravel some of its inner workings and hopefully gain some insight into how these models function.

The only prerequisite for following along with this article is a basic understanding of linear algebra - if you know how to multiply matrices, you are good to go.

Let's begin!

What is a Transformer?

The Transformer is a machine learning model for sequence modeling. Given a sequence of things, the model can predict what the next thing in the sequence might be. In this article, we will look at word sequence prediction, but you can apply Transformers to any kind of sequential data.

As an example, take the following phrase that we would like to complete (a.k.a. the “prompt”):

“The robots will bring _________”

Conceptually, we might think of the Transformer as a function that operates on this phrase as input:

The input sentence is converted to a sequence of tokens (read: uniquely identifying numbers) before being passed as input to the model. The reason for this will become evident in the next section, where we will discuss how the input to the model is prepared.

The model parameters, denoted by θ, are the set of weights of the model that are tuned to the training data. We will take a peek into what θ contains, in each of the following sections.

Tokenization

The Transformer operates on sequences, and so as a first step, we need to tokenize the given English phrase into a sequence of tokens that can be passed as input to the model. One obvious approach is to treat each word in the sentence as a token.

The model doesn't understand words; it identifies tokens by the unique number assigned to each one. To achieve this, we assign a unique number to each word in our dictionary:

The call to the Transformer for our input sequence “the robots will bring” might then look something like:

(where token number 2532 is the model's best guess for the token that might appear next in the sequence)
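
To make this concrete, here is a minimal sketch of word-level tokenization in Python. The vocabulary and token numbers below are invented for illustration - a real model's vocabulary (and the `transformer` call at the end) would look different.

```python
# A toy word-level tokenizer. The vocabulary and token numbers here are
# invented for illustration; a real model's vocabulary is far larger.
vocab = {"the": 0, "robots": 1, "will": 2, "bring": 3, "prosperity": 4}

def tokenize(text):
    # Map each word to its unique token number.
    return [vocab[word] for word in text.lower().split()]

tokens = tokenize("The robots will bring")   # [0, 1, 2, 3]
# next_token = transformer(tokens, theta)    # hypothetical call to a trained model
```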

We are now ready to pass our input to the Transformer. In the next several sections, we will walk through all the major steps involved in getting to the desired output: the next word in the sequence.

NOTE

In this article we only discuss the architecture of auto-regressive, decoder-only Transformers, like the GPT family of models from OpenAI. For simplicity, we do not consider encoder-decoder architectures such as the one described in the seminal “Attention is All You Need” 2017 paper.

We also do not cover model training, i.e., how the model parameters for the Transformer are tweaked and tuned to data (doing so would have made this already very long article unbearably long). Instead we take a trained model, and walk through the computation that takes place during inference.

Most of the details described here come from the excellent nanoGPT implementation by Andrej Karpathy, which is roughly equivalent to GPT-2.

1. Embeddings: Numbers Speak Louder than Words

For each token, the Transformer maintains a vector called an “embedding”. An embedding aims to capture the semantic meaning of the token - similar tokens have similar embeddings.

The input tokens are mapped to their corresponding embeddings:

Our Transformer has embedding vectors of length 768. All the embeddings can be packed together into a single T × C matrix, where T = 4 is the number of input tokens and C = 768 is the size of each embedding.

In order to capture the significance of the position of a token within a sequence, the Transformer also maintains embeddings for each position. Here, we fetch embeddings for positions 0 to 3, since we only have 4 tokens:

Finally, these two T × C matrices are added together to obtain a position-dependent embedding for each token:

The token and position embeddings are all part of θ, the model parameters, which means they are tuned during model training.
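
In code, this step might look roughly like the sketch below (in PyTorch, which is what nanoGPT uses). The sizes match the ones described above; the token IDs are placeholders, not real ones.

```python
import torch

T, C = 4, 768                            # number of input tokens, embedding size
vocab_size, block_size = 50257, 1024     # GPT-2-style vocabulary and context sizes

# Both embedding tables are part of θ and are learned during training.
token_emb = torch.nn.Embedding(vocab_size, C)
position_emb = torch.nn.Embedding(block_size, C)

tokens = torch.tensor([10, 20, 30, 40])  # placeholder IDs for "the robots will bring"
positions = torch.arange(T)              # positions 0 to 3

x = token_emb(tokens) + position_emb(positions)   # T × C position-dependent embeddings
```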

2. Queries, keys and values

The Transformer then computes three vectors for each of the T vectors (each row of the T × C matrix from the previous section): “query”, “key” and “value” vectors. This is done by way of three linear transformations (i.e., multiplying with a weight matrix):

The weight matrices that produce Q, K, V matrices are all part of θ.

The query, key and value vectors for each token are packed together into T × C matrices, just like the input embedding matrix. These vectors are the primary participants involved in the main event which is coming up shortly: self-attention.
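
A minimal sketch of this step, assuming three separate C × C linear layers (nanoGPT actually computes all three with a single fused linear layer, but the result is equivalent):

```python
import torch

T, C = 4, 768
x = torch.randn(T, C)        # stand-in for the token + position embeddings

# The weight matrices inside these linear layers are part of θ.
W_q = torch.nn.Linear(C, C)
W_k = torch.nn.Linear(C, C)
W_v = torch.nn.Linear(C, C)

Q, K, V = W_q(x), W_k(x), W_v(x)   # each of Q, K, V is a T × C matrix
```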

NOTE

But what are these queries, keys and values? One way to think about these vectors is using the following analogy:

Imagine you have a database of images with their text descriptions, and you would like to build an image search engine. Users will search by providing some text (the query), which will be matched against the text descriptions in your database (the key), and the final result is the image itself (the value). Only those images (values) whose corresponding text description (key) best matches the user input (query) will be returned in the search results.

Self-attention works in a similar manner - the tokens in the user input are trying to “query” other tokens to find the ones they should be paying attention to.

3. Two heads are better than one

To recap: so far, starting from our input T × C matrix containing the token + position embeddings, we have computed three T × C matrices: (1) the Query matrix Q, (2) the Key matrix K, and (3) the Value matrix V.

The Transformer then splits these matrices into multiple so-called “heads”:

Here, we see the Q matrix being split along its columns into twelve T × H heads. Since Q has 768 columns, each head has 64 columns.

Self-attention operates independently within each head, and it does so in parallel. In other words, the first head of the Query matrix only interacts with the first heads of the Key and Value matrices. There is no interaction between different heads.

The idea behind splitting into multiple heads is to afford the Transformer greater freedom to capture different characteristics of the input embeddings. E.g., one head might specialize in capturing part-of-speech relationships, another might focus on semantic meaning, and so on.
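
The split itself is just a reshape. Here is a sketch for the Query matrix; the Key and Value matrices are split the same way:

```python
import torch

T, C, n_head = 4, 768, 12
H = C // n_head                         # 768 // 12 = 64 columns per head

Q = torch.randn(T, C)                   # stand-in for the Query matrix

# Reshape (T, C) into (n_head, T, H): one independent T × H block per head.
Q_heads = Q.view(T, n_head, H).transpose(0, 1)
print(Q_heads.shape)                    # torch.Size([12, 4, 64])
```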

4. Time to pay attention

Self-attention, as we've alluded to earlier, is the core idea behind the Transformer model.

We first compute an “attention scores” matrix by multiplying the query and key matrices (note that we are only looking at the first head here, but the same operation occurs for all heads):

This matrix tells us how much attention, or weight, a particular token needs to pay to every other token in the sequence when producing its output, i.e., its prediction for the next token. E.g., the token "bring" has an attention score of 0.3 for the token "robots" (row 4, column 2 in matrix A1).
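
In code, the scores for one head might be computed as below. (Standard implementations, including nanoGPT, also scale the product by 1/√H to keep the scores in a reasonable range - a detail we gloss over in the prose.)

```python
import math
import torch

T, H = 4, 64
Q1 = torch.randn(T, H)    # first head of the Query matrix
K1 = torch.randn(T, H)    # first head of the Key matrix

# T × T matrix of raw attention scores: row i, column j says how strongly
# token i's query matches token j's key.
A1 = (Q1 @ K1.T) / math.sqrt(H)
```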

5. Applying attention

A token's attention score for another token needs to be masked if that other token occurs later in the sequence. E.g., in our input phrase “the robots will bring _____”, it makes sense for the token “bring” to pay attention to the token “robots”, but not vice-versa, because a token should not be allowed to look at future tokens when making its prediction for the next token.

So we hide the upper-right triangle of the square matrix A1, effectively setting those attention scores to 0. The surviving scores in each row are then normalized (in practice, with a softmax) so that they sum to 1 and can be used as weights.

We then bring the third actor onto the stage, the Value matrix V:

The output for the token “robots” is a weighted sum of the Value vectors for the previous token “the” and itself. Specifically, in this case, it applies a 47% weight to the former, and 53% weight to its own Value vector (work out the matrix multiplication between A1 (T × T) and V1 (T × H) to convince yourself that this is true). The outputs for all other tokens are computed similarly.
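
Here is a sketch of the masking and weighting for a single head, continuing from the scores above. The softmax is what turns the surviving scores into weights that sum to 1, like the 47%/53% split described here.

```python
import torch

T, H = 4, 64
A1 = torch.randn(T, T)    # raw attention scores for the first head
V1 = torch.randn(T, H)    # first head of the Value matrix

# Causal mask: each token may only attend to itself and earlier tokens.
mask = torch.tril(torch.ones(T, T, dtype=torch.bool))
A1 = A1.masked_fill(~mask, float("-inf"))

# Normalize each row into weights that sum to 1 (masked entries become 0).
A1 = torch.softmax(A1, dim=-1)

# Each output row is a weighted sum of the Value vectors for that token and earlier ones.
Y1 = A1 @ V1              # T × H output of this head
```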

6. Putting all heads together

The final output for each head of self-attention is a matrix Y of dimensions T × H (T = 4, H = 64).

Having computed the output embeddings for all tokens across all 12 heads, we now combine the individual T × H matrices into a single matrix of dimension T × C, by simply stacking them side-by-side:

64 embedding dims per head (H) × 12 heads = 768, the original size of our input embeddings (C).
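
A sketch of the recombination. (In GPT-2 / nanoGPT, the concatenated matrix is also passed through one more C × C linear projection before leaving the attention unit - another detail we skip.)

```python
import torch

T, H, n_head = 4, 64, 12
head_outputs = [torch.randn(T, H) for _ in range(n_head)]   # the 12 per-head outputs

# Stack the heads side-by-side to get back a T × C matrix, C = 12 × 64 = 768.
Y = torch.cat(head_outputs, dim=-1)
print(Y.shape)            # torch.Size([4, 768])
```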

NOTE

Sections 2 - 6 collectively describe what happens in a single self-attention unit.

The input and output of a single self-attention unit are both T × C dimension matrices.

7. Feed forward

Aside from normalizing the attention scores, everything we have done up to this point has involved only linear operations - i.e., matrix multiplications. This is not sufficient to capture complex relationships between tokens, so the Transformer introduces a neural network with a single hidden layer (also referred to as a feed forward network, or a multi-layer perceptron (MLP)) to add non-linearity.

The C length row vectors are transformed to vectors of length (4 * C) by way of a linear transform, a non-linear function like ReLU is applied, and finally the vectors are linearly transformed back to vectors of length C.

All the weight matrices involved in the feed forward network are part of θ.
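
A sketch of the feed forward network. GPT-2 / nanoGPT use the GELU non-linearity, so that is what appears below; ReLU would work the same way structurally.

```python
import torch

T, C = 4, 768
x = torch.randn(T, C)     # stand-in for the output of the self-attention unit

# Expand to 4*C, apply a non-linearity, then project back to C.
# Both weight matrices are part of θ.
fc_in = torch.nn.Linear(C, 4 * C)
fc_out = torch.nn.Linear(4 * C, C)

y = fc_out(torch.nn.functional.gelu(fc_in(x)))   # T × C output
```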

8. We need to go deeper

All the steps in Sections 2 to 7 above constitute a single Transformer block. Each block takes as input a T × C matrix, and outputs a T × C matrix.

In order to arm our Transformer model with the ability to capture complex relationships between words, many such blocks are stacked together in sequence:
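
Because every block maps a T × C matrix to a T × C matrix, stacking them is just repeated application. A shape-level sketch, using GPT-2 (small)'s 12 blocks:

```python
import torch

def transformer_block(x):
    # Stand-in for sections 2 to 7: self-attention followed by the feed forward network.
    # (A real block also has details we skip here, like layer norm and residual connections.)
    return x               # shape-preserving: T × C in, T × C out

def transformer_body(x, n_layer=12):
    # Apply the blocks one after another.
    for _ in range(n_layer):
        x = transformer_block(x)
    return x

y = transformer_body(torch.randn(4, 768))   # still a 4 × 768 matrix
```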

9. Making a prediction

Finally, we are ready to make a prediction:

The output of the last block in the Transformer gives us a C-length vector for each of our input tokens. Since we only care about what comes after the last token, “bring”, we look at its vector. A linear transform on this vector - multiplying with another weight matrix of dimensions V × C, where V is the total number of words in our dictionary - will give us a vector of length V.

This vector, when normalized, gives us a probability distribution over every word in our dictionary, which allows us to pick the one with the highest probability as the next token.

In this case, our Transformer has assigned a probability of 92% to “prosperity” being the next token, and far lower odds to “destruction”, so our completed sentence now reads: “the robots will bring prosperity”. I suppose we can now rest easy with the knowledge that AI's imminent takeover of human civilization promises a future of prosperity and well-being, rather than death and destruction.
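
A sketch of this final step. The V × C matrix is often called the “language-model head”; picking the highest-probability token (greedy decoding) matches the description above, though real systems often sample from the distribution instead.

```python
import torch

C, vocab_size = 768, 50257      # embedding size and a GPT-2-style dictionary size (V)
x_last = torch.randn(C)         # output vector for the final token, "bring"

# The V × C weight matrix, also part of θ, maps the embedding to one score per word.
lm_head = torch.nn.Linear(C, vocab_size, bias=False)

logits = lm_head(x_last)                 # vector of length V
probs = torch.softmax(logits, dim=-1)    # normalize into a probability distribution
next_token = int(torch.argmax(probs))    # pick the most likely next token
```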

10. Text Generator Go Brrr

Now that we can predict the next token, we can generate text, one token at a time:

The first token that the model produces is added to the prompt and fed back in to produce the second token, which is in turn fed back in to produce the third, and so on. The Transformer has a limit on the maximum number of tokens (N) that it can take as input, so as the number of generated tokens grows, we eventually need to either keep only the last N tokens, or devise some other technique for shortening the prompt without losing information from the oldest tokens.
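
A sketch of this generation loop. The `model` below is a hypothetical callable assumed to return one row of next-token scores per input position, and `block_size` plays the role of the maximum input length N.

```python
import torch

def generate(model, tokens, max_new_tokens, block_size=1024):
    # Autoregressive generation: predict one token, append it, and repeat.
    for _ in range(max_new_tokens):
        context = tokens[-block_size:]              # keep only the last N tokens
        logits = model(torch.tensor([context]))     # assumed shape: (1, T, vocab_size)
        next_token = int(torch.argmax(logits[0, -1]))  # greedy pick at the last position
        tokens = tokens + [next_token]
    return tokens
```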

And that's it!

If you made it all the way to the end, well done! I hope this was worthwhile and contributed in some way to your understanding of LLMs.

NOTE

To focus on the most important aspects of a Transformer's implementation, several details were intentionally left out. Some of these include: layer normalization, dropout, residual connections, etc.

I highly recommend reading the full nanoGPT implementation to get all the details.

This project was made with Manim, a Python library for mathematical animations, created by Grant Sanderson who runs the YouTube channel 3Blue1Brown (which you should definitely check out, BTW).
