

Triton lets you write GPU kernels in Python-like syntax. The compiler lowers them through its own IRs and LLVM IR to PTX, the same intermediate representation that CUDA C++ compiles to on NVIDIA GPUs. This guide is based on Lecture 14 by Umer Adil, which walks through the full practitioner workflow: understanding the programming model, writing real kernels, masking out-of-bounds memory, tiling 2-D computations, and automating tile-size selection with @triton.autotune.

Why Triton?

CUDA gives you full control over registers, warp shuffles, and shared-memory banks, but that power comes at a steep cost in complexity. Triton occupies the sweet spot between PyTorch convenience and hand-tuned CUDA performance.
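|                          | CUDA         | Triton                 | torch.compile |
|--------------------------|--------------|------------------------|---------------|
| Control                  | Complete     | Block-level            | Almost none   |
| Performance ceiling      | Absolute max | Near-optimal           | Good baseline |
| Lines of code            | High         | Medium                 | None          |
| Shared-memory management | Manual       | Automatic              | Automatic     |
| Debuggable in Python     | No           | Yes (interpreter mode) | No            |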
When to reach for Triton:
1. Start with torch.compile. If your model is too slow, run it through torch.compile. It often generates Triton kernels automatically and can be a free win.
2. Reshape for the compiler. Restructure your PyTorch code to be more torch.compile-friendly (contiguous tensors, fewer dynamic shapes).
3. Write a Triton kernel. Profile to identify the bottleneck operation and replace it with a hand-written Triton kernel. This page shows you how.
4. Drop to CUDA only if needed. If you need absolute maximum performance (e.g. a custom warp-level primitive), write CUDA C++. For most ML workloads Triton is sufficient.
Because torch.compile generates Triton kernels internally, its output makes an excellent starting point for your own customizations. See Lecture 1 for how to extract the generated kernels.

Core Concepts: Programs, Blocks, and Pointer Arithmetic

The CUDA vs. Triton programming model

In CUDA you decompose work at two levels: blocks, each scheduled on an SM, and threads within a block, each of which operates on a single value. Triton removes the thread level entirely. Each kernel instance, called a program, operates on a whole block of values at once, and the compiler handles the mapping onto threads, including shared-memory management, for you.
# CUDA pseudocode: scalar operations per thread
def add_cuda_k(x, y, z, n, bs):
    block_id  = ...          # e.g., one of [0, 1]
    thread_id = ...          # e.g., one of [0, 1, 2, 3]
    offs = block_id * bs + thread_id   # scalar index
    if offs < n:
        z[offs] = x[offs] + y[offs]   # scalar arithmetic

# Triton: vector operations per program
import triton
import triton.language as tl

@triton.jit
def add_triton_k(x, y, z, n, bs: tl.constexpr):
    pid  = tl.program_id(0)             # program id (≈ block_id)
    offs = pid * bs + tl.arange(0, bs)  # vector of indices
    mask = offs < n                     # vector of bools
    x_vals = tl.load(x + offs, mask=mask)   # vector load
    y_vals = tl.load(y + offs, mask=mask)   # vector load
    tl.store(z + offs, x_vals + y_vals, mask=mask)  # vector store
The program ID, obtained with tl.program_id(axis), plays the same role as blockIdx in CUDA.
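The program model can be made concrete without a GPU. The plain-Python emulation below is a sketch, not Triton code, and the name emulate_add is hypothetical. It mimics how a 1-D launch runs add_triton_k: the outer loop stands in for the programs the GPU would run in parallel, and each iteration builds the full vector of offsets and the mask for its block.

```python
def emulate_add(x, y, bs):
    """Emulate a 1-D Triton launch of add_triton_k on Python lists."""
    n = len(x)
    z = [None] * n
    num_programs = (n + bs - 1) // bs            # ceiling division, like triton.cdiv
    for pid in range(num_programs):              # on a GPU these programs run in parallel
        offs = [pid * bs + i for i in range(bs)] # pid * bs + tl.arange(0, bs)
        mask = [o < n for o in offs]             # guards the ragged last block
        for o, m in zip(offs, mask):
            if m:                                # masked-off lanes are never read or written
                z[o] = x[o] + y[o]
    return z, num_programs


z, num_programs = emulate_add([1, 2, 3, 4, 5], [10, 20, 30, 40, 50], bs=2)
print(z, num_programs)  # [11, 22, 33, 44, 55] 3
```

With 5 elements and bs=2, three programs are launched, and the third one has only one active lane.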

Pointer arithmetic

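When you pass a PyTorch tensor to a kernel, Triton hands the kernel a pointer to the tensor's first element. Pointer arithmetic counts elements, not bytes, so ptr + i addresses element i. Within a program, tl.arange(0, BLOCK_SIZE) produces [0, 1, ..., BLOCK_SIZE-1], so the offsets for this program's block are: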
pid         = tl.program_id(0)
block_start = pid * BLOCK_SIZE
offsets     = block_start + tl.arange(0, BLOCK_SIZE)  # addresses for this program
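For 2-D tensors, you pass the strides explicitly and broadcast a vector of row offsets against a vector of column offsets to get a 2-D block of offsets: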
# 2-D offset: row * stride_row + col * stride_col
offs_2d = (
    tl.expand_dims(row_offs, 1) * stride_row +
    tl.expand_dims(col_offs, 0) * stride_col
)
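The broadcasting trick can be checked on the CPU. This pure-Python sketch uses offsets_2d, a hypothetical helper. It builds the same row_offs[:, None] * stride_row + col_offs[None, :] * stride_col grid for a tile of a row-major matrix and confirms that each entry is the flat index of element (r, c). Swapping the strides is how the same code reads a transposed view.

```python
def offsets_2d(row_offs, col_offs, stride_row, stride_col):
    """Outer 'sum' of two offset vectors, like expand_dims + broadcasting in Triton."""
    return [[r * stride_row + c * stride_col for c in col_offs] for r in row_offs]


# A 4 x 6 row-major matrix stored flat: element (r, c) lives at r * 6 + c
n_rows, n_cols = 4, 6
flat = list(range(n_rows * n_cols))

# The 2 x 3 tile starting at row 2, column 3
offs = offsets_2d([2, 3], [3, 4, 5], stride_row=n_cols, stride_col=1)
tile = [[flat[o] for o in row] for row in offs]   # equals offs here, since flat[i] == i
print(offs)    # [[15, 16, 17], [21, 22, 23]]

# Reading the same buffer as its transpose only requires swapping the strides
offs_t = offsets_2d([0, 1], [0, 1, 2], stride_row=1, stride_col=n_cols)
print(offs_t)  # [[0, 6, 12], [1, 7, 13]]
```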

tl.load and tl.store with Masking

Because tl.arange (and therefore every block size) needs a power-of-two length, and tensor sizes are rarely exact multiples of the block size, you almost always need a mask to guard against out-of-bounds accesses.
mask = offsets < n_elements          # boolean vector
x    = tl.load(x_ptr + offsets, mask=mask)          # safe load
tl.store(output_ptr + offsets, result, mask=mask)   # safe store
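Lanes where the mask is False are never read from or written to memory. A masked load fills those lanes with the value passed as other. If you omit other, their contents are undefined, so pass a neutral value (0.0 for sums, -inf for maxima) whenever the loaded block feeds a reduction. A masked store simply skips those lanes.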
Forgetting the mask is one of the most common bugs in Triton kernels. Always check whether your tensor size is guaranteed to be a multiple of your block size; if not, add the mask.
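The padding value matters because masked-off lanes still take part in any reduction over the loaded block. In the pure-Python sketch below, masked_load is a hypothetical stand-in for tl.load, not Triton code. It fills masked lanes with other and shows that the right neutral value depends on the reduction.

```python
import math


def masked_load(buf, offs, mask, other):
    """Return buf[o] where mask is True and `other` elsewhere, never indexing out of bounds."""
    return [buf[o] if m else other for o, m in zip(offs, mask)]


data = [-3.0, -1.0, -4.0, -5.0, -2.0]   # 5 elements
BLOCK = 8                                # power-of-two block, so 3 lanes are out of bounds
offs = list(range(BLOCK))
mask = [o < len(data) for o in offs]

print(sum(masked_load(data, offs, mask, other=0.0)))        # -15.0 (0 is neutral for sum)
print(max(masked_load(data, offs, mask, other=-math.inf)))  # -1.0  (-inf is neutral for max)
print(max(masked_load(data, offs, mask, other=0.0)))        # 0.0   (wrong: padding leaked in)
```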

Writing a Vector Addition Kernel

The following add_kernel comes from the Lecture 29 code (lecture_029/vector_add.py); Umer Adil uses the same kernel as his starter example:
import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(
    x_ptr,        # Pointer to first input vector
    y_ptr,        # Pointer to second input vector
    output_ptr,   # Pointer to output vector
    n_elements,   # Size of the vector
    BLOCK_SIZE: tl.constexpr,  # Elements each program processes
):
    # Each program covers a different slice of the vector.
    # For a 256-element vector with BLOCK_SIZE=64, programs cover
    # [0:64], [64:128], [128:192], [192:256].
    pid         = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets     = block_start + tl.arange(0, BLOCK_SIZE)

    # Guard against out-of-bounds if n_elements % BLOCK_SIZE != 0
    mask = offsets < n_elements

    x      = tl.load(x_ptr + offsets, mask=mask)
    y      = tl.load(y_ptr + offsets, mask=mask)
    output = x + y

    tl.store(output_ptr + offsets, output, mask=mask)


if __name__ == "__main__":
    size   = 1024
    x      = torch.rand(size, device="cuda")
    y      = torch.rand(size, device="cuda")
    output = torch.empty_like(x)

    # Grid: one program per BLOCK_SIZE chunk
    grid = lambda meta: (triton.cdiv(size, meta["BLOCK_SIZE"]),)

    compiled_kernel = add_kernel[grid](x, y, output, size, BLOCK_SIZE=1024)

    # Inspect the generated IR/PTX
    print(compiled_kernel.asm.keys())
    # dict_keys(['ttir', 'ttgir', 'llir', 'ptx', 'cubin'])
    print(compiled_kernel.asm["ttir"])
Key observations:
  • BLOCK_SIZE: tl.constexpr marks the parameter as a compile-time constant. tl.arange needs constant bounds, which lets the compiler infer block shapes and unroll loops; each distinct value compiles a separate specialized kernel.
  • The grid lambda returns the number of programs to launch. triton.cdiv is ceiling-division: (a + b - 1) // b.
  • compiled_kernel.asm exposes every stage of the compilation pipeline.
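Here is a worked example of the grid arithmetic for a size that is not a multiple of the block size. The helpers below are plain Python and their names are illustrative. cdiv computes the same ceiling division as triton.cdiv, and lanes_per_program counts how many lanes each program actually uses.

```python
def cdiv(a, b):
    """Ceiling division: the number of size-b blocks needed to cover a elements."""
    return (a + b - 1) // b


def lanes_per_program(n_elements, block_size):
    """Active (unmasked) lanes for each program in a 1-D launch."""
    n_programs = cdiv(n_elements, block_size)
    return [min(block_size, n_elements - pid * block_size) for pid in range(n_programs)]


print(cdiv(1000, 256))               # 4
print(lanes_per_program(1000, 256))  # [256, 256, 256, 232]
```

The fourth program covers offsets 768 to 1023, and only 232 of its 256 lanes pass the mask.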

Blocked / Tiled Computation for 2-D Problems

For matrix operations (matmul, softmax, etc.) you need 2-D tiling. The idea is to assign each program a 2-D tile (bm × bn) of the output and accumulate partial results over the shared k-dimension.
import triton
import triton.language as tl


@triton.jit
def naive_matmul_k(
    a_ptr, b_ptr, c_ptr,
    m, n, k,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    bm: tl.constexpr, bn: tl.constexpr, bk: tl.constexpr,
):
    pid_m = tl.program_id(0)   # row tile index
    pid_n = tl.program_id(1)   # col tile index

    # 1-D offsets for rows (m), columns (n), and the k-chunk
    rm = pid_m * bm + tl.arange(0, bm)   # [pid_m*bm, ..., pid_m*bm + bm-1]
    rn = pid_n * bn + tl.arange(0, bn)
    rk = tl.arange(0, bk)

    # Pointers into A and B for the first k-tile
    offs_a = a_ptr + tl.expand_dims(rm, 1) * stride_am + tl.expand_dims(rk, 0) * stride_ak
    offs_b = b_ptr + tl.expand_dims(rk, 1) * stride_bk + tl.expand_dims(rn, 0) * stride_bn

    acc = tl.zeros((bm, bn), dtype=tl.float32)

    for k_start in range(0, k, bk):
        # Mask rows/columns past the matrix edge and the ragged last k-tile;
        # masked lanes load 0.0, so they add nothing to the dot product
        mask_a = (tl.expand_dims(rm, 1) < m) & (tl.expand_dims(k_start + rk, 0) < k)
        mask_b = (tl.expand_dims(k_start + rk, 1) < k) & (tl.expand_dims(rn, 0) < n)
        a    = tl.load(offs_a, mask=mask_a, other=0.0)   # (bm, bk)
        b    = tl.load(offs_b, mask=mask_b, other=0.0)   # (bk, bn)
        # allow_tf32=False: full-precision fp32 dot (newer Triton spells this input_precision="ieee")
        acc += tl.dot(a, b, allow_tf32=False)
        offs_a += bk * stride_ak   # advance k pointer in A
        offs_b += bk * stride_bk   # advance k pointer in B

    # Write output tile
    offs_c = c_ptr + tl.expand_dims(rm, 1) * stride_cm + tl.expand_dims(rn, 0) * stride_cn
    mask   = (tl.expand_dims(rm, 1) < m) & (tl.expand_dims(rn, 0) < n)
    tl.store(offs_c, acc, mask=mask)
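The same tiling can be traced step by step on the CPU. This pure-Python sketch favors clarity over speed, and blocked_matmul is a hypothetical name. It walks the (pid_m, pid_n) grid and accumulates a bm x bn tile over k in chunks of bk. Out-of-range rows, columns, and k-indices are treated as zeros, which is the role masked loads with other=0.0 play on the GPU.

```python
def cdiv(a, b):
    return (a + b - 1) // b


def blocked_matmul(a, b, bm, bn, bk):
    """C = A @ B for lists of lists, computed tile by tile like naive_matmul_k."""
    m, k, n = len(a), len(a[0]), len(b[0])
    c = [[0.0] * n for _ in range(m)]
    for pid_m in range(cdiv(m, bm)):                  # tl.program_id(0)
        for pid_n in range(cdiv(n, bn)):              # tl.program_id(1)
            rm = [pid_m * bm + i for i in range(bm)]
            rn = [pid_n * bn + j for j in range(bn)]
            acc = [[0.0] * bn for _ in range(bm)]     # tl.zeros((bm, bn))
            for k_start in range(0, k, bk):           # the k-loop
                rk = [k_start + t for t in range(bk)]
                # "Masked loads": out-of-bounds entries read as 0.0
                a_tile = [[a[r][kk] if r < m and kk < k else 0.0 for kk in rk] for r in rm]
                b_tile = [[b[kk][cc] if kk < k and cc < n else 0.0 for cc in rn] for kk in rk]
                # acc += a_tile @ b_tile  (tl.dot)
                for i in range(bm):
                    for j in range(bn):
                        acc[i][j] += sum(a_tile[i][t] * b_tile[t][j] for t in range(bk))
            # "Masked store": only rows < m and columns < n are written
            for i, r in enumerate(rm):
                for j, cc in enumerate(rn):
                    if r < m and cc < n:
                        c[r][cc] = acc[i][j]
    return c


print(blocked_matmul([[1, 2], [3, 4], [5, 6]], [[1, 0, 2], [0, 1, 3]], bm=2, bn=2, bk=2))
# [[1.0, 2.0, 8.0], [3.0, 4.0, 18.0], [5.0, 6.0, 28.0]]
```

None of the dimensions here are multiples of the tile sizes, so the zero-padding path is exercised on every edge tile.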
The grouped (swizzled) variant remaps program IDs so that programs launched close together in time compute neighboring output tiles. Those tiles share rows of A and columns of B, so the blocks they read are more likely to still be in the L2 cache:
pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, group_sz)
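To see what the remapping does, here is a pure-Python transcription of the index arithmetic behind tl.swizzle2d. It is reconstructed from the Triton source as we understand it, so treat it as a sketch and check it against your Triton version. launch_order is a hypothetical helper that records which linear program ID ends up computing each output tile.

```python
def swizzle2d(i, j, size_i, size_j, size_g):
    """Pure-Python version of the index math behind tl.swizzle2d."""
    ij = i * size_j + j                    # row-major linear program index
    size_gj = size_g * size_j              # programs per group of size_g tile-rows
    group_id = ij // size_gj
    off_i = group_id * size_g              # first tile-row of this group
    size_g = min(size_i - off_i, size_g)   # the last group may have fewer rows
    ij = ij % size_gj                      # position inside the group
    return off_i + ij % size_g, ij // size_g   # walk down a column within the group


def launch_order(size_i, size_j, size_g):
    """order[r][c] = linear program id that computes output tile (r, c)."""
    order = [[None] * size_j for _ in range(size_i)]
    for i in range(size_i):
        for j in range(size_j):
            new_i, new_j = swizzle2d(i, j, size_i, size_j, size_g)
            order[new_i][new_j] = i * size_j + j
    return order


for row in launch_order(5, 4, 3):
    print(row)
# [0, 3, 6, 9]
# [1, 4, 7, 10]
# [2, 5, 8, 11]
# [12, 14, 16, 18]
# [13, 15, 17, 19]
```

Programs 0, 1, and 2 compute one column of tiles and therefore read the same columns of B. Programs 0 through 11 touch only three tile-rows of A before the next group starts.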

Fused Kernels

Fusion eliminates intermediate global-memory round-trips. A fused softmax loads each row once, computes exp(x - max(x)) / sum(exp(x - max(x))) on-chip, and writes the result once:
@triton.jit
def fused_softmax_k(
    x_ptr, out_ptr,
    n_cols,
    stride_row,
    BLOCK_SIZE: tl.constexpr,
):
    row = tl.program_id(0)
    offs = tl.arange(0, BLOCK_SIZE)
    mask = offs < n_cols

    x = tl.load(x_ptr + row * stride_row + offs, mask=mask, other=-float("inf"))

    # Numerically stable softmax: subtract row max before exp
    x_max = tl.max(x, axis=0)
    x     = x - x_max
    num   = tl.exp(x)
    denom = tl.sum(num, axis=0)
    out   = num / denom

    tl.store(out_ptr + row * stride_row + offs, out, mask=mask)
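Each program handles one row, so this kernel is only correct when a whole row fits in one block (BLOCK_SIZE >= n_cols). The launcher typically sets BLOCK_SIZE = triton.next_power_of_2(n_cols). The kernel also assumes that the input and output share the same row stride. For rows too wide for a single block, you would loop over the row in chunks, accumulating the max and sum across chunks before normalizing.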
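For the wide-row case, one common approach is the online softmax. It is not shown in the lecture. You stream the row in chunks and keep a running max m and a running sum s of exp(x - m). Whenever the max grows, you rescale s by exp(m_old - m_new), then make a second pass to normalize. The pure-Python sketch below shows only that bookkeeping, and chunked_softmax is a hypothetical name. A Triton version would do the same with tl.max, tl.exp, and tl.sum over each loaded chunk.

```python
import math


def chunked_softmax(row, block_size):
    """Softmax of one row, reading it block_size elements at a time (online max/sum)."""
    m = -math.inf   # running max
    s = 0.0         # running sum of exp(x - m)
    # Pass 1: update the max, rescaling the partial sum whenever the max grows
    for start in range(0, len(row), block_size):
        chunk = row[start:start + block_size]
        m_new = max(m, max(chunk))
        s = s * math.exp(m - m_new) + sum(math.exp(x - m_new) for x in chunk)
        m = m_new
    # Pass 2: normalize with the final max and sum
    return [math.exp(x - m) / s for x in row]


print(sum(chunked_softmax([1000.0, 1001.0, 1002.0], block_size=2)))  # ~1.0, no overflow
```

On the first chunk, exp(-inf - m_new) evaluates to 0.0, so the empty running sum is discarded cleanly.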

@triton.autotune for Automatic Tile-Size Tuning

Picking the right block size by hand is tedious and hardware-dependent. @triton.autotune benchmarks each candidate triton.Config you list and picks the fastest one.
@triton.autotune(
    configs=[
        triton.Config({'bm': 128, 'bn': 256, 'bk': 64, 'group_sz': 8},
                      num_stages=3, num_warps=8),
        triton.Config({'bm':  64, 'bn': 256, 'bk': 32, 'group_sz': 8},
                      num_stages=4, num_warps=4),
        triton.Config({'bm': 128, 'bn': 128, 'bk': 32, 'group_sz': 8},
                      num_stages=4, num_warps=4),
        triton.Config({'bm': 128, 'bn':  64, 'bk': 32, 'group_sz': 8},
                      num_stages=4, num_warps=4),
        triton.Config({'bm':  64, 'bn': 128, 'bk': 32, 'group_sz': 8},
                      num_stages=4, num_warps=4),
        triton.Config({'bm': 128, 'bn':  32, 'bk': 32, 'group_sz': 8},
                      num_stages=4, num_warps=4),
        triton.Config({'bm':  64, 'bn':  32, 'bk': 32, 'group_sz': 8},
                      num_stages=5, num_warps=2),
        triton.Config({'bm':  32, 'bn':  64, 'bk': 32, 'group_sz': 8},
                      num_stages=5, num_warps=2),
    ],
    # Re-tune when any of these problem dimensions change
    key=['m', 'n', 'k'],
)
@triton.jit
def grouped_autotuned_matmul_k(
    a_ptr, b_ptr, c_ptr,
    m, n, k,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    bm: tl.constexpr, bn: tl.constexpr, bk: tl.constexpr,
    group_sz: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    num_pid_m = tl.num_programs(0)
    num_pid_n = tl.num_programs(1)
    pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, group_sz)
    # ... rest of the matmul body as shown above
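The key list tells Triton which arguments define the problem size. The first time the kernel is called with a new combination of m, n, and k, Triton benchmarks every config and caches the fastest. Later calls with the same values reuse the cached choice. Because the autotuner supplies bm, bn, bk, and group_sz, don't pass them at the call site; the grid lambda reads them from its meta argument instead.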
Both num_stages (software-pipelining depth) and num_warps (warps per CTA) can dramatically affect performance. Include a range of values in your configs list.
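The caching policy can be pictured with a toy model. The class below is hypothetical and is not Triton's autotuner. It only illustrates the behavior described above: every config is benchmarked the first time a new key tuple appears, and the cached winner is reused after that. The cost function is a made-up stand-in for a real timing run.

```python
class ToyAutotuner:
    """Illustrative stand-in for @triton.autotune's per-key caching policy."""

    def __init__(self, configs, key, cost_fn):
        self.configs = configs      # list of dicts, e.g. {'bm': 64}
        self.key = key              # argument names that define the problem size
        self.cost_fn = cost_fn      # (config, **args) -> simulated runtime
        self.cache = {}             # key tuple -> best config
        self.benchmark_runs = 0

    def best_config(self, **args):
        key = tuple(args[name] for name in self.key)
        if key not in self.cache:
            # New problem size: time every config and keep the fastest
            timings = []
            for cfg in self.configs:
                self.benchmark_runs += 1
                timings.append((self.cost_fn(cfg, **args), cfg))
            self.cache[key] = min(timings, key=lambda t: t[0])[1]
        return self.cache[key]


def padded_rows_cost(cfg, m, n, k):
    """Hypothetical cost model: rows computed (including padding) plus per-tile overhead."""
    tiles = (m + cfg['bm'] - 1) // cfg['bm']
    return tiles * cfg['bm'] + 10 * tiles


tuner = ToyAutotuner([{'bm': 32}, {'bm': 64}, {'bm': 128}], key=['m', 'n', 'k'],
                     cost_fn=padded_rows_cost)
print(tuner.best_config(m=100, n=8, k=8))  # {'bm': 128}
print(tuner.best_config(m=40, n=8, k=8))   # {'bm': 64}
print(tuner.benchmark_runs)                # 6
```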

Debugging Triton Kernels

Interpreter mode

Set TRITON_INTERPRET=1 before importing Triton to run kernels entirely on the CPU. This lets you use standard Python debuggers.
TRITON_INTERPRET=1 python my_kernel.py
Some Triton operations (such as tl.swizzle2d and certain dtype promotions) behave subtly differently in interpreter mode. Always verify on a real GPU before shipping.

Utility functions from triton_util.py (Lecture 14)

Umer Adil ships a small utility module, triton_util.py, in the lecture repo that makes debugging much easier:
import operator
import os

import triton
import triton.language as tl

# Comparison operators allowed in condition strings ('=' means equality)
_OPS = {
    '<': operator.lt, '>': operator.gt,
    '<=': operator.le, '>=': operator.ge,
    '=': operator.eq, '!=': operator.ne,
}


def test_pid_conds(conds, pid_0=[0], pid_1=[0], pid_2=[0]):
    """Test whether conditions on program IDs are met.

    Examples:
        '=0'    checks pid_0 == 0
        ',>1'   checks pid_1 > 1
        '>1,=0' checks pid_0 > 1 and pid_1 == 0
    """
    pids  = pid_0[0], pid_1[0], pid_2[0]
    conds = conds.replace(' ', '').split(',')
    for cond, pid in zip(conds, pids):
        if cond == '':
            continue
        # Match two-character operators before one-character ones
        op = cond[:2] if cond[:2] in ('<=', '>=', '!=') else cond[0]
        if op not in _OPS:
            raise ValueError(f"Invalid op in rule: '{cond}'")
        threshold = int(cond[len(op):])
        if not _OPS[op](int(pid), threshold):
            return False
    return True


def breakpoint_if(conds, pid_0=[0], pid_1=[0], pid_2=[0]):
    """Drop into the debugger when PID conditions are met."""
    from IPython.core.debugger import set_trace
    if test_pid_conds(conds, pid_0, pid_1, pid_2):
        set_trace()


def print_if(txt, conds, pid_0=[0], pid_1=[0], pid_2=[0]):
    """Print txt when PID conditions are met."""
    if test_pid_conds(conds, pid_0, pid_1, pid_2):
        print(txt)


def check_tensors_gpu_ready(*tensors):
    """Assert tensors are contiguous and on CUDA (unless interpreting)."""
    for t in tensors:
        assert t.is_contiguous(), "A tensor is not contiguous"
        if not os.environ.get('TRITON_INTERPRET') == '1':
            assert t.is_cuda, "A tensor is not on CUDA"


def cdiv(a, b):
    """Ceiling division."""
    return (a + b - 1) // b


@triton.jit
def get_1d_offset(size, n_prev_chunks):
    return n_prev_chunks * size + tl.arange(0, size)


@triton.jit
def get_2d_offset(offs_0, offs_1, stride_0, stride_1=1):
    return (
        tl.expand_dims(offs_0, 1) * stride_0 +
        tl.expand_dims(offs_1, 0) * stride_1
    )


@triton.jit
def get_1d_mask(offs, max):
    return offs < max


@triton.jit
def get_2d_mask(offs_0, offs_1, max_0, max_1):
    return (
        (tl.expand_dims(offs_0, 1) < max_0) &
        (tl.expand_dims(offs_1, 0) < max_1)
    )
Usage pattern inside a kernel. print_if and breakpoint_if are ordinary Python functions, so they only work when the kernel runs under TRITON_INTERPRET=1. Each condition is tested against the pid_0, pid_1, and pid_2 arguments, which default to [0], so pass the program IDs explicitly:
@triton.jit
def my_kernel(x_ptr, out_ptr, n, BLOCK_SIZE: tl.constexpr):
    pid  = tl.program_id(0)
    offs = get_1d_offset(BLOCK_SIZE, pid)
    mask = get_1d_mask(offs, n)

    # Only print for program 0 along axis 0
    print_if(f"pid={pid}, offs={offs}", "=0", pid_0=[pid])

    # Drop into the debugger for program 1 along axis 0
    breakpoint_if("=1", pid_0=[pid])

    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x * 2, mask=mask)

Profiling with Nsight Compute

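Run the whole script under Nsight Compute. The --target-processes all flag also profiles kernels launched from child processes:
ncu --target-processes all python my_kernel.py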

Inspecting compiled artifacts

# After launching the kernel, access generated IR at every stage:
compiled = my_kernel[grid](x, out, n, BLOCK_SIZE=64)
print(compiled.asm.keys())
# dict_keys(['ttir', 'ttgir', 'llir', 'ptx', 'cubin'])

# Human-readable Triton IR
print(compiled.asm['ttir'])

# PTX assembly
print(compiled.asm['ptx'])

Built-in Benchmarking

Triton ships triton.testing for reproducible micro-benchmarks:
import torch
import triton

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['square_matrix_size'],
        x_vals=[2**i for i in range(5, 12)],
        x_log=True,
        line_arg='provider',
        line_vals=['naive', 'grouped', 'torch'],
        line_names=['Naive', 'Grouped', 'Torch'],
        styles=[('blue', '-'), ('green', '-'), ('orange', '-')],
        ylabel='GB/s',
        plot_name='matmul-performance',
        args={},
    )
)
def benchmark(square_matrix_size, provider):
    sz = square_matrix_size
    a  = torch.rand((sz, sz), device='cuda', dtype=torch.float32)
    b  = torch.rand((sz, sz), device='cuda', dtype=torch.float32)
    quantiles = [0.5, 0.2, 0.8]   # median, 20th and 80th percentile runtimes (ms)
    # naive_matmul / grouped_matmul are host-side wrappers that launch the
    # kernels above; they are not shown on this page
    if provider == 'naive':
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: naive_matmul(a, b), quantiles=quantiles)
    elif provider == 'grouped':
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: grouped_matmul(a, b, group_sz=8), quantiles=quantiles)
    elif provider == 'torch':
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: torch.matmul(a, b), quantiles=quantiles)
    # Bytes moved: read A and B, write C, each sz*sz fp32 values (4 bytes)
    gbps = lambda ms: 12 * sz * sz / ms * 1e-6
    return gbps(ms), gbps(max_ms), gbps(min_ms)

benchmark.run(print_data=True, show_plots=True)
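For a square fp32 matmul, a GB/s figure should count 3 * sz * sz * 4 bytes (read A, read B, write C). FLOP throughput, 2 * sz^3 operations, is usually the more telling metric for matmul. The helper below is a hypothetical sketch of that arithmetic. Its traffic model counts each matrix once and ignores re-reads across tiles, so it gives a lower bound on real memory traffic.

```python
def matmul_throughput(sz, ms, bytes_per_elem=4):
    """Effective GB/s and TFLOP/s for an sz x sz x sz matmul that took `ms` milliseconds."""
    bytes_moved = 3 * sz * sz * bytes_per_elem   # read A, read B, write C once each
    flops = 2 * sz ** 3                          # one multiply + one add per inner-product term
    seconds = ms * 1e-3
    return bytes_moved / seconds * 1e-9, flops / seconds * 1e-12


gbps, tflops = matmul_throughput(sz=2048, ms=1.0)
print(round(gbps, 3), round(tflops, 3))  # 50.332 17.18
```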

Lecture 14 Notebook

Full worked examples including copy, greyscale, and matrix multiply kernels.

Triton Official Tutorials

The vector-add, fused softmax, and matmul tutorials from the Triton team.

Triton Internals (Lecture 29)

Understand what happens after @triton.jit: AST → MLIR → PTX.

Iris: Multi-GPU Triton (Lecture 78)

Extend Triton kernels across multiple GPUs with the Iris programming model.
