The decoding benchmark measures the per-step cost of generating one new token. During decode, a single query vector attends to all cached K/V values from every previous position across non-contiguous physical memory blocks. Run it with:
uv run python benchmark_decoding.py

What paged attention decode does

Unlike prefill, which processes a full prompt at once, decode runs one step at a time. At each step:
  1. The model produces a single query vector q of shape (batch_size, num_heads, head_dim).
  2. That query must attend to all context_len cached key and value vectors.
  3. Those cached vectors are stored in a paged KV cache — a pool of fixed-size physical blocks that are allocated on demand and may be non-contiguous in memory.
The block table maps logical token positions to physical block indices:
# Lookup pattern used by all three implementations
physical_block_idx = block_tables[batch_idx][token_idx // block_size]
block_offset       = token_idx % block_size
Each implementation must resolve this indirection for every one of the context_len tokens before it can compute attention scores.
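The lookup pattern above can be exercised in isolation. The following is a minimal sketch in plain Python rather than code from the benchmark; the helper name `resolve_slot` and the example block table row are hypothetical and chosen only for illustration.

```python
def resolve_slot(block_table_row, token_idx, block_size):
    """Map a logical token position to (physical_block_idx, block_offset)."""
    physical_block_idx = block_table_row[token_idx // block_size]
    block_offset = token_idx % block_size
    return physical_block_idx, block_offset


# Hypothetical block table row for one sequence: logical blocks 0, 1, 2
# live in physical blocks 7, 2, 9 of the cache pool.
block_table_row = [7, 2, 9]
block_size = 16

slot_first = resolve_slot(block_table_row, 0, block_size)   # start of logical block 0
slot_mid = resolve_slot(block_table_row, 21, block_size)    # 21 // 16 = 1, 21 % 16 = 5
slot_last = resolve_slot(block_table_row, 47, block_size)   # end of logical block 2

print(slot_first, slot_mid, slot_last)
```

Neighbouring token positions share a physical block until the index crosses a multiple of block_size, so the block table only needs to be consulted once per block rather than once per token.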

Implementations compared

The simplest PyTorch implementation iterates over each sequence in the batch with a Python for loop. For each sequence it walks the block table block by block, gathers slices of k_cache and v_cache with torch.cat, pads all sequences to max_context_len, then runs standard torch.matmul attention. The Python-level loop and repeated torch.cat calls make this the slowest option, but it is the easiest to read and debug.
# Core gather loop (per sequence)
for block_idx in range(num_blocks_needed):
    block_id = block_tables[i, block_idx].item()
    seq_k_list.append(k_cache[block_id])  # (block_size, num_kv_heads, head_dim)
    seq_v_list.append(v_cache[block_id])
# Concatenate whole blocks, then drop the unused tail of the last block
seq_k = torch.cat(seq_k_list, dim=0)[:seq_len]
seq_v = torch.cat(seq_v_list, dim=0)[:seq_len]
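The gather, truncate, and pad steps described above can be sketched without tensors. This is an illustrative stand-in that uses nested Python lists instead of torch tensors; the function name `gather_and_pad`, the toy cache, and the boolean mask layout are assumptions, not the benchmark's code.

```python
def gather_and_pad(cache, block_tables, context_lens, block_size):
    """Gather each sequence's cached entries, then pad to the longest sequence."""
    max_context_len = max(context_lens)
    padded, masks = [], []
    for row, seq_len in zip(block_tables, context_lens):
        num_blocks_needed = (seq_len + block_size - 1) // block_size
        tokens = []
        for block_idx in range(num_blocks_needed):
            tokens.extend(cache[row[block_idx]])   # whole physical block
        tokens = tokens[:seq_len]                  # drop unused tail of last block
        pad = max_context_len - seq_len
        padded.append(tokens + [None] * pad)       # None marks a padded slot
        masks.append([True] * seq_len + [False] * pad)
    return padded, masks


# Toy cache pool: 4 physical blocks of 2 slots each (labels instead of vectors).
cache = [["a0", "a1"], ["b0", "b1"], ["c0", "c1"], ["d0", "d1"]]
block_tables = [[2, 0], [3, -1]]   # -1 marks an unallocated slot
context_lens = [3, 1]

padded, masks = gather_and_pad(cache, block_tables, context_lens, block_size=2)
print(padded)
print(masks)
```

In a tensor version, the mask would be used to set the scores of padded positions to negative infinity before the softmax, so that padding does not contribute to the attention output.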

Benchmark configurations

The script runs four configurations that span different batch size and context length regimes:
batch_size  seq_len  num_iterations  What it tests
2           60       100             Small batch, short context
1           512      100             Single sequence, medium context
16          256      50              Large batch, medium context
4           2048     20              Small batch, long context
All configurations use num_heads=32, num_kv_heads=8, head_dim=128, and block_size=16. Ten warmup iterations run before timing begins, and torch.cuda.synchronize() fences each timed loop.
block_size=16 means each physical KV block holds 16 token positions. A sequence of 512 tokens requires 32 physical blocks, and the block table for that sequence has 32 entries that may point to non-contiguous locations in the cache pool.
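The same arithmetic applies to every configuration. The sketch below computes blocks per sequence and the total pool size for the four batch size and context length pairs, assuming the pool is sized as batch_size times blocks per sequence; the helper name `blocks_per_sequence` is hypothetical.

```python
import math

BLOCK_SIZE = 16

# (batch_size, seq_len) pairs for the four benchmark configurations.
configs = [(2, 60), (1, 512), (16, 256), (4, 2048)]


def blocks_per_sequence(seq_len, block_size=BLOCK_SIZE):
    """Number of physical blocks needed to hold seq_len token positions."""
    return math.ceil(seq_len / block_size)


summary = []
for batch_size, seq_len in configs:
    per_seq = blocks_per_sequence(seq_len)
    summary.append((batch_size, seq_len, per_seq, batch_size * per_seq))

for batch_size, seq_len, per_seq, total in summary:
    print(f"batch={batch_size:>2} seq_len={seq_len:>4} "
          f"blocks/seq={per_seq:>3} total_blocks={total:>3}")
```

A context length that is not a multiple of block_size, such as 60, still occupies a whole final block: it needs 4 blocks, and the last one holds only 12 valid positions.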

Block table lookup mechanics

The block table is a 2-D integer tensor of shape (batch_size, max_num_blocks). A value of -1 indicates that a slot has not been allocated. The setup used by the benchmark assigns contiguous block IDs for simplicity:
# From setup_test_data()
max_num_blocks = ceil(seq_len / block_size)
total_blocks   = batch_size * max_num_blocks

# KV cache pool
k_cache = torch.randn(total_blocks, block_size, num_kv_heads, head_dim, ...)
v_cache = torch.randn(total_blocks, block_size, num_kv_heads, head_dim, ...)

# Block table: each row lists the physical block indices for that sequence
block_tables = torch.arange(total_blocks).reshape(batch_size, max_num_blocks)
In a production engine the block table entries are non-contiguous — blocks for a single sequence may be scattered across the physical pool — which is exactly why the lookup indirection exists.
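To see what a scattered layout looks like, here is a small sketch that hands out blocks from a shuffled free list and pads unused slots with -1. It is not how any particular engine allocates blocks; the function `allocate_block_tables` and its parameters are hypothetical, and plain Python lists stand in for the integer tensor.

```python
import random


def allocate_block_tables(context_lens, block_size, num_physical_blocks, seed=0):
    """Build a block table whose rows point at scattered physical blocks."""
    rng = random.Random(seed)
    free_blocks = list(range(num_physical_blocks))
    rng.shuffle(free_blocks)   # simulate a fragmented pool

    blocks_needed = [(n + block_size - 1) // block_size for n in context_lens]
    max_num_blocks = max(blocks_needed)

    tables = []
    for needed in blocks_needed:
        row = [free_blocks.pop() for _ in range(needed)]
        row += [-1] * (max_num_blocks - needed)   # -1 marks an unallocated slot
        tables.append(row)
    return tables


tables = allocate_block_tables([40, 16, 70], block_size=16, num_physical_blocks=32)
for row in tables:
    print(row)
```

Each row has the same width so the table can be stored as a rectangular 2-D tensor; shorter sequences simply carry trailing -1 entries that a kernel must never dereference.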

Why CUDA chunk processing matters

The Triton kernel processes tokens in chunks of BLOCK_N rather than one at a time or all at once. This matters for three reasons:
  1. Amortize block table lookups. Resolving block_tables[batch_idx][token_idx // block_size] is a random-access load from global memory. By batching BLOCK_N tokens per iteration, the kernel can reuse the same physical block pointer for all block_size tokens that fall within one block, reducing the number of expensive pointer chases.
  2. Enable online softmax across the full context. Chunking lets the kernel apply the same running-max online softmax used by Flash Attention. A running maximum m_i and normalizer l_i are updated after each chunk, so the kernel never has to store all context_len raw scores simultaneously; only the current chunk's BLOCK_N scores need to be held in registers at any time.
  3. Keep register pressure manageable. Loading all context tokens at once would exhaust GPU registers for long sequences. Chunk-based processing bounds the per-iteration working set to BLOCK_N × head_dim elements regardless of how long the context is.
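The running-max update can be checked numerically outside a kernel. The sketch below is plain Python rather than Triton, and it uses scalar values in place of head_dim-sized value vectors to keep the arithmetic visible; the function names and the sample scores are assumptions for illustration only.

```python
import math


def online_softmax_weighted_sum(scores, values, chunk_size):
    """Softmax-weighted sum of values, computed one chunk of scores at a time."""
    m_i = float("-inf")   # running maximum of the scores seen so far
    l_i = 0.0             # running normalizer (sum of exp(score - m_i))
    acc = 0.0             # running unnormalized output
    for start in range(0, len(scores), chunk_size):
        s_chunk = scores[start:start + chunk_size]
        v_chunk = values[start:start + chunk_size]
        m_new = max(m_i, max(s_chunk))
        alpha = math.exp(m_i - m_new)   # rescales old state; 0.0 on the first chunk
        p = [math.exp(s - m_new) for s in s_chunk]
        l_i = l_i * alpha + sum(p)
        acc = acc * alpha + sum(pi * vi for pi, vi in zip(p, v_chunk))
        m_i = m_new
    return acc / l_i


def full_softmax_weighted_sum(scores, values):
    """Reference: materialize every score, then apply a standard softmax."""
    m = max(scores)
    p = [math.exp(s - m) for s in scores]
    return sum(pi * vi for pi, vi in zip(p, values)) / sum(p)


scores = [0.5, 2.0, -1.0, 3.5, 0.0, 1.25, -2.0]
values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]

online = online_softmax_weighted_sum(scores, values, chunk_size=3)
reference = full_softmax_weighted_sum(scores, values)
print(online, reference)
```

Only m_i, l_i, and acc persist between chunks, so the state carried across iterations is independent of the context length; the chunk size changes how much work each iteration does, not the result.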
For long-context workloads (seq_len=2048) the Triton kernel’s advantage over Optimized PyTorch grows because the Python-side batch gather loop in the PyTorch implementations scales with batch_size, while the Triton kernel keeps all work on the GPU.
