LLM inference workloads call hundreds of small GPU kernels per forward pass. Each kernel launch has overhead, each transition between kernels stalls the pipeline, and intermediate results move through global memory unnecessarily. Mirage addresses this by superoptimizing LLM operations: it automatically searches for fused implementations that eliminate these inefficiencies. Lecture 79, presented by Mengdi Wu and Xinhao Cheng, introduces Mirage’s multi-level algebraic search and its programming interface. This page explains the problem, the approach, and how to use it.
The problem: many small kernels with high overhead
A typical transformer forward pass decomposes into hundreds of separate CUDA kernels: QKV projections, attention score computation, softmax, value aggregation, MLP layers, layer normalization, residual additions, and more. Each kernel incurs:
- Launch overhead: scheduling a kernel on the GPU takes microseconds. Repeated hundreds of times, this adds up.
- Memory round-trips: intermediate tensors are written to global memory (HBM) after each kernel and read back by the next one. Global memory bandwidth is a hard bottleneck.
- Poor occupancy: small kernels that don’t tile well leave SMs partially idle.
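To get a feel for how the first two costs compound, the sketch below does a back-of-envelope estimate. Every number in it (launch latency, kernel count, intermediate size, HBM bandwidth) is an illustrative assumption, not a measurement from the lecture.

```python
# Back-of-envelope cost model. All constants are assumed, illustrative values.
LAUNCH_OVERHEAD_S = 5e-6   # assumed per-launch overhead (5 microseconds)
HBM_BANDWIDTH_BPS = 2e12   # assumed HBM bandwidth (2 TB/s)


def overhead_seconds(num_kernels: int, intermediate_bytes: int) -> float:
    """Launch overhead plus the time to write and re-read each intermediate."""
    launch = num_kernels * LAUNCH_OVERHEAD_S
    # Each boundary between consecutive kernels costs one write and one read.
    round_trip_bytes = max(num_kernels - 1, 0) * 2 * intermediate_bytes
    return launch + round_trip_bytes / HBM_BANDWIDTH_BPS


# Hypothetical workload: 300 kernels with a 1 MiB intermediate between each pair,
# compared against a single fused kernel that keeps intermediates on-chip.
unfused = overhead_seconds(num_kernels=300, intermediate_bytes=1 << 20)
fused = overhead_seconds(num_kernels=1, intermediate_bytes=1 << 20)
print(f"unfused overhead: {unfused * 1e3:.3f} ms")
print(f"fused overhead:   {fused * 1e3:.3f} ms")
```

The model ignores occupancy and compute time entirely; it only shows that the fixed per-kernel costs scale with the number of launches.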
Compilers such as torch.compile fuse some operations, but they operate at a high level and miss many optimization opportunities that require reasoning about the GPU memory hierarchy and tiling.
Mirage automates this fusion through algebraic superoptimization at multiple levels of the GPU memory hierarchy.
What Mirage does: algebraic kernel superoptimization
Mirage treats kernel optimization as a program synthesis problem. Given a computation graph (e.g. the operations in a transformer layer), Mirage:
- Searches over a space of algebraically equivalent implementations.
- Evaluates each candidate on the target hardware.
- Returns the fastest implementation that passes its equivalence check against the original (a probabilistic check, described under "Safety" below).
The term “mega kernel” refers to a single CUDA kernel that implements multiple logical operations from the original computation graph. Mirage generates these kernels automatically; you do not write them by hand.
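The search, evaluate, and select loop can be illustrated with a toy example in plain Python. This is not Mirage's API: the candidates are ordinary Python functions, `agrees` is a hypothetical stand-in for the verification step, and wall-clock timing stands in for profiling on the target GPU.

```python
import random
import time


def reference(xs):
    # Unfused: two passes with a materialized intermediate list.
    scaled = [2.0 * x for x in xs]
    return [s + 1.0 for s in scaled]


def fused(xs):
    # Fused: one pass, no intermediate list.
    return [2.0 * x + 1.0 for x in xs]


def wrong(xs):
    # Deliberately non-equivalent candidate; the check should reject it.
    return [2.0 * (x + 1.0) for x in xs]


def agrees(candidate, baseline, trials=8, n=64, tol=1e-9):
    """Random-input comparison, a stand-in for the verification step."""
    rng = random.Random(0)
    for _ in range(trials):
        xs = [rng.uniform(-1.0, 1.0) for _ in range(n)]
        if any(abs(a - b) > tol for a, b in zip(candidate(xs), baseline(xs))):
            return False
    return True


def measure(fn, xs, repeats=5):
    """Best-of-N wall-clock time, standing in for on-device profiling."""
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        fn(xs)
        best = min(best, time.perf_counter() - start)
    return best


def pick_fastest(candidates, baseline):
    xs = [float(i) for i in range(10_000)]
    verified = [c for c in candidates if agrees(c, baseline)]
    return min(verified, key=lambda c: measure(c, xs)), verified


best, verified = pick_fastest([reference, fused, wrong], reference)
print("verified:", [f.__name__ for f in verified])
print("selected:", best.__name__)
```

Which verified candidate wins depends on the machine, which mirrors why the real search has to be evaluated on the target hardware.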
Multi-level tiling
Mirage’s search space is structured around the three levels of the GPU memory hierarchy:
GPU level (L0)
At the top level, Mirage partitions the computation across the full device: how the global input tensors are split across streaming multiprocessors (SMs). This determines which data each SM needs to load from HBM.
Thread block level (L1)
Within each SM, Mirage determines how the computation is tiled into thread blocks (TBs). Each TB loads a tile of the input into shared memory (SMEM) and computes a tile of the output. Mirage searches over tile sizes, shared memory layouts, and which intermediate results to keep in SMEM.
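As a rough illustration of the two levels described above, here is a CPU-only sketch of a tiled matrix multiply. It is a simplification under stated assumptions: the outer loops play the role of partitioning the output across thread blocks, the inner loop stages one tile of the reduction dimension at a time (mimicking loads into shared memory), and the parameter names `grid_tile` and `block_tile` are invented for this example.

```python
def matmul_naive(a, b):
    """Reference matmul on lists of lists."""
    m, k, n = len(a), len(b), len(b[0])
    return [[sum(a[i][p] * b[p][j] for p in range(k)) for j in range(n)]
            for i in range(m)]


def matmul_tiled(a, b, grid_tile, block_tile):
    """Two-level tiled matmul.

    grid_tile:  rows/cols of the output owned by one simulated thread block.
    block_tile: chunk of the reduction dimension staged per step.
    """
    m, k, n = len(a), len(b), len(b[0])
    out = [[0.0] * n for _ in range(m)]
    # Outer level: partition the output across simulated thread blocks.
    for row0 in range(0, m, grid_tile):
        for col0 in range(0, n, grid_tile):
            rows = range(row0, min(row0 + grid_tile, m))
            cols = range(col0, min(col0 + grid_tile, n))
            # Inner level: walk the reduction dimension one staged tile at a time.
            for k0 in range(0, k, block_tile):
                ks = range(k0, min(k0 + block_tile, k))
                # "Load" only the tiles this block needs (stand-in for HBM -> SMEM).
                a_tile = {(i, p): a[i][p] for i in rows for p in ks}
                b_tile = {(p, j): b[p][j] for p in ks for j in cols}
                for i in rows:
                    for j in cols:
                        out[i][j] += sum(a_tile[i, p] * b_tile[p, j] for p in ks)
    return out


# Small integer-valued inputs so both versions agree exactly.
A = [[float(i + p) for p in range(5)] for i in range(4)]
B = [[float(p * j + 1) for j in range(3)] for p in range(5)]
print(matmul_tiled(A, B, grid_tile=2, block_tile=3) == matmul_naive(A, B))
```

Changing `grid_tile` and `block_tile` changes how much data each simulated block loads and how often, without changing the result. Choosing among such tilings is the kind of decision the search makes.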
The Mirage search algorithm
Mirage’s search enumerates candidate kernel implementations by applying algebraic rewrites to the original computation graph:
Represent the computation as a µ-graph
Mirage represents computations as µ-graphs — directed acyclic graphs where nodes are tensor operators and edges are data dependencies. Each level of the hierarchy (GPU, TB, warp) has its own µ-graph layer.
Enumerate valid fusions
Mirage applies a set of rewrite rules that preserve the semantics of the computation. Rules include operator fusion, tile size changes, layout transforms, and reordering of independent operations. The search space is pruned using cost estimates and algebraic constraints.
Lower to CUDA
Each candidate µ-graph is lowered to a CUDA kernel using a code generation backend. The generated code uses CUTLASS primitives for matrix operations and custom intrinsics for other operations.
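The first two steps can be sketched with a toy graph structure. This is a minimal illustration, not Mirage's internal representation: the `Node` class, the example graph, and the single fusion rule (fuse a producer into its consumer when the intermediate has exactly one reader) are all assumptions made for this example.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Node:
    name: str
    op: str
    inputs: tuple = ()


# Hypothetical toy graph: out = matmul(rms_norm(x), w) + x
GRAPH = [
    Node("x", "input"),
    Node("w", "input"),
    Node("norm", "rms_norm", ("x",)),
    Node("proj", "matmul", ("norm", "w")),
    Node("out", "add", ("proj", "x")),
]


def consumers(graph):
    """Map each node name to the names of the nodes that read it."""
    users = {node.name: [] for node in graph}
    for node in graph:
        for src in node.inputs:
            users[src].append(node.name)
    return users


def fusion_candidates(graph):
    """Producer/consumer pairs whose intermediate has exactly one reader.

    In this toy model such an intermediate never needs to reach global
    memory, so fusing the pair preserves the graph's semantics.
    """
    users = consumers(graph)
    pairs = []
    for node in graph:
        if node.op == "input":
            continue  # graph inputs already live in global memory
        if len(users[node.name]) == 1:
            pairs.append((node.name, users[node.name][0]))
    return pairs


print(fusion_candidates(GRAPH))
```

A real search applies many rule types (tile sizes, layouts, reordering) and prunes with cost estimates; this sketch enumerates only one kind of rewrite.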
Safety: verifying semantic equivalence
Superoptimization is only useful if the optimized kernel produces correct results. Mirage uses probabilistic equivalence checking to verify candidate kernels:
- It evaluates both the original and candidate implementations on random inputs.
- It compares outputs with tolerance for floating-point differences.
- By the Schwartz–Zippel lemma, two different polynomials rarely agree at a randomly chosen point. So if the original and candidate computations, viewed as polynomials over their inputs, agree on sufficiently many random inputs, they are algebraically equivalent with high probability.
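The idea behind the Schwartz–Zippel argument can be shown with a small polynomial identity test over a finite field. This is a sketch of the general technique, not Mirage's verifier: the prime `P`, the three example functions, and the helper `probably_equivalent` are all chosen for illustration.

```python
import random

P = 2_147_483_647  # the Mersenne prime 2**31 - 1, used as the field size


def original(a, b, c):
    return (a * c + b * c) % P


def candidate(a, b, c):
    # Algebraically equal to `original` by distributivity.
    return ((a + b) * c) % P


def buggy(a, b, c):
    # Not equivalent: drops the factor of c on b.
    return (a * c + b) % P


def probably_equivalent(f, g, arity, trials=16, seed=0):
    """Compare f and g at random points of the field.

    For two different polynomials of total degree d, the lemma bounds the
    chance of agreement at one uniformly random point by d / P, so passing
    every trial makes inequivalence extremely unlikely.
    """
    rng = random.Random(seed)
    for _ in range(trials):
        args = [rng.randrange(P) for _ in range(arity)]
        if f(*args) != g(*args):
            return False  # a single disagreement proves inequivalence
    return True


print(probably_equivalent(original, candidate, arity=3))
print(probably_equivalent(original, buggy, arity=3))
```

Working in exact modular arithmetic sidesteps floating-point tolerance altogether, which is one reason finite fields are a natural setting for this kind of check.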
Results: speedups on attention and FFN
Mirage achieves speedups of 1.2× to 2.5× over hand-optimized kernels on key LLM operations:

| Operation | Baseline | Mirage speedup |
|---|---|---|
| Multi-head attention (MHA) | FlashAttention-2 | 1.3–1.9× |
| Grouped-query attention (GQA) | FlashAttention-2 | 1.4–2.1× |
| Feed-forward network (FFN) | cuBLAS + fused activation | 1.2–1.6× |
| Attention + FFN combined | Separate kernels | 1.5–2.5× |
Actual speedups depend on model configuration (head count, head dimension, sequence length) and GPU generation (A100 vs. H100 vs. H200). Mirage re-runs its search for each target configuration.
Programming interface
Mirage exposes a Python API for defining computations and triggering optimization.
Using Mirage with existing models
Mirage can be applied to PyTorch models via a graph capture-and-replace workflow.
Lecture 79 slides
Mengdi Wu and Xinhao Cheng’s original Mirage lecture slides
GPU Mode Discord
Ask questions and discuss Mirage, kernel fusion, and LLM compilers