The prefill benchmark measures how fast each attention implementation processes the full input prompt before any tokens are generated. Run it with:
What prefill does
During prefill, every token in the input sequence attends to every previous token. For a sequence of length N, this produces an N × N attention matrix. The cost of this step dominates short-context latency, and the memory required to hold that matrix determines which implementations can scale.
Implementations compared
PyTorch Standard
O(N²) memory. Materializes the full attention matrix in GPU global memory using standard torch.matmul and torch.softmax. Works at any sequence length, but memory usage grows quadratically.
Naive Triton
O(N²) memory. A custom Triton kernel that loads the entire Q, K, V sequence into shared memory and computes the full attention matrix there. Limited to ≤128 tokens due to the GPU shared memory budget.
Flash Attention
O(N) memory. Tiled computation with online softmax: accumulates the output block by block without ever materializing the full N×N matrix. Memory footprint is proportional to sequence length, not sequence length squared.
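To make the O(N²) versus O(N) distinction concrete, the sketch below compares the per-head size of a materialized score matrix with the per-row running statistics a tiled kernel keeps. The function names and the 4-byte (float32) element size are assumptions for illustration and are not taken from the benchmark script.

```python
def full_matrix_bytes(seq_len: int, bytes_per_elem: int = 4) -> int:
    """Bytes for one head's materialized seq_len x seq_len score matrix."""
    return seq_len * seq_len * bytes_per_elem


def running_stats_bytes(seq_len: int, bytes_per_elem: int = 4) -> int:
    """Bytes for one running maximum and one running normalizer per query row."""
    return 2 * seq_len * bytes_per_elem


# Sequence lengths borrowed from the benchmark configurations.
for n in (64, 1024, 4096):
    print(n, full_matrix_bytes(n), running_stats_bytes(n))
```

Under these assumptions, the score matrix for a single head at 4096 tokens already takes 64 MiB, while the running statistics stay in the tens of kilobytes.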
Function signatures
The attention functions take packed tensors of shape (total_tokens, num_heads, head_dim) and a cu_seqlens offsets array that marks where each sequence starts and ends in the packed layout.
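The packed layout can be illustrated with a small sketch that builds the offsets array from per-sequence lengths. It assumes the common convention of num_seqs + 1 cumulative offsets starting at 0; build_cu_seqlens is a hypothetical helper, not a function from the script.

```python
from itertools import accumulate


def build_cu_seqlens(seq_lens):
    """Cumulative offsets: sequence i occupies rows cu[i]:cu[i + 1]."""
    return [0] + list(accumulate(seq_lens))


seq_lens = [60, 60]  # two sequences of 60 tokens, as in the first benchmark row
cu_seqlens = build_cu_seqlens(seq_lens)
total_tokens = cu_seqlens[-1]  # first dimension of the packed tensors

for i in range(len(seq_lens)):
    start, end = cu_seqlens[i], cu_seqlens[i + 1]
    print(f"sequence {i}: rows {start}..{end - 1} of {total_tokens}")
```

Packing this way avoids padding: sequences of different lengths sit back to back, and a kernel only needs two adjacent offsets to find its slice.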
Benchmark configurations
The script runs four configurations to expose different regimes:

| num_seqs | seq_len | Total tokens | What it tests |
|---|---|---|---|
| 2 | 60 | 120 | Short sequences — kernel launch overhead dominates |
| 4 | 64 | 256 | Small batch at Naive Triton’s shared memory limit |
| 2 | 1024 | 2048 | Medium sequences — Naive Triton is skipped |
| 1 | 4096 | 4096 | Long sequences — Flash Attention’s efficiency advantage is largest |
Timings are bracketed with torch.cuda.synchronize() to get accurate wall-clock measurements.
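Synchronization matters because GPU kernels are launched asynchronously, so a host-side timer can stop before the work has finished. A minimal sketch of such a timing loop follows; time_call, its warmup and iters defaults, and the injectable sync argument are assumptions for illustration, not the script's actual harness.

```python
import time


def time_call(fn, sync=lambda: None, warmup=3, iters=10):
    """Return mean seconds per call of fn, synchronizing around the timed region."""
    for _ in range(warmup):
        fn()  # untimed runs absorb one-off costs such as compilation
    sync()  # drain pending work before starting the clock
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    sync()  # wait for queued work to finish before stopping the clock
    return (time.perf_counter() - start) / iters


# CPU-only demonstration; the default sync is a no-op.
print(time_call(lambda: sum(range(1000))))
```

On a GPU, one would pass sync=torch.cuda.synchronize so that both clock readings happen only after all queued kernels have completed.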
Shared memory constraint for Naive Triton
The Naive Triton kernel stores the entire BLOCK_SIZE × BLOCK_SIZE attention matrix in GPU shared memory. The memory cost in bytes is BLOCK_SIZE × BLOCK_SIZE × 4, that is, 4 bytes per element:

| BLOCK_SIZE | Attention matrix | Status |
|---|---|---|
| 64 | 64 × 64 × 4 = 16 KB | Safe |
| 128 | 128 × 128 × 4 = 64 KB | Exceeds limit — results may be incorrect |
When head_dim > 64, the kernel selects BLOCK_SIZE = 64. Sequences longer than BLOCK_SIZE are silently skipped at runtime.
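The arithmetic in the table can be reproduced with a short sketch. The 48 KB limit used here is an assumption (a common default per-block shared memory budget on NVIDIA GPUs); the actual limit depends on the device, and both function names are hypothetical.

```python
def attention_tile_bytes(block_size: int, bytes_per_elem: int = 4) -> int:
    """Bytes needed for a block_size x block_size score matrix."""
    return block_size * block_size * bytes_per_elem


def fits_in_shared_memory(block_size: int, limit_bytes: int = 48 * 1024) -> bool:
    """True if the score matrix alone fits within the assumed budget."""
    return attention_tile_bytes(block_size) <= limit_bytes


for block_size in (64, 128):
    kib = attention_tile_bytes(block_size) // 1024
    print(block_size, f"{kib} KB", fits_in_shared_memory(block_size))
```

This counts only the score matrix. The Q, K, and V tiles also occupy shared memory, so the real headroom is smaller than the sketch suggests.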
Crossover analysis
The script includes a find_crossover_point() function that sweeps seq_len from 16 to 1024 (with num_seqs=2, num_heads=32, num_kv_heads=8, head_dim=128) to find where Flash Attention overtakes Naive Triton.
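The idea of a crossover sweep can be sketched independently of the GPU code. The helper below is not the script's find_crossover_point(); it is a hypothetical stand-in that takes two timing callables and returns the first swept length at which the second is faster. The power-of-two sweep and the synthetic cost models are assumptions chosen only to make the example runnable.

```python
def first_crossover(seq_lens, time_a, time_b):
    """Return the first seq_len where time_b(seq_len) < time_a(seq_len), else None."""
    for n in seq_lens:
        if time_b(n) < time_a(n):
            return n
    return None


# Synthetic cost models standing in for real measurements (illustrative only).
def quadratic_cost(n):
    return 1.0 + 0.001 * n * n  # low fixed cost, quadratic growth


def tiled_cost(n):
    return 3.0 + 0.02 * n  # higher fixed cost, linear growth


sweep = [16, 32, 64, 128, 256, 512, 1024]
print(first_crossover(sweep, quadratic_cost, tiled_cost))
```

In a real sweep the two callables would wrap timed kernel runs, and the resolution of the answer is limited by the spacing of the swept lengths.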
Kernel launch analysis
Flash Attention’s grid has three dimensions; Naive Triton’s grid has two. With BLOCK_M=32 and 32 heads:
| Implementation | Grid | Total kernels |
|---|---|---|
| Naive Triton | (2, 32) | 64 |
| Flash Attention | (2, 32, 2) | 128 |
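The grid sizes in the table can be reproduced with a few lines of arithmetic. This sketch assumes that the extra Flash Attention dimension is the number of BLOCK_M-sized query tiles per sequence, that the dimensions are ordered as (sequences, heads, tiles), and that seq_len is 64; these are inferences from the table rather than details confirmed by the kernel source.

```python
import math


def cdiv(a: int, b: int) -> int:
    """Ceiling division, as used when splitting a length into tiles."""
    return (a + b - 1) // b


def naive_grid(num_seqs: int, num_heads: int):
    return (num_seqs, num_heads)


def flash_grid(num_seqs: int, num_heads: int, seq_len: int, block_m: int):
    return (num_seqs, num_heads, cdiv(seq_len, block_m))


naive = naive_grid(num_seqs=2, num_heads=32)
flash = flash_grid(num_seqs=2, num_heads=32, seq_len=64, block_m=32)
print(naive, math.prod(naive))
print(flash, math.prod(flash))
```

Under these assumptions, Flash Attention launches twice as many program instances as Naive Triton for the same input.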
This overhead becomes negligible at longer sequences where the compute time per kernel dwarfs the launch cost. At seq_len=1024, Flash Attention processes each tile independently and with O(N) memory access patterns, which is why it wins decisively.
Why Flash Attention wins at long sequences
O(N) memory access
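Flash Attention never writes a full N×N matrix. Each tile of Q streams tiles of K and V from HBM, accumulates into a running output, and discards the intermediate scores. The score matrix itself never travels to or from HBM, so the O(N²) reads and writes that the standard implementation spends on it disappear, and the extra memory is O(N) rather than O(N²).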
Online softmax keeps precision
The kernel tracks a running maximum m_i and a running normalizer l_i per row. When a new K/V tile arrives, it rescales the previous accumulator with alpha = exp(m_i_old - m_i_new) before adding the new contribution. The final output is divided by l_i once. No second pass over the data is needed.
Shared memory is reused, not exhausted
Each tile is small enough to fit entirely in shared memory. The kernel loops over tiles, reusing the same shared memory budget for each one. Naive Triton allocates a slot for the entire N×N matrix upfront, which caps the maximum sequence length.
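The tile loop and the online softmax rescaling described above can be sketched in plain Python for a single query row. This is an illustration of the technique and not the Triton kernel: the function names are hypothetical, scores stands for the already-scaled Q·K products of one row, and Python lists stand in for the buffers that would live in shared memory.

```python
import math


def online_softmax_attention(scores, values, tile_size):
    """Attention output for one query row, consuming K/V one tile at a time."""
    m_i = float("-inf")  # running maximum of the scores seen so far
    l_i = 0.0            # running normalizer (sum of shifted exponentials)
    acc = [0.0] * len(values[0])  # unnormalized output accumulator

    for start in range(0, len(scores), tile_size):
        s_tile = scores[start:start + tile_size]
        v_tile = values[start:start + tile_size]

        m_new = max(m_i, max(s_tile))
        alpha = math.exp(m_i - m_new)  # rescales everything accumulated so far
        p = [math.exp(s - m_new) for s in s_tile]

        l_i = alpha * l_i + sum(p)
        acc = [alpha * a for a in acc]
        for p_j, v_j in zip(p, v_tile):
            acc = [a + p_j * x for a, x in zip(acc, v_j)]
        m_i = m_new

    return [a / l_i for a in acc]  # single final division by l_i


def two_pass_attention(scores, values):
    """Reference: full softmax over the row, then the weighted sum of values."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    dim = len(values[0])
    return [sum(w_j * v_j[d] for w_j, v_j in zip(w, values)) / z for d in range(dim)]
```

Both functions should agree up to floating-point rounding for any tile size, which is the property that lets the kernel reuse one tile-sized buffer instead of reserving space for a whole row of scores.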