
Memory Consumption and Limitations in LLMs with Large Context Windows, Part II

Part II: Tokens, Embeddings, and Memory

This post is the second in a series where we will explore the limits of large language models (LLMs) with respect to memory overhead and context windows. In the first post, we covered what an LLM is and defined some terminology relating to LLMs. In this post, we’ll dive deeper into how LLMs turn input text into tokens and then embeddings. We’ll use this information to understand how the length of the input context window affects the memory requirements when prompting the LLM.


Tokenization

Let’s start with an example prompt that you might send to an LLM: “Compare storytelling techniques in novels and in films.” Your prompt’s input text needs to be turned into a form the LLM can use. The first step in this process is converting the words that make up the prompt into a set of tokens that represent them. This step is performed by the tokenizer.

Tokenizers can be as simple or as complex as desired, but they all do the same job of turning raw text into numbers. A simple word-based tokenizer might split an input string into parts by delimiting on spaces, similar to the split() function in Python. Using our example prompt from above, the tokenized result would be an array that looks like this:


text = "Compare storytelling techniques in novels and in films."
print(text.split())

['Compare', 'storytelling', 'techniques', 'in', 'novels', 'and', 'in', 'films.']


The unique tokens that appear make up the model’s vocabulary. Each of these tokens gets assigned an ID, starting from 0, and the model uses these IDs to identify each word. In our example, the IDs would look like this:


[0, 1, 2, 3, 4, 5, 3, 6]


Notice that ID 3 appears twice because the word “in” is used twice in the prompt we’re tokenizing.
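
To make the ID assignment concrete, here is a minimal Python sketch that builds a toy vocabulary in order of first appearance (a real tokenizer fixes its vocabulary ahead of time rather than building it from each prompt):

text = "Compare storytelling techniques in novels and in films."
tokens = text.split()

# Assign each unique token the next available ID, in order of first appearance.
vocab = {}
for token in tokens:
    if token not in vocab:
        vocab[token] = len(vocab)

print([vocab[token] for token in tokens])

[0, 1, 2, 3, 4, 5, 3, 6]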

One limitation of this simple word-based tokenization is punctuation. If a prompt ends with a question mark instead of a period, the word-based tokenizer treats the final token (in our example, “films?” versus “films.”) as a different word and assigns it a different ID. Extrapolate this to a vocabulary covering a large fraction of the English language, and accounting for every combination of word and common punctuation symbol would vastly increase the vocabulary’s size.

Many LLMs instead use a more sophisticated tokenization scheme called byte pair encoding [1]. Byte pair encoding is a data compression algorithm that iteratively replaces the most frequent pair of bytes in a sequence with a single, unused byte. In natural language processing, it is often generalized to sequences of characters (like the words, numbers, and punctuation of a statement or question) instead of bytes. GPT models have historically used byte pair encoding as part of their tokenization process. OpenAI provides an online tool and a Python package that can be used to estimate how many tokens a body of text will be broken into.
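
As a tiny illustration of the underlying idea: in the string “aaabdaaabac”, the most frequent pair is “aa”, so it gets replaced with a new symbol, say “Z”, giving “ZabdZabac”; the next most frequent pair, “ab”, is then replaced with “Y”, giving “ZYdZYac”, and so on until no pair repeats or a target vocabulary size is reached.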

Here’s an example of this more complex tokenization in action. OpenAI’s online tokenizer tool shows how a piece of text gets split into tokens and what the resulting IDs are.
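
If you’d rather estimate token counts in code, OpenAI’s tiktoken package does the same job locally. A minimal sketch (assuming tiktoken is installed; the exact IDs and counts depend on which encoding you choose):

import tiktoken

text = "Compare storytelling techniques in novels and in films."

# cl100k_base is one of the encodings used by recent OpenAI models.
encoding = tiktoken.get_encoding("cl100k_base")
token_ids = encoding.encode(text)

print(len(token_ids))  # how many tokens the prompt costs
print(token_ids)       # the integer IDs the model actually sees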

Embeddings

Once you have your raw text tokenized, you need to convert the tokenized text into word embeddings. Word embeddings are high-dimensional, real-valued vectors that encode a word’s meaning in such a way that vectors associated with similar words are expected to appear closer to each other in the vector space. The integer ID of each tokenized word is used to look up its embedding vector in an embedding table (or matrix). This table is a key component of the model architecture. Each row of the table corresponds to a unique token in the model’s vocabulary and is a vector in \( d \)-dimensional space, where \( d \) is the embedding dimension. The dimensionality of the embeddings varies between model implementations, but it is usually on the order of 100 to 10,000 dimensions per embedding.
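
To make the lookup concrete, here is a minimal NumPy sketch; the table here is tiny and random, whereas in a real model it is learned during training and far larger:

import numpy as np

vocab_size = 7      # toy vocabulary from the tokenization example above
embedding_dim = 8   # real models use hundreds to thousands of dimensions

# One row per token in the vocabulary.
rng = np.random.default_rng(0)
embedding_table = rng.normal(size=(vocab_size, embedding_dim)).astype(np.float32)

token_ids = [0, 1, 2, 3, 4, 5, 3, 6]          # IDs from the tokenization step
word_embeddings = embedding_table[token_ids]  # embedding lookup is just row indexing

print(word_embeddings.shape)  # (8, 8): one vector per token in the prompt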

In addition to word embeddings, transformer models rely on positional embeddings. Positional embeddings capture information about the position of a word within a sequence. This is crucial for models like the transformer, which do not inherently understand the concept of sequence order. Positional embeddings are added to word embeddings, allowing the model to consider both a word’s meaning and its position in the sequence when making predictions. In transformer models, the word embedding and the positional embedding for each token are usually summed together element-wise, and the resulting sum is fed into the model. Mathematically, if \( w_i \) is the word embedding for the \( i^{th} \) token and \( p_i \) is its positional embedding, the input \( x_i \) to the transformer is \( x_i = w_i + p_i \). Both word embeddings and positional embeddings typically have the same number of dimensions (\( d \)) in transformer models.
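
Here is a sketch of that sum using the sinusoidal positional embeddings from the original transformer paper; many models instead learn their positional embeddings, but the element-wise addition is the same:

import numpy as np

def sinusoidal_positional_embeddings(seq_len, dim):
    # Classic fixed (non-learned) positional embeddings: sin on even indices, cos on odd.
    positions = np.arange(seq_len)[:, None]            # (seq_len, 1)
    div_terms = 10000 ** (np.arange(0, dim, 2) / dim)  # (dim / 2,)
    pe = np.zeros((seq_len, dim), dtype=np.float32)
    pe[:, 0::2] = np.sin(positions / div_terms)
    pe[:, 1::2] = np.cos(positions / div_terms)
    return pe

seq_len, d = 8, 8  # toy sizes matching the embedding sketch above
word_embeddings = np.random.default_rng(0).normal(size=(seq_len, d)).astype(np.float32)
positional_embeddings = sinusoidal_positional_embeddings(seq_len, d)

x = word_embeddings + positional_embeddings  # x_i = w_i + p_i, the input to the first layer
print(x.shape)                               # (8, 8)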

The number of dimensions correlates directly with the memory overhead required by the model. For the model itself, the memory required to store the word embedding matrix scales linearly with both the vocabulary size (\( N \)) and the embedding dimension (\( d \)).

The proportionality \( \text{Memory} \propto N \times d \) captures this relationship.
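
For example, a hypothetical model with a 50,000-token vocabulary and 4,096-dimensional embeddings stored as float32 values (4 bytes each) would need roughly \( 50{,}000 \times 4{,}096 \times 4 \text{ bytes} \approx 0.8 \text{ GB} \) of memory just for the embedding table.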

The memory footprint of running an LLM is a significant concern when it comes to sending prompts in and getting a usable response. As we discussed in the first section, each prompt you send to the LLM gets tokenized, and the set of unique tokens the tokenizer knows about makes up the model’s vocabulary. The embedding for every token in that vocabulary has to be loaded into memory before you can prompt the model. In addition, the prompt itself is tokenized and turned into embeddings that are fed into the layers of the neural network that make up the model. That concludes a fairly comprehensive overview of how raw text is transformed into a form the model can operate on.

Attention

Once the model has ingested your prompt in a usable format, the embeddings are passed through the many layers that make up the core architecture of the model. “Layers” generally refers to the sequence of identical blocks that make up the encoder or decoder portions of the model. Layers take many different forms and serve many functions, but most of them are out of the scope of these posts. The most important layers when it comes to the memory footprint of the model are the attention layers. Attention is the mechanism that allows the model to consider other words in the input when encoding a particular word. Attention keeps the model’s train of thought “on track” and prevents the model from generating incoherent text when presented with a long-context input. For more detailed information on attention, see [2] and [3].

The attention mechanism is the primary contributor to the memory footprint of the model. Let’s dive into the reasons behind this:

Consider a sequence of tokens \( X = [x_1, x_2, \ldots, x_n] \), where \( n \) is the sequence length. The input embeddings for each token are transformed into Query (\( Q \)), Key (\( K \)), and Value (\( V \)) matrices using learned weight matrices \( W_Q, W_K, W_V \):

\begin{eqnarray*} Q = X \cdot W_Q \\ K = X \cdot W_K \\ V = X \cdot W_V \end{eqnarray*}

In the transformer model, the learned weight matrices \( W_Q, W_K, \) and \( W_V \) are trainable parameters. They are initialized, usually randomly following some distribution, at the start of training and then updated during backpropagation. These matrices serve to transform the input token embeddings into the Query, Key, and Value spaces. They are part of what people mean when they refer to “the weights” of an LLM.

The weight matrices \( W_Q, W_K, \) and \( W_V \) each have dimensions \( d_{\text{model}} \times d_k \) for Query and Key, and \( d_{\text{model}} \times d_v \) for Value, where:

– \( d_{\text{model}} \) is the dimensionality of the input token embeddings.

– \( d_k \) is the dimensionality of the Query and Key vectors.

– \( d_v \) is the dimensionality of the Value vectors.

So if you have an input sequence \( X \) with shape \( (n, d_{\text{model}}) \), then after the linear transformations, you get:

– \( Q \) with shape \( (n, d_k) \)

– \( K \) with shape \( (n, d_k) \)

– \( V \) with shape \( (n, d_v) \)

The attention score between a query and a key is calculated as the dot product of the query and key, scaled down by the square root of the key dimension \( d_k \); a softmax over these scores then weights the Value vectors to produce the attention output:

\[ \text{Attention}(Q, K, V) = \text{softmax}\left( \frac{QK^T}{\sqrt{d_k}} \right) \cdot V \]
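
For readers who find code easier to follow than notation, here is a minimal single-head sketch in NumPy, with toy sizes and random matrices standing in for the learned weights:

import numpy as np

rng = np.random.default_rng(0)
n, d_model, d_k, d_v = 8, 32, 16, 16     # toy sizes for illustration

X = rng.normal(size=(n, d_model))        # input token embeddings
W_Q = rng.normal(size=(d_model, d_k))    # learned in a real model; random here
W_K = rng.normal(size=(d_model, d_k))
W_V = rng.normal(size=(d_model, d_v))

Q, K, V = X @ W_Q, X @ W_K, X @ W_V      # shapes (n, d_k), (n, d_k), (n, d_v)

scores = Q @ K.T / np.sqrt(d_k)          # (n, n) scaled attention scores
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights = weights / weights.sum(axis=-1, keepdims=True)  # row-wise softmax

output = weights @ V                     # (n, d_v) attention output
print(scores.shape, output.shape)        # (8, 8) (8, 16)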

The details of the equations aren’t especially important for making the connection between the memory footprint and the attention mechanism, so we won’t go into them very much. The point you should take away from all the math is this: the attention mechanism requires storing multiple matrices in memory, and their size depends on the sequence length \( n \). Specifically, the \( Q \), \( K \), and \( V \) matrices each have dimensions \( (n, d_k) \) or \( (n, d_v) \). Long sequence (AKA long context), large matrix.

Furthermore, you’ll have to store the attention scores, which involve a dot product between the Query and Key matrices. This results in an attention matrix of shape \((n, n)\), and this needs to be stored in memory before the softmax operation and subsequent multiplications with the Value matrix. 

Let’s put some real numbers to work. If we assume that the model uses float32 numbers (4 bytes per number), the memory required for the attention scores for one head would be:

\[ \text{Memory}_{\text{attention scores}} = n \times n \times 4 \text{ bytes} \]

Considering \( Q \), \( K \), \( V \), and the attention scores, the total additional memory \( M \) required for the attention operation would be roughly:

\[ M = 2 \times n \times n \times 4 \text{ bytes} + 3 \times n \times d_k \times 4 \text{ bytes} \]

The factor of 2 for \( n \times n \) matrices accounts for both the raw attention scores and the softmax-normalized scores. The factor of 3 for \( n \times d_k \) matrices accounts for \( Q \), \( K \), \( V \).

For the sake of illustration, let’s assume \( d_k = 64 \).

– For \( n = 1,000 \) tokens:

\[ M = 2 \times 1000 \times 1000 \times 4 \text{ bytes} + 3 \times 1000 \times 64 \times 4 \text{ bytes} \]

\[ M = 8000000 + 768000 = 8768000 \text{ bytes} \approx 8.7 \text{MB} \]

– For \( n = 10,000 \) tokens:

\[ M = 2 \times 10000 \times 10000 \times 4 \text{ bytes} + 3 \times 10000 \times 64 \times 4 \text{ bytes} \]

\[ M = 800000000 + 7680000 = 807680000 \text{ bytes} \approx 807 \text{MB} \]

– For \( n = 100,000 \) tokens:

\[ M = 2 \times 100000 \times 100000 \times 4 \text{ bytes} + 3 \times 100000 \times 64 \times 4 \text{ bytes} \]

\[ M = 80000000000 + 76800000 = 80076800000 \text{ bytes} \approx 80 \text{GB} \]

– For \( n = 1,000,000 \) tokens:

\[ M = 2 \times 1000000 \times 1000000 \times 4 \text{ bytes} + 3 \times 1000000 \times 64 \times 4 \text{ bytes} \]

\[ M \approx 8 \text{TB} \]

– For \( n = 1,000,000,000 \) tokens:

\[ M = 2 \times 1000000000 \times 1000000000 \times 4 \text{ bytes} + 3 \times 1000000000 \times 64 \times 4 \text{ bytes} \]

\[ M \approx 8000 \text{PB}!! \]
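
If you want to experiment with other sequence lengths, the same arithmetic fits in a few lines of Python, using the same simplifying assumptions (a single head, float32 values, and \( d_k = 64 \)):

BYTES_PER_FLOAT32 = 4

def attention_memory_bytes(n, d_k=64):
    # Two (n, n) score matrices (raw and softmax-normalized) plus Q, K, and V.
    scores = 2 * n * n * BYTES_PER_FLOAT32
    qkv = 3 * n * d_k * BYTES_PER_FLOAT32
    return scores + qkv

for n in (1_000, 10_000, 100_000, 1_000_000):
    print(f"n = {n:>9,}: {attention_memory_bytes(n) / 1e9:,.2f} GB")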

These calculations do not take into account that most transformer architectures employ a multi-headed attention approach, which introduces another multiplicative factor into both terms of the memory calculation. Even setting multi-headed attention aside, the memory cost is prohibitive, especially for anyone hoping to run long-context models on their own hardware.

There are optimizations to be had here, and it’s a very active field of study. In the next blog post, we will look at optimizations to the attention mechanism in more detail.

Want to chat more on this topic (or anything tech)? Connect with a member of our team. (We love this stuff!)

We're building an AI-powered Product Operations Cloud, leveraging AI in almost every aspect of the software delivery lifecycle. Want to test drive it with us? Join the ProdOps party at ProdOps.ai.