LLM inference workloads call hundreds of small GPU kernels per forward pass. Each kernel launch has overhead, each transition between kernels stalls the pipeline, and intermediate results move through global memory unnecessarily. Mirage addresses this by superoptimizing LLM operations — automatically searching for fused implementations that eliminate these inefficiencies. Lecture 79, presented by Mengdi Wu and Xinhao Cheng, introduces Mirage’s multi-level algebraic search and its programming interface. This page explains the problem, the approach, and how to use it.

The problem: many small kernels with high overhead

A typical transformer forward pass decomposes into a long sequence of separate CUDA kernels: QKV projections, attention score computation, softmax, value aggregation, MLP layers, layer normalization, residual additions, and more. Each kernel incurs:
  • Launch overhead: scheduling a kernel on the GPU takes on the order of microseconds. Repeated hundreds of times per forward pass, this adds up.
  • Memory round-trips: intermediate tensors are written to global memory (HBM) after each kernel and read back by the next one, so every intermediate consumes HBM bandwidth twice.
  • Poor occupancy: small kernels that don’t tile well leave SMs partially idle.
The standard solution is kernel fusion: combine multiple operations into a single kernel that keeps intermediates in registers or shared memory. But writing fused kernels by hand is difficult, error-prone, and hardware-specific. Compilers like torch.compile fuse some operations, but they operate at a high level and miss many optimization opportunities that require reasoning about GPU memory hierarchy and tiling. Mirage automates this fusion through algebraic superoptimization at multiple levels of the GPU memory hierarchy.
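To get a feel for the two costs that fusion removes, here is a back-of-the-envelope sketch. The kernel count, launch cost, intermediate size, and bandwidth below are illustrative assumptions, not measurements from the lecture.

```python
def unfused_overhead_us(num_kernels, launch_us, intermediate_bytes, hbm_gb_per_s):
    """Estimate launch and HBM round-trip overhead for one forward pass.

    Each intermediate tensor is written once by its producer and read once
    by its consumer, so it crosses HBM twice.
    """
    launch_total_us = num_kernels * launch_us
    bytes_moved = 2 * intermediate_bytes
    # 1 GB/s = 1e9 bytes/s = 1e3 bytes per microsecond
    traffic_us = bytes_moved / (hbm_gb_per_s * 1e3)
    return launch_total_us, traffic_us


# Assumed, illustrative numbers (hypothetical workload and device):
launch_total, traffic = unfused_overhead_us(
    num_kernels=300,                      # kernels per forward pass
    launch_us=5.0,                        # launch cost per kernel
    intermediate_bytes=64 * 1024 * 1024,  # total size of all intermediates
    hbm_gb_per_s=2000.0,                  # HBM bandwidth
)
print(f"launch overhead: {launch_total:.0f} us")
print(f"intermediate HBM traffic: {traffic:.1f} us")
```

A fully fused kernel would pay one launch and keep the intermediates on chip, so both terms shrink toward zero. Real kernels overlap these costs with compute, so treat the result as an upper-bound intuition rather than a prediction.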

What Mirage does: algebraic kernel superoptimization

Mirage treats kernel optimization as a program synthesis problem. Given a computation graph (e.g. the operations in a transformer layer), Mirage:
  1. Searches over a space of algebraically equivalent implementations.
  2. Evaluates each candidate on the target hardware.
  3. Returns the fastest implementation that has been verified, with high probability, to be equivalent to the original.
The key insight is that many optimization opportunities — operator fusion, tiling, reordering, layout transformations — can be expressed as algebraic rewrites on tensor computations. By searching over these rewrites systematically, Mirage finds fused mega kernels that human-written code and conventional compilers miss.
The term “mega kernel” refers to a single CUDA kernel that implements multiple logical operations from the original computation graph. Mirage generates these kernels automatically; you do not write them by hand.
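The search, verify, and select loop can be illustrated with a toy example. This is a minimal sketch, not Mirage's implementation: the candidates, the cost model, and the function names are all hypothetical, and the "hardware evaluation" is replaced by an operation count so the result is deterministic.

```python
import random


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def add(a, b):
    """Add two matrices elementwise."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def reference(x, w1, w2):
    # Original program: two matmuls followed by an add.
    return add(matmul(x, w1), matmul(x, w2))


def candidate_distribute(x, w1, w2):
    # Algebraic rewrite using distributivity: X@W1 + X@W2 == X@(W1 + W2).
    return matmul(x, add(w1, w2))


def candidate_wrong(x, w1, w2):
    # Deliberately incorrect rewrite, to show that verification rejects it.
    return matmul(x, w1)


def cost(name, m, k, n):
    """Toy cost model: count scalar multiply-adds and adds."""
    return {
        "reference": 2 * m * k * n + m * n,
        "distribute": m * k * n + k * n,
        "wrong": m * k * n,
    }[name]


def random_matrix(rows, cols, rng):
    return [[rng.randint(-9, 9) for _ in range(cols)] for _ in range(rows)]


def superoptimize(m=4, k=8, n=4, trials=5, seed=0):
    rng = random.Random(seed)
    candidates = {
        "reference": reference,
        "distribute": candidate_distribute,
        "wrong": candidate_wrong,
    }
    verified = []
    for name, fn in candidates.items():
        ok = True
        for _ in range(trials):
            x = random_matrix(m, k, rng)
            w1 = random_matrix(k, n, rng)
            w2 = random_matrix(k, n, rng)
            if fn(x, w1, w2) != reference(x, w1, w2):
                ok = False  # one mismatch is a definite counterexample
                break
        if ok:
            verified.append(name)
    # Only verified candidates compete on cost.
    return min(verified, key=lambda name: cost(name, m, k, n))


best = superoptimize()
print(best)
```

The incorrect candidate has the lowest cost, which is why verification has to happen before selection rather than after.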

Multi-level tiling

Mirage’s search space is structured around the three levels of the GPU memory hierarchy:
1. GPU level (L0)

At the top level, Mirage partitions the computation across the full device: how the global input tensors are split across streaming multiprocessors (SMs). This determines which data each SM needs to load from HBM.
2. Thread block level (L1)

Within each SM, Mirage determines how the computation is tiled into thread blocks (TBs). Each TB loads a tile of the input into shared memory (SMEM) and computes a tile of the output. Mirage searches over tile sizes, shared memory layouts, and which intermediate results to keep in SMEM.
3. Warp level (L2)

Within each thread block, Mirage assigns work to individual warps. At this level, it chooses the register allocation for intermediates, which warp-level matrix-multiply (WMMA/MMA) instructions to use, and how register tiles are laid out to minimize bank conflicts.
By optimizing across all three levels simultaneously, Mirage can find fusions that hide latency between levels — for example, keeping an intermediate tensor in registers across two operations that a naive compiler would spill to shared memory.
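The three levels map naturally onto nested loops. The sketch below is a plain-Python analogy, not generated Mirage code: the `block` and `warp` tile sizes are arbitrary assumptions, and a local variable stands in for a register.

```python
def naive_matmul(a, b):
    """Reference matrix multiply on lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def tiled_matmul(a, b, block=4, warp=2):
    """Matrix multiply structured like the three tiling levels.

    GPU level (L0): the output is split into block x block tiles.
    Thread block level (L1): each tile is split into warp x warp sub-tiles.
    Warp level (L2): each output element accumulates in a local variable,
    standing in for a register, and is written out once.
    """
    m, k, n = len(a), len(b), len(b[0])
    c = [[0] * n for _ in range(m)]
    for bi in range(0, m, block):                      # L0: grid of tiles
        for bj in range(0, n, block):
            for wi in range(bi, min(bi + block, m), warp):      # L1: sub-tiles
                for wj in range(bj, min(bj + block, n), warp):
                    for i in range(wi, min(wi + warp, bi + block, m)):
                        for j in range(wj, min(wj + warp, bj + block, n)):
                            acc = 0                    # L2: "register"
                            for p in range(k):
                                acc += a[i][p] * b[p][j]
                            c[i][j] = acc              # single write-back
    return c


a = [[i + j for j in range(3)] for i in range(5)]      # 5 x 3
b = [[i * j + 1 for j in range(7)] for i in range(3)]  # 3 x 7
assert tiled_matmul(a, b) == naive_matmul(a, b)
```

The loop bounds use `min` so that shapes that are not multiples of the tile sizes are still covered exactly once. Choosing `block` and `warp` is the kind of decision the search makes at each level.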

The Mirage search algorithm

Mirage’s search enumerates candidate kernel implementations by applying algebraic rewrites to the original computation graph:
1. Represent the computation as a µ-graph

Mirage represents computations as µ-graphs — directed acyclic graphs where nodes are tensor operators and edges are data dependencies. Each level of the hierarchy (GPU, TB, warp) has its own µ-graph layer.
2. Enumerate valid fusions

Mirage applies a set of rewrite rules that preserve the semantics of the computation. Rules include operator fusion, tile size changes, layout transforms, and reordering of independent operations. The search space is pruned using cost estimates and algebraic constraints.
3. Lower to CUDA

Each candidate µ-graph is lowered to a CUDA kernel using a code generation backend. The generated code uses CUTLASS primitives for matrix operations and custom intrinsics for other operations.
4. Profile and select

Candidate kernels are profiled on the target GPU. Mirage selects the fastest implementation and caches it for future use.
The search is not exhaustive over all possible kernels — the rewrite rule set constrains it to a structured space. Within this space, Mirage guarantees that every candidate is semantically equivalent to the original computation (up to floating-point rounding).
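To make the graph-plus-rewrite idea concrete, here is a minimal sketch of a DAG of operators and one semantics-preserving rewrite: fusing a single-input node into its only consumer. The `Node` class and the `fuse` function are hypothetical illustrations, far simpler than Mirage's µ-graphs, and they treat every single-input node as an elementwise operator.

```python
import math
from dataclasses import dataclass
from typing import Callable


@dataclass
class Node:
    name: str
    fn: Callable      # applied to the values of `inputs`
    inputs: tuple     # names of producer nodes or graph inputs


def evaluate(nodes, feeds):
    """Run nodes in list order (assumed topologically sorted)."""
    values = dict(feeds)
    for node in nodes:
        values[node.name] = node.fn(*(values[i] for i in node.inputs))
    return values


def fuse_once(nodes, outputs):
    """Apply one fusion rewrite, or return None if none applies."""
    consumers = {}
    for node in nodes:
        for src in node.inputs:
            consumers.setdefault(src, []).append(node.name)
    by_name = {n.name: n for n in nodes}
    for node in nodes:
        if len(node.inputs) != 1:
            continue
        producer = by_name.get(node.inputs[0])
        if producer is None or len(producer.inputs) != 1:
            continue
        # The producer's value must not be needed anywhere else.
        if producer.name in outputs or len(consumers[producer.name]) != 1:
            continue
        fused = Node(
            name=node.name,
            fn=lambda x, f=producer.fn, g=node.fn: g(f(x)),
            inputs=producer.inputs,
        )
        return [fused if n.name == node.name else n
                for n in nodes if n.name != producer.name]
    return None


def fuse(nodes, outputs):
    """Apply the rewrite until it no longer matches."""
    while True:
        rewritten = fuse_once(nodes, outputs)
        if rewritten is None:
            return nodes
        nodes = rewritten


graph = [
    Node("scale", lambda x: 0.5 * x, ("x",)),
    Node("shift", lambda s: s - 1.0, ("scale",)),
    Node("act", math.exp, ("shift",)),
    Node("y", lambda a, x: a * x, ("act", "x")),
]
fused = fuse(graph, outputs={"y"})
print([n.name for n in fused])
```

Here three chained single-input nodes collapse into one, so their intermediates never need to be materialized. A real search would apply many such rules and keep every resulting graph as a candidate instead of rewriting greedily.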

Safety: verifying semantic equivalence

Superoptimization is only useful if the optimized kernel produces correct results. Mirage uses probabilistic equivalence checking to verify candidate kernels:
  • It evaluates both the original and candidate implementations on random inputs.
  • It compares outputs with tolerance for floating-point differences.
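  • By the Schwartz–Zippel lemma, two distinct polynomials of total degree at most d agree at a point drawn uniformly at random from a finite set S with probability at most d/|S|. If the original and candidate computations, viewed as polynomials in their inputs, agree on several independent random inputs, they are therefore equivalent with high probability.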
Probabilistic checks have a small false-positive rate. Mirage runs multiple independent checks to reduce this probability to negligible levels, but for safety-critical applications you should also run deterministic validation on representative inputs.
This approach is practical because it avoids the cost of formal theorem proving while providing strong correctness guarantees for the algebraic transformations that Mirage applies.
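The random-testing argument can be sketched in a few lines. This toy version works over a prime field with exact integer arithmetic, which is the setting where the Schwartz–Zippel bound applies directly; the polynomials and helper names are assumptions chosen for illustration, not Mirage's checker.

```python
import random

P = 2_147_483_647  # the Mersenne prime 2**31 - 1, used as the field modulus


def lhs(a, b, c):
    # (a + b) * c
    return (a + b) * c % P


def rhs(a, b, c):
    # a*c + b*c, the same polynomial as lhs after expansion
    return (a * c + b * c) % P


def different(a, b, c):
    # a*c + b, a different polynomial from lhs
    return (a * c + b) % P


def probably_equivalent(f, g, num_vars, trials=8, seed=0):
    """Compare f and g at random points of the field.

    For polynomials of total degree at most d, each trial misses a real
    difference with probability at most d / P, so `trials` independent
    agreements leave an error probability of at most (d / P) ** trials.
    """
    rng = random.Random(seed)
    for _ in range(trials):
        point = [rng.randrange(P) for _ in range(num_vars)]
        if f(*point) != g(*point):
            return False  # a single disagreement is a definite counterexample
    return True


print(probably_equivalent(lhs, rhs, num_vars=3))
print(probably_equivalent(lhs, different, num_vars=3))
```

A disagreement is always conclusive, while agreement is only probabilistic. That asymmetry is why false positives are possible but false negatives are not in this exact-arithmetic setting.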

Results: speedups on attention and FFN

On key LLM operations, Mirage’s reported speedups over hand-optimized baselines range from 1.2× to 2.5×:
| Operation | Baseline | Mirage speedup |
| --- | --- | --- |
| Multi-head attention (MHA) | FlashAttention-2 | 1.3–1.9× |
| Grouped-query attention (GQA) | FlashAttention-2 | 1.4–2.1× |
| Feed-forward network (FFN) | cuBLAS + fused activation | 1.2–1.6× |
| Attention + FFN combined | Separate kernels | 1.5–2.5× |
Speedups are largest when Mirage can fuse operations that span the attention and FFN blocks — something conventional compilers do not attempt because the operations are too large to fuse naively.
Actual speedups depend on model configuration (head count, head dimension, sequence length) and GPU generation (A100 vs. H100 vs. H200). Mirage re-runs its search for each target configuration.
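Because the best kernel depends on the configuration, a natural pattern is to key cached search results on the GPU and the shape parameters. The sketch below shows that pattern with a stand-in search function. `best_kernel_for` and its arguments are hypothetical and say nothing about how Mirage's own cache is organized.

```python
from functools import lru_cache

search_calls = 0  # counts how often the (expensive) search actually runs


@lru_cache(maxsize=None)
def best_kernel_for(gpu, num_heads, head_dim, seq_len):
    """Stand-in for an expensive search; returns a label, not a kernel."""
    global search_calls
    search_calls += 1
    return f"kernel[{gpu},h={num_heads},d={head_dim},s={seq_len}]"


best_kernel_for("H100", 32, 128, 4096)  # first call runs the search
best_kernel_for("H100", 32, 128, 4096)  # same configuration: cache hit
best_kernel_for("A100", 32, 128, 4096)  # different GPU: search runs again
print(search_calls)
```

Any parameter that changes the best kernel has to be part of the key; leaving one out would silently reuse a kernel tuned for a different configuration.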

Programming interface

Mirage exposes a Python API for defining computations and triggering optimization:
import mirage as mi

# Define the computation graph
graph = mi.new_kernel_graph()

# Input tensors: batch=1, seqlen=4096, heads=32, dim=128
Q = mi.input(graph, shape=(1, 32, 4096, 128), dtype=mi.float16)
K = mi.input(graph, shape=(1, 32, 4096, 128), dtype=mi.float16)
V = mi.input(graph, shape=(1, 32, 4096, 128), dtype=mi.float16)

# Define multi-head attention
output = mi.attention(graph, Q, K, V)

# Run superoptimization search
optimized = mi.superoptimize(graph, config={
    "target_gpu": "H100",
    "optimization_level": 2,
})

# Run the optimized kernel
import torch

q = torch.randn(1, 32, 4096, 128, dtype=torch.float16, device="cuda")
k = torch.randn(1, 32, 4096, 128, dtype=torch.float16, device="cuda")
v = torch.randn(1, 32, 4096, 128, dtype=torch.float16, device="cuda")

result = optimized(q, k, v)

# Integration with PyTorch via custom op
import torch
import mirage as mi

# Register as a PyTorch custom operator
@mi.torch_op
def fused_attention_ffn(q, k, v, w1, w2):
    # Mirage fuses attention + FFN into a single mega kernel
    attn_out = mi.attention(q, k, v)
    ffn_out = mi.mlp(attn_out, w1, w2)
    return ffn_out

The optimization_level parameter controls the trade-off between search time and kernel quality. Level 1 is fast and finds good solutions for common patterns. Level 2 runs a longer search and can find larger speedups for unusual configurations.

Using Mirage with existing models

Mirage can be applied to PyTorch models via a graph capture and replace workflow:
import torch
import mirage as mi

model = load_your_model()  # Any PyTorch model

# Capture the computation graph
with torch.no_grad():
    example_input = torch.randn(1, 512, device="cuda", dtype=torch.float16)
    traced = torch.jit.trace(model, example_input)

# Replace attention and FFN subgraphs with Mirage mega kernels
optimized_model = mi.optimize(traced, target_gpu="H100")

# Run inference with the optimized model
output = optimized_model(example_input)
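
Before swapping an optimized model into production, it is worth comparing its outputs against the original on representative inputs, as the safety section recommends. The helper below is a dependency-free sketch of such a check on nested lists of floats. The function names and the tolerances are assumptions, with tolerances chosen loosely for half-precision outputs.

```python
import math


def flatten(values):
    """Yield the scalars of an arbitrarily nested list or tuple."""
    for v in values:
        if isinstance(v, (list, tuple)):
            yield from flatten(v)
        else:
            yield v


def outputs_match(reference, candidate, rtol=1e-2, atol=1e-3):
    """Return True if every element agrees within the given tolerances."""
    ref = list(flatten(reference))
    cand = list(flatten(candidate))
    if len(ref) != len(cand):
        return False
    return all(
        math.isclose(r, c, rel_tol=rtol, abs_tol=atol)
        for r, c in zip(ref, cand)
    )


# Hypothetical outputs from the original and the optimized model:
print(outputs_match([[1.0, 2.0]], [[1.001, 1.999]]))
print(outputs_match([[1.0, 2.0]], [[1.5, 2.0]]))
```

With PyTorch tensors, torch.testing.assert_close performs the same kind of comparison directly and reports which elements differ.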

Lecture 79 slides

Mengdi Wu and Xinhao Cheng’s original Mirage lecture slides

GPU Mode Discord

Ask questions and discuss Mirage, kernel fusion, and LLM compilers
