Flash Attention is the attention algorithm miniVLLM uses during the prefill phase. It is mathematically equivalent to standard scaled dot-product attention (results match up to floating-point rounding), but it never writes the N×N attention-score matrix to GPU high-bandwidth memory (HBM). Extra memory therefore grows as O(N) rather than O(N²), and HBM traffic drops sharply. This makes long sequences practical where standard attention would be bottlenecked by memory bandwidth.
The memory problem with standard attention
Standard attention materializes the full N×N score matrix in HBM:
# Standard PyTorch — O(N²) memory
attn_scores = torch.matmul(q_seq, k_seq.transpose(1, 2)) * scale # (H, N, N)
attn_probs = torch.softmax(attn_scores, dim=-1)
out_seq = torch.matmul(attn_probs, v_seq)
For a 4096-token sequence with 32 heads and float16, the score matrix alone occupies roughly 1 GB. HBM bandwidth, not arithmetic throughput, becomes the bottleneck.
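The ~1 GB figure is simple arithmetic: heads × N × N elements at 2 bytes each for float16. A minimal sketch (the helper name score_matrix_bytes is hypothetical), assuming batch size 1 and counting only the attn_scores tensor of a single layer:

```python
def score_matrix_bytes(seq_len: int, num_heads: int, bytes_per_elem: int = 2) -> int:
    """Size in bytes of one (num_heads, seq_len, seq_len) attention-score tensor."""
    return num_heads * seq_len * seq_len * bytes_per_elem

size = score_matrix_bytes(4096, 32)  # float16 -> 2 bytes per element
print(f"{size} bytes = {size / 2**30:.2f} GiB")  # 1073741824 bytes = 1.00 GiB
```

Doubling N quadruples this number, and attn_probs is a second tensor of the same size, so the cost compounds quickly.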
Flash Attention’s solution: tiled computation
Flash Attention never materializes the full N×N matrix. Instead it processes Q in horizontal tiles of BLOCK_M rows and streams K and V in tiles of BLOCK_N rows (each K tile supplies BLOCK_N columns of the score matrix). The softmax is computed incrementally with an online-softmax accumulator, so only per-tile state stays on-chip: a running maximum and a running sum for each of the BLOCK_M rows, plus a BLOCK_M × head_dim output accumulator.
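The incremental softmax is easiest to see for a single query row. This NumPy sketch is not miniVLLM code: online_softmax_weighted_sum is a hypothetical name and block_n stands in for BLOCK_N. It streams scores and values in chunks, keeps only a running max m, a running sum l and a running weighted sum acc, and then checks the result against a full softmax:

```python
import numpy as np

def online_softmax_weighted_sum(scores, values, block_n):
    """softmax(scores) @ values for one query row, computed in chunks of block_n."""
    m = -np.inf                       # running maximum of the scores seen so far
    l = 0.0                           # running sum of exp(score - m)
    acc = np.zeros(values.shape[1])   # running sum of exp(score - m) * value
    for start in range(0, len(scores), block_n):
        s = scores[start:start + block_n]
        v = values[start:start + block_n]
        m_new = max(m, s.max())
        alpha = np.exp(m - m_new)     # rescales what was accumulated under the old max
        p = np.exp(s - m_new)
        acc = acc * alpha + p @ v
        l = l * alpha + p.sum()
        m = m_new
    return acc / l

rng = np.random.default_rng(0)
scores = rng.normal(size=10)
values = rng.normal(size=(10, 4))
w = np.exp(scores - scores.max())
reference = (w / w.sum()) @ values
print(np.allclose(online_softmax_weighted_sum(scores, values, block_n=3), reference))  # True
```

No chunk ever needs the scores of another chunk; everything earlier is summarized by (m, l, acc), which is what lets the kernel discard each K/V tile after use.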
Query tiles (BLOCK_M rows each)
┌──────┐
│ Q₀ │ ──→ streams over all K, V tiles and accumulates output
├──────┤
│ Q₁ │ ──→ same
└──────┘
Key / Value tiles (BLOCK_N columns each)
┌───┬───┬───┬───┐
│K₀ │K₁ │K₂ │...│
└───┴───┴───┴───┘
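Each Q tile makes a single pass over K and V, so HBM reads per tile grow linearly with N. Summed over all N / BLOCK_M tiles, K/V reads are still quadratic in N but divided by BLOCK_M, and the N×N score and probability matrices are never written to or read back from HBM.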
The flash_attention_prefill function
flash_attention_prefill in layers/attention.py is the Python entry point:
def flash_attention_prefill(
    q: torch.Tensor,           # (total_tokens, num_heads, head_dim)
    k: torch.Tensor,           # (total_tokens, num_kv_heads, head_dim)
    v: torch.Tensor,           # (total_tokens, num_kv_heads, head_dim)
    cu_seqlens: torch.Tensor,  # cumulative sequence lengths, shape (num_seqs + 1,)
    scale: float,
    num_heads: int,
    num_kv_heads: int,
    head_dim: int,
) -> torch.Tensor:             # (total_tokens, num_heads, head_dim)
    ...  # body omitted here; see layers/attention.py
All sequences in the batch are concatenated into a single flat tensor. cu_seqlens is the array of cumulative lengths that tells the kernel where each sequence starts and ends — for example, [0, 512, 1024, 1536] represents three sequences of 512 tokens each.
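As a sketch of how such an array can be built from per-sequence lengths (build_cu_seqlens is a hypothetical helper using plain Python lists; the real argument is a torch.Tensor on the GPU):

```python
from itertools import accumulate

def build_cu_seqlens(seq_lens):
    """Prefix sums with a leading 0: sequence i occupies rows cu[i] .. cu[i+1]-1."""
    return [0] + list(accumulate(seq_lens))

print(build_cu_seqlens([512, 512, 512]))   # [0, 512, 1024, 1536]

cu = build_cu_seqlens([3, 5, 2])           # uneven lengths work the same way
for i in range(len(cu) - 1):
    print(f"seq {i}: rows {cu[i]}..{cu[i + 1] - 1} of the flat tensor")
```

The array always has one more entry than there are sequences, and consecutive differences recover the individual lengths.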
Block size selection
Shared memory usage grows with BLOCK_M × head_dim (for the Q tile) and BLOCK_N × head_dim (for each K and V tile). Tile sizes therefore shrink as head_dim grows, using conservative values that stay within the ~48 KB per-block shared memory limit:
if head_dim <= 64:
    BLOCK_M, BLOCK_N = 64, 64
elif head_dim <= 128:
    BLOCK_M, BLOCK_N = 32, 32
else:
    BLOCK_M, BLOCK_N = 16, 16
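A back-of-the-envelope check of these choices, assuming fp16/bf16 inputs (2 bytes per element) and one buffer per tile. Triton's real allocation also depends on software pipelining (num_stages) and on the fp32 score tile, so treat the figure as a lower bound. tile_bytes is a hypothetical helper:

```python
def tile_bytes(head_dim: int, bytes_per_elem: int = 2) -> dict:
    # Same thresholds as the block-size selection above.
    if head_dim <= 64:
        block_m, block_n = 64, 64
    elif head_dim <= 128:
        block_m, block_n = 32, 32
    else:
        block_m, block_n = 16, 16
    q_tile = block_m * head_dim * bytes_per_elem
    kv_tiles = 2 * block_n * head_dim * bytes_per_elem   # one K tile plus one V tile
    return {"BLOCK_M": block_m, "BLOCK_N": block_n, "qkv_bytes": q_tile + kv_tiles}

for d in (64, 128, 256):
    print(d, tile_bytes(d))   # qkv_bytes is 24576 (24 KiB) in all three cases
```

At head_dim 64, 128 and 256 the rule keeps BLOCK × head_dim at 4096 elements, so the Q, K and V tiles together use about 24 KiB, roughly half the budget, leaving headroom for the score tile and pipelining buffers.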
Grid layout
grid = (triton.cdiv(max_seq_len, BLOCK_M), num_heads, num_seqs)
Each Triton program processes one BLOCK_M-row tile of Q for one attention head in one sequence. Programs for different tiles, heads and sequences are independent, so the GPU runs as many of them in parallel as its occupancy allows.
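The launch tuple can be reproduced in plain Python without importing Triton. In this sketch, cdiv mimics triton.cdiv for positive integers and prefill_grid is a hypothetical helper that derives max_seq_len and num_seqs from cu_seqlens:

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division for positive integers (what triton.cdiv computes)."""
    return (a + b - 1) // b

def prefill_grid(cu_seqlens, num_heads, block_m):
    seq_lens = [cu_seqlens[i + 1] - cu_seqlens[i] for i in range(len(cu_seqlens) - 1)]
    # axis 0: Q tiles (sized for the longest sequence), axis 1: heads, axis 2: sequences
    return (cdiv(max(seq_lens), block_m), num_heads, len(seq_lens))

grid = prefill_grid([0, 512, 1024, 1536], num_heads=32, block_m=32)
print(grid)                           # (16, 32, 3)
print(grid[0] * grid[1] * grid[2])    # 1536 programs
```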
The kernel: online softmax
The key to Flash Attention is updating the running max m_i and normalizer l_i as new K tiles arrive, rescaling the accumulated output acc accordingly.
# flash_attention_varlen_kernel (attention.py)
# Per-row state (size BLOCK_M)
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)           # running sum of exp scores
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - 1e10    # running maximum
acc = tl.zeros([BLOCK_M, head_dim], dtype=tl.float32) # output accumulator
for block_n in range(num_blocks):
    # ... load K tile as (head_dim, BLOCK_N) so tl.dot yields (BLOCK_M, BLOCK_N) ...
    qk = tl.dot(q, k) * scale
    # causal mask applied here
    # online softmax update
    m_ij = tl.max(qk, axis=1)            # max over this K tile
    m_i_new = tl.maximum(m_i, m_ij)      # new global max
    alpha = tl.exp(m_i - m_i_new)        # rescale factor for old accumulator
    p = tl.exp(qk - m_i_new[:, None])    # softmax numerators (this tile)
    acc = acc * alpha[:, None]           # rescale previous output
    # ... load V tile ...
    acc = acc + tl.dot(p.to(v.dtype), v) # accumulate weighted values
    l_i = l_i * alpha + tl.sum(p, axis=1) # update normalizer
    m_i = m_i_new
# final normalization
acc = acc / l_i[:, None]
Because m_i_new ≥ m_i, alpha = exp(m_i - m_i_new) is always ≤ 1. Multiplying by it corrects the previously accumulated output and normalizer for the new maximum, and since every exponent is taken relative to the running maximum, no exp() call can overflow.
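To check the update rule end to end, here is a NumPy sketch of the same loop for one head of one sequence, with causal masking. It is an illustration, not the kernel: Python loops stand in for Triton programs, everything is float64, and tiled_causal_attention and reference_causal_attention are hypothetical names. The state variables mirror the kernel's m_i, l_i, acc and alpha:

```python
import numpy as np

def tiled_causal_attention(q, k, v, scale, block_m=4, block_n=4):
    """Causal attention for one head of one sequence; q, k, v have shape (N, d)."""
    n, d = q.shape
    out = np.empty((n, d))
    for start_m in range(0, n, block_m):
        end_m = min(start_m + block_m, n)
        rows = np.arange(start_m, end_m)
        m_i = np.full(end_m - start_m, -1e10)   # running maximum (same init as the kernel)
        l_i = np.zeros(end_m - start_m)         # running sum of exp scores
        acc = np.zeros((end_m - start_m, d))    # output accumulator
        # Causal: K tiles that start after the tile's last row are skipped entirely.
        for start_n in range(0, end_m, block_n):
            cols = np.arange(start_n, min(start_n + block_n, n))
            qk = q[rows] @ k[cols].T * scale
            qk = np.where(cols[None, :] <= rows[:, None], qk, -np.inf)  # causal mask
            m_ij = qk.max(axis=1)
            m_i_new = np.maximum(m_i, m_ij)
            alpha = np.exp(m_i - m_i_new)
            p = np.exp(qk - m_i_new[:, None])
            acc = acc * alpha[:, None] + p @ v[cols]
            l_i = l_i * alpha + p.sum(axis=1)
            m_i = m_i_new
        out[start_m:end_m] = acc / l_i[:, None]
    return out

def reference_causal_attention(q, k, v, scale):
    n = q.shape[0]
    s = np.where(np.tril(np.ones((n, n), dtype=bool)), q @ k.T * scale, -np.inf)
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ v

rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(10, 8)) for _ in range(3))
scale = 8 ** -0.5
print(np.allclose(tiled_causal_attention(q, k, v, scale, block_m=4, block_n=2),
                  reference_causal_attention(q, k, v, scale)))  # True
```

The finite -1e10 initial maximum is a deliberate guard: if a row were fully masked in its first tile and m_i started at -inf, alpha would be exp(-inf - (-inf)) = NaN and would poison the row.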
Variable-length sequence support
A batched prefill processes sequences of different lengths in one kernel launch. Rather than padding all sequences to the same length, miniVLLM passes cu_seqlens — a cumulative-length tensor — so each program can find its own sequence boundary:
seq_start = tl.load(cu_seqlens_q_ptr + seq_idx)
seq_end = tl.load(cu_seqlens_q_ptr + seq_idx + 1)
seq_len = seq_end - seq_start
# early exit if this Q tile is beyond the sequence
if start_m * BLOCK_M >= seq_len:
    return
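Because the grid's first dimension is sized for max_seq_len, programs assigned to a shorter sequence may receive a Q tile that starts past its end; they exit immediately. No padding tokens are stored in the flat tensor, so no compute is spent on them.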
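To see how much the early exit matters, this sketch counts launched programs versus those that pass the check for a batch with one long and two short sequences. count_programs is a hypothetical helper that reuses the grid shape and the kernel's exit condition:

```python
def count_programs(cu_seqlens, num_heads, block_m):
    """Return (programs launched, programs that pass the early-exit check)."""
    seq_lens = [cu_seqlens[i + 1] - cu_seqlens[i] for i in range(len(cu_seqlens) - 1)]
    tiles_per_seq = -(-max(seq_lens) // block_m)   # ceiling division, like triton.cdiv
    launched = tiles_per_seq * num_heads * len(seq_lens)
    active = num_heads * sum(
        1
        for seq_len in seq_lens
        for start_m in range(tiles_per_seq)
        if start_m * block_m < seq_len             # negation of the kernel's exit test
    )
    return launched, active

print(count_programs([0, 1000, 1100, 1164], num_heads=1, block_m=32))  # (96, 38)
```

Here 58 of 96 programs return immediately. They still occupy a launch slot, but they load only two cu_seqlens entries and never touch Q, K or V.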
Grouped Query Attention (GQA)
Models like Qwen3 use fewer KV heads than query heads. Each query head maps to a KV head by integer division:
kv_head_idx = off_h // (num_heads // num_kv_heads)
For example, with num_heads=32 and num_kv_heads=8, query heads 0–3 all read from KV head 0, query heads 4–7 from KV head 1, and so on. Compared with one KV head per query head, this shrinks the KV cache by a factor of num_heads / num_kv_heads (4× in this example), and the kernel needs no additional code path to support it.
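The mapping and the resulting cache savings can be checked directly. In this sketch, kv_head_for and kv_bytes_per_token are hypothetical helpers; head_dim=128 and 2-byte elements are illustrative assumptions, not a specific Qwen3 configuration:

```python
def kv_head_for(off_h: int, num_heads: int, num_kv_heads: int) -> int:
    # Assumes num_heads is a multiple of num_kv_heads, as GQA requires.
    assert num_heads % num_kv_heads == 0
    return off_h // (num_heads // num_kv_heads)

mapping = [kv_head_for(h, num_heads=32, num_kv_heads=8) for h in range(32)]
print(mapping[:8])    # [0, 0, 0, 0, 1, 1, 1, 1]
print(mapping[-4:])   # [7, 7, 7, 7]

def kv_bytes_per_token(num_kv_heads: int, head_dim: int, bytes_per_elem: int = 2) -> int:
    return 2 * num_kv_heads * head_dim * bytes_per_elem   # K and V, per layer

print(kv_bytes_per_token(32, 128), kv_bytes_per_token(8, 128))   # 16384 4096
```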
When flash attention is used
Flash attention is used only during prefill. The Attention.forward method selects the algorithm based on context.is_prefill:
if context.is_prefill:
    o = flash_attention_prefill(
        q, k, v, cu_seqlens, scale,
        self.num_heads, self.num_kv_heads, self.head_dim
    )
else:
    o = paged_attention_decode(
        q, k_cache, v_cache, block_tables, context_lens,
        scale, self.num_heads, self.num_kv_heads, self.head_dim, self.block_size
    )
During decode, each sequence contributes only one new query token per step, and its keys and values are read from the paged KV cache. Flash Attention's Q tiling pays off only when many query tokens of a sequence are processed together, as they are when a prompt is processed for the first time during prefill.
Flash Attention starts outperforming the naive Triton implementation at roughly sequence length 64–128 for head_dim=128. Below that threshold, the overhead of finer tiling (more loop iterations and online-softmax rescaling per tile) outweighs the HBM savings. The benchmark_prefilling.py script measures this crossover point empirically.