
Flash Attention rewrites the standard attention algorithm to minimize reads and writes to GPU High Bandwidth Memory (HBM). It tiles the query, key, and value matrices and keeps intermediate results in fast on-chip SRAM. As a result, it computes exact attention with far less HBM traffic, which makes much longer sequences practical on the same hardware. This page covers Lecture 12 by Thomas Viehmann and Lecture 36 by Jay Shah (Flash Attention 3).

The attention bottleneck

Self-attention computes a weighted sum of values, where weights come from the similarity between queries and keys:
Attention(Q, K, V) = softmax(QKᵀ / √d) · V
The bottleneck is not arithmetic: modern GPUs can compute matrix multiplications extremely fast. The bottleneck is memory bandwidth. The standard algorithm materializes the full N×N attention score matrix in HBM, which requires O(N²) memory reads and writes. For a sequence of length 4096, the score matrix alone holds 4096² ≈ 16.8 million entries. That is 64 MiB in fp32 (32 MiB in fp16) per head and per batch element, regardless of the head dimension. Keeping both S and P in fp32 doubles this to 128 MiB.
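The following sketch checks that arithmetic by computing the size of one score matrix for a single head and a single batch element. The helper name `score_matrix_bytes` is illustrative. The head dimension does not appear in the formula, because S always has N × N entries. Materializing both S and P in fp32 doubles the footprint to 128 MiB.

```python
MIB = 1024 ** 2

def score_matrix_bytes(seq_len: int, bytes_per_elem: int) -> int:
    """Bytes for one N x N score matrix (one head, one batch element).

    The head dimension d does not appear: S has N * N entries whatever d is.
    """
    return seq_len * seq_len * bytes_per_elem

N = 4096
s_fp32 = score_matrix_bytes(N, 4)      # S alone, fp32
s_fp16 = score_matrix_bytes(N, 2)      # S alone, fp16
s_and_p_fp32 = 2 * s_fp32              # S and P both materialized in fp32

print(s_fp32 // MIB, s_fp16 // MIB, s_and_p_fp32 // MIB)  # 64 32 128
```

Multiply by batch size and number of heads to get the total for a whole attention layer.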
Flash Attention does not approximate attention. It computes the exact same result as standard attention, just with a different order of operations that avoids materializing the full N×N matrix.

Standard attention and its O(N²) memory problem

The naive attention algorithm proceeds in three distinct passes, each requiring a round-trip to HBM:
1. Compute scores. Load Q and K from HBM, compute S = QKᵀ / √d, and write S back to HBM. Cost: O(Nd) reads and O(N²) writes. The O(N²d) term is arithmetic, not memory traffic.

2. Compute softmax. Load S from HBM, compute P = softmax(S) row by row, and write P back to HBM. Cost: O(N²) reads and O(N²) writes.

3. Compute output. Load P and V from HBM, compute O = PV, and write O to HBM. Cost: O(N² + Nd) reads and O(Nd) writes.
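Each of these passes moves the full N×N matrix through HBM. The softmax step does very little arithmetic per byte moved, so memory bandwidth, not compute, sets the wall-clock time. On an A100 (roughly 1.5–2 TB/s of HBM bandwidth, depending on the variant), attention is already a bottleneck at a sequence length of 2048 tokens.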
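The following is a minimal pure-Python reference for the three passes, with no GPU involved. The lists `S` and `P` stand in for the N×N buffers that the standard algorithm writes to HBM. The names `matmul`, `transpose`, and `naive_attention` are illustrative helpers, not part of the lecture code.

```python
import math

def matmul(a, b):
    """Plain triple-loop matrix product of lists of lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def transpose(a):
    return [list(col) for col in zip(*a)]

def naive_attention(Q, K, V):
    """Three-pass attention that materializes S and P (each N x N)."""
    d = len(Q[0])
    # Pass 1: S = Q K^T / sqrt(d); on a GPU this N x N matrix goes to HBM.
    S = [[x / math.sqrt(d) for x in row] for row in matmul(Q, transpose(K))]
    # Pass 2: row-wise softmax (max-subtracted for stability); P is another N x N matrix.
    P = []
    for row in S:
        m = max(row)
        e = [math.exp(x - m) for x in row]
        z = sum(e)
        P.append([x / z for x in e])
    # Pass 3: O = P V, which has only N x d entries.
    O = matmul(P, V)
    return O, S, P

# Usage: all keys are identical, so every score in a row is equal,
# P is uniform, and each output row is the mean of the rows of V.
Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[0.5, 0.5]] * 3
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
O, S, P = naive_attention(Q, K, V)
```

Both intermediate lists grow quadratically with sequence length. The output grows only linearly, which is what Flash Attention exploits.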

Flash Attention: tiling and on-the-fly softmax

The key idea Dao et al. (2022) build on is that softmax can be computed incrementally using a running maximum and normalizer. You never need a full row of scores in memory at once. Instead, you process small tiles of K and V and update the output accumulator tile by tile.

Running softmax decomposition: for a row of scores [s₁, s₂, ..., sₙ], maintain:
  • m: running max of scores seen so far
  • l: running sum of exp(sᵢ - m)
  • O: running weighted sum Σ exp(sᵢ - m) · vᵢ
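When a new tile of scores arrives, let m_new = max(m, largest score in the tile). Update the state as follows, with each Σ running over the new tile only:
l_new = exp(m - m_new) · l + Σ exp(sᵢ - m_new)
O_new = exp(m - m_new) · O + Σ exp(sᵢ - m_new) · vᵢ
Then replace m with m_new. After the last tile, the output row is O / l. This online algorithm is the mathematical foundation of Flash Attention.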
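The next sketch applies the update rule to a single query row in plain Python. It checks the tiled result against a direct softmax-weighted sum for several tile sizes. The function names are hypothetical, and the scores are treated as already scaled by 1/√d.

```python
import math

def softmax_weighted_sum(scores, values):
    """Reference: materialize the whole softmax row, then take the weighted sum of values."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return [sum(wi * v[c] for wi, v in zip(w, values)) / z for c in range(len(values[0]))]

def online_softmax_row(scores, values, tile_size):
    """Same result, but visiting scores/values one tile at a time with running (m, l, O)."""
    m = -math.inf                 # running max of scores seen so far
    l = 0.0                       # running sum of exp(s_i - m)
    o = [0.0] * len(values[0])    # running sum of exp(s_i - m) * v_i (unnormalized)
    for start in range(0, len(scores), tile_size):
        s_tile = scores[start:start + tile_size]
        v_tile = values[start:start + tile_size]
        m_new = max(m, max(s_tile))
        alpha = math.exp(m - m_new)   # rescale old state; exp(-inf) == 0.0 on the first tile
        l = alpha * l
        o = [alpha * x for x in o]
        for s, v in zip(s_tile, v_tile):
            p = math.exp(s - m_new)
            l += p
            o = [x + p * vc for x, vc in zip(o, v)]
        m = m_new
    return [x / l for x in o]

scores = [0.3, -1.2, 2.5, 0.0, 1.1]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 3.0], [0.5, 0.5]]
expected = softmax_weighted_sum(scores, values)
results = {t: online_softmax_row(scores, values, t) for t in (1, 2, 3, 5)}
```

Every tile size gives the same answer up to floating-point rounding, which is why Flash Attention is exact rather than approximate.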

Forward pass algorithm

The Flash Attention forward pass has an outer loop over query tiles (blocks of rows of Q) and an inner loop over key/value tiles:
// Abridged from lecture_012/flash_attention.cu by Thomas Viehmann.
// bdx, bdy are the thread-block dimensions; B_r_over_bdy (= B_r / bdy) and
// d_over_bdx (= d / bdx) are compile-time constants defined in the full source.
constexpr int d = 128;
constexpr int B_r = 8;   // query tile size (rows)
constexpr int B_c = 32;  // key/value tile size (columns)

__global__ void flash_attention_k(
    float *out, float *out_l,
    float *Q, float *K, float *V,
    float scaling, int n, int T_r, int T_c
) {
    // Shared memory holds one tile of Q, K, V at a time
    __shared__ float Q_i[B_r][d];
    __shared__ float K_j[B_c][d];
    __shared__ float V_j[B_c][d];
    __shared__ float S[B_r][B_c];

    // Per-thread accumulators stay in registers and are never written to HBM.
    // They must start at 0, -INFINITY, and 0 respectively.
    float l_i[B_r_over_bdy];   // normalizer
    float m_i[B_r_over_bdy];   // running max
    float O_i[B_r_over_bdy][d_over_bdx];  // output accumulator

    // Load Q tile once per block into shared memory
    // Inner loop: iterate over K/V tiles
    for (int j = 0; j < T_c; j++) {
        // Load K_j, V_j into shared memory
        // Compute S = Q_i @ K_j^T  (scaled dot products)
        // Update running max m_i, normalizer l_i, output O_i
    }

    // Write O_i / l_i to HBM once, only at the end
    for (int ii = 0; ii < B_r_over_bdy; ii++) {
        for (int dd = 0; dd < d_over_bdx; dd++)
            out[...] = O_i[ii][dd] / l_i[ii];
        out_l[...] = m_i[ii] + logf(l_i[ii]);  // log-sum-exp for backward
    }
}
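The out_l buffer stores the log-sum-exp value (m + log l) for each query row. With it, the backward pass can recompute each attention weight directly as exp(s - L), tile by tile. It needs neither a stored copy of P nor an extra pass to rebuild each row's max and normalizer.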
The full kernel from lecture_012/flash_attention.cu implements the online softmax update inline inside the K/V tile loop:
for (int ii = 0; ii < B_r_over_bdy; ii++) {
    float m = m_i[ii];
    float last_m = m;

    // Find new max over this K tile
    for (int jj = 0; jj < B_c; jj++) {
        if (m < S[ii * bdy + tid_y][jj]) m = S[ii * bdy + tid_y][jj];
    }
    m_i[ii] = m;

    // Rescale previous accumulator
    float l = expf(last_m - m) * l_i[ii];
    for (int dd = 0; dd < d_over_bdx; dd++)
        O_i[ii][dd] *= expf(last_m - m);

    // Accumulate new tile
    for (int jj = 0; jj < B_c; jj++) {
        float P_ij = expf(S[ii * bdy + tid_y][jj] - m);
        l += P_ij;
        for (int dd = 0; dd < d_over_bdx; dd++)
            O_i[ii][dd] += P_ij * V_j[jj][dd * bdx + tid_x];
    }
    l_i[ii] = l;
}
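The following pure-Python sketch mirrors the structure of the kernel above for one head. An outer loop runs over query tiles of size `B_r`, and an inner loop runs over K/V tiles of size `B_c`. Per-row running state `m_i`, `l_i`, and `O_i` is updated with the same rescale-then-accumulate steps. Each row is normalized once at the end, and its log-sum-exp is returned alongside, just as the kernel writes it to `out_l`. The tile sizes and helper names are illustrative. This is a sequential sketch, not a model of GPU parallelism.

```python
import math
import random

def flash_forward(Q, K, V, B_r=2, B_c=3):
    """Tiled attention forward pass with online softmax (pure-Python sketch).

    Returns (O, lse); lse[i] = m_i + log(l_i) is what the CUDA kernel stores in out_l.
    """
    n, d = len(Q), len(Q[0])
    scaling = 1.0 / math.sqrt(d)
    O, lse = [], []
    for i0 in range(0, n, B_r):               # outer loop over query tiles
        Q_i = Q[i0:i0 + B_r]
        m_i = [-math.inf] * len(Q_i)          # running max per row
        l_i = [0.0] * len(Q_i)                # running normalizer per row
        O_i = [[0.0] * d for _ in Q_i]        # unnormalized output accumulator
        for j0 in range(0, n, B_c):           # inner loop over K/V tiles
            K_j, V_j = K[j0:j0 + B_c], V[j0:j0 + B_c]
            for r, q in enumerate(Q_i):
                S = [scaling * sum(a * b for a, b in zip(q, k)) for k in K_j]
                m = max(m_i[r], max(S))
                alpha = math.exp(m_i[r] - m)  # rescale previous accumulator
                l = alpha * l_i[r]
                acc = [alpha * x for x in O_i[r]]
                for s, v in zip(S, V_j):      # accumulate the new tile
                    p = math.exp(s - m)
                    l += p
                    acc = [x + p * vc for x, vc in zip(acc, v)]
                m_i[r], l_i[r], O_i[r] = m, l, acc
        for r in range(len(Q_i)):             # normalize and write each row once, at the end
            O.append([x / l_i[r] for x in O_i[r]])
            lse.append(m_i[r] + math.log(l_i[r]))
    return O, lse

def reference_attention(Q, K, V):
    """Direct softmax(Q K^T / sqrt(d)) V, plus the per-row log-sum-exp."""
    d = len(Q[0])
    O, lse = [], []
    for q in Q:
        s = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K]
        m = max(s)
        z = sum(math.exp(x - m) for x in s)
        O.append([sum(math.exp(x - m) * v[c] for x, v in zip(s, V)) / z for c in range(len(V[0]))])
        lse.append(m + math.log(z))
    return O, lse

def rand_matrix(rows, cols):
    return [[random.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]

random.seed(0)
n, d = 7, 4   # n is deliberately not a multiple of B_r or B_c
Q, K, V = rand_matrix(n, d), rand_matrix(n, d), rand_matrix(n, d)
O, lse = flash_forward(Q, K, V)
O_ref, lse_ref = reference_attention(Q, K, V)
```

Choosing `n = 7` exercises partial tiles at the end of both loops, which a real kernel must also handle with bounds checks.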

Backward pass with recomputation

The backward pass recomputes attention weights on the fly rather than storing the N×N matrix from the forward pass. Given the saved log-sum-exp values L = m + log(l), it can reconstruct P = exp(S - L) during the backward sweep. This is a deliberate trade: the backward pass does more arithmetic (recomputing softmax), but it avoids the enormous memory cost of storing P. The recomputation overhead is small compared to the bandwidth savings.
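The following sketch shows the recomputation step for one row: given raw scores and the saved log-sum-exp L, each tile of P is rebuilt as exp(S - L) without reference to any other tile. It covers only this step, not the gradient computation. The helper `recompute_probs` is hypothetical, and the scores are assumed to be already scaled.

```python
import math

def recompute_probs(score_tile, lse_row):
    """Rebuild attention weights for one tile of a row: P = exp(S - L).

    Needs only the tile's scores and the row's saved log-sum-exp L; no running max or
    normalizer from other tiles is required, and the full P matrix is never stored.
    """
    return [math.exp(s - lse_row) for s in score_tile]

# Forward pass (conceptually): save L = m + log(l) for the row.
scores_row = [0.3, -1.2, 2.5, 0.0, 1.1]
m = max(scores_row)
l = sum(math.exp(s - m) for s in scores_row)
L = m + math.log(l)

# Backward pass: recompute the row tile by tile (tiles of 2 here) from S and L alone.
p_tiles = [recompute_probs(scores_row[j:j + 2], L) for j in range(0, len(scores_row), 2)]
p_row = [p for tile in p_tiles for p in tile]
softmax_row = [math.exp(s - m) / l for s in scores_row]
```

Storing one float per row (L) therefore replaces storing N floats per row (a row of P).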
Flash Attention’s backward pass requires the out_l buffer written during the forward pass. If you call a custom forward kernel and discard this buffer, you cannot use the standard backward pass.

IO complexity analysis

Let N = sequence length, d = head dimension, and M = SRAM size (in elements).
| Algorithm | HBM reads/writes |
|---|---|
| Standard attention | O(Nd + N²) |
| Flash Attention | O(N²d² / M) |
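When M is much larger than d², Flash Attention needs many times fewer HBM accesses than standard attention. This condition usually holds. With typical head dimensions of 64–128, d² is 4,096–16,384 elements, while on-chip SRAM per streaming multiprocessor is on the order of 100 KB. In practice, this translates to a 2–4× wall-clock speedup for attention on A100 hardware at typical sequence lengths. Extra memory usage drops from O(N²) to O(N). The N×N attention matrix is never materialized, and only per-row statistics such as the log-sum-exp are stored.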
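To make the table concrete, the following sketch plugs numbers into the two asymptotic expressions. Constants are dropped, so only the order of magnitude is meaningful. The value `M = 100_000` elements is a hypothetical SRAM capacity chosen for illustration, not a measured hardware figure.

```python
def hbm_accesses_standard(N, d):
    """Order-of-magnitude HBM traffic of standard attention: N*d + N^2 (constants dropped)."""
    return N * d + N * N

def hbm_accesses_flash(N, d, M):
    """Order-of-magnitude HBM traffic of Flash Attention: N^2 * d^2 / M (constants dropped)."""
    return N * N * d * d / M

N, d = 4096, 64
M = 100_000   # hypothetical SRAM capacity in elements, for illustration only
standard = hbm_accesses_standard(N, d)
flash = hbm_accesses_flash(N, d, M)
print(f"standard ~ {standard:.3g}, flash ~ {flash:.3g}, ratio ~ {standard / flash:.1f}x")
```

At M = d², the Flash Attention term equals N², the same order as standard attention. Every further factor of M / d² cuts the traffic by that factor.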

Flash Attention 2 and 3

Flash Attention 2 (Dao, 2023) improved on the original with:
  • Better work partitioning between warps, which reduces communication through shared memory
  • Fewer non-matmul FLOPs, and for causal attention, skipping tiles that are entirely masked out, which roughly halves the work
  • Sequences of different lengths in the same batch (variable-length batching)
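In causal attention, any (query tile, key tile) pair lying entirely above the diagonal contains only masked scores, so a tiled kernel can skip it. The hypothetical helper below counts computed versus skipped pairs for one head. With T tiles per side, T(T+1)/2 of the T² pairs are computed, so the computed fraction approaches one half as T grows. Tiles that straddle the diagonal are still computed, with masking applied inside the tile.

```python
def causal_tile_counts(n, B_r, B_c):
    """Count (query tile, key tile) pairs a causal tiled kernel computes vs. skips.

    A pair is skipped when its first key column comes after its last query row,
    so every score in the tile is masked out.
    """
    computed = skipped = 0
    for i0 in range(0, n, B_r):
        last_row = min(i0 + B_r, n) - 1
        for j0 in range(0, n, B_c):
            if j0 > last_row:      # whole tile above the diagonal: fully masked
                skipped += 1
            else:
                computed += 1
    return computed, skipped

print(causal_tile_counts(1024, 128, 128))  # (36, 28)
```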
Flash Attention 3 (Shah et al., 2024) — covered in Lecture 36 by Jay Shah — targets Hopper (H100) architecture specifically:
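  • Exploits Hopper's asynchronous hardware (TMA for data movement, asynchronous WGMMA for matrix multiplies) with warp specialization to overlap data movement, GEMMs, and softmax
  • Uses WGMMA (warpgroup-level matrix multiply-accumulate) instructions, which are needed to reach H100's peak Tensor Core throughput
  • FP8 attention support
  • Achieves up to 75% of H100 theoretical peak FLOPS (about 740 TFLOPS) in FP16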
Lecture 36 by Jay Shah covers the integration of CUTLASS and Flash Attention 3 on Hopper GPUs. The slides are in the lecture_036/ folder.

Using Flash Attention in PyTorch

PyTorch 2.0+ ships a Flash Attention kernel as one of the backends for scaled_dot_product_attention and picks it automatically when the inputs and GPU support it:
import torch
import torch.nn.functional as F

# Example sizes (arbitrary)
batch, heads, seq_len, head_dim = 2, 8, 1024, 64

# Standard usage: PyTorch selects Flash Attention automatically when available
q = torch.randn(batch, heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn(batch, heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
v = torch.randn(batch, heads, seq_len, head_dim, device="cuda", dtype=torch.float16)

# Uses Flash Attention when running on a supported GPU with fp16/bf16
output = F.scaled_dot_product_attention(q, k, v)

# With causal masking (decoder self-attention)
output = F.scaled_dot_product_attention(q, k, v, is_causal=True)
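To force the Flash Attention backend, disable the others. If Flash Attention cannot handle the inputs, PyTorch then raises an error instead of silently falling back:
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    output = F.scaled_dot_product_attention(q, k, v)
# Newer PyTorch releases deprecate sdp_kernel in favor of:
# from torch.nn.attention import sdpa_kernel, SDPBackend
# with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
#     output = F.scaled_dot_product_attention(q, k, v)
Flash Attention requires fp16 or bf16 inputs. If you pass fp32 tensors, PyTorch does not use the Flash kernel and falls back to another backend (memory-efficient or math). Cast inputs with .half() or .to(torch.bfloat16) to get the Flash Attention kernel.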

Further reading

  • Lecture 12 source code: Thomas Viehmann's CUDA implementation of the Flash Attention forward pass
  • Lecture 36, Flash Attention 3: Jay Shah on CUTLASS integration and Hopper-specific optimizations
  • Flash Attention paper: Dao et al., 2022, the original Flash Attention paper
  • Ring Attention: extends Flash Attention across multiple GPUs for arbitrarily long sequences
