Writing high-performance GPU kernels (the kind that saturate the hardware and outperform hand-tuned libraries) requires deep knowledge of GPU architecture and a significant amount of low-level code. ScaleML Lecture 75 is split between two speakers. William Brandon gives a refresher on GPU programming fundamentals. Simran Arora introduces ThunderKittens, a domain-specific language (DSL) embedded in C++ for writing efficient GPU kernels, such as attention, at a higher level of abstraction. This lecture is part of the GPU Mode ScaleML Series.

Part 1: GPU programming fundamentals (William Brandon)

Before writing any kernel, you need a working mental model of GPU hardware. This section covers the hierarchy from threads to hardware, and how attention kernels map onto it.

GPU memory hierarchy

Modern NVIDIA GPUs expose three levels of memory that a kernel programmer must understand:

Global (HBM)

High-bandwidth memory stacked beside the GPU die on the same package. It provides terabytes per second of bandwidth, but latency is high (~400–800 cycles). Every tensor you allocate with torch.empty(..., device='cuda') lives here.

Shared memory (SRAM)

On-chip scratchpad shared by all threads in a block. Latency is far lower than global memory (tens of cycles) and bandwidth is very high, but capacity is small: typically 48 KB to 228 KB per SM, depending on the GPU and configuration. It is the key resource for tiling algorithms.

Registers

Per-thread private storage and the fastest memory on the GPU; instructions read their operands directly from registers. Each thread can use at most 255 32-bit registers. Values that do not fit spill to local memory, which lives in device memory (cached in L1/L2) and is far slower.
Global memory (HBM)       ~2 TB/s
      ↓  load/store
Shared memory (SRAM)      ~20 TB/s aggregate across all SMs
      ↓  load/store
Register file              ~340 TB/s aggregate across all SMs (H100 estimate)
FlashAttention achieves its speedup primarily by tiling the attention computation so that the Q, K, V tiles fit in shared memory. This avoids repeated round-trips to global memory for the intermediate attention matrix, which is the bottleneck in naive attention implementations.
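A back-of-the-envelope comparison shows why tiling matters. The full attention matrix grows as N², while the tiles a block stages in shared memory stay small and fixed in size. The helper names and the sizes in the checks below are illustrative assumptions, not figures from the lecture.

```cpp
#include <cstdint>

// Hypothetical back-of-the-envelope helpers (all sizes in bytes).

// Full N x N attention matrix for one (batch, head) pair.
constexpr std::uint64_t attn_matrix_bytes(std::uint64_t n, std::uint64_t elem_bytes) {
    return n * n * elem_bytes;
}

// One Q tile plus one K tile and one V tile staged in shared memory.
constexpr std::uint64_t qkv_tile_bytes(std::uint64_t tile_q, std::uint64_t tile_k,
                                       std::uint64_t head_dim, std::uint64_t elem_bytes) {
    return (tile_q + 2 * tile_k) * head_dim * elem_bytes;
}

// bf16 scores with N = 4096: 32 MiB per head, far too large for shared memory
static_assert(attn_matrix_bytes(4096, 2) == 33554432ULL, "");
// Batch 8 with 32 heads: 8 GiB that a naive kernel writes to HBM and reads back
static_assert(8 * 32 * attn_matrix_bytes(4096, 2) == 8589934592ULL, "");
// fp32 tiles of 64 rows with head dimension 64: exactly 48 KB, the default
// per-block limit for static shared memory
static_assert(qkv_tile_bytes(64, 64, 64, 4) == 49152ULL, "");
```

At roughly 2 TB/s, writing and re-reading those 8 GiB takes several milliseconds per attention call before any arithmetic is done. The tiles, by contrast, fit comfortably in one SM's shared memory.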

Warps, blocks, and the execution model

Threads are the basic unit of execution, but they execute in lockstep groups of 32 called warps:
  • A warp is 32 threads that execute the same instruction simultaneously (SIMT)
  • A block (also called a cooperative thread array, CTA) is a group of up to 1024 threads that share an SM and its shared memory
  • A grid is the full set of blocks launched by a single kernel call
// Query the thread and block indices inside a kernel
int thread_id   = threadIdx.x;                        // 0..blockDim.x-1
int block_id    = blockIdx.x;                         // 0..gridDim.x-1
int global_id   = blockIdx.x * blockDim.x + threadIdx.x;

// Warp identity
int warp_id     = threadIdx.x / 32;
int lane_id     = threadIdx.x % 32;  // position within warp (0..31)
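To make the index arithmetic concrete, here is a host-side C++ model of the same formulas. It also includes the usual ceiling division for choosing a grid size. `ids_for` and `blocks_for` are illustrative names, not CUDA APIs, and the model assumes a 1-D launch with 32-thread warps.

```cpp
// Host-side model of the in-kernel index arithmetic for a 1-D launch.
// block_idx, block_dim and thread_idx play the roles of blockIdx.x,
// blockDim.x and threadIdx.x. Assumes 32-thread warps.
struct ThreadIds {
    int global_id;
    int warp_id;
    int lane_id;
};

constexpr int kWarpSize = 32;

constexpr ThreadIds ids_for(int block_idx, int block_dim, int thread_idx) {
    return ThreadIds{block_idx * block_dim + thread_idx,  // global_id
                     thread_idx / kWarpSize,              // warp within the block
                     thread_idx % kWarpSize};             // lane within the warp
}

// Blocks needed to cover n elements with block_dim threads each
// (ceiling division). Threads in the last block past n must do nothing.
constexpr int blocks_for(int n, int block_dim) {
    return (n + block_dim - 1) / block_dim;
}
```

For example, thread 37 of block 1 with 128 threads per block has global id 165 and sits in warp 1, lane 5. Covering 1000 elements with 256-thread blocks takes 4 blocks.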
Warp divergence occurs when threads in the same warp take different branches. The warp then executes each path in turn, so a two-way split can halve throughput, or worse. In attention kernels, avoid data-dependent branches inside the inner loop. Branch-free masking is usually preferable to if statements, for example setting masked scores to -∞ before the softmax or multiplying by a 0/1 mask.
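A common branch-free masking pattern in attention sets masked scores to -∞ before the softmax. exp then maps them to exactly 0, and every key goes through the same code path. The CPU sketch below illustrates this for one row of a causal mask. `causal_masked_softmax` is a hypothetical name. On a GPU, a ternary like this one typically compiles to a predicated select rather than a divergent branch.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Softmax over one row of attention scores for query position q with a
// causal mask (keys k > q are masked). Masked scores become -infinity
// instead of being skipped with a branch. Assumes 0 <= q < scores.size().
std::vector<float> causal_masked_softmax(const std::vector<float>& scores, int q) {
    const int n = static_cast<int>(scores.size());
    std::vector<float> p(n);
    float row_max = -INFINITY;
    for (int k = 0; k < n; ++k) {
        p[k] = (k <= q) ? scores[k] : -INFINITY;  // select, do not skip
        row_max = std::fmax(row_max, p[k]);
    }
    float sum = 0.0f;
    for (int k = 0; k < n; ++k) {
        p[k] = std::exp(p[k] - row_max);  // exp(-inf) == 0 for masked keys
        sum += p[k];
    }
    for (float& x : p) x /= sum;
    return p;
}
```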

Writing an attention kernel from scratch

A naive attention kernel writes the full N×N attention matrix to global memory and then reads it back. Here is the structure of a tiled implementation that avoids this:
// Pseudocode for a tiled attention kernel. load_tile, flash_attention_inner
// and store_tile are placeholder helpers, not library functions.
// Each block handles one query tile and writes the matching output tile.
constexpr int TILE_Q   = 64;
constexpr int TILE_K   = 64;
constexpr int HEAD_DIM = 64;

__global__ void tiled_attention(
    const float* Q,   // [N, d]
    const float* K,   // [N, d]
    const float* V,   // [N, d]
    float*       O,   // [N, d]
    int N, int d      // assumes d == HEAD_DIM
) {
    // Allocate tiles in shared memory
    __shared__ float Q_tile[TILE_Q][HEAD_DIM];
    __shared__ float K_tile[TILE_K][HEAD_DIM];
    __shared__ float V_tile[TILE_K][HEAD_DIM];

    // Load Q tile once for this block
    load_tile(Q_tile, Q, blockIdx.x * TILE_Q, N, d);

    // Per-row state, conceptually held in registers. A real kernel
    // partitions it across the threads of the block.
    float O_acc[TILE_Q][HEAD_DIM] = {};  // output accumulator (zeroed)
    float m[TILE_Q];                     // running max for online softmax
    float l[TILE_Q];                     // running denominator
    for (int i = 0; i < TILE_Q; ++i) { m[i] = -INFINITY; l[i] = 0.0f; }

    // Iterate over K/V tiles
    for (int kv_start = 0; kv_start < N; kv_start += TILE_K) {
        __syncthreads();  // previous iteration is done reading K_tile/V_tile
        load_tile(K_tile, K, kv_start, N, d);
        load_tile(V_tile, V, kv_start, N, d);
        __syncthreads();  // tiles fully loaded before any thread reads them

        // Compute S = Q_tile @ K_tile^T  (tensor core MMA in a real kernel)
        // Apply online softmax update (FlashAttention algorithm)
        // Accumulate O_acc += softmax(S) @ V_tile
        flash_attention_inner(Q_tile, K_tile, V_tile, O_acc, m, l);
    }

    // Normalize by l and write output
    store_tile(O, O_acc, l, blockIdx.x * TILE_Q, N, d);
}
The flash_attention_inner function implements the online softmax trick from FlashAttention. It maintains running statistics (the max m and the denominator l) that let you update the output accumulator incrementally without materializing the full attention row.
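The tiled structure and the online softmax update can be checked on the CPU. The sketch below is plain C++ under our own naming: `naive_attention` and `tiled_attention` are hypothetical, not the lecture's code. It uses exp rather than exp2, scalar loops rather than tensor cores, and processes one query tile at a time, the way one thread block would. The tiled version only ever holds one tile's worth of scores per row, yet should match the naive result up to floating-point rounding.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

using Mat = std::vector<float>;  // row-major [n, d]

// Reference: materializes each full row of the N x N score matrix.
Mat naive_attention(const Mat& Q, const Mat& K, const Mat& V, int n, int d) {
    Mat O(n * d, 0.0f);
    const float scale = 1.0f / std::sqrt(static_cast<float>(d));
    for (int i = 0; i < n; ++i) {
        std::vector<float> s(n);
        float m = -INFINITY;
        for (int j = 0; j < n; ++j) {
            float dot = 0.0f;
            for (int c = 0; c < d; ++c) dot += Q[i * d + c] * K[j * d + c];
            s[j] = dot * scale;
            m = std::max(m, s[j]);
        }
        float l = 0.0f;
        for (int j = 0; j < n; ++j) { s[j] = std::exp(s[j] - m); l += s[j]; }
        for (int j = 0; j < n; ++j)
            for (int c = 0; c < d; ++c) O[i * d + c] += s[j] / l * V[j * d + c];
    }
    return O;
}

// Tiled forward pass with the online softmax update (running m and l).
Mat tiled_attention(const Mat& Q, const Mat& K, const Mat& V, int n, int d,
                    int tile_q, int tile_k) {
    Mat O(n * d, 0.0f);
    const float scale = 1.0f / std::sqrt(static_cast<float>(d));
    for (int q0 = 0; q0 < n; q0 += tile_q) {            // one "block" per Q tile
        const int rows = std::min(tile_q, n - q0);
        std::vector<float> m(rows, -INFINITY), l(rows, 0.0f);
        Mat acc(rows * d, 0.0f);                         // unnormalized output tile
        for (int k0 = 0; k0 < n; k0 += tile_k) {         // stream K/V tiles
            const int cols = std::min(tile_k, n - k0);
            for (int r = 0; r < rows; ++r) {
                const int i = q0 + r;
                std::vector<float> s(cols);
                float tile_max = -INFINITY;
                for (int j = 0; j < cols; ++j) {
                    float dot = 0.0f;
                    for (int c = 0; c < d; ++c) dot += Q[i * d + c] * K[(k0 + j) * d + c];
                    s[j] = dot * scale;
                    tile_max = std::max(tile_max, s[j]);
                }
                const float m_new = std::max(m[r], tile_max);
                const float correction = std::exp(m[r] - m_new);  // 0 on the first tile
                l[r] *= correction;
                for (int c = 0; c < d; ++c) acc[r * d + c] *= correction;
                for (int j = 0; j < cols; ++j) {
                    const float p = std::exp(s[j] - m_new);
                    l[r] += p;
                    for (int c = 0; c < d; ++c) acc[r * d + c] += p * V[(k0 + j) * d + c];
                }
                m[r] = m_new;
            }
        }
        for (int r = 0; r < rows; ++r)                   // final normalization by l
            for (int c = 0; c < d; ++c) O[(q0 + r) * d + c] = acc[r * d + c] / l[r];
    }
    return O;
}
```

The correction factor exp(m_old - m_new) is the key step. Without rescaling acc and l whenever the running max grows, the tiled result diverges from the naive one.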

Part 2: ThunderKittens (Simran Arora)

Writing the kernel above correctly — with correct memory access patterns, tensor core utilization, and pipeline overlap — takes thousands of lines of careful CUDA C++. ThunderKittens is a DSL that makes this tractable.

What ThunderKittens is

ThunderKittens (TK) is a C++ library developed at Stanford and Together AI that provides tile-level abstractions for writing GPU kernels. Instead of reasoning about individual threads and bytes, you write operations on tiles — rectangular blocks of data that map directly onto warp-level hardware primitives (tensor core operations, shared memory loads, register files).
Raw CUDA:        threads → warps → registers → tensor cores (manual)
ThunderKittens:  tiles   → tile operations → hardware (automatic)
The key insight is that modern GPU kernels — especially attention — are best described as sequences of tile operations: load a tile, multiply tiles, apply a softmax over a tile, store a tile. TK gives you exactly this vocabulary.

Tile types

ThunderKittens has three fundamental types: two tile types and a global-memory descriptor. Each maps to a different level of the memory hierarchy:

rt (register tile)

A tile of values distributed across the registers of threads in a warp. This is the primary compute tile — tensor core MMA operations consume and produce register tiles.

st (shared tile)

A tile of values in shared memory, accessible by all threads in a block. Used for staging data between global memory and registers.

gl (global layout)

A descriptor for data in global memory, parameterized by shape and stride. Used for structured loads and stores.
#include "kittens.cuh"
using namespace kittens;

// Define tile shapes
using q_tile = rt_bf<16, 64>;   // 16×64 register tile, bfloat16
using k_tile = rt_bf<16, 64>;
using v_tile = rt_bf<16, 64>;
using o_tile = rt_bf<16, 64>;

// Shared memory tiles for staging
using sq_tile = st_bf<16, 64>;
using sk_tile = st_bf<16, 64>;
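These declarations have a concrete cost. A register tile's elements are spread over the 32 lanes of a warp, while a shared tile occupies rows × cols × element-size bytes of shared memory. The helpers below are illustrative arithmetic, not TK functions. They assume bf16 values are packed two per 32-bit register and fp32 values one per register.

```cpp
#include <cstddef>

constexpr int kWarpLanes = 32;

// Elements of a rows x cols register tile held by each lane of one warp.
constexpr int elems_per_lane(int rows, int cols) { return rows * cols / kWarpLanes; }

// 32-bit registers per lane, given the element size (2 for bf16, 4 for fp32).
constexpr int regs_per_lane(int rows, int cols, int elem_bytes) {
    return elems_per_lane(rows, cols) * elem_bytes / 4;
}

// Bytes of shared memory for a rows x cols shared tile.
constexpr std::size_t shared_tile_bytes(int rows, int cols, int elem_bytes) {
    return static_cast<std::size_t>(rows) * cols * elem_bytes;
}

// rt_bf<16, 64>: 32 bf16 values, i.e. 16 registers per lane
static_assert(elems_per_lane(16, 64) == 32, "");
static_assert(regs_per_lane(16, 64, 2) == 16, "");
// A 16 x 64 fp32 accumulator tile: 32 registers per lane
static_assert(regs_per_lane(16, 64, 4) == 32, "");
// st_bf<16, 64>: 2 KiB of shared memory
static_assert(shared_tile_bytes(16, 64, 2) == 2048, "");
```

Under these assumptions, the four bf16 register tiles above plus one fp32 accumulator of the same shape already take about 96 registers per lane before any temporaries. The hardware limit is 255 per thread.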

Core operations: load, store, MMA

TK provides three categories of operations that map directly to hardware.

Load and store move data between levels of the memory hierarchy:
// Load from global to shared memory (coalesced; TK also provides asynchronous variants)
load(sq, gl_q, {batch, head, row_offset, 0});  // gl_q is a global layout descriptor

// Load from shared to register tile (warp-level load)
load(q_reg, sq);  // q_reg: register tile, sq: shared tile
MMA (matrix multiply-accumulate) performs tensor core operations:
// Compute: acc += A @ B^T  (warp-level, uses tensor cores)
rt_fl<16, 16> acc;   // float32 accumulator
zero(acc);
mma_ABt(acc, q_reg, k_reg, acc);  // acc += q_reg @ k_reg^T
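As a CPU reference for what mma_ABt computes (not how the tensor cores compute it): A is stored as [M, K] and B as [N, K], so both operands are contiguous along K. This is why Q and K can both be kept as [rows, head_dim] tiles. `mma_abt_reference` is a hypothetical name.

```cpp
#include <cassert>
#include <vector>

// acc[M, N] += A[M, K] @ B[N, K]^T, all row-major.
void mma_abt_reference(std::vector<float>& acc, const std::vector<float>& A,
                       const std::vector<float>& B, int M, int N, int K) {
    for (int i = 0; i < M; ++i)
        for (int j = 0; j < N; ++j) {
            float sum = 0.0f;
            for (int k = 0; k < K; ++k) sum += A[i * K + k] * B[j * K + k];
            acc[i * N + j] += sum;  // accumulate, like passing acc as the last argument
        }
}
```

With a 16×64 q_reg and a 16×64 k_reg, this gives M = N = 16 and K = 64, which is why the accumulator above is declared 16×16.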
Reduction and element-wise operations work over tile dimensions:
// Row-wise softmax over acc. A col_vec holds one value per row of the tile.
rt_fl<16, 16>::col_vec row_max_vec, row_sum_vec;
row_max(row_max_vec, acc);       // max across columns, one result per row
sub_row(acc, acc, row_max_vec);  // broadcast subtract along each row
exp2(acc, acc);                  // element-wise exp2 (scores assumed pre-scaled by log2(e))
row_sum(row_sum_vec, acc);       // sum across columns
div_row(acc, acc, row_sum_vec);  // normalize each row
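Note that exp2 computes 2^x. To get the usual e-based softmax, scores must first be multiplied by log2(e) ≈ 1.4427, which kernels typically fold into the 1/√d scale. The CPU sketch below mirrors the tile-op sequence for a single row and matches an exp-based softmax. `softmax_exp2` is a hypothetical name.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// One row of the max / subtract / exp2 / sum / divide sequence, with the
// log2(e) factor applied first so that exp2(x * log2(e)) == exp(x).
std::vector<float> softmax_exp2(std::vector<float> x) {
    const float log2e = 1.44269504089f;
    float m = -INFINITY;
    for (float& v : x) { v *= log2e; m = std::fmax(m, v); }  // pre-scale, row max
    float sum = 0.0f;
    for (float& v : x) { v = std::exp2(v - m); sum += v; }   // sub_row, exp2, row_sum
    for (float& v : x) v /= sum;                             // div_row
    return x;
}
```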

Writing Flash Attention with ThunderKittens

Here is a simplified but representative TK kernel for the Flash Attention forward pass:
#include "kittens.cuh"
using namespace kittens;

// Sequence tile size: rows of Q, K and V processed per tile
constexpr int BLOCK_DIM = 16;

// Simplified sketch: one warp per block, seq_len a multiple of BLOCK_DIM,
// no causal mask, and no double-buffering of the K/V loads.
template<int HEAD_DIM>
__global__ void flash_attention_tk(
    const bf16* Q, const bf16* K, const bf16* V, bf16* O,
    int seq_len, int num_heads
) {
    // Register tiles for Q, the output accumulator and the score tile
    rt_bf<BLOCK_DIM, HEAD_DIM> q_reg;
    rt_fl<BLOCK_DIM, HEAD_DIM> o_acc;
    rt_fl<BLOCK_DIM, BLOCK_DIM> attn_scores;

    // Running softmax statistics: one value per row (a col_vec in TK terms)
    using row_stats = typename rt_fl<BLOCK_DIM, BLOCK_DIM>::col_vec;
    row_stats row_max_old, row_max_new, row_l;

    // Shared memory tiles for Q, streamed K/V, and the output
    __shared__ st_bf<BLOCK_DIM, HEAD_DIM> q_smem, k_smem, v_smem, o_smem;

    // Global memory descriptors. make_global_layout is a simplified stand-in;
    // TK kernels usually receive gl descriptors built on the host as arguments.
    auto gl_q = make_global_layout(Q, seq_len, HEAD_DIM);
    auto gl_k = make_global_layout(K, seq_len, HEAD_DIM);
    auto gl_v = make_global_layout(V, seq_len, HEAD_DIM);
    auto gl_o = make_global_layout(O, seq_len, HEAD_DIM);

    int q_row = blockIdx.x * BLOCK_DIM;

    // exp2 computes 2^x, so fold log2(e) into the 1/sqrt(d) softmax scale
    const float scale = 1.44269504089f / sqrtf((float)HEAD_DIM);

    // Load Q tile into registers (stays resident throughout)
    load(q_smem, gl_q, {blockIdx.z, blockIdx.y, q_row, 0});
    __syncwarp();
    load(q_reg, q_smem);

    zero(o_acc);
    neg_infty(row_max_old);
    zero(row_l);

    // Stream over K/V tiles
    for (int kv_row = 0; kv_row < seq_len; kv_row += BLOCK_DIM) {
        __syncwarp();  // all lanes have finished reading the previous K/V tiles
        load(k_smem, gl_k, {blockIdx.z, blockIdx.y, kv_row, 0});
        load(v_smem, gl_v, {blockIdx.z, blockIdx.y, kv_row, 0});
        __syncwarp();  // K/V tiles are complete before any lane reads them

        rt_bf<BLOCK_DIM, HEAD_DIM> k_reg;
        rt_bf<BLOCK_DIM, HEAD_DIM, ducks::rt_layout::col> v_reg;  // mma_AB takes B in column layout
        load(k_reg, k_smem);
        load(v_reg, v_smem);

        // S = Q @ K^T, scaled by log2(e) / sqrt(d)
        zero(attn_scores);
        mma_ABt(attn_scores, q_reg, k_reg, attn_scores);
        mul(attn_scores, attn_scores, scale);

        // Online softmax: new running max per row
        row_max(row_max_new, attn_scores);
        max(row_max_new, row_max_old, row_max_new);

        // P = exp2(S - m_new)
        sub_row(attn_scores, attn_scores, row_max_new);
        exp2(attn_scores, attn_scores);

        // Correction factor exp2(m_old - m_new) rescales the previous output and denominator
        row_stats correction;
        sub(correction, row_max_old, row_max_new);
        exp2(correction, correction);
        mul_row(o_acc, o_acc, correction);
        mul(row_l, row_l, correction);

        // l += rowsum(P)
        row_stats block_sum;
        row_sum(block_sum, attn_scores);
        add(row_l, row_l, block_sum);

        // O += P @ V (P converted to bf16 for the tensor cores)
        rt_bf<BLOCK_DIM, BLOCK_DIM> attn_probs;
        copy(attn_probs, attn_scores);
        mma_AB(o_acc, attn_probs, v_reg, o_acc);

        copy(row_max_old, row_max_new);
    }

    // Normalize output
    div_row(o_acc, o_acc, row_l);

    // Write output through shared memory
    store(o_smem, o_acc);
    __syncwarp();
    store(gl_o, o_smem, {blockIdx.z, blockIdx.y, q_row, 0});
}
The key advantage of writing this in TK versus raw CUDA is that TK handles the register layout required by tensor cores automatically. In raw CUDA, you must manually pack data into the mma.sync fragment layout — a major source of bugs and performance loss. TK’s rt type is always in the correct layout.
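To illustrate the bookkeeping TK hides, here is the ownership map for the fp32 accumulator fragment of PTX `mma.sync.aligned.m16n8k16`, following the layout in NVIDIA's PTX ISA documentation. Each of the 32 lanes holds four values of the 16×8 tile: lane l owns rows l/4 and l/4 + 8, at columns 2·(l % 4) and 2·(l % 4) + 1. This is a host-side model with hypothetical function names, and TK's internal organization of larger tiles may differ from this single-fragment view.

```cpp
#include <array>
#include <cassert>
#include <utility>

// (row, col) of accumulator value i (0..3) held by `lane` in a 16x8 fp32
// C/D fragment of mma.sync m16n8k16.
constexpr std::pair<int, int> acc_fragment_coord(int lane, int i) {
    const int group = lane >> 2;         // "groupID" in the PTX docs
    const int tid_in_group = lane & 3;   // "threadID_in_group"
    const int row = (i < 2) ? group : group + 8;
    const int col = tid_in_group * 2 + (i & 1);
    return {row, col};
}

// True if 32 lanes x 4 values cover every element of the 16x8 tile exactly once.
inline bool fragment_covers_tile_once() {
    std::array<int, 16 * 8> hits{};
    for (int lane = 0; lane < 32; ++lane)
        for (int i = 0; i < 4; ++i) {
            auto [r, c] = acc_fragment_coord(lane, i);
            ++hits[r * 8 + c];
        }
    for (int h : hits)
        if (h != 1) return false;
    return true;
}
```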

Performance: how TK achieves near-cuBLAS speeds

ThunderKittens achieves high performance through several mechanisms:
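1. Tensor core alignment

All register tiles are built from 16×16 base tiles sized and laid out to match the hardware's MMA instruction fragments (mma.sync with shape 16×8×16 for bfloat16 on Ampere and Hopper; on Hopper, TK can also use the warpgroup-level wgmma instructions). No reshape or copy is needed before issuing tensor core instructions.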
2. Warp-level abstraction

Most TK operations are scoped to a single warp, with warpgroup-level (four-warp) variants on Hopper. This makes it easy to pipeline work across warps, overlapping memory loads issued by one warp with compute in another.
3. Asynchronous copies

Global-to-shared loads can be issued asynchronously through TK's asynchronous load functions, which use cp.async on Ampere and the Tensor Memory Accelerator (TMA) on Hopper. This overlaps memory transfers with compute on previous tiles.
4. No unnecessary abstraction overhead

TK is a thin header-only library, with no runtime and no JIT compilation. The compiler sees the full kernel and can optimize register allocation, instruction scheduling, and unrolling.
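To get a sense of scale, the arithmetic below counts how many m16n8k16 MMA instructions the two products in the Flash Attention example issue per K/V tile. It assumes the 16-row tiles map directly onto that Ampere-style instruction; Hopper kernels may use wgmma instead. The function name is hypothetical.

```cpp
// mma.sync m16n8k16 instructions needed for an (M x K) @ (K x N) product.
// Returns -1 if a dimension is not a multiple of the instruction shape
// (such a tile would need padding).
constexpr int mma_m16n8k16_count(int M, int N, int K) {
    return (M % 16 == 0 && N % 8 == 0 && K % 16 == 0)
               ? (M / 16) * (N / 8) * (K / 16)
               : -1;
}

// S = Q @ K^T with 16x64 Q and K tiles: M = 16, N = 16, K = 64
static_assert(mma_m16n8k16_count(16, 16, 64) == 8, "");
// O += P @ V with P 16x16 and V 16x64: M = 16, N = 64, K = 16
static_assert(mma_m16n8k16_count(16, 64, 16) == 8, "");
```

Because TK fixes tile shapes at compile time, misaligned shapes like the -1 case can be rejected before the kernel ever runs.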
In the lecture, benchmarks show TK's Flash Attention reaching 90–95% of cuBLAS GEMM throughput on H100 GPUs, on par with the hand-tuned FlashAttention-3 implementation.

When to use ThunderKittens vs. raw CUDA vs. Triton

Use ThunderKittens when:
  • You are writing a new attention variant or custom matmul-style kernel
  • You need tensor core utilization and correct warp-level data layout
  • You want near-hardware-peak performance without writing raw PTX or assembly
  • You are doing kernel research and need to iterate quickly on new designs
Use raw CUDA when:
  • You need fine-grained control over instruction scheduling or PTX-level optimizations
  • You are writing non-attention kernels (reductions, scans, custom elementwise ops)
  • You need to interface directly with CUTLASS or cuBLAS primitives
  • You are targeting architectures where TK does not yet have support
Use Triton when:
  • You are a Python programmer who wants GPU performance without C++
  • Your kernel is embarrassingly parallel and fits the Triton tile model
  • You want easy multi-backend support (NVIDIA, AMD, Intel via the Triton backend ecosystem)
  • You do not need the final ~10% of performance that raw CUDA/TK can provide
ThunderKittens is actively developed and is designed for recent GPU generations (Hopper and newer), where tensor core throughput is the dominant performance lever. If you are writing custom attention mechanisms for research, TK is currently one of the most practical ways to get hardware-competitive implementations without a team of CUDA experts.

Lecture references

GPU Programming Fundamentals slides

William Brandon’s slides on GPU programming fundamentals (Lecture 75, Part 1)

ThunderKittens slides

Simran Arora’s ThunderKittens slides (ThunderKittens.pdf in the lecture_075 folder)

Simran Arora

Speaker homepage — research on efficient ML systems and custom hardware kernels

GPU Mode YouTube

Full lecture recordings on the GPU Mode YouTube channel
