

Ring Attention solves one of the fundamental constraints of transformer scaling: as sequence length grows, even Flash Attention’s O(N) memory per GPU becomes too large to fit on a single device. Ring Attention distributes the sequence across a ring of GPUs, allowing each device to hold only a fraction of the sequence while collectively computing exact attention. This page is based on Lecture 13 by Andreas Koepf.

The long-sequence problem

Flash Attention reduces attention’s memory from O(N²) to O(N), a massive improvement. But O(N) still grows linearly. At sequence length 128K (131,072 tokens) with head dimension 128, each of Q, K, V, and the output holds about 16.8M elements per head. In 16-bit precision, that means a single attention head needs roughly 128 MiB of activations per GPU for those four tensors alone. For a 70-billion-parameter model with 64 heads, that is over 4 GB per layer, per forward pass. The GPU memory wall therefore puts a hard upper bound on sequence length for single-GPU attention. Ring Attention breaks that ceiling.
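The per-head arithmetic can be checked with a few lines of Python. This is a minimal sketch that assumes 16-bit (fp16/bf16) storage and standard multi-head attention, where every head has its own K and V. Models using grouped-query attention store fewer K/V heads, so their totals are lower. The function name `attention_io_bytes` is illustrative.

```python
def attention_io_bytes(seq_len, head_dim, n_tensors=4, bytes_per_elem=2):
    """Bytes for one head's Q, K, V, and output (n_tensors=4), 16-bit by default."""
    return seq_len * head_dim * n_tensors * bytes_per_elem

per_head = attention_io_bytes(128 * 1024, 128)
all_heads = 64 * per_head  # 64 heads, each with its own K and V (no GQA)

print(per_head / 2**20, "MiB per head")    # 128.0 MiB per head
print(all_heads / 2**30, "GiB per layer")  # 8.0 GiB per layer
```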
Ring Attention computes exact attention — not approximate. The key insight is that Flash Attention’s tiling approach generalizes naturally to a distributed setting where different GPUs hold different tiles of K and V.

How Ring Attention distributes the sequence

Ring Attention assigns each GPU a contiguous chunk of the sequence:
  • GPU 0 holds tokens [0, N/P)
  • GPU 1 holds tokens [N/P, 2N/P)
  • ...
  • GPU P-1 holds tokens [(P-1)N/P, N)
where P is the number of GPUs. Each GPU owns the Q, K, and V projections for its chunk. To compute full attention, every GPU’s queries must attend to the K and V of every chunk, not just the local one. Ring Attention does this by passing K/V blocks around the ring while computing locally.
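The ownership rule can be written as a small helper. This is a sketch that assumes N is divisible by P. The name `chunk_bounds` is hypothetical and not part of any library.

```python
def chunk_bounds(N, P):
    """Half-open token range [start, end) owned by each of P GPUs."""
    if N % P != 0:
        raise ValueError("this sketch assumes N is divisible by P")
    c = N // P  # tokens per GPU
    return [(rank * c, (rank + 1) * c) for rank in range(P)]

# 8 tokens on 4 GPUs: [(0, 2), (2, 4), (4, 6), (6, 8)]
print(chunk_bounds(8, 4))
```

Each rank would then take rows `start:end` of its sequence-major Q, K, and V tensors.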

The ring communication pattern

The algorithm runs for P rounds. In each round, every GPU:
  1. Computes attention between its local Q chunk and the current K/V chunk it holds
  2. Updates its local Flash Attention accumulators (running max, normalizer, output)
  3. Sends the current K/V chunk to the next GPU in the ring
  4. Receives a new K/V chunk from the previous GPU
After P rounds, every GPU has attended to every K/V chunk in the full sequence, and the Flash Attention online softmax merge rule combines the partial results correctly.
Round 0: GPU 0 has KV[0], GPU 1 has KV[1], GPU 2 has KV[2], GPU 3 has KV[3]
         All GPUs compute local Q @ local KV

Round 1: GPUs rotate KV chunks (GPU 0 ← KV[3], GPU 1 ← KV[0], ...)
         All GPUs compute local Q @ new KV, merge with accumulators

Round 2: Another rotation
...

Round P-1: Final rotation, all merges complete
The ring topology ensures each K/V chunk travels exactly P-1 hops and visits each GPU exactly once. Point-to-point send/receive between adjacent GPUs is the only communication primitive required.
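The rotation in the diagram follows a simple rule. In round s, rank r holds the chunk that started on rank (r - s) mod P. The short sketch below prints the schedule for four GPUs and checks that every GPU sees every chunk exactly once. `ring_schedule` is a hypothetical helper used only for illustration.

```python
def ring_schedule(P):
    """schedule[step][rank] = index of the K/V chunk that rank holds in that round."""
    # Each round, every rank forwards its chunk to rank (r + 1) % P, so in
    # round s rank r holds the chunk that started on rank (r - s) % P.
    return [[(rank - step) % P for rank in range(P)] for step in range(P)]

schedule = ring_schedule(4)
for step, row in enumerate(schedule):
    print(f"Round {step}: " + ", ".join(f"GPU {r} has KV[{c}]" for r, c in enumerate(row)))

# Over P rounds, every GPU holds every chunk exactly once
assert all(sorted(held) == list(range(4)) for held in zip(*schedule))
```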

Combining with Flash Attention tiling

Ring Attention adds a loop around Flash Attention. Each GPU runs Flash Attention for its local Q chunk. The outer ring loop walks over the K/V chunks arriving from the ring, and for each chunk the usual Flash Attention inner loop iterates over that chunk’s tiles:
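# Pseudocode: Ring Attention forward pass on one GPU
# (start_send_recv and flash_attention_inner are hypothetical placeholders)
m = -inf   # running max (one value per query position)
l = 0      # running normalizer (one value per query position)
O = 0      # running unnormalized output accumulator
kv_chunk = KV_local   # round 0 uses this GPU's own K/V chunk

for step in range(num_gpus):
    # Launch the transfer first so it runs while we compute: forward the
    # current chunk to the next GPU and receive one from the previous GPU.
    # The last round skips it because every chunk has already been seen.
    if step < num_gpus - 1:
        transfer = start_send_recv(kv_chunk, next_gpu, prev_gpu)

    # Run Flash Attention inner loop over this K/V chunk
    # Updates m, l, O using online softmax merge
    m, l, O = flash_attention_inner(Q_local, kv_chunk, m, l, O)

    # Wait for the transfer; the received chunk is used in the next round
    if step < num_gpus - 1:
        kv_chunk = transfer.wait()

# Final output: normalize by the softmax denominator
output = O / l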
The online softmax merge from Flash Attention works identically across the ring steps. When a new K/V chunk arrives from a neighboring GPU, you rescale the existing accumulators O and l by exp(m_old - m_new), where m_new is the updated running max, before adding the new chunk’s contribution. This is the same update rule used within a single GPU’s tile loop.
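To see that the merge is exact, here is a single-process NumPy simulation of the ring loop for non-causal attention. The P "GPUs" are just list indices, and rank r processes chunk (r - step) mod P in each round, matching the rotation in the diagram. The function names are hypothetical. A real implementation would also tile each chunk and run on GPUs, but the accumulator updates are the ones described above.

```python
import numpy as np

def ring_attention_one_gpu(Q_chunks, K_chunks, V_chunks, rank):
    """Simulate the ring loop for one GPU (rank) in a single process."""
    P = len(K_chunks)
    Q = Q_chunks[rank]
    scale = 1.0 / np.sqrt(Q.shape[-1])
    m = np.full(Q.shape[0], -np.inf)                   # running max per query
    l = np.zeros(Q.shape[0])                           # running normalizer per query
    O = np.zeros((Q.shape[0], V_chunks[0].shape[-1]))  # unnormalized output
    for step in range(P):
        j = (rank - step) % P                # chunk this GPU holds in this round
        S = (Q @ K_chunks[j].T) * scale      # scores against chunk j
        m_new = np.maximum(m, S.max(axis=1))
        alpha = np.exp(m - m_new)            # exp(m_old - m_new); 0 in round 0
        p = np.exp(S - m_new[:, None])
        l = alpha * l + p.sum(axis=1)
        O = alpha[:, None] * O + p @ V_chunks[j]
        m = m_new
    return O / l[:, None]

def full_attention(Q, K, V):
    """Reference: ordinary softmax attention over the whole sequence."""
    S = (Q @ K.T) / np.sqrt(Q.shape[-1])
    w = np.exp(S - S.max(axis=1, keepdims=True))
    return (w / w.sum(axis=1, keepdims=True)) @ V

rng = np.random.default_rng(0)
N, d, P = 32, 8, 4
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
Qc, Kc, Vc = (np.split(x, P) for x in (Q, K, V))
ring_out = np.concatenate([ring_attention_one_gpu(Qc, Kc, Vc, r) for r in range(P)])
print(np.allclose(ring_out, full_attention(Q, K, V)))  # True
```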

Memory scaling

The memory cost on each GPU scales as:
| Quantity | Memory per GPU |
|---|---|
| Q, K, V projections | O(N/P · d) |
| Flash Attention accumulators | O(N/P) |
| K/V communication buffer (one chunk in flight) | O(N/P · d) |
| Total | O(N/P · d) |
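Each GPU’s memory use is therefore O(N/P · d), or O(N/P) for a fixed head dimension: a factor of P smaller than single-GPU Flash Attention. This counts only the sequence-dependent activations, not model weights. Adding GPUs linearly extends the maximum sequence length you can handle at the same per-GPU memory budget.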
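Ring Attention’s total communication volume is O(N · d · P): each of the P K/V chunks, of size O(N/P · d), makes P-1 hops, so each GPU sends O(N · d). This matches a ring all-gather of K and V, but it is pipelined with compute. The overlap hides the transfers only while each step’s compute, which scales as (N/P)² · d, outlasts the chunk transfer, which scales as N/P · d. For a fixed sequence length, adding GPUs shrinks the chunks until bandwidth becomes the bottleneck, so returns diminish as P grows.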
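A back-of-the-envelope model makes the bandwidth limit concrete. This sketch assumes 16-bit K and V, non-causal blocks, and about 4·c²·d FLOPs per ring step for a chunk of c tokens (two c×c×d matrix products at 2 FLOPs per multiply-add). It ignores latency and kernel efficiency. The throughput (1 PFLOP/s) and link bandwidth (100 GB/s) in the example are illustrative placeholders, not measurements of any particular GPU or interconnect, and the helper names are hypothetical.

```python
import math

def per_gpu_send_bytes(N, d, P, bytes_per_elem=2):
    """Bytes one GPU sends during a full ring pass, for one head's K and V."""
    chunk_bytes = 2 * (N // P) * d * bytes_per_elem  # K and V for one chunk
    return (P - 1) * chunk_bytes                     # no send in the final round

def min_chunk_for_overlap(flops_per_s, link_bytes_per_s, bytes_per_elem=2):
    """Smallest chunk length c (tokens) whose per-step compute, about
    4 * c**2 * d FLOPs, takes at least as long as sending its K/V chunk,
    2 * c * d * bytes_per_elem bytes. The head dimension d cancels out."""
    return math.ceil(flops_per_s * bytes_per_elem / (2 * link_bytes_per_s))

# Illustrative placeholder numbers:
print(per_gpu_send_bytes(128 * 1024, 128, 8))  # 58720256 bytes (56 MiB) per head
print(min_chunk_for_overlap(1e15, 100e9))      # 10000 tokens per GPU
```

Under these assumptions, once N/P drops below the printed chunk length, a ring step becomes communication-bound no matter how well the transfer is overlapped.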

Implementation with NCCL

Ring Attention uses NCCL’s point-to-point primitives (ncclSend / ncclRecv) to pass K/V chunks between adjacent GPUs. The critical optimization is overlapping communication with computation: while each GPU computes attention for the current K/V chunk, it simultaneously forwards that chunk to the next GPU and receives the following one from the previous GPU.
import torch
import torch.distributed as dist

def ring_attention_step(Q_local, KV_local, rank, world_size):
    # flash_attention_local and merge_flash_attention_results are placeholders:
    # the first must return each partial output with its per-query log-sum-exp
    # so the second can combine the partials with the online softmax rule.
    next_rank = (rank + 1) % world_size
    prev_rank = (rank - 1) % world_size

    results = []
    current_kv = KV_local

    for step in range(world_size):
        reqs = []
        if step < world_size - 1:
            # Launch the transfer before computing so the two overlap:
            # forward the chunk we hold, receive the previous rank's chunk
            recv_buf = torch.empty_like(current_kv)
            send_op = dist.P2POp(dist.isend, current_kv, next_rank)
            recv_op = dist.P2POp(dist.irecv, recv_buf, prev_rank)
            reqs = dist.batch_isend_irecv([send_op, recv_op])

        # Compute attention with the chunk we already hold
        result = flash_attention_local(Q_local, current_kv)
        results.append(result)

        # Wait for both ops before reading recv_buf or dropping current_kv
        for req in reqs:
            req.wait()
        if step < world_size - 1:
            current_kv = recv_buf

    return merge_flash_attention_results(results)
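True overlap between compute and communication requires the transfer and the Flash Attention kernel to run on different CUDA streams. With raw NCCL, launch ncclSend / ncclRecv on a separate stream from the attention kernel. PyTorch’s NCCL backend already issues point-to-point operations on its own internal stream, so in the code above what matters is launching the transfer before the compute and waiting on it only afterwards.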
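The code above leaves merge_flash_attention_results unspecified. One standard way to combine the partials, assuming each one is returned as a chunk-normalized output o_i plus its per-query log-sum-exp lse_i over that chunk’s attention scores, is the following (all quantities are per query position):

```latex
\mathrm{lse} = \log \sum_{i=0}^{P-1} e^{\mathrm{lse}_i},
\qquad
o = \sum_{i=0}^{P-1} e^{\mathrm{lse}_i - \mathrm{lse}}\, o_i
```

Because each o_i is its chunk’s softmax numerator divided by e^{lse_i}, the weighted sum recovers the full numerator divided by e^{lse}, which is exact attention. In floating point, compute lse by first subtracting max_i lse_i inside the exponent, just as the running max m does in the single-GPU loop.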

Practical use cases

Ring Attention is well-suited to tasks that require long context at training or inference time:
  • Long-context LLMs: models trained on books, codebases, or long documents (e.g., 128K–1M token contexts)
  • Multi-modal inputs: combining high-resolution image tokens with text in a single sequence
  • Scientific sequences: genomic data, protein structure prediction, or long time-series
  • Sliding-window hybrid: Ring Attention can be combined with local/sparse attention patterns to reduce communication while keeping global context

Causal masking considerations

For decoder-only models with causal masking, queries in a chunk attend only to keys in the same or earlier chunks. With contiguous chunks, load balancing across the ring becomes uneven. The GPU holding the first chunk skips every K/V chunk that arrives from later positions and computes only its own masked diagonal block, while the GPU holding the last chunk attends to the entire sequence. A common fix is to assign chunks in an interleaved or zigzag pattern so that every GPU handles both early and late tokens, balancing the masked-out fraction.
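The imbalance, and one zigzag fix, can be checked by counting the unmasked score entries each GPU computes. This sketch assumes the sequence splits evenly into 2P chunks and that GPU r takes chunks r and 2P-1-r, which is one common zigzag variant. The helper names are hypothetical.

```python
def contiguous_assignment(n_tokens, P):
    """GPU r holds one contiguous chunk of n_tokens // P positions."""
    c = n_tokens // P
    return [list(range(r * c, (r + 1) * c)) for r in range(P)]

def zigzag_assignment(n_tokens, P):
    """Split into 2P chunks; GPU r holds chunk r and chunk 2P - 1 - r."""
    c = n_tokens // (2 * P)
    chunks = [list(range(k * c, (k + 1) * c)) for k in range(2 * P)]
    return [chunks[r] + chunks[2 * P - 1 - r] for r in range(P)]

def causal_work(positions):
    """Unmasked score entries: a query at position t sees keys 0..t."""
    return sum(t + 1 for t in positions)

print([causal_work(p) for p in contiguous_assignment(16, 4)])  # [10, 26, 42, 58]
print([causal_work(p) for p in zigzag_assignment(16, 4)])      # [34, 34, 34, 34]
```

With contiguous chunks, the last GPU does almost six times the work of the first. With the zigzag split, each GPU’s load is (2P-1)c² + c(c+1) for chunk size c, which does not depend on r.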

Further reading

Lecture 13 slides

Andreas Koepf’s Ring Attention slides

Flash Attention

The single-GPU tiling algorithm that Ring Attention builds on

GPU Collectives & NCCL

Lecture 17: collective communication primitives used for ring passes

Ring Attention paper

Liu et al., 2023 — Ring Attention with Blockwise Transformers
