Transformers process context through attention over a key-value (KV) cache, and that cache grows linearly with sequence length. At 100K tokens, a single KV cache for a 7B-parameter model can consume tens of gigabytes of memory, which makes long-context inference both slow and expensive. In ScaleML Lecture 72, Guangxuan Xiao covers the techniques his research group has developed to extend context windows without paying the full linear cost. This lecture is part of the GPU Mode ScaleML Series.

The long-context challenge

Standard attention has two costs that both scale with sequence length $N$:
  • Memory: the KV cache stores $2 \times N \times d_{head} \times n_{heads}$ elements per layer
  • Compute: attention scores are $O(N^2)$ in the sequence length
For a sequence of 128K tokens, the memory cost alone makes inference infeasible on a single GPU without algorithmic intervention. The challenge is reducing these costs while preserving the model’s ability to use long-range context.
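As a rough illustration of the memory term above, the sketch below multiplies it out across layers. The function name `kv_cache_bytes` and the configuration (32 layers, 32 KV heads, head dimension 128, 2-byte elements) are assumptions chosen to resemble a 7B-class model with standard multi-head attention; they are not taken from the lecture.

```python
def kv_cache_bytes(seq_len, n_layers, n_kv_heads, head_dim, bytes_per_elem=2):
    """Bytes needed to cache keys and values for one sequence."""
    # 2 tensors (K and V) per layer, each seq_len x n_kv_heads x head_dim
    return 2 * seq_len * n_kv_heads * head_dim * n_layers * bytes_per_elem


# Assumed configuration: 32 layers, 32 KV heads, head_dim 128, fp16 (2 bytes)
size = kv_cache_bytes(seq_len=128 * 1024, n_layers=32, n_kv_heads=32, head_dim=128)
print(f"{size / 2**30:.0f} GiB")  # prints "64 GiB"
```

The result is sensitive to the number of KV heads: with grouped-query attention and 8 KV heads, the same call gives 16 GiB.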

KV cache size

For LLaMA-2-7B (32 layers, 32 KV heads, head dimension 128) with a 128K token context in fp16: ~64 GB just for the KV cache, exceeding the usable VRAM of most consumer GPUs.

Attention compute

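Quadratic attention complexity means doubling the context quadruples the attention FLOPs. FlashAttention reduces memory traffic and avoids materializing the $N \times N$ score matrix, but it computes exact attention, so the FLOP count is unchanged.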
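The quadrupling can be checked with a back-of-the-envelope count. This is a sketch that counts only the two large matmuls (scores and weighted values) and ignores softmax and the savings from causal masking; `attention_flops` and the model dimensions are illustrative assumptions.

```python
def attention_flops(seq_len, d_model, n_layers):
    """Approximate FLOPs for the QK^T and attention-times-V matmuls in one prefill pass."""
    # Each of the two matmuls costs about 2 * seq_len^2 * d_model FLOPs per layer
    # (summed over heads); softmax and the causal-mask saving are ignored.
    return 4 * seq_len ** 2 * d_model * n_layers


base = attention_flops(seq_len=64 * 1024, d_model=4096, n_layers=32)
doubled = attention_flops(seq_len=128 * 1024, d_model=4096, n_layers=32)
print(doubled / base)  # prints 4.0
```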

StreamingLLM: attention sinks and sliding window

StreamingLLM is Guangxuan Xiao’s work on enabling LLMs to operate on infinite-length streams without retraining or fine-tuning. The key insight comes from analyzing where attention mass concentrates.

Attention sinks

Attention patterns in trained LLMs show a surprising regularity: a disproportionate fraction of attention weight is concentrated on the first few tokens of the sequence, regardless of their semantic content. These initial tokens act as “attention sinks” that absorb attention probability which would otherwise be spread across the full context.
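import torch
import matplotlib.pyplot as plt

def visualize_attention_sinks(attn_weights):
    """
    attn_weights: [num_heads, seq_len, seq_len], post-softmax causal attention
    Plots the mean attention received by each key position.
    """
    seq_len = attn_weights.shape[-1]
    # Average over heads, then sum over query positions -> [seq_len]
    received = attn_weights.detach().float().mean(dim=0).sum(dim=0)
    # Under a causal mask, key j is visible only to queries j..seq_len-1,
    # so divide by that count rather than by seq_len.
    visible = torch.arange(seq_len, 0, -1, device=received.device)
    received = received / visible
    plt.bar(range(seq_len), received.cpu().numpy())
    plt.xlabel("Key position")
    plt.ylabel("Mean attention weight received")
    plt.title("Attention sink pattern")
    plt.show()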
The attention sink phenomenon is not specific to any one model family. It appears in LLaMA, Mistral, Falcon, and other decoder-only transformers. The first token is particularly affected because it is always visible (never masked) to every subsequent token during causal pretraining.
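To put a number on the effect rather than plot it, one option is to measure how much attention later queries place on the first few keys. The sketch below uses synthetic attention (random logits with an artificial boost on key 0) in place of a real model, and compares it with uniform causal attention; `sink_mass` and the boost value are illustrative assumptions.

```python
import torch

def sink_mass(attn_weights, num_sinks=4):
    """
    attn_weights: [num_heads, seq_len, seq_len], rows sum to 1 (post-softmax).
    Returns the fraction of attention that queries at positions >= num_sinks
    place on the first num_sinks keys, averaged over heads and queries.
    """
    later_queries = attn_weights[:, num_sinks:, :]  # skip queries that can only see sinks
    return later_queries[..., :num_sinks].sum(dim=-1).mean().item()


# Synthetic causal attention: random logits with an artificial boost on key 0
torch.manual_seed(0)
heads, seq_len = 8, 64
logits = torch.randn(heads, seq_len, seq_len)
logits[..., 0] += 5.0
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
attn = logits.masked_fill(~causal, float("-inf")).softmax(dim=-1)

# Baseline: every query spreads attention evenly over the keys it can see
uniform = torch.tril(torch.ones(seq_len, seq_len))
uniform = (uniform / uniform.sum(dim=-1, keepdim=True)).expand(heads, -1, -1)

print(sink_mass(attn), sink_mass(uniform))
```

With a real model, `attn` would come from the attention weights of one layer; a sink mass well above the uniform baseline indicates the sink pattern.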

Sliding window with sink tokens

StreamingLLM exploits this observation to build a streaming-compatible attention mechanism:
  1. Always keep the first $k_{sink}$ tokens (the sinks) in the KV cache
  2. Keep the most recent $k_{window}$ tokens in a sliding window
  3. Evict everything in between
from collections import deque

import torch

class StreamingKVCache:
    def __init__(self, sink_size=4, window_size=512):
        self.sink_size = sink_size
        self.window_size = window_size
        # The first sink_size tokens are stored here and never evicted
        self.sink_keys = []
        self.sink_values = []
        # deque(maxlen=...) drops the oldest entry automatically when full
        self.recent_keys = deque(maxlen=window_size)
        self.recent_values = deque(maxlen=window_size)

    def update(self, new_key, new_value):
        """
        new_key, new_value: [batch, heads, 1, head_dim]
        """
        if len(self.sink_keys) < self.sink_size:
            # Fill sink buffer first
            self.sink_keys.append(new_key)
            self.sink_values.append(new_value)
        else:
            # Slide the window
            self.recent_keys.append(new_key)
            self.recent_values.append(new_value)

    def get_cache(self):
        """Returns (keys, values), each [batch, heads, cached_len, head_dim]."""
        keys = self.sink_keys + list(self.recent_keys)
        values = self.sink_values + list(self.recent_values)
        if not keys:
            raise ValueError("cache is empty; call update() first")
        return torch.cat(keys, dim=2), torch.cat(values, dim=2)
This gives StreamingLLM an $O(1)$ memory footprint relative to sequence length, at the cost of losing tokens that fall outside the window.
StreamingLLM does not require any retraining or fine-tuning. It works with off-the-shelf pretrained models by restricting attention to the sink + window pattern. Position IDs are assigned by each token's slot in the cache rather than by its index in the original text, so relative distances stay within the range the model saw during pretraining.
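The bookkeeping behind the eviction rule and the position adjustment can be sketched in plain Python. The helper names `retained_tokens` and `cache_position_ids` are hypothetical, and the sketch assumes positions are simply renumbered 0, 1, 2, ... by slot in the cache.

```python
def retained_tokens(num_seen, sink_size=4, window_size=512):
    """
    Original indices of the tokens still cached after num_seen tokens
    have been processed, under the sink + sliding-window policy.
    """
    sinks = list(range(min(sink_size, num_seen)))
    window_start = max(sink_size, num_seen - window_size)
    return sinks + list(range(window_start, num_seen))


def cache_position_ids(num_seen, sink_size=4, window_size=512):
    """Position IDs assigned by slot in the cache, not by original index."""
    return list(range(len(retained_tokens(num_seen, sink_size, window_size))))


kept = retained_tokens(num_seen=10, sink_size=2, window_size=3)
print(kept)                          # prints [0, 1, 7, 8, 9]
print(cache_position_ids(10, 2, 3))  # prints [0, 1, 2, 3, 4]
```

In the example, tokens 2 through 6 have been evicted, and the token originally at index 7 is treated as position 2.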

SnapKV: selective KV cache compression

While StreamingLLM drops tokens based on recency, SnapKV takes a quality-aware approach: it identifies which KV entries are actually important for answering a query and keeps only those. The core observation is that attention patterns during prompt processing are predictive of which KV entries will matter during generation. SnapKV compresses the prompt’s KV cache before generation begins by:
  1. Process the prompt normally. Run the full prefill pass to compute attention over the complete prompt. This produces standard KV entries for every prompt token.
  2. Measure observation frequency. For each key position in the prompt, measure how much attention weight it receives from the last few query positions (the observation window, a proxy for the query tokens that matter most).
  3. Select important positions. Keep only the top-$k$ positions by observation frequency, plus a local window around the most recent tokens. Evict the rest.
  4. Generate with compressed cache. Run autoregressive generation with the reduced KV cache. The prompt's share of the cache is now bounded by $k$ plus the local window rather than by the full prompt length.
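import torch

def snapkv_compress(keys, values, query_context, budget_per_layer):
    """
    keys, values: [batch, heads, prompt_len, head_dim]
    query_context: last few query vectors [batch, heads, context_len, head_dim]
    budget_per_layer: number of KV positions to keep
    """
    prompt_len, context_len = keys.shape[2], query_context.shape[2]
    budget = min(budget_per_layer, prompt_len)

    # Compute attention logits from the observation window
    scores = torch.einsum(
        'bhqd,bhkd->bhqk', query_context, keys
    ) / keys.shape[-1] ** 0.5  # [batch, heads, context_len, prompt_len]

    # Causal mask: observation query i sits at prompt position
    # prompt_len - context_len + i and must not see later keys
    q_pos = torch.arange(prompt_len - context_len, prompt_len, device=keys.device)
    k_pos = torch.arange(prompt_len, device=keys.device)
    scores = scores.masked_fill(k_pos[None, :] > q_pos[:, None], float('-inf'))

    # Softmax to get attention weights, then aggregate across observation queries
    importance = scores.softmax(dim=-1).mean(dim=2)  # [batch, heads, prompt_len]

    # Always keep the observation window itself (the local window)
    importance[..., prompt_len - context_len:] = float('inf')

    # Select top-k positions per head
    _, top_indices = importance.topk(budget, dim=-1)
    top_indices, _ = top_indices.sort(dim=-1)  # preserve order

    # Gather selected KV entries
    compressed_keys = keys.gather(
        2, top_indices.unsqueeze(-1).expand(-1, -1, -1, keys.shape[-1])
    )
    compressed_values = values.gather(
        2, top_indices.unsqueeze(-1).expand(-1, -1, -1, values.shape[-1])
    )
    return compressed_keys, compressed_values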
SnapKV is most effective for long-prompt, short-generation tasks (e.g., document QA, summarization). For tasks where generation itself is long, the growing generation cache eventually dominates memory anyway.

Sparse attention patterns

Both StreamingLLM and SnapKV are instances of a broader family of sparse attention methods. Instead of computing full $N \times N$ attention, sparse methods restrict each query to attending over a structured subset of keys. The two canonical primitives are local and global attention; strided and sink + local patterns are common variants:

| Pattern | Description | Good for |
| --- | --- | --- |
| Local / sliding window | Each token attends to its nearest $w$ neighbors | Capturing syntactic and local semantic structure |
| Global / landmark | A small set of special tokens attends to and is attended to by all positions | Propagating long-range information |
| Strided | Every $k$-th position is included in the attention set | Efficient long-range coverage |
| Sink + local | StreamingLLM’s pattern: fixed sinks + sliding window | Streaming inference |
Most practical sparse attention systems combine local and global patterns. For example, Longformer uses a sliding window with task-specific global tokens, and BigBird adds random attention on top.
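These patterns can be expressed as boolean masks and combined with a logical OR. The sketch below builds a causal variant, so only the "is attended to by all positions" half of the global pattern applies; `sparse_causal_mask` and its parameters are illustrative names, not an API from Longformer or BigBird.

```python
import torch

def sparse_causal_mask(seq_len, window=4, num_global=1, stride=None):
    """
    Boolean [seq_len, seq_len] mask; True where query i may attend to key j.
    Combines a causal sliding window, leading global keys, and optional strided keys.
    """
    q = torch.arange(seq_len).unsqueeze(1)  # query positions, column vector
    k = torch.arange(seq_len).unsqueeze(0)  # key positions, row vector
    causal = k <= q
    local = (q - k) < window                # the nearest `window` keys, including self
    global_keys = k < num_global            # first num_global keys visible to every query
    mask = local | global_keys
    if stride is not None:
        mask = mask | (k % stride == 0)     # every stride-th key
    return mask & causal


mask = sparse_causal_mask(seq_len=8, window=3, num_global=1)
print(mask.int())
print(mask.sum(dim=-1))  # number of keys visible to each query
```

With `num_global` set to the sink count and no stride, this mask reproduces the sink + local row of the table.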

KV cache eviction strategies

When operating under a memory budget, you must decide which KV entries to evict. The lecture surveys four strategies:
  • Recency-based: Keep the most recent $k$ tokens. Simple and predictable. Works well when relevant context is local. Fails for retrieval-heavy tasks where the answer is in the distant past.
  • Score-based: Keep tokens that historically received high attention. Tracks which keys are “heavy hitters.” More expensive to maintain (requires running statistics) but significantly better for long-document tasks.
  • Learned: Train a small auxiliary network to predict which KV entries will be needed. Highest quality but requires training and adds inference overhead.
  • Hybrid: Always keep a recent window (recency) plus a fixed budget for globally important tokens (score-based). Combines the benefits of both. Used by SnapKV and several follow-up works.
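A minimal sketch of the hybrid policy, assuming the cache maintains a running total of the attention each key has received. The function name `hybrid_keep_indices` and the example scores are hypothetical; real systems differ in how they accumulate and normalize these statistics.

```python
import torch

def hybrid_keep_indices(acc_attention, budget, recent):
    """
    acc_attention: [cache_len] attention mass each cached key has received so far.
    Returns sorted indices of the keys to keep: the last `recent` keys plus the
    highest-scoring older keys, `budget` entries in total.
    """
    cache_len = acc_attention.shape[0]
    if cache_len <= budget:
        return torch.arange(cache_len)      # nothing to evict yet
    recent = min(recent, budget)
    older_scores = acc_attention[: cache_len - recent]
    heavy = older_scores.topk(budget - recent).indices
    recent_idx = torch.arange(cache_len - recent, cache_len)
    return torch.cat([heavy, recent_idx]).sort().values


scores = torch.tensor([0.9, 0.1, 0.05, 0.6, 0.02, 0.03, 0.2, 0.1])
print(hybrid_keep_indices(scores, budget=4, recent=2))  # prints tensor([0, 3, 6, 7])
```

Setting `recent` equal to `budget` reduces this to pure recency, and setting it to 0 reduces it to pure score-based eviction.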

Evaluation metrics for long-context

Evaluating long-context models requires benchmarks that actually require long-range reasoning, not ones that can be answered from a short local window.

RULER

A synthetic benchmark with controlled needle-in-a-haystack, variable tracking, and aggregation tasks at specified context lengths (4K to 128K). Tests whether a model can actually use distant context.

LongBench

A multi-task benchmark covering single-document QA, multi-document QA, summarization, few-shot learning, and code tasks. Real-world, not synthetic.

HELMET

A more recent benchmark spanning several application-centric task categories, including synthetic recall and retrieval-augmented generation, at controllable context lengths up to 128K tokens.

SCROLLS

Summarization and QA over long documents, with average lengths in the tens of thousands of tokens.
Perplexity on long sequences is a misleading proxy for long-context ability. A model can achieve low perplexity by exploiting short-range patterns while completely ignoring distant tokens. Always evaluate on tasks that require long-range retrieval or reasoning.
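A retrieval-style check can be as simple as a needle-in-a-haystack probe. The toy harness below is a generic sketch, not the procedure of any benchmark listed above; the filler text, the needle, and the function names are all made up for illustration, and the model call is replaced by a stand-in string.

```python
def build_needle_prompt(filler_sentence, num_fillers, needle, depth):
    """
    Places `needle` at a relative `depth` (0.0 = start, 1.0 = end) inside a
    haystack made of repeated filler sentences, then appends the question.
    """
    sentences = [filler_sentence] * num_fillers
    insert_at = round(depth * num_fillers)
    sentences.insert(insert_at, needle)
    return " ".join(sentences) + " What is the secret number?"


def score_answer(model_output, expected):
    """Exact-substring match: 1.0 if the expected value appears, else 0.0."""
    return 1.0 if expected in model_output else 0.0


prompt = build_needle_prompt(
    filler_sentence="The sky was clear that day.",
    num_fillers=1000,
    needle="The secret number is 7481.",
    depth=0.1,
)
# A real harness would call the model on `prompt`; the string below is a stand-in.
print(score_answer("The secret number is 7481.", "7481"))  # prints 1.0
```

Sweeping `depth` and `num_fillers` shows where in the context retrieval starts to fail, which perplexity alone does not reveal.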

Memory-compute tradeoffs

The lecture concludes with a unified view of the tradeoff space:
Full attention (baseline)
  Memory: O(N)  ← KV cache
  Compute: O(N²) ← attention scores
  Quality: best

StreamingLLM (sink + window)
  Memory: O(k)  ← fixed cache size
  Compute: O(N·k) ← linear in tokens generated
  Quality: good for streaming, poor for long-range retrieval

SnapKV (importance-based compression)
  Memory: O(k)  ← compressed prompt cache
  Compute: O(N·k) ← same as StreamingLLM after compression
  Quality: better than recency for retrieval-heavy tasks
  Overhead: one prefill pass + O(N·w) scoring

Sparse attention (architectural)
  Memory: O(N)  ← full KV cache, but sparse access
  Compute: O(N·s) ← s = sparsity budget per token
  Quality: depends on sparsity pattern and task
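To make the memory column concrete, the sketch below counts cached KV positions per layer for each strategy. It is a simplification under stated assumptions: the budget `k` is treated as fixed, and tokens generated after SnapKV's compression are not counted. The function name and the example numbers are illustrative.

```python
def cached_tokens(strategy, seq_len, budget):
    """Number of KV positions held per layer once seq_len tokens have been seen."""
    if strategy in ("full", "sparse"):
        return seq_len               # O(N): every token stays cached
    if strategy in ("streaming", "snapkv"):
        return min(seq_len, budget)  # O(k): capped at the budget
    raise ValueError(f"unknown strategy: {strategy}")


seq_len, budget = 128 * 1024, 4096
for name in ("full", "streaming", "snapkv", "sparse"):
    kept = cached_tokens(name, seq_len, budget)
    print(f"{name:10s} {kept:7d} cached tokens ({seq_len // kept}x reduction)")
```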
For production inference systems, SnapKV or a hybrid recency + importance strategy typically provides the best quality-per-memory tradeoff for document QA workloads. StreamingLLM is the right choice when you genuinely need infinite streaming (e.g., always-on agents processing continuous input).

Lecture references

Lecture 72 slides

ScaleML Lecture 72 slides by Guangxuan Xiao (StreamingLLM.pdf in the lecture_072 folder)

Guangxuan Xiao

Speaker homepage — research on efficient LLM inference

StreamingLLM paper

“Efficient Streaming Language Models with Attention Sinks” (Xiao et al., 2023)

GPU Mode YouTube

Full lecture recordings on the GPU Mode YouTube channel
