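Fusing GPU kernels means combining multiple operations that would normally run as separate CUDA kernels into a single kernel. That kernel keeps intermediate results in fast registers or shared memory instead of writing them to HBM. The payoff can be large. A chain of memory-bandwidth-bound operations then costs roughly one pass over memory instead of one pass per operation. When the chain is folded into a compute-heavy kernel such as a matmul, its cost can disappear almost entirely. Speedups of 2–5× are often seen for such chains. Results are mathematically the same, although rounding can differ slightly because intermediates no longer pass through lower-precision storage. This page is based on Lecture 18 by Kapil Sharma.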

Why kernel fusion matters

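Modern GPUs can do arithmetic far faster than they can move data to and from memory. An A100 can perform ~312 TFLOPS of dense FP16 tensor-core math, but its HBM bandwidth is only ~1.5–2 TB/s depending on the model. A kernel therefore needs on the order of 150–200 FLOPs per byte moved before compute, rather than memory, becomes the limit. For a simple elementwise operation like ReLU on a 1 GB activation tensor: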
  • Without fusion: read 1 GB from HBM, apply ReLU, write 1 GB back to HBM
  • With fusion into the preceding matmul: the matmul’s output never leaves registers before ReLU is applied
The “GPUs go brrr” mental model (referenced in the lecture) attributes an operation’s runtime to compute, memory bandwidth, or overhead. Pointwise operations and normalizations are almost always memory-bandwidth-bound, because they do only a few FLOPs per byte they load. Fusing them into adjacent compute-heavy kernels (matmuls, convolutions) hides their cost almost entirely, because their inputs are already sitting in registers.
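A back-of-the-envelope roofline calculation makes the bandwidth/compute split concrete. This sketch makes several assumptions. It uses 312 TFLOPS of dense FP16 math and 2 TB/s of HBM bandwidth. It counts ReLU as one FLOP per FP16 element. It uses a 4096 × 4096 × 4096 matmul as an illustrative compute-heavy comparison. The helper names are made up for this example.

```python
# Back-of-the-envelope roofline check.
# Assumed figures: 312 TFLOPS dense FP16 tensor-core math, 2 TB/s HBM bandwidth.
PEAK_FLOPS = 312e12        # FLOP/s
PEAK_BW = 2e12             # bytes/s

# Ridge point: FLOPs per byte a kernel needs before compute, not memory, limits it
RIDGE = PEAK_FLOPS / PEAK_BW   # 156 FLOP/byte


def classify(flops: float, bytes_moved: float) -> str:
    """Label a kernel by comparing its arithmetic intensity to the ridge point."""
    intensity = flops / bytes_moved
    return "compute-bound" if intensity >= RIDGE else "bandwidth-bound"


def min_time_s(flops: float, bytes_moved: float) -> float:
    """Lower bound on runtime: whichever of compute or memory traffic takes longer."""
    return max(flops / PEAK_FLOPS, bytes_moved / PEAK_BW)


# ReLU on a 1 GB FP16 tensor: 0.5e9 elements, ~1 FLOP each, read 1 GB + write 1 GB
relu_elems = 1e9 / 2
relu_flops = relu_elems
relu_bytes = 2e9

# A 4096 x 4096 x 4096 FP16 matmul: 2*M*N*K FLOPs, reads A and B, writes C
n = 4096
mm_flops = 2 * n**3
mm_bytes = 3 * n * n * 2

print(classify(relu_flops, relu_bytes), f"{min_time_s(relu_flops, relu_bytes) * 1e3:.2f} ms")
print(classify(mm_flops, mm_bytes), f"{min_time_s(mm_flops, mm_bytes) * 1e3:.2f} ms")
```

Under these assumptions the ReLU line prints `bandwidth-bound 1.00 ms`. That whole millisecond is memory traffic, which is exactly what fusing the ReLU into the preceding matmul avoids. The matmul’s arithmetic intensity (about 1,365 FLOPs per byte) is far above the ridge point of 156.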
The lecture covers this in the context of recommendation models (DLRM) rather than transformer attention. The fusion principles are the same, but the dominant operations differ: embedding lookups, dense linear layers, and interaction features rather than QKV projections.

Common fusion patterns

Pointwise + pointwise

The simplest fusion: chain multiple elementwise operations into a single kernel pass.
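import torch

# Unfused (eager mode): three kernel launches, three round-trips to HBM
def relu_scale_bias(x, scale, bias):
    x = torch.relu(x)   # kernel 1: read x, write an intermediate
    x = x * scale       # kernel 2: read it back, write another intermediate
    x = x + bias        # kernel 3: read it back, write the result
    return x

# Fused: one kernel launch, one round-trip. Writing torch.relu(x) * scale + bias
# on one line is not enough, because eager mode still launches three kernels;
# torch.compile generates the single fused kernel automatically.
fused_relu_scale_bias = torch.compile(relu_scale_bias)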
Each unfused step reads and writes the entire tensor. The fused version reads once and writes once.
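The saving grows linearly with the length of the chain. The minimal traffic model below makes the ratio explicit. It assumes each unfused op is its own kernel, and it ignores cache hits and the small `scale`/`bias` operands. The function name and the 64M-element tensor size are illustrative.

```python
# Rough HBM-traffic model for a chain of pointwise ops on one tensor.
# Sketch only: ignores caches and the small scale/bias operands.
def pointwise_chain_bytes(numel: int, n_ops: int, bytes_per_elem: int = 2, fused: bool = False) -> int:
    """Bytes read + written to HBM for n_ops elementwise ops applied in sequence."""
    tensor_bytes = numel * bytes_per_elem
    if fused:
        # One kernel: read the input once, write the final output once
        return 2 * tensor_bytes
    # One kernel per op: each reads its input and writes an intermediate
    return 2 * tensor_bytes * n_ops


numel = 64 * 1024 * 1024  # a 64M-element FP16 activation (128 MiB)
unfused = pointwise_chain_bytes(numel, n_ops=3)
fused = pointwise_chain_bytes(numel, n_ops=3, fused=True)
print(unfused // 2**20, "MiB vs", fused // 2**20, "MiB")  # 768 MiB vs 256 MiB
```

For the three-op ReLU/scale/bias chain, the fused kernel moves a third of the bytes. Because the chain is bandwidth-bound, that ratio translates almost directly into runtime.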

Matmul + bias + activation

This is one of the most impactful fusion targets in neural networks. Instead of:
  1. Launch GEMM kernel → write output to HBM
  2. Launch bias-add kernel → read/write output from HBM
  3. Launch activation kernel → read/write output from HBM
A fused kernel computes the bias-add and activation on each output tile as soon as the GEMM tile is ready, before it would be written to HBM.
# torch.nn.Linear with activation fused via torch.compile
import torch

model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024),
    torch.nn.ReLU(),
    torch.nn.Linear(1024, 512),
    torch.nn.GELU(),
)

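# By default Inductor typically calls cuBLAS for each matmul and fuses the bias add
# and activation into one Triton kernel; mode="max-autotune" can instead choose
# Triton GEMM templates that fuse them into the matmul epilogue itself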
compiled_model = torch.compile(model)
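What a fused epilogue does inside each tile can be illustrated without a GPU. The pure-Python version below is a hypothetical sketch: the function names and the tiny `tile` size are invented. A per-tile dict plays the role of registers, and bias + ReLU run on each tile before its only write to `out`. The unfused reference makes three full passes instead.

```python
# Pure-Python sketch of GEMM epilogue fusion: each output tile is accumulated in a
# local buffer (standing in for registers), then bias + ReLU are applied to that
# tile before the single "store" into the output matrix.
def fused_linear_relu(x, w, b, tile=2):
    """x: M x K, w: K x N, b: length N (lists of lists). Returns relu(x @ w + b)."""
    M, K, N = len(x), len(w), len(w[0])
    out = [[0.0] * N for _ in range(M)]
    for m0 in range(0, M, tile):
        for n0 in range(0, N, tile):
            rows = range(m0, min(m0 + tile, M))
            cols = range(n0, min(n0 + tile, N))
            # Accumulate the matmul for this tile only
            acc = {(i, j): 0.0 for i in rows for j in cols}
            for k in range(K):
                for i in rows:
                    for j in cols:
                        acc[i, j] += x[i][k] * w[k][j]
            # Epilogue: bias + ReLU on the tile while it is still "in registers"
            for (i, j), v in acc.items():
                out[i][j] = max(v + b[j], 0.0)  # one write per output element
    return out


def unfused_linear_relu(x, w, b):
    """Reference: three separate passes, each producing a full intermediate."""
    M, K, N = len(x), len(w), len(w[0])
    y = [[sum(x[i][k] * w[k][j] for k in range(K)) for j in range(N)] for i in range(M)]
    y = [[y[i][j] + b[j] for j in range(N)] for i in range(M)]
    return [[max(v, 0.0) for v in row] for row in y]


def allclose(p, q, tol=1e-12):
    return all(abs(a - c) < tol for rp, rq in zip(p, q) for a, c in zip(rp, rq))


x = [[1.0, -2.0, 0.5], [0.0, 3.0, -1.0], [2.0, 1.0, 1.0]]
w = [[0.5, -1.0, 2.0], [1.0, 0.0, -0.5], [-2.0, 1.0, 0.0]]
b = [0.1, -0.2, 0.3]
print(allclose(fused_linear_relu(x, w, b), unfused_linear_relu(x, w, b)))  # True
```

A GPU kernel follows the same structure, with tiles sized for registers and the store going to HBM.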

Normalization + scale + shift

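Layer norm and RMS norm each require a statistics pass (mean and variance for LayerNorm, mean of squares for RMSNorm) and a normalization pass. When followed by a learned scale and shift (gamma, beta), the scale and shift can be fused into the normalization pass:
import torch
import torch.nn.functional as F

x = torch.randn(32, 768)  # illustrative shapes
normalized_shape = (768,)
gamma, beta = torch.randn(768), torch.randn(768)

# Unfused (eager): layer_norm, then separate multiply and add kernels
y_unfused = F.layer_norm(x, normalized_shape) * gamma + beta

# Fused: layer_norm applies gamma and beta internally, in the same kernel
y_fused = F.layer_norm(x, normalized_shape, weight=gamma, bias=beta)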
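Per row, the fused pattern has two steps: a statistics pass, then one normalization pass that applies gamma and beta as each output is produced. The pure-Python sketch below shows that structure. The function names are hypothetical, and this is not how `F.layer_norm` is implemented internally.

```python
import math


def layer_norm_affine_row(row, gamma, beta, eps=1e-5):
    """LayerNorm over one row with gamma/beta applied in the same output loop."""
    n = len(row)
    # Statistics pass (in a real kernel the row stays in registers/shared memory)
    mean = sum(row) / n
    var = sum((v - mean) ** 2 for v in row) / n   # biased variance, as LayerNorm uses
    inv_std = 1.0 / math.sqrt(var + eps)
    # Normalization pass: scale and shift fused in, so each output is written once
    return [(v - mean) * inv_std * g + bt for v, g, bt in zip(row, gamma, beta)]


def layer_norm_then_affine(row, gamma, beta, eps=1e-5):
    """Unfused reference: materialize the normalized row, then apply gamma and beta."""
    normalized = layer_norm_affine_row(row, [1.0] * len(row), [0.0] * len(row), eps)
    return [v * g + bt for v, g, bt in zip(normalized, gamma, beta)]


row = [1.0, 2.0, 3.0, 4.0]
gamma = [1.0, 0.5, 2.0, -1.0]
beta = [0.0, 0.1, -0.1, 0.2]
fused = layer_norm_affine_row(row, gamma, beta)
reference = layer_norm_then_affine(row, gamma, beta)
print(all(abs(a - c) < 1e-12 for a, c in zip(fused, reference)))  # True
```

Both versions compute the same values. The difference is that the unfused one materializes `normalized` as a full intermediate, which on a GPU means an extra write and read through HBM.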

torch.compile and kernel fusion

torch.compile is PyTorch’s primary interface for automatic kernel fusion. TorchDynamo captures the computation graph from Python bytecode. TorchInductor then applies optimization passes (including fusion) and generates Triton kernels for GPUs or C++/OpenMP code for CPUs.
import torch

def forward(x, weight, bias):
    x = x @ weight.T
    x = x + bias
    x = torch.relu(x)
    return x

# Compile with default settings
compiled_forward = torch.compile(forward)

# Run once to trigger compilation
x = torch.randn(128, 1024, device="cuda", dtype=torch.float16)
weight = torch.randn(512, 1024, device="cuda", dtype=torch.float16)
bias = torch.randn(512, device="cuda", dtype=torch.float16)

output = compiled_forward(x, weight, bias)
To inspect what torch.compile generated, set the TORCH_LOGS environment variable:
TORCH_LOGS=output_code python your_script.py
This prints the generated Triton or C++ kernel code. The lecture references output_triton_code/ and torch_compile_generated_triton.py as examples of this output.
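Use torch.compile(model, mode="reduce-overhead") when kernel-launch and Python overhead dominate, as in small-batch inference. This mode captures the compiled graph in CUDA graphs. Use mode="max-autotune" for long-running workloads where the extra compilation time is worth the throughput gain. It benchmarks candidate kernel configurations, including Triton matmul templates that can fuse epilogues, and picks the fastest.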

Writing custom fused kernels in Triton

When torch.compile’s automatic fusion does not cover your pattern, or you need precise control over tile sizes, you can write fused kernels directly in Triton. Triton operates at the tile level, making fusion natural: compute one tile’s worth of operation A, then immediately apply operation B to that tile before moving on.
import triton
import triton.language as tl

@triton.jit
def fused_linear_relu_kernel(
    x_ptr, w_ptr, b_ptr, out_ptr,
    M, N, K,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    # Block indices
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    # Compute output tile: accumulate matmul
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        x_tile = tl.load(x_ptr + ...)
        w_tile = tl.load(w_ptr + ...)
        acc += tl.dot(x_tile, w_tile)

    # Apply bias + ReLU before writing to HBM
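    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    bias = tl.load(b_ptr + offs_n, mask=offs_n < N, other=0.0)  # guard the last partial tile
    acc = acc + bias[None, :]
    acc = tl.maximum(acc, 0.0)  # ReLU: fused, never hits HBM

    # Write result to HBM once, casting the fp32 accumulator to the output dtype
    tl.store(out_ptr + ..., acc.to(out_ptr.dtype.element_ty))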
The key: acc lives in registers throughout the entire kernel. The bias add and ReLU happen at register speed, not HBM speed.

LoRA fusion example

The lecture uses LoRA (Low-Rank Adaptation) as a motivating example. Standard LoRA computes:
output = x @ W + x @ A @ B * alpha
This requires two separate matmul chains that each read x from HBM. A fused kernel instead loads each tile of x once and feeds it to both paths:
# Unfused LoRA forward
def lora_forward_unfused(x, W, A, B, alpha):
    base = x @ W          # kernel 1: load x, W
    lora = x @ A @ B      # kernel 2: load x, A; kernel 3: load result, B
    return base + lora * alpha  # kernels 4-5: scale, then add

# With torch.compile, the addition and scale can be fused
# A custom Triton kernel can fuse the entire computation
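@triton.jit
def fused_lora_kernel(x_ptr, W_ptr, A_ptr, B_ptr, out_ptr, alpha, ...):
    # Sketch: one program per (BLOCK_M, BLOCK_N) output tile. K, the LoRA rank R and
    # the BLOCK_* sizes are assumed to arrive via the elided arguments; pointer
    # offsets and masks are elided (...) as in the kernel above.
    base = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    xa = tl.zeros((BLOCK_M, R), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        # Load this x tile once and use it for both paths
        x_tile = tl.load(x_ptr + ...)
        base += tl.dot(x_tile, tl.load(W_ptr + ...))  # base path: x @ W
        xa += tl.dot(x_tile, tl.load(A_ptr + ...))    # LoRA down-projection: x @ A
    # LoRA up-projection on the small (BLOCK_M, R) intermediate: (x @ A) @ B
    lora = tl.dot(xa, tl.load(B_ptr + ...).to(tl.float32))
    # Combine and write once
    tl.store(out_ptr + ..., base + lora * alpha)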
The lecture’s lora_on_simple_mlp.py trains a small MLP on the Criteo click prediction dataset, using LoRA adapters on the dense layers. The motivating use case is recommendation-model inference, where the LoRA path should add as little latency as possible.
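The scheduling idea behind the fused LoRA kernel can be checked in plain Python with the tiling stripped out. This is a hypothetical sketch, with invented function names and toy shapes. In the fused version each element of `x` is read once per row and shared by the base path and the LoRA down-projection. The result matches the unfused version, which reads `x` in two separate matmuls.

```python
# Pure-Python sketch of the fused LoRA schedule: one sweep over K, where each
# x[i][k] is read once and used by both the base path (x @ W) and the LoRA
# down-projection (x @ A). The small rank-R intermediate is then multiplied by B.
def fused_lora(x, W, A, B, alpha):
    M, K, N, R = len(x), len(W), len(W[0]), len(A[0])
    out = []
    for i in range(M):
        base = [0.0] * N
        xa = [0.0] * R
        for k in range(K):
            xik = x[i][k]               # single read of x shared by both paths
            for j in range(N):
                base[j] += xik * W[k][j]
            for r in range(R):
                xa[r] += xik * A[k][r]
        # Up-projection on the rank-R intermediate, then combine and write once
        out.append([base[j] + alpha * sum(xa[r] * B[r][j] for r in range(R)) for j in range(N)])
    return out


def matmul(P, Q):
    return [[sum(P[i][k] * Q[k][j] for k in range(len(Q))) for j in range(len(Q[0]))]
            for i in range(len(P))]


def unfused_lora(x, W, A, B, alpha):
    base = matmul(x, W)                 # reads x
    lora = matmul(matmul(x, A), B)      # reads x again
    return [[b + alpha * l for b, l in zip(rb, rl)] for rb, rl in zip(base, lora)]


x = [[1.0, 2.0, 3.0], [-1.0, 0.5, 2.0]]
W = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
A = [[1.0], [0.0], [-1.0]]              # rank R = 1
B = [[2.0, -1.0]]
alpha = 0.5
fused_out = fused_lora(x, W, A, B, alpha)
ref_out = unfused_lora(x, W, A, B, alpha)
print(all(abs(p - q) < 1e-12 for rp, rq in zip(fused_out, ref_out) for p, q in zip(rp, rq)))  # True
```

Because R is small, the extra work for the LoRA path inside the fused loop is tiny compared with the base matmul. The main saving is the second full read of `x`.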

Measuring speedup with profiling

Use PyTorch’s built-in profiler, or NVIDIA’s Nsight Systems (timelines) and Nsight Compute (per-kernel metrics), to verify fusion is happening and quantify the speedup:
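import torch
from torch.profiler import profile, ProfilerActivity

# Reuse model and compiled_model from above; move the model to the GPU and
# build a matching input (the batch size of 64 is arbitrary).
model = model.to("cuda")
input_tensor = torch.randn(64, 1024, device="cuda")

# Warm up outside the profiler so compilation time is not recorded
compiled_model(input_tensor)
torch.cuda.synchronize()

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    with_stack=True,
) as prof:
    for _ in range(10):
        output = compiled_model(input_tensor)
    torch.cuda.synchronize()

# Print kernel-level timing
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))

# Export Chrome trace for detailed visualization
prof.export_chrome_trace("trace.json")
# Open at chrome://tracing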
Key metrics to compare before and after fusion:
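| Metric | Unfused | Fused |
|---|---|---|
| Number of CUDA kernel launches | Many small kernels | Fewer, larger kernels |
| HBM bytes read/written | High (intermediate tensors round-trip through HBM) | Low (mainly inputs and final outputs) |
| Memory bandwidth utilization | Near peak (bandwidth-bound) | Can stay high for pointwise chains, but over far fewer bytes; lower when fused into a compute-bound matmul |
| Arithmetic intensity | Low | Higher |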
The lecture’s perf_screenshots/ folder contains profiler output from a real DLRM training run, showing the before/after kernel timeline when torch.compile fusion is applied.

Further reading

Lecture 18 code

Kapil Sharma’s fused kernel examples including LoRA and DLRM

Lecture 29: Triton Internals

How Triton compiles your kernel code down to PTX and SASS

Liger Kernel (Lecture 28)

Production-quality fused kernels for LLM training (RMSNorm, cross-entropy)

GPUs go brrr

Horace He’s guide to bandwidth vs. compute bottlenecks
