Flash Attention rewrites the standard attention algorithm to minimize reads and writes to GPU High Bandwidth Memory (HBM). By tiling the query, key, and value matrices and keeping intermediate results in fast on-chip SRAM, it computes exact attention with far less HBM traffic, which enables much longer sequence lengths at the same hardware cost. This page covers Lecture 12 by Thomas Viehmann and Lecture 36 by Jay Shah (Flash Attention 3).
The attention bottleneck
Self-attention computes a weighted sum of values, where the weights come from the similarity between queries and keys:
Attention(Q, K, V) = softmax(QKᵀ / √d) V
Here Q, K, and V each have N rows (one per token) and d is the head dimension.
Flash Attention does not approximate attention. It computes the exact same result as standard attention, just with a different order of operations that avoids materializing the full N×N matrix.
Standard attention and its O(N²) memory problem
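The naive attention algorithm proceeds in three distinct passes, each requiring a round-trip to HBM:
Compute scores
Load Q and K from HBM, compute S = QKᵀ / √d, write S back to HBM. Cost: O(Nd) reads, O(N²) writes, and O(N²d) FLOPs.
Compute softmax
Load S from HBM, compute P = softmax(S), write P back to HBM. Cost: O(N²) reads and writes.
Compute output
Load P and V from HBM, compute O = PV, write O back to HBM. Cost: O(N² + Nd) reads, O(Nd) writes.
The N×N matrices S and P dominate both the memory footprint and the HBM traffic.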
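The passes above can be sketched in plain Python to make the N×N intermediates visible. This is a minimal illustration, not the lecture's code: the function names and the tiny 3×2 inputs are made up for the example, and nested lists stand in for GPU tensors.

```python
import math

def softmax(row):
    """Numerically stable softmax over one row of scores."""
    m = max(row)
    exps = [math.exp(s - m) for s in row]
    total = sum(exps)
    return [e / total for e in exps]

def naive_attention(Q, K, V):
    """Standard attention; S and P are materialized as full N x N matrices."""
    d = len(Q[0])
    scale = 1.0 / math.sqrt(d)
    # Pass 1: S = Q K^T / sqrt(d), an N x N matrix.
    S = [[scale * sum(a * b for a, b in zip(q, k)) for k in K] for q in Q]
    # Pass 2: P = softmax(S) row by row, another N x N matrix.
    P = [softmax(row) for row in S]
    # Pass 3: O = P V, an N x d matrix.
    dv = len(V[0])
    O = [[sum(p * v[c] for p, v in zip(p_row, V)) for c in range(dv)] for p_row in P]
    return O, S, P

# Hypothetical toy inputs: N = 3 tokens, d = 2.
Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
O, S, P = naive_attention(Q, K, V)
print(len(S), "x", len(S[0]), "score matrix;", len(O), "x", len(O[0]), "output")
```

On a GPU, S and P would each be written to and read back from HBM, which is the traffic Flash Attention removes.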
Flash Attention: tiling and on-the-fly softmax
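The key insight from Dao et al. (2022) is that softmax can be computed incrementally using a running maximum and normalizer. This means you never need the full row of scores in memory at once: you can process small tiles of K and V and update the output accumulator tile by tile.
Running softmax decomposition: for a row of scores [s₁, s₂, ..., sₙ], maintain:
- m: running max of the scores seen so far
- l: running sum of exp(sᵢ - m)
- O: running weighted sum Σ exp(sᵢ - m) · vᵢ
When a new tile of scores sⱼ with values vⱼ arrives, rescale the old state to the new maximum m_new:
- m_new = max(m, maxⱼ sⱼ)
- l_new = exp(m - m_new) · l + Σⱼ exp(sⱼ - m_new)
- O_new = exp(m - m_new) · O + Σⱼ exp(sⱼ - m_new) · vⱼ
After the last tile, the attention output for the row is O_new / l_new. This online algorithm is the mathematical foundation of Flash Attention.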
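A minimal sketch of this running update, assuming scalar values vᵢ for brevity (real attention uses a vector per key). The function names and sample numbers are illustrative only. It processes the scores in tiles and is compared against an ordinary two-pass softmax.

```python
import math

def online_softmax_weighted_sum(scores, values, tile_size=2):
    """Softmax-weighted sum of values, keeping only m, l and O between tiles."""
    m = float("-inf")  # running max of the scores seen so far
    l = 0.0            # running sum of exp(s - m)
    O = 0.0            # running sum of exp(s - m) * v
    for start in range(0, len(scores), tile_size):
        s_tile = scores[start:start + tile_size]
        v_tile = values[start:start + tile_size]
        m_new = max(m, max(s_tile))
        # Rescale the old state to the new max; exp(-inf) is 0 on the first tile.
        alpha = math.exp(m - m_new)
        l = alpha * l + sum(math.exp(s - m_new) for s in s_tile)
        O = alpha * O + sum(math.exp(s - m_new) * v for s, v in zip(s_tile, v_tile))
        m = m_new
    return O / l

def two_pass_softmax_weighted_sum(scores, values):
    """Reference: needs the whole row of scores before producing a result."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)

scores = [0.5, 2.0, -1.0, 3.0, 1.5]
values = [1.0, 2.0, 3.0, 4.0, 5.0]
result = online_softmax_weighted_sum(scores, values, tile_size=2)
expected = two_pass_softmax_weighted_sum(scores, values)
print(abs(result - expected) < 1e-12)
```

The tile size only changes how the row is chunked; the result should agree with the two-pass version up to floating-point rounding.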
Forward pass algorithm
The Flash Attention forward pass runs an outer loop over query blocks (rows of Q) and an inner loop over key/value blocks, so each query block's output accumulator stays in SRAM while the K/V tiles stream past. In the lecture code, lecture_012/flash_attention.cu implements the online softmax update inline inside the K/V tile loop.
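The loop structure can be sketched in plain Python as follows. This is not the contents of lecture_012/flash_attention.cu; it is an illustrative CPU version with made-up names (flash_attention_forward, block_q, block_kv) and nested lists in place of SRAM tiles, checked against a reference that sees each full row of scores.

```python
import math

def flash_attention_forward(Q, K, V, block_q=2, block_kv=2):
    """Tiled forward pass; returns the output O and the per-row log-sum-exp L."""
    n, d, dv = len(Q), len(Q[0]), len(V[0])
    scale = 1.0 / math.sqrt(d)
    O = [[0.0] * dv for _ in range(n)]
    L = [0.0] * n
    for q0 in range(0, n, block_q):               # outer loop: query blocks
        rows = range(q0, min(q0 + block_q, n))
        m = {i: float("-inf") for i in rows}      # running max per query row
        l = {i: 0.0 for i in rows}                # running normalizer per row
        acc = {i: [0.0] * dv for i in rows}       # unnormalized output per row
        for k0 in range(0, len(K), block_kv):     # inner loop: K/V blocks
            K_tile = K[k0:k0 + block_kv]
            V_tile = V[k0:k0 + block_kv]
            for i in rows:
                s = [scale * sum(a * b for a, b in zip(Q[i], k)) for k in K_tile]
                m_new = max(m[i], max(s))
                alpha = math.exp(m[i] - m_new)    # rescales the old state
                p = [math.exp(x - m_new) for x in s]
                l[i] = alpha * l[i] + sum(p)
                acc[i] = [alpha * a + sum(pj * v[c] for pj, v in zip(p, V_tile))
                          for c, a in enumerate(acc[i])]
                m[i] = m_new
        for i in rows:                            # normalize once per query block
            O[i] = [a / l[i] for a in acc[i]]
            L[i] = m[i] + math.log(l[i])
    return O, L

def reference_attention(Q, K, V):
    """Standard attention that uses the full row of scores for every query."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        s = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
        m = max(s)
        p = [math.exp(x - m) for x in s]
        z = sum(p)
        out.append([sum(pj * v[c] for pj, v in zip(p, V)) / z for c in range(len(V[0]))])
    return out

# Hypothetical toy inputs: N = 5 tokens, d = 4, value dimension 3.
Q = [[0.1 * (i + j) for j in range(4)] for i in range(5)]
K = [[0.2 * (i - j) for j in range(4)] for i in range(5)]
V = [[float(i * 3 + j) for j in range(3)] for i in range(5)]
O, L = flash_attention_forward(Q, K, V)
O_ref = reference_attention(Q, K, V)
max_err = max(abs(a - b) for ra, rb in zip(O, O_ref) for a, b in zip(ra, rb))
print(max_err < 1e-9)
```

Only one block_q × block_kv tile of scores exists at any time, which is the property a CUDA kernel exploits to keep the working set in SRAM.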
Backward pass with recomputation
The backward pass recomputes attention weights on the fly rather than storing the N×N matrix from the forward pass. Given the saved per-row log-sum-exp values L = m + log(l), it recomputes each tile of S from Q and K and reconstructs P = exp(S - L) during the backward sweep.
This is a deliberate trade: the backward pass does more arithmetic (recomputing softmax), but it avoids the enormous memory cost of storing P. The recomputation overhead is small compared to the bandwidth savings.
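The reconstruction step can be illustrated with a small sketch. It covers only the recomputation of P from the saved log-sum-exp, not the gradient formulas, and the helper names (forward_logsumexp, recompute_P_tile) are hypothetical. For simplicity the forward helper here looks at a full row of scores; in Flash Attention the same L falls out of the tiled forward pass.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def forward_logsumexp(Q, K):
    """What the forward pass saves: one value L = m + log(l) per query row."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    L = []
    for q in Q:
        s = [scale * dot(q, k) for k in K]
        m = max(s)
        l = sum(math.exp(x - m) for x in s)
        L.append(m + math.log(l))
    return L

def recompute_P_tile(Q, K, L, rows, cols):
    """Backward-pass step: rebuild one tile of P = exp(S - L) from Q, K and L."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    return [[math.exp(scale * dot(Q[i], K[j]) - L[i]) for j in cols] for i in rows]

# Hypothetical toy inputs: N = 4 tokens, d = 3.
Q = [[0.3 * i - 0.1 * j for j in range(3)] for i in range(4)]
K = [[0.2 * j - 0.4 * i for j in range(3)] for i in range(4)]
L = forward_logsumexp(Q, K)  # N floats instead of an N x N matrix

# Rebuild P two key columns at a time and accumulate the row sums.
row_sums = [0.0] * len(Q)
for c0 in range(0, len(K), 2):
    cols = range(c0, min(c0 + 2, len(K)))
    tile = recompute_P_tile(Q, K, L, range(len(Q)), cols)
    for i, tile_row in enumerate(tile):
        row_sums[i] += sum(tile_row)
print(all(abs(rs - 1.0) < 1e-12 for rs in row_sums))
```

Each row of the recomputed tiles sums to 1, so together they reproduce softmax(S) even though P was never stored.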
IO complexity analysis
Let N = sequence length, d = head dimension, M = SRAM size (bytes).
| Algorithm | HBM reads/writes |
|---|---|
| Standard attention | O(Nd + N²) |
| Flash Attention | O(N²d² / M) |
When M ≥ d² (which holds for typical head dimensions and SRAM sizes), the Flash Attention bound N²d² / M is at most N², so Flash Attention requires fewer HBM accesses than standard attention. In practice, this translates to a 2–4× wall-clock speedup on A100 hardware for typical sequence lengths.
Memory usage drops from O(N²) to O(N): the N×N attention matrix is never materialized.
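To get a feel for the two bounds, the expressions can be evaluated with constant factors dropped. The numbers below (N = 4096, d = 64, M = 100,000) are illustrative assumptions rather than measurements, and M is treated as a plain count in the same units as the matrix entries.

```python
def standard_hbm_accesses(n, d):
    """O(Nd + N^2) from the table, with the constant factor dropped."""
    return n * d + n * n

def flash_hbm_accesses(n, d, m):
    """O(N^2 d^2 / M) from the table, with the constant factor dropped."""
    return n * n * d * d / m

# Illustrative values only; note that d * d = 4096 is well below m.
n, d, m = 4096, 64, 100_000
standard = standard_hbm_accesses(n, d)
flash = flash_hbm_accesses(n, d, m)
ratio = standard / flash
print(f"standard ~ {standard:.3g}, flash ~ {flash:.3g}, ratio ~ {ratio:.1f}")
```

Under these assumptions the standard bound is roughly 25 times larger. The measured wall-clock speedup is smaller because the asymptotic bounds ignore constant factors and compute time.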
Flash Attention 2 and 3
Flash Attention 2 (Dao, 2023) improved on the original with:
- Better work partitioning across warps (fewer unnecessary reads within a warp)
- Causal attention skips fully masked blocks, cutting the attention work roughly in half
- Sequences of different lengths in the same batch (variable-length batching)
Flash Attention 3 (covered in Lecture 36) builds on this for Hopper (H100) GPUs:
- Exploits H100’s asynchronous data movement (TMA) to overlap GEMM and softmax
- Uses WGMMA (warpgroup matrix multiply) for 2× the throughput of WMMA
- FP8 attention support
- Achieves up to 75% of H100 theoretical peak FLOPS
Lecture 36 by Jay Shah covers the integration of CUTLASS and Flash Attention 3 on Hopper GPUs. The slides are in the lecture_036/ folder.
Using Flash Attention in PyTorch
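PyTorch 2.0+ includes Flash Attention as a backend for torch.nn.functional.scaled_dot_product_attention, which dispatches to it automatically when the inputs qualify (for example, fp16 or bf16 tensors on a supported CUDA GPU).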
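A minimal usage sketch, assuming PyTorch 2.0 or newer is installed. The tensor shapes are arbitrary example values, and the call runs on CPU in float32 so that it works without a GPU. On CPU a different backend is used, but the function call is the same one that reaches the Flash Attention kernel for eligible CUDA inputs.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Shapes follow the (batch, heads, seq_len, head_dim) layout expected by SDPA.
q = torch.randn(2, 4, 128, 64)
k = torch.randn(2, 4, 128, 64)
v = torch.randn(2, 4, 128, 64)

# Fused attention; PyTorch selects the backend from device, dtype and arguments.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Reference: explicit causal attention that materializes the N x N matrix.
scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
causal = torch.ones(128, 128, dtype=torch.bool).tril()
scores = scores.masked_fill(~causal, float("-inf"))
ref = torch.softmax(scores, dim=-1) @ v

max_err = (out - ref).abs().max().item()
print(out.shape, max_err)
```

To exercise the Flash Attention path, the same call would be made with q, k, and v moved to a CUDA device in fp16 or bf16. Recent PyTorch releases also provide torch.nn.attention.sdpa_kernel for restricting which backends may run; check the documentation for your installed version before relying on it.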
Further reading
Lecture 12 source code
Thomas Viehmann’s CUDA implementation of the Flash Attention forward pass
Lecture 36: Flash Attention 3
Jay Shah on CUTLASS integration and Hopper-specific optimizations
Flash Attention paper
Dao et al., 2022 — the original Flash Attention paper
Ring Attention
Extend Flash Attention across multiple GPUs for arbitrarily long sequences