The decoding benchmark measures the per-step cost of generating one new token. During decode, a single query vector attends to all cached K/V values from every previous position across non-contiguous physical memory blocks. Run it with:
What paged attention decode does
Unlike prefill, which processes a full prompt at once, decode runs one step at a time. At each step:
- The model produces a single query vector q of shape (batch_size, num_heads, head_dim).
- That query must attend to all context_len cached key and value vectors.
- Those cached vectors are stored in a paged KV cache: a pool of fixed-size physical blocks that are allocated on demand and may be non-contiguous in memory.
The kernel must therefore resolve the physical location of all context_len tokens before it can compute attention scores.
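The decode step described above can be sketched for one sequence and one attention head in plain Python. This is a minimal illustration, not the benchmark's code: the function names gather_paged and decode_attention, the nested-list cache layout, and the 1/sqrt(head_dim) scaling are assumptions made for the example.

```python
import math

def gather_paged(cache, block_table_row, context_len, block_size):
    """Collect the first context_len cached vectors of one sequence.

    Assumed layout: cache[physical_block][offset] is one vector (list of floats),
    and block_table_row[i] is the physical block that holds logical block i.
    """
    out = []
    for token_idx in range(context_len):
        physical_block = block_table_row[token_idx // block_size]
        out.append(cache[physical_block][token_idx % block_size])
    return out

def decode_attention(q, k_cache, v_cache, block_table_row, context_len, block_size):
    """Single-query attention for one sequence and one head."""
    keys = gather_paged(k_cache, block_table_row, context_len, block_size)
    values = gather_paged(v_cache, block_table_row, context_len, block_size)
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    head_dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(head_dim)]

# Three physical blocks of two tokens each; the sequence uses blocks 2 then 0.
k_cache = [[[1.0, 0.0], [0.0, 1.0]], [[9.0, 9.0], [9.0, 9.0]], [[0.0, 0.0], [1.0, 1.0]]]
v_cache = [[[5.0, 6.0], [7.0, 8.0]], [[9.0, 9.0], [9.0, 9.0]], [[1.0, 2.0], [3.0, 4.0]]]
out = decode_attention([0.0, 0.0], k_cache, v_cache, [2, 0], context_len=3, block_size=2)
```

With a zero query every score is equal, so out is simply the mean of the three gathered value vectors. The point of the sketch is the indirection: logical token order comes from the block table, not from the physical layout of the cache.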
Implementations compared
- Naive PyTorch
- Optimized PyTorch
- Triton Kernel
The naive PyTorch implementation iterates over each sequence in the batch with a Python for loop. For each sequence it walks the block table block by block, gathers slices of k_cache and v_cache with torch.cat, pads all sequences to max_context_len, then runs standard torch.matmul attention. The Python-level loop and repeated torch.cat calls make this the slowest option, but it is the easiest to read and debug.
Benchmark configurations
The script runs four configurations that span different batch size and context length regimes:

| batch_size | seq_len | num_iterations | What it tests |
|---|---|---|---|
| 2 | 60 | 100 | Small batch, short context |
| 1 | 512 | 100 | Single sequence, medium context |
| 16 | 256 | 50 | Large batch, medium context |
| 4 | 2048 | 20 | Small batch, long context |

All configurations use num_heads=32, num_kv_heads=8, head_dim=128, and block_size=16. Ten warmup iterations run before timing begins, and torch.cuda.synchronize() fences each timed loop.
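The warmup-then-fence timing pattern can be sketched as follows. This is a hedged illustration, not the benchmark script itself: the benchmark function and its sync parameter are hypothetical names, and sync stands in for a call such as torch.cuda.synchronize() so that the sketch runs without a GPU.

```python
import time

def benchmark(fn, num_iterations, num_warmup=10, sync=lambda: None):
    """Return the mean wall-clock seconds per call of fn, excluding warmup.

    sync is a hypothetical hook; on a GPU it would wait for queued kernels
    to finish so the clock is not read while work is still in flight.
    """
    for _ in range(num_warmup):
        fn()  # warmup calls absorb one-time costs such as compilation
    sync()  # drain warmup work before starting the clock
    start = time.perf_counter()
    for _ in range(num_iterations):
        fn()
    sync()  # wait for all timed work to finish before stopping the clock
    return (time.perf_counter() - start) / num_iterations

# Usage with a CPU stand-in for the kernel under test.
mean_seconds = benchmark(lambda: sum(range(1000)), num_iterations=100)
```

Without the second fence, asynchronous GPU launches would return before the kernels complete, and the measured time would reflect launch overhead rather than kernel execution.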
block_size=16 means each physical KV block holds 16 token positions. A sequence of 512 tokens requires 32 physical blocks, and the block table for that sequence has 32 entries that may point to non-contiguous locations in the cache pool.
Block table lookup mechanics
The block table is a 2-D integer tensor of shape (batch_size, max_num_blocks). A value of -1 indicates that a slot has not been allocated. The setup used by the benchmark assigns contiguous block IDs for simplicity.
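A block table of this shape, and the lookup it supports, might be sketched as below. The helper names blocks_needed, build_block_table, and locate are hypothetical, nested lists stand in for the integer tensor, and the contiguous-ID allocation is only one plausible reading of the benchmark's setup.

```python
def blocks_needed(context_len, block_size):
    """Number of physical blocks required to hold context_len tokens."""
    return (context_len + block_size - 1) // block_size  # ceiling division

def build_block_table(context_lens, block_size):
    """Assign contiguous block IDs per sequence and pad unused slots with -1."""
    max_num_blocks = max(blocks_needed(n, block_size) for n in context_lens)
    table, next_id = [], 0
    for n in context_lens:
        used = blocks_needed(n, block_size)
        row = list(range(next_id, next_id + used)) + [-1] * (max_num_blocks - used)
        table.append(row)
        next_id += used
    return table

def locate(table, batch_idx, token_idx, block_size):
    """Translate a logical token index into (physical_block, offset_in_block)."""
    physical_block = table[batch_idx][token_idx // block_size]
    if physical_block < 0:
        raise IndexError("token falls in an unallocated block slot")
    return physical_block, token_idx % block_size

table = build_block_table([60, 20], block_size=16)
# Sequence 0 needs 4 blocks and sequence 1 needs 2, so the table is
# [[0, 1, 2, 3], [4, 5, -1, -1]].
position = locate(table, batch_idx=1, token_idx=17, block_size=16)  # (5, 1)
```

In a real paged cache the IDs in a row need not be consecutive; the lookup logic is unchanged, because every access goes through the table.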
Why CUDA chunk processing matters
The Triton kernel processes tokens in chunks of BLOCK_N rather than one at a time or all at once. This matters for two reasons:
Amortize block table lookups
Resolving block_tables[batch_idx][token_idx // block_size] is a random-access load from global memory. By batching BLOCK_N tokens per iteration, the kernel can reuse the same physical block pointer for all block_size tokens that fall within one block, reducing the number of expensive pointer chases.
Enable online softmax across the full context
Chunking lets the kernel apply the same running-max online softmax used by Flash Attention. A running maximum m_i and normalizer l_i are updated after each chunk, so the kernel never has to store all context_len raw scores simultaneously. Only the current chunk's BLOCK_N scores need to be held in registers at any time.
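The running-max update can be sketched in plain Python for a single query. This is an illustrative version under stated assumptions, not the Triton kernel: scores are passed in precomputed (the kernel would compute each chunk's scores from q and the gathered keys), and the function names are hypothetical.

```python
import math

def online_softmax_attention(scores, values, block_n):
    """Softmax-weighted sum of values, computed block_n scores at a time."""
    m_i = float("-inf")           # running maximum of all scores seen so far
    l_i = 0.0                     # running normalizer, relative to m_i
    acc = [0.0] * len(values[0])  # running weighted sum, relative to m_i
    for start in range(0, len(scores), block_n):
        chunk_s = scores[start:start + block_n]
        chunk_v = values[start:start + block_n]
        m_new = max(m_i, max(chunk_s))
        alpha = math.exp(m_i - m_new)  # rescales earlier partial sums; 0.0 on the first chunk
        p = [math.exp(s - m_new) for s in chunk_s]
        l_i = alpha * l_i + sum(p)
        acc = [alpha * a + sum(pj * v[d] for pj, v in zip(p, chunk_v))
               for d, a in enumerate(acc)]
        m_i = m_new
    return [a / l_i for a in acc]

def reference_attention(scores, values):
    """Same result computed with all scores held at once, for comparison."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    total = sum(w)
    return [sum(wi * v[d] for wi, v in zip(w, values)) / total
            for d in range(len(values[0]))]

scores = [0.5, 2.0, -1.0, 3.0, 0.0]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 4.0], [3.0, 1.0]]
chunked = online_softmax_attention(scores, values, block_n=2)
full = reference_attention(scores, values)
```

The two results agree up to floating-point rounding. Whenever a later chunk raises the maximum, the alpha factor shrinks the earlier partial sums so that everything stays expressed relative to the newest m_i.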