

The prefill benchmark measures how fast each attention implementation processes the full input prompt before any tokens are generated. Run it with:
```bash
uv run python benchmark_prefilling.py
```

What prefill does

During prefill, every token in the input sequence attends to itself and to every earlier token. For a sequence of length N, this produces an N × N attention matrix (with causal masking, only its lower triangle holds valid scores). The cost of this step sets the time to first token, and the memory required to hold that matrix determines which implementations can scale.
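To put the quadratic growth in numbers, the sketch below computes the size of the score matrix that a materializing implementation must hold. It assumes float32 scores and one full N × N matrix per head per sequence; `score_matrix_bytes` is a hypothetical helper, not part of the benchmark script.

```python
def score_matrix_bytes(seq_len: int, num_heads: int = 1, bytes_per_elem: int = 4) -> int:
    """Bytes for one seq_len x seq_len score matrix per head (float32 by default)."""
    return seq_len * seq_len * num_heads * bytes_per_elem


for n in (64, 1024, 4096):
    per_head_mib = score_matrix_bytes(n) / 2**20
    all_heads_mib = score_matrix_bytes(n, num_heads=32) / 2**20
    print(f"N={n:5d}: {per_head_mib:8.3f} MiB per head, {all_heads_mib:9.1f} MiB for 32 heads")
```

At N = 4096 this comes to 64 MiB per head and 2 GiB across 32 heads for a single sequence, which is the footprint that tiled implementations avoid.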

Implementations compared

PyTorch Standard

O(N²) memory. Materializes the full attention matrix in GPU global memory using standard torch.matmul and torch.softmax. Works at any sequence length that fits in GPU memory, but memory usage grows quadratically.

Naive Triton

O(N²) memory. A custom Triton kernel that loads the entire Q, K, V sequence into shared memory and computes the full attention matrix there. Because of the GPU shared memory budget, each sequence must fit in a single block: at most BLOCK_SIZE tokens, which is 64 when head_dim > 64.

Flash Attention

O(N) memory. Tiled computation with online softmax: accumulates the output block by block without ever materializing the full N×N matrix. Memory footprint is proportional to sequence length, not sequence length squared.

Function signatures

```python
import torch


def pytorch_standard_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
    scale: float,
    num_heads: int,
    num_kv_heads: int,
    head_dim: int,
) -> torch.Tensor:
    """Standard PyTorch attention - O(N²) memory"""
    ...  # body omitted on this page
```
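The attention functions receive packed token tensors: q has shape (total_tokens, num_heads, head_dim), and k and v have shape (total_tokens, num_kv_heads, head_dim). The cu_seqlens array holds cumulative offsets that mark where each sequence starts and ends in the packed layout: sequence i occupies rows cu_seqlens[i] through cu_seqlens[i+1] - 1.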
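The packed layout can be illustrated without a GPU. The sketch below is a plain-Python reference, not the repository's `pytorch_standard_attention`: it walks `cu_seqlens` to find each sequence, maps every query head to a shared KV head (assuming the usual grouped-query layout where `num_heads` is a multiple of `num_kv_heads`), and applies a causal mask. Tensors are nested lists indexed as `[token][head][dim]`, and `packed_causal_attention` is a hypothetical name.

```python
import math


def packed_causal_attention(q, k, v, cu_seqlens, scale, num_heads, num_kv_heads, head_dim):
    """Reference causal attention over packed sequences.

    q: [total_tokens][num_heads][head_dim]; k, v: [total_tokens][num_kv_heads][head_dim].
    cu_seqlens: cumulative offsets, e.g. [0, len0, len0 + len1, ...].
    """
    group = num_heads // num_kv_heads  # query heads that share one KV head
    out = [[[0.0] * head_dim for _ in range(num_heads)] for _ in range(len(q))]
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        for h in range(num_heads):
            kv_h = h // group
            for i in range(start, end):
                # Causal mask: token i sees tokens start..i of its own sequence only.
                keys = range(start, i + 1)
                scores = [scale * sum(a * b for a, b in zip(q[i][h], k[j][kv_h])) for j in keys]
                m = max(scores)  # subtract the row max for a numerically stable softmax
                weights = [math.exp(s - m) for s in scores]
                total = sum(weights)
                for d in range(head_dim):
                    out[i][h][d] = sum(w * v[j][kv_h][d] for w, j in zip(weights, keys)) / total
    return out
```

With `cu_seqlens = [0, 2, 5]`, tokens 0 to 1 form one sequence and tokens 2 to 4 another; the first token of each sequence can only attend to itself, so its output equals its own value vector.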

Benchmark configurations

The script runs four configurations to expose different regimes:
num_seqs | seq_len | Total tokens | What it tests
------------------------------------------------------------------------------------------
       2 |      60 |          120 | Short sequences: kernel launch overhead dominates
       4 |      64 |          256 | Small batch at Naive Triton's shared memory limit
       2 |    1024 |         2048 | Medium sequences: Naive Triton is skipped
       1 |    4096 |         4096 | Long sequences: Flash Attention's efficiency advantage is largest
Each configuration runs a warmup pass before timing, and the timing loop uses torch.cuda.synchronize() to get accurate wall-clock measurements.
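A minimal sketch of that warmup-then-time pattern follows. `time_ms` is a hypothetical helper; its `sync` argument stands in for `torch.cuda.synchronize`, which the script calls so that asynchronous GPU work has finished before the clock is read. It defaults to a no-op so the sketch runs on a CPU.

```python
import time
from typing import Callable


def time_ms(fn: Callable[[], object], warmup: int = 3, iters: int = 10,
            sync: Callable[[], None] = lambda: None) -> float:
    """Return the mean wall-clock time of fn() in milliseconds."""
    for _ in range(warmup):  # warmup absorbs one-time costs such as compilation and autotuning
        fn()
    sync()  # drain queued work before starting the clock
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    sync()  # wait for the last call to actually finish
    return (time.perf_counter() - start) * 1000.0 / iters
```

With PyTorch on a GPU you would pass `sync=torch.cuda.synchronize`. Without the second synchronize, kernel launches return immediately and the timer would mostly measure how long it takes to enqueue work.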

Shared memory constraint for Naive Triton

The Naive Triton kernel stores the entire BLOCK_SIZE × BLOCK_SIZE attention matrix in GPU shared memory. The memory cost in bytes is:
attention_matrix_bytes = BLOCK_SIZE² × 4  (float32)
With a GPU shared memory budget of roughly 48 KB per block:
BLOCK_SIZE | Attention matrix        | Status
------------------------------------------------------------------------
        64 | 64 × 64 × 4 = 16 KB     | Safe
       128 | 128 × 128 × 4 = 64 KB   | Exceeds limit; results may be incorrect
When head_dim > 64, the kernel selects BLOCK_SIZE = 64.
The benchmark skips Naive Triton for any configuration where seq_len > BLOCK_SIZE and prints a "SKIPPED" message instead. You will see it for the seq_len=1024 and seq_len=4096 runs.
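The budget check and the skip rule can be sketched as follows. The helper names are hypothetical, and while the choice of 64 for head_dim > 64 comes from the text above, returning 128 for smaller head_dim is an assumption. The byte count covers only the float32 score matrix, not the Q, K, and V data that the kernel also keeps in shared memory.

```python
SHARED_MEM_BUDGET = 48 * 1024  # bytes per block, the approximate budget quoted above


def select_block_size(head_dim: int) -> int:
    # head_dim > 64 -> 64 is stated in the docs; 128 otherwise is an assumption.
    return 64 if head_dim > 64 else 128


def score_tile_bytes(block_size: int, bytes_per_elem: int = 4) -> int:
    # float32 scores: BLOCK_SIZE^2 x 4 bytes (Q, K, V data is not counted here)
    return block_size * block_size * bytes_per_elem


def naive_triton_runs(seq_len: int, head_dim: int) -> bool:
    # A sequence must fit in a single block, otherwise Naive Triton is skipped.
    return seq_len <= select_block_size(head_dim)


for bs in (64, 128):
    size = score_tile_bytes(bs)
    print(bs, f"{size // 1024} KB", "fits" if size <= SHARED_MEM_BUDGET else "exceeds budget")
```

With head_dim=128, as in the crossover sweep, `naive_triton_runs` is true up to seq_len=64 and false from 128 onward.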

Crossover analysis

The script includes a find_crossover_point() function that sweeps seq_len from 16 to 1024 (with num_seqs=2, num_heads=32, num_kv_heads=8, head_dim=128) to find where Flash Attention overtakes Naive Triton. Its output has this shape (timings vary by GPU):
Seq Len  |  Naive (ms)  |  Flash (ms)  |  Winner  |  Speedup
-----------------------------------------------------------------
     16  |       0.xxx  |       0.xxx  |   Naive  |  x.xxX
     32  |       0.xxx  |       0.xxx  |   Naive  |  x.xxX
     64  |       0.xxx  |       0.xxx  |   Flash  |  x.xxX   <-- crossover
    128  |         OOM  |       0.xxx  |   Flash  |  N/A
    ...
   1024  |         OOM  |       0.xxx  |   Flash  |  N/A
At short sequences Naive Triton wins: its grid has fewer programs, and each one computes its softmax in a single pass without the tile-by-tile rescaling that Flash Attention performs. Once the sequence is long enough for the O(N²) score matrix to dominate, Flash Attention's tiled approach wins, and beyond the shared memory limit it is the only Triton implementation that runs.
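The crossover search reduces to scanning the sweep in order and reporting the first length where Flash Attention is faster or Naive Triton no longer runs. This is a hypothetical sketch that operates on already-measured timings keyed by `seq_len`; it is not the script's `find_crossover_point()`.

```python
from typing import Dict, Optional


def find_crossover(naive_ms: Dict[int, float], flash_ms: Dict[int, float]) -> Optional[int]:
    """Smallest seq_len where Flash is faster or Naive has no result; None if Naive always wins."""
    for seq_len in sorted(flash_ms):
        naive = naive_ms.get(seq_len)  # missing entry = Naive Triton skipped or out of memory
        if naive is None or flash_ms[seq_len] < naive:
            return seq_len
    return None
```

Treating a missing Naive Triton timing as a Flash Attention win matches the OOM rows in the table: past that point there is nothing left to compare against.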

Kernel launch analysis

Flash Attention’s grid has three dimensions; Naive Triton’s grid has two:
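```python
from math import ceil

# Example: 2 sequences of 60 tokens, 32 heads, BLOCK_M = 32
num_seqs, seq_len, num_heads, BLOCK_M = 2, 60, 32, 32

# Naive Triton grid: one program (thread block) per (sequence, head)
naive_grid = (num_seqs, num_heads)
naive_kernels = num_seqs * num_heads  # number of programs in the grid

# Flash Attention grid: one program per (query tile, head, sequence)
num_blocks_m = ceil(seq_len / BLOCK_M)
flash_grid = (num_blocks_m, num_heads, num_seqs)
flash_kernels = num_blocks_m * num_heads * num_seqs  # number of programs in the grid

print(naive_grid, naive_kernels)  # (2, 32) 64
print(flash_grid, flash_kernels)  # (2, 32, 2) 128
```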
For 2 sequences of 60 tokens each, with BLOCK_M=32 and 32 heads:
Implementation  | Grid        | Programs (thread blocks)
----------------------------------------------------------
Naive Triton    | (2, 32)     | 64
Flash Attention | (2, 32, 2)  | 128
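Each call to either kernel is still a single launch, so the fixed launch cost (on the order of 5–20 µs) is paid once by both. What differs is the grid: Flash Attention schedules twice as many programs, and each one repeats its own setup (loading a Q tile and initializing the running softmax state) before looping over K/V tiles. With only 60 tokens per sequence there is little compute to amortize that per-program cost, which is why Naive Triton can be faster at short sequences despite doing essentially the same mathematical work.
This overhead becomes negligible at longer sequences, where the compute time per program dwarfs the fixed costs. At seq_len=1024, Naive Triton no longer runs, and Flash Attention's tiled computation with an O(N) memory footprint is what lets it win decisively over the materializing PyTorch Standard path.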

Why Flash Attention wins at long sequences

1

O(N) memory, no N×N round trip

Flash Attention never writes a full N×N matrix. Each Q tile streams the K and V tiles from HBM, accumulates into a running output, and discards the intermediate scores. Extra memory is O(N), and HBM traffic avoids the O(N²) write and re-read of a materialized score matrix; only K and V are re-read, once per Q tile.
2

Online softmax keeps precision

The kernel tracks a running maximum m_i and a running normalizer l_i per row. When a new K/V tile arrives, it computes the new maximum m_i_new and rescales both the previous accumulator and l_i by alpha = exp(m_i_old - m_i_new) before adding the new tile's contribution. The final output is divided by l_i once, so no second pass over the data is needed.
3

Shared memory is reused, not exhausted

Each tile is small enough to fit entirely in shared memory. The kernel loops over tiles, reusing the same shared memory budget for each one. Naive Triton allocates a slot for the entire N×N matrix upfront, which caps the maximum sequence length.
4

No sequence length ceiling

Because memory is O(N), Flash Attention can handle arbitrarily long sequences (subject only to total HBM capacity). Naive Triton is skipped for any sequence that exceeds BLOCK_SIZE.
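The online softmax update from step 2 can be checked in plain Python on a single query row. This is a sketch, not the kernel: scalar loops stand in for the tile-level matrix operations, `tile_size` plays the role of the kernel's K/V tile width, and apart from `m_i`, `l_i`, and `alpha` the names are hypothetical. Consuming the scores one tile at a time with rescaling gives the same result as a full softmax over the whole row.

```python
import math


def online_softmax_attention(scores, values, tile_size):
    """softmax(scores) @ values for one query row, consuming K/V one tile at a time.

    scores: list of already-scaled q·k dot products; values: one vector per key.
    """
    m_i = -math.inf            # running maximum of the scores seen so far
    l_i = 0.0                  # running sum of exp(score - m_i)
    acc = [0.0] * len(values[0])  # running, unnormalized weighted sum of values
    for t in range(0, len(scores), tile_size):
        s_tile, v_tile = scores[t:t + tile_size], values[t:t + tile_size]
        m_new = max(m_i, max(s_tile))
        alpha = math.exp(m_i - m_new)  # rescale factor; exp(-inf) == 0.0 on the first tile
        p = [math.exp(s - m_new) for s in s_tile]
        l_i = alpha * l_i + sum(p)
        acc = [alpha * a + sum(pj * vj[d] for pj, vj in zip(p, v_tile)) for d, a in enumerate(acc)]
        m_i = m_new
    return [a / l_i for a in acc]  # single normalization at the end


def full_softmax_attention(scores, values):
    """Reference: materialize the whole softmax row, then take the weighted sum."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    return [sum(wj * vj[d] for wj, vj in zip(w, values)) / sum(w) for d in range(len(values[0]))]


scores = [0.3, -1.2, 2.5, 0.0, 1.1]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 3.0], [0.5, 0.5]]
for tile_size in (1, 2, 5):
    print(tile_size, online_softmax_attention(scores, values, tile_size))
print("full", full_softmax_attention(scores, values))
```

Every tile size prints the same vector as the full softmax, up to floating-point rounding, because each rescaling by alpha re-expresses the accumulated terms relative to the newest maximum.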
