KV Caching in Attention

Large Language Models (LLMs) like ChatGPT generate text one token (word or sub-word unit) at a time. At the heart of this generation process lies the "attention mechanism," which allows the model to weigh the importance of different words in the input sequence when producing the next word. While powerful, calculating attention for each new token can be computationally intensive if we do it naively. This is where KV caching comes in as a crucial optimization.

The Challenge: Repetitive Work in Attention

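Imagine an LLM has already generated "The cat sat on the" and is now trying to predict the next word, say "mat." To do this, the attention mechanism needs Query (Q), Key (K), and Value (V) vectors for the tokens in the current sequence ("The", "cat", "sat", "on", "the"), the last of which is the token currently being processed. Predicting the next word only requires the query of that last token, but its attention draws on the keys and values of every token in the sequence.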

Without KV caching, every time we generate a new token, the model would recompute the K and V vectors for *all* preceding tokens in the sequence. For "mat", it would calculate K and V for "The", "cat", "sat", "on", "the". When it then tries to predict the token *after* "mat", it would again recompute K and V for "The", "cat", "sat", "on", "the", "mat". This is incredibly wasteful, especially for long sequences!
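To put a number on the waste, here is a small counting sketch. The helper names are hypothetical, and it counts per attention layer, with one unit standing for projecting a single token's K and V:

```python
def kv_projections_without_cache(n_tokens):
    """Per-token K/V projections if every step re-projects the whole prefix."""
    # Step t (1-based) re-projects all t tokens seen so far: 1 + 2 + ... + n.
    return sum(range(1, n_tokens + 1))

def kv_projections_with_cache(n_tokens):
    """Per-token K/V projections with a cache: each token is projected exactly once."""
    return n_tokens

for n in (6, 100, 1000):
    print(n, kv_projections_without_cache(n), kv_projections_with_cache(n))
# 6 21 6
# 100 5050 100
# 1000 500500 1000
```

For the six tokens of "The cat sat on the mat" that is 21 projections instead of 6; in general n(n+1)/2 versus n, so the gap widens quadratically with length.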

KV Caching to the Rescue!

KV caching elegantly solves this problem. The core idea is simple: store (cache) the Key (K) and Value (V) vectors for all previously processed tokens.

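When the newest token (call it token T) is processed:

  1. Compute Q, K, and V for token T only.
  2. Append K_T and V_T to the cache, which now holds keys and values for tokens 1 through T.
  3. Compute attention scores between token T's query and every cached key.
  4. Use the resulting weights (after softmax) to combine the cached values into token T's attention output.
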
This way, we avoid redundant computations for K and V vectors of past tokens, significantly speeding up the process.
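A minimal sketch of the idea in plain Python, assuming a single attention head in a single layer, no batching, and small made-up weights. `KVCache`, `decode_step`, and `full_causal_attention` are hypothetical names for illustration, not any library's API, and the 1/sqrt(hs) factor is the standard scaled dot-product scaling. The usage at the end checks that decoding token by token with the cache reproduces what a full recomputation with a causal mask gives.

```python
import math
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def matvec(x, W):
    # Row vector x (length d_in) times matrix W (d_in rows, d_out columns).
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]

def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

class KVCache:
    """Keys and values of every token processed so far (one head, one layer)."""
    def __init__(self):
        self.keys = []
        self.values = []

    def append(self, k, v):
        self.keys.append(k)
        self.values.append(v)

def decode_step(x_new, W_Q, W_K, W_V, cache):
    """Attention output for the newest token, reusing the cached K/V of earlier tokens."""
    # Project only the new token.
    q, k, v = matvec(x_new, W_Q), matvec(x_new, W_K), matvec(x_new, W_V)
    # Extend the cache so it covers tokens 1..T.
    cache.append(k, v)
    # Score the single query against all T cached keys, then mix the T cached values.
    scale = 1.0 / math.sqrt(len(q))
    probs = softmax([dot(q, k_i) * scale for k_i in cache.keys])
    return [sum(p * v_i[j] for p, v_i in zip(probs, cache.values)) for j in range(len(v))]

def full_causal_attention(xs, W_Q, W_K, W_V):
    """Reference without a cache: recompute Q, K, V for the whole sequence."""
    Q = [matvec(x, W_Q) for x in xs]
    K = [matvec(x, W_K) for x in xs]
    V = [matvec(x, W_V) for x in xs]
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for t, q in enumerate(Q):
        # Causal mask: position t only sees keys 0..t.
        probs = softmax([dot(q, K[i]) * scale for i in range(t + 1)])
        out.append([sum(probs[i] * V[i][j] for i in range(t + 1)) for j in range(len(V[0]))])
    return out

def random_matrix(rows, cols, rng):
    return [[rng.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]

# Usage: six tokens of dimension 4 with made-up weights.
rng = random.Random(0)
d = 4
W_Q, W_K, W_V = (random_matrix(d, d, rng) for _ in range(3))
xs = random_matrix(6, d, rng)

cache = KVCache()
cached_out = [decode_step(x, W_Q, W_K, W_V, cache) for x in xs]
reference = full_causal_attention(xs, W_Q, W_K, W_V)
max_diff = max(abs(a - b) for rc, rr in zip(cached_out, reference) for a, b in zip(rc, rr))
print(len(cache.keys), max_diff)
```

Real implementations typically store the cache as tensors (one K and one V per layer) rather than Python lists, either preallocated or grown by concatenation, but the bookkeeping is the same: project the new token once, append, attend over everything cached.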

Why This Works: The Math Doesn't Lie

You might wonder why we can just reuse old K and V vectors. The reason lies in how they are calculated. For any token i in the sequence, its Key vector K_i is typically computed as K_i = Embedding_i @ W_K, where Embedding_i is the input to the attention layer at position i (the token's input embedding in the first layer, the previous layer's output at that position in deeper layers) and W_K is a weight matrix learned by the model. The Value vector is computed the same way: V_i = Embedding_i @ W_V.

When token T arrives, we have a new input, Embedding_T. This is used to calculate the *newest* Key (K_T) and Value (V_T).
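For a single attention head, the output for the newest token can be written out explicitly. Here x_i stands for the attention layer's input at position i (Embedding_i in the text), W_Q is the query weight matrix (assumed, analogous to W_K and W_V), hs is the per-head size, and the square-root factor is the usual scaled dot-product scaling:

$$
q_T = x_T W_Q, \qquad k_i = x_i W_K, \qquad v_i = x_i W_V
$$

$$
y_T = \sum_{i=1}^{T} \alpha_{T,i}\, v_i, \qquad
\alpha_{T,i} = \frac{\exp\left(q_T k_i^{\top} / \sqrt{hs}\right)}{\sum_{j=1}^{T} \exp\left(q_T k_j^{\top} / \sqrt{hs}\right)}
$$

Only q_T, k_T, and v_T depend on the new input; every k_i and v_i with i < T is exactly what was computed, and cached, at an earlier step.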

Crucially, the K and V vectors for all previous tokens (K_1, ..., K_{T-1} and V_1, ..., V_{T-1}) were calculated from *their own* inputs (Embedding_1, ..., Embedding_{T-1}) and the same weight matrices (W_K, W_V). The arrival of Embedding_T changes neither those earlier inputs nor the weight matrices. In deeper layers this also relies on the causal mask: the layer input at position i is the previous layer's output there, which depends only on tokens 1 through i, so it is unaffected by token T as well. Therefore, the K and V vectors for all prior tokens remain valid and can be reused directly from the cache; only the latest input is needed to produce the latest Key and Value pair.
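A tiny numerical check of this argument, in plain Python with random made-up inputs (`matvec` is a helper defined here, not a library function): projecting the sequence again after token T arrives reproduces the earlier keys exactly, so the new row is the only one that actually has to be computed.

```python
import random

def matvec(x, W):
    # Row vector x times matrix W, i.e. x @ W.
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]

rng = random.Random(1)
d = 3
W_K = [[rng.uniform(-1, 1) for _ in range(d)] for _ in range(d)]
prefix = [[rng.uniform(-1, 1) for _ in range(d)] for _ in range(5)]  # inputs of tokens 1..T-1
x_new = [rng.uniform(-1, 1) for _ in range(d)]                      # input of token T

K_before = [matvec(x, W_K) for x in prefix]           # what the cache already holds
K_after = [matvec(x, W_K) for x in prefix + [x_new]]  # full recomputation after token T

print(K_after[:-1] == K_before)  # True: earlier keys are unchanged
```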

A Closer Look: Tensor Transformations

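Let's break down the process step by step, comparing a full pass (no cache) with an incremental pass (with KV cache). Assume the sequence now holds T tokens and the newest one, token T, is being processed. Shapes use B for the batch size, nh for the number of attention heads, hs for the head size (d_model / nh), and T for the sequence length.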

| Step | Without KV cache (full pass for token T) | With KV cache (incremental decode for token T) |
|---|---|---|
| Input to attention | Full sequence embedding: (B, T, d_model) | Single new token's embedding: (B, 1, d_model) |
| 1. Generate Q, K, V | `q`, `k`, `v`: (B, nh, T, hs) | `q_new`, `k_new`, `v_new`: (B, nh, 1, hs) |
| 2. Retrieve from cache | N/A | `k_cache`, `v_cache`: (B, nh, T-1, hs) |
| 3. Concatenate K and V | N/A (`k` and `v` already cover all T tokens) | `k = concat(k_cache, k_new)` → (B, nh, T, hs); `v = concat(v_cache, v_new)` → (B, nh, T, hs) |
| 4. `att_scores = q @ k.transpose(-2, -1)` | (B, nh, T, hs) @ (B, nh, hs, T) → (B, nh, T, T) | (B, nh, 1, hs) @ (B, nh, hs, T) → (B, nh, 1, T) |
| 5. Apply causal mask | A full (T, T) causal mask is applied (e.g., with `masked_fill`). | No explicit (T, T) mask is needed. The single query row of `q_new` scores the current token against all T tokens (past and current), and causality holds automatically because the cache only contains K/V for the current and earlier tokens. |
| 6. Softmax over `att_scores` | Applied to the (B, nh, T, T) tensor | Applied to the (B, nh, 1, T) tensor |
| 7. Output `y = att_probs @ v` | (B, nh, T, T) @ (B, nh, T, hs) → (B, nh, T, hs) | (B, nh, 1, T) @ (B, nh, T, hs) → (B, nh, 1, hs) |

Notice the critical difference in steps 4 and 7. With KV caching, the query q_new has a sequence length of 1, so attention scores are computed only for the current token against all T keys. The resulting attention output y also has a sequence length of 1, representing the context-aware information for the current prediction.
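The shape arithmetic in the table can be checked mechanically. This sketch uses a hypothetical `matmul_shape` helper that applies the rule for stacked matrix products (leading dimensions equal, inner dimensions matching) to shape tuples only; unlike PyTorch it does not broadcast, and the sizes are arbitrary example values.

```python
import math

def matmul_shape(a, b):
    """Shape of a @ b for stacked matrices (no broadcasting: leading dims must be equal)."""
    if a[:-2] != b[:-2] or a[-1] != b[-2]:
        raise ValueError(f"incompatible shapes {a} @ {b}")
    return a[:-1] + (b[-1],)

def transpose_last2(shape):
    """Shape after .transpose(-2, -1)."""
    return shape[:-2] + (shape[-1], shape[-2])

B, nh, T, hs = 2, 8, 10, 64  # arbitrary example sizes

# Without KV cache: queries, keys and values for all T tokens.
q = k = v = (B, nh, T, hs)
scores_full = matmul_shape(q, transpose_last2(k))  # (B, nh, T, T)
y_full = matmul_shape(scores_full, v)              # (B, nh, T, hs)

# With KV cache: one new query; K and V are the T-1 cached rows plus the new one.
q_new = (B, nh, 1, hs)
k_cat = v_cat = (B, nh, (T - 1) + 1, hs)
scores_new = matmul_shape(q_new, transpose_last2(k_cat))  # (B, nh, 1, T)
y_new = matmul_shape(scores_new, v_cat)                   # (B, nh, 1, hs)

# Number of attention scores computed in each case.
print(math.prod(scores_full), math.prod(scores_new))  # 1600 160
```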

From Attention to the Next Word

After the attention mechanism (using KV caching) produces its output for the current token, of shape (B, nh, 1, hs), this output undergoes further transformations:

  1. The heads are merged: the tensor is transposed and reshaped to (B, 1, nh * hs) = (B, 1, d_model), then passed through the attention layer's output projection. The result is the context-aware representation of the current position, from which the next token will be predicted.
  2. This single vector then passes through the rest of the Transformer block (residual connections, Layer Normalization, and the Feed-Forward Network (FFN)), and on through every later block. Each block has its own attention layer, and therefore its own KV cache, updated in the same way.
  3. Finally, the output of the last Transformer block (still a single position, so shape (B, 1, d_model)) is passed to a linear layer, often called the "language model head" or lm_head, which projects it from d_model to the vocabulary size (vocab_size).
  4. The result is a tensor of logits of shape (B, 1, vocab_size). These logits are raw, unnormalized scores for every possible token in the vocabulary.
  5. A softmax function is applied to these logits to convert them into probabilities.
  6. A decoding strategy (such as greedy decoding, top-k sampling, or nucleus/top-p sampling) then selects the actual next token from this probability distribution.
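The last three steps can be sketched for a single position (B = 1) in plain Python. The vocabulary and logits are made up, and `greedy`, `top_k_sample`, and `top_p_sample` are illustrative helpers, not a particular library's API:

```python
import math
import random

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def greedy(probs):
    """Greedy decoding: index of the most probable token."""
    return max(range(len(probs)), key=probs.__getitem__)

def top_k_sample(probs, k, rng):
    """Sample among the k most probable tokens, proportionally to their probabilities."""
    top = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)[:k]
    return rng.choices(top, weights=[probs[i] for i in top], k=1)[0]

def top_p_sample(probs, p, rng):
    """Nucleus sampling: smallest set of top tokens whose total probability reaches p."""
    kept, cumulative = [], 0.0
    for i in sorted(range(len(probs)), key=probs.__getitem__, reverse=True):
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= p:
            break
    return rng.choices(kept, weights=[probs[i] for i in kept], k=1)[0]

# Made-up logits for one position (B = 1) over a toy five-word vocabulary.
vocab = ["mat", "rug", "sofa", "moon", "roof"]
logits = [3.2, 2.1, 1.5, -1.0, 0.3]
probs = softmax(logits)
rng = random.Random(0)
print(vocab[greedy(probs)])                  # mat
print(vocab[top_k_sample(probs, 2, rng)])    # "mat" or "rug"
print(vocab[top_p_sample(probs, 0.9, rng)])  # "mat", "rug" or "sofa"
```

Greedy decoding is deterministic, while the two sampling variants trade some likelihood for variety by restricting the draw to the most probable tokens.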

The selected token is then fed back in as the next input, and the cycle repeats, again leveraging the KV cache, which now holds one more token!

The Big Wins: Benefits of KV Caching

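KV caching isn't just a minor tweak; it's a fundamental optimization for LLM inference:

  1. Less redundant work: each token's K and V are computed once per layer, instead of being recomputed at every later step.
  2. Cheaper decode steps: each step feeds in a single token, and its attention scores form a (B, nh, 1, T) tensor instead of a full (B, nh, T, T) one.
  3. Faster generation overall: the per-step cost grows linearly with the sequence length rather than quadratically, which is what makes long outputs practical.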
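The cache does cost memory, and its size follows directly from the attention shapes: one K and one V tensor of shape (B, nh, T, hs) per layer. A back-of-the-envelope sketch, assuming standard multi-head attention (every head keeps its own K and V) and a made-up model configuration; `kv_cache_bytes` is a hypothetical helper:

```python
def kv_cache_bytes(n_layers, batch, n_heads, seq_len, head_size, bytes_per_elem):
    """Memory for the K and V caches of standard multi-head attention."""
    # Per layer: one K and one V tensor, each of shape (batch, n_heads, seq_len, head_size).
    return 2 * n_layers * batch * n_heads * seq_len * head_size * bytes_per_elem

# Hypothetical model: 32 layers, 32 heads of size 128 (d_model = 4096), 16-bit values.
size = kv_cache_bytes(n_layers=32, batch=1, n_heads=32, seq_len=4096,
                      head_size=128, bytes_per_elem=2)
print(size / 2**30, "GiB")  # 2.0 GiB
```

Because this grows linearly with both sequence length and batch size, memory rather than compute can become the limiting factor for long generations.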

Conclusion

KV caching is a cornerstone of efficient autoregressive inference in Transformer-based Large Language Models. By reusing previously computed Key and Value states instead of recomputing them, it removes most of the redundant computation in each decoding step, which speeds up text generation and makes long sequences practical to generate, though the cache itself consumes memory that grows with sequence length. Understanding KV caching is key to appreciating how modern LLMs can generate long outputs so quickly.