When you decorate a function with @triton.jit and launch it over a grid, Triton compiles your Python code through a multi-stage pipeline into a CUDA binary (cubin) for the target GPU. This guide is based on Lecture 29 by Kapil Sharma (Software Engineer @ Meta), who traces that pipeline end-to-end and compares it with the classical NVCC toolchain. Kapil has written a three-part deep-dive blog series that this lecture draws from: Part 1, Part 2, Part 3.

Compilation Pipeline Overview

Triton’s pipeline has several distinct stages, each lowering the representation one level closer to hardware:
Python (DSL)

    ▼  @triton.jit / AST visitor
Triton IR (ttir)          — high-level, device-agnostic

    ▼  TTIR → TTGIR pass manager
TritonGPU IR (ttgir)      — GPU-aware MLIR dialect

    ▼  TTGIR → LLVM IR pass manager
LLVM IR (llir)            — LLVM IR for the NVPTX target (uses NVVM intrinsics)

    ▼  LLVM PTX backend
PTX                       — NVIDIA virtual ISA

    ▼  ptxas
CUBIN / fatbinary         — executable ELF binary loaded by the driver
All artifacts are cached on disk at ~/.triton/cache/ (configurable with TRITON_CACHE_DIR). You can inspect every stage after running a kernel.
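To see which stages a cached kernel produced, a small script can walk the cache directory and group files by suffix. This is a minimal sketch that assumes the layout shown later in this guide (one hashed subdirectory per compiled kernel, one file per stage). The helpers triton_cache_dir and artifacts_by_stage are hypothetical, and the lookup only mirrors the TRITON_CACHE_DIR rule stated above rather than calling Triton's own resolver.

```python
import os
from collections import defaultdict
from pathlib import Path
from typing import Dict, List

# File suffix -> pipeline stage that produced it (hand-written for this sketch).
STAGE_BY_SUFFIX = {
    ".ttir": "Triton IR",
    ".ttgir": "TritonGPU IR",
    ".llir": "LLVM IR",
    ".ptx": "PTX",
    ".cubin": "CUBIN",
    ".json": "metadata",
}


def triton_cache_dir() -> Path:
    """Hypothetical helper: TRITON_CACHE_DIR if set, else ~/.triton/cache."""
    override = os.environ.get("TRITON_CACHE_DIR")
    return Path(override) if override else Path.home() / ".triton" / "cache"


def artifacts_by_stage(cache_dir: Path) -> Dict[str, List[str]]:
    """Group every file under cache_dir by the stage its suffix belongs to."""
    if not cache_dir.is_dir():
        return {}
    grouped: Dict[str, List[str]] = defaultdict(list)
    for path in sorted(cache_dir.rglob("*")):
        if path.is_file():
            stage = STAGE_BY_SUFFIX.get(path.suffix, "other")
            grouped[stage].append(path.relative_to(cache_dir).as_posix())
    return dict(grouped)


if __name__ == "__main__":
    for stage, files in artifacts_by_stage(triton_cache_dir()).items():
        print(f"{stage}: {len(files)} file(s)")
```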

Background: The NVCC Toolchain

To appreciate what Triton does, it helps to understand the classical CUDA compilation path driven by nvcc.
1. Source separation: nvcc splits .cu files into host code (C/C++) and device code (CUDA kernels). Host code is compiled by a system C++ compiler (e.g., g++ or clang).
2. Device compilation: the CUDA C++ preprocessor and front end, followed by cicc (NVIDIA's LLVM-based NVVM compiler), process device code into an intermediate form, optimize it, and generate PTX.
3. PTX → SASS: ptxas compiles PTX into SASS, the actual hardware instruction set for a specific GPU architecture (compute capability).
4. Linking: the cubin (an ELF-formatted CUDA binary) is packaged into a fatbinary, embedded in the host object file, and linked with the host code into the final executable.
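Triton takes over step 2 with its own Python-driven pipeline built on MLIR and LLVM, still calls ptxas for step 3, and, instead of linking the cubin into an executable (step 4), loads it at runtime through the CUDA driver API.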
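The device-side nvcc stages can also be run one at a time by hand, which makes the intermediate PTX and cubin easy to compare with Triton's. This sketch only builds the command lines: nvcc_stage_commands and run_if_available are hypothetical helpers, sm_89 is an assumed target, and the flags used are the standard nvcc -ptx, ptxas -arch and cuobjdump -sass options.

```python
import shutil
import subprocess
from pathlib import Path
from typing import List


def nvcc_stage_commands(src: str, arch: str = "sm_89") -> List[List[str]]:
    """Commands that run the device-side stages of the nvcc flow separately."""
    stem = str(Path(src).with_suffix(""))
    ptx, cubin = stem + ".ptx", stem + ".cubin"
    return [
        ["nvcc", f"-arch={arch}", "-ptx", src, "-o", ptx],  # step 2: CUDA C++ -> PTX
        ["ptxas", f"-arch={arch}", ptx, "-o", cubin],       # step 3: PTX -> SASS in a cubin
        ["cuobjdump", "-sass", cubin],                      # inspect the resulting SASS
    ]


def run_if_available(commands: List[List[str]]) -> None:
    """Run each command whose tool is on PATH; stop on the first failure."""
    for cmd in commands:
        if shutil.which(cmd[0]) is None:
            print("skipping, not on PATH:", " ".join(cmd))
            continue
        subprocess.run(cmd, check=True)


if __name__ == "__main__":
    for cmd in nvcc_stage_commands("vector_add.cu"):
        print(" ".join(cmd))
```

Calling run_if_available(nvcc_stage_commands("kernel.cu")) on a real source file runs each step whose tool is installed. A single nvcc -keep invocation is an alternative that leaves every intermediate file on disk.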

Stage 1: Python AST → Triton IR

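@triton.jit wraps the function in a JITFunction object. The first time the kernel is launched (or when it is compiled ahead of time), Triton walks the function's Python AST and translates it into Triton IR: a high-level, device-agnostic MLIR dialect made of tt.* ops alongside standard arith.* ops. Consider the canonical add_kernel from lecture_029/vector_add.py: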
import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(
    x_ptr,        # Pointer to first input vector
    y_ptr,        # Pointer to second input vector
    output_ptr,   # Pointer to output vector
    n_elements,   # Size of the vector
    BLOCK_SIZE: tl.constexpr,
):
    pid         = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets     = block_start + tl.arange(0, BLOCK_SIZE)
    mask        = offsets < n_elements
    x           = tl.load(x_ptr + offsets, mask=mask)
    y           = tl.load(y_ptr + offsets, mask=mask)
    output      = x + y
    tl.store(output_ptr + offsets, output, mask=mask)
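To make the AST-walking step concrete, the toy visitor below parses add_kernel's source with Python's ast module and lists the tl.* calls in the order a front end meets them, mapped to the Triton IR op each one becomes in the IR listing for this kernel. It is an illustration only: Triton's real front end (a CodeGenerator class built on ast.NodeVisitor) also handles types, control flow and constexprs, and the TTIR_OP table here is hand-written from that listing.

```python
import ast
import textwrap
from typing import List

# add_kernel's body as plain source text; nothing here imports Triton.
KERNEL_SRC = textwrap.dedent("""
    def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        y = tl.load(y_ptr + offsets, mask=mask)
        output = x + y
        tl.store(output_ptr + offsets, output, mask=mask)
""")

# Hand-written mapping: which tt op each tl call turns into in the IR listing.
TTIR_OP = {
    "program_id": "tt.get_program_id",
    "arange": "tt.make_range",
    "load": "tt.load",
    "store": "tt.store",
}


class TlCallCollector(ast.NodeVisitor):
    """Records the attribute name of every `tl.<name>(...)` call, in visit order."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def visit_Call(self, node: ast.Call) -> None:
        f = node.func
        if isinstance(f, ast.Attribute) and isinstance(f.value, ast.Name) and f.value.id == "tl":
            self.calls.append(f.attr)
        self.generic_visit(node)  # keep walking into the call's arguments


def sketch_ttir_ops(src: str) -> List[str]:
    collector = TlCallCollector()
    collector.visit(ast.parse(src))
    return [TTIR_OP.get(name, "tl." + name) for name in collector.calls]


if __name__ == "__main__":
    print(sketch_ttir_ops(KERNEL_SRC))
```

Arithmetic such as pid * BLOCK_SIZE appears as ast.BinOp nodes and becomes arith.* ops; this toy ignores them.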
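The same kernel can also be compiled ahead of time (AOT). The trailing 64 in --signature is the value of the BLOCK_SIZE constexpr, and the --grid values are baked into the generated C launcher:
python3 triton/python/triton/tools/compile.py \
  --kernel-name add_kernel \
  --signature "*fp32,*fp32,*fp32,i32,64" \
  --grid=1024,1024,1024 \
  vector_add.py
The IR listing below corresponds to a launch with BLOCK_SIZE=1024 and num_warps=4 (as in the JIT example later in this guide), not to this 64-element, single-warp AOT build. Strictly, it is also the TritonGPU IR (.ttgir): the #blocked layout encoding and the triton_gpu.* module attributes are attached during the TTIR → TTGIR conversion, while the plain .ttir contains the same tt.* and arith.* ops on tensor types without a layout encoding: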
#blocked = #triton_gpu.blocked<{
  sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]
}>
module attributes {
  "triton_gpu.num-ctas"         = 1 : i32,
  "triton_gpu.num-warps"        = 4 : i32,
  triton_gpu.target             = "cuda:89",
  "triton_gpu.threads-per-warp" = 32 : i32
} {
  tt.func public @add_kernel(
      %arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
      %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32},
      %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32},
      %arg3: i32        {tt.divisibility = 16 : i32})
      attributes {noinline = false} {

    %c1024_i32 = arith.constant 1024 : i32
    %0  = tt.get_program_id x : i32          // pid
    %1  = arith.muli %0, %c1024_i32 : i32    // block_start = pid * BLOCK_SIZE
    %2  = tt.make_range {end = 1024 : i32, start = 0 : i32}
                        : tensor<1024xi32, #blocked>
    %3  = tt.splat %1 : i32 -> tensor<1024xi32, #blocked>
    %4  = arith.addi %3, %2 : tensor<1024xi32, #blocked>  // offsets
    %5  = tt.splat %arg3 : i32 -> tensor<1024xi32, #blocked>
    %6  = arith.cmpi slt, %4, %5 : tensor<1024xi32, #blocked>  // mask
    %7  = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked>
    %8  = tt.addptr %7, %4 : tensor<1024x!tt.ptr<f32>, #blocked>,
                             tensor<1024xi32, #blocked>
    %9  = tt.load %8, %6 : tensor<1024x!tt.ptr<f32>, #blocked>  // load x
    %10 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked>
    %11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr<f32>, #blocked>,
                              tensor<1024xi32, #blocked>
    %12 = tt.load %11, %6 : tensor<1024x!tt.ptr<f32>, #blocked>  // load y
    %13 = arith.addf %9, %12 : tensor<1024xf32, #blocked>         // x + y
    %14 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked>
    %15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr<f32>, #blocked>,
                              tensor<1024xi32, #blocked>
    tt.store %15, %13, %6 : tensor<1024x!tt.ptr<f32>, #blocked>  // store
    tt.return
  }
}
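Every operation works on rank-1 tensors of 1024 elements: the whole block is a single tensor value, scalars such as block_start and n_elements are broadcast to it with tt.splat, and the masked load/store pattern becomes tt.load / tt.store taking the mask tensor %6 as an operand. The #blocked encoding records how that tensor is spread across the 4 warps of 32 threads, with 4 contiguous elements per thread.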
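The arithmetic behind the #blocked encoding can be checked with a small model. The sketch below assumes a 1-D blocked layout with order = [0] that simply repeats its 4 × 32 × 4 = 512-element CTA tile across the 1024-element tensor; blocked_owner and elements_of_thread are hypothetical helpers, not Triton APIs.

```python
from typing import List, Tuple


def blocked_owner(i: int, size_per_thread: int = 4, threads_per_warp: int = 32,
                  warps_per_cta: int = 4) -> Tuple[int, int, int]:
    """Return (warp, lane, register) holding element i of a 1-D tensor."""
    tile = size_per_thread * threads_per_warp * warps_per_cta  # 512 with the defaults
    rep, j = divmod(i, tile)                  # which copy of the CTA tile, offset inside it
    thread, elem = divmod(j, size_per_thread)
    warp, lane = divmod(thread, threads_per_warp)
    return warp, lane, rep * size_per_thread + elem


def elements_of_thread(warp: int, lane: int, n: int = 1024) -> List[int]:
    """All element indices that one (warp, lane) owns under the default layout."""
    return [i for i in range(n) if blocked_owner(i)[:2] == (warp, lane)]


if __name__ == "__main__":
    print(elements_of_thread(0, 0))
```

Under this model every thread holds 8 elements: two runs of 4 contiguous fp32 values (16 bytes each), the kind of access pattern that lets the backend issue vectorized loads and stores.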

Stage 2: Triton IR → TritonGPU IR

The TritonGPU IR pass manager converts the device-agnostic tt.* ops into GPU-aware ttg.* ops. This is where decisions about warp and CTA layouts, shared memory usage, and instruction scheduling are made. Key passes applied in make_ttgir:
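from triton._C.libtriton import ir, passes, nvidia  # Triton's C++ bindings


def make_ttgir(mod, metadata, opt, capability):
    # Abridged from Triton's NVIDIA backend. The tt -> ttg conversion itself
    # (passes.ttir.add_convert_to_ttgpuir) and a few other passes run first
    # and are omitted from this excerpt.
    cluster_info = nvidia.ClusterInfo()  # describes the CTA cluster shape
    pm = ir.pass_manager(mod.context)
    pm.enable_debug()

    nvidia.passes.ttnvgpuir.add_plan_cta(pm, cluster_info)
    passes.ttgpuir.add_remove_layout_conversions(pm)
    passes.ttgpuir.add_optimize_thread_locality(pm)
    passes.ttgpuir.add_accelerate_matmul(pm)
    passes.ttgpuir.add_remove_layout_conversions(pm)
    passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
    passes.common.add_cse(pm)

    if capability // 10 >= 8:  # compute capability 8.x (Ampere) and newer
        passes.ttgpuir.add_combine_tensor_select_and_if(pm)
        passes.ttgpuir.add_pipeline(pm, opt.num_stages)  # software pipelining

    pm.run(mod)
    return mod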
Notable passes include:
Pass                             What it does
add_plan_cta                     Determines the CTA (Cooperative Thread Array) cluster configuration
add_remove_layout_conversions    Removes redundant conversions between tensor layouts
add_optimize_thread_locality     Reduces cross-thread communication (e.g., in reductions) by keeping more work within each thread
add_accelerate_matmul            Maps tt.dot onto Tensor Core MMA instructions
add_optimize_dot_operands        Selects optimal layouts for MMA operands
add_pipeline                     Inserts software pipelining (async loads + barrier scheduling)
add_cse                          Common subexpression elimination
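The pass-manager pattern in make_ttgir (create a manager, add passes in order, run it once over the module) can be mimicked on a toy IR. This sketch assumes straight-line SSA ops stored as (result, op, operands) tuples and implements a simple common subexpression elimination. PassManager, cse and EXAMPLE are hypothetical stand-ins, not Triton's bindings, and memory ops are conservatively never merged.

```python
from typing import Callable, Dict, List, Tuple

Op = Tuple[str, str, Tuple[str, ...]]  # (result, op name, operands)
Module = List[Op]
Pass = Callable[[Module], Module]

# Ops with memory effects are never merged in this toy.
MEMORY_OPS = {"tt.load", "tt.store"}


def cse(module: Module) -> Module:
    """Drop pure ops that recompute an existing (op, operands) pair."""
    seen: Dict[Tuple[str, Tuple[str, ...]], str] = {}
    rename: Dict[str, str] = {}
    out: Module = []
    for result, name, operands in module:
        operands = tuple(rename.get(o, o) for o in operands)  # point at surviving values
        key = (name, operands)
        if name not in MEMORY_OPS:
            if key in seen:
                rename[result] = seen[key]  # duplicate: reuse the earlier result
                continue
            seen[key] = result
        out.append((result, name, operands))
    return out


class PassManager:
    """Runs passes in the order they were added (loosely like ir.pass_manager)."""

    def __init__(self) -> None:
        self.passes: List[Pass] = []

    def add(self, p: Pass) -> None:
        self.passes.append(p)

    def run(self, module: Module) -> Module:
        for p in self.passes:
            module = p(module)
        return module


# Hand-written fragment in the style of the add_kernel IR, with a duplicated
# splat/add pair of the kind a naive lowering might produce.
EXAMPLE: Module = [
    ("%0", "tt.get_program_id", ("x",)),
    ("%1", "arith.muli", ("%0", "%c1024_i32")),
    ("%2", "tt.make_range", ("0", "1024")),
    ("%3", "tt.splat", ("%1",)),
    ("%4", "arith.addi", ("%3", "%2")),
    ("%5", "tt.splat", ("%1",)),         # same as %3
    ("%6", "arith.addi", ("%5", "%2")),  # same as %4 once %5 is renamed to %3
    ("", "tt.store", ("%ptr", "%6")),
]

if __name__ == "__main__":
    pm = PassManager()
    pm.add(cse)
    for op in pm.run(EXAMPLE):
        print(op)
```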

Stage 3: TritonGPU IR → PTX / CUBIN

Triton's own conversion passes lower the TritonGPU IR into MLIR's LLVM dialect, which is then translated to LLVM IR for the NVPTX target. LLVM's NVPTX backend emits PTX, and ptxas assembles the PTX into a CUBIN: an ELF-formatted binary that the CUDA driver loads onto the GPU.
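Once PTX exists, its header records which GPU it was generated for. This is a minimal sketch that assumes the usual .version, .target and .entry directives found in PTX text (for example the add_kernel.ptx cache artifact). parse_ptx_header and capability_from_target are hypothetical helpers, and SAMPLE_PTX is an illustrative fragment, not output captured from a real run.

```python
import re
from typing import Dict, Optional


def parse_ptx_header(ptx: str) -> Dict[str, object]:
    """Pull the PTX ISA version, target architecture and kernel entry names."""
    version = re.search(r"^\s*\.version\s+(\S+)", ptx, re.MULTILINE)
    target = re.search(r"^\s*\.target\s+([^\s,]+)", ptx, re.MULTILINE)
    return {
        "version": version.group(1) if version else None,
        "target": target.group(1) if target else None,
        "entries": re.findall(r"\.entry\s+([A-Za-z_$][\w$]*)", ptx),
    }


def capability_from_target(target: str) -> Optional[int]:
    """Map e.g. 'sm_89' (or 'sm_90a') to the integer capability 89 (or 90)."""
    m = re.match(r"sm_(\d+)", target)
    return int(m.group(1)) if m else None


SAMPLE_PTX = """\
//
// Generated by LLVM NVPTX Back-End
//

.version 8.4
.target sm_89
.address_size 64

\t// .globl\tadd_kernel
.visible .entry add_kernel(
\t.param .u64 add_kernel_param_0
)
{
\tret;
}
"""

if __name__ == "__main__":
    header = parse_ptx_header(SAMPLE_PTX)
    print(header, capability_from_target(header["target"]))
```

The integer capability is the same number that appears in the module's "cuda:89" target attribute in the IR listing.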

Triton Passes (Triton IR level)

Before the GPU-specific passes, Triton applies several language-level optimizations to the Triton IR (in the make_ttir step of the NVIDIA backend):
  • TritonCombineOps: Folds pointer additions and accumulations.
    • dot(a, b, 0) + c → dot(a, b, c)
    • addptr(addptr(ptr, i), j) → addptr(ptr, i + j)
  • TritonReorderBroadcast: Pushes broadcasts past elementwise ops to reduce work.
    • elementwise(broadcast(a)) → broadcast(elementwise(a))
  • TritonRewriteTensorPointer: Rewrites block pointers created with tt.make_tensor_ptr / tt.advance into ordinary tensors of pointers, eliminating those ops.
  • TritonLoopUnroll: Unrolls loops annotated with tt.loop_unroll_factor.
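Both TritonCombineOps rewrites listed above are local pattern matches, so they can be illustrated on a toy expression tree. In this sketch (an assumption, not Triton's data structures) ops are nested tuples such as ("addptr", base, offset), leaves are plain names, and the hypothetical combine function rewrites children before their parent, so a whole chain of addptr ops collapses in a single pass.

```python
from typing import Union

Expr = Union[str, tuple]  # leaf name, or (op, *operands)


def combine(expr: Expr) -> Expr:
    """Apply two toy TritonCombineOps-style rewrites bottom-up."""
    if not isinstance(expr, tuple):
        return expr
    op, *args = expr
    args = [combine(a) for a in args]  # children first, so chains fold in one pass

    # addptr(addptr(ptr, i), j) -> addptr(ptr, i + j)
    if op == "addptr" and isinstance(args[0], tuple) and args[0][0] == "addptr":
        _, ptr, i = args[0]
        return ("addptr", ptr, ("add", i, args[1]))

    # dot(a, b, 0) + c -> dot(a, b, c)
    if op == "add" and isinstance(args[0], tuple) and args[0][0] == "dot" and args[0][3] == "0":
        _, a, b, _zero = args[0]
        return ("dot", a, b, args[1])

    return (op, *args)


if __name__ == "__main__":
    print(combine(("addptr", ("addptr", ("addptr", "p", "i"), "j"), "k")))
    print(combine(("add", ("dot", "a", "b", "0"), "c")))
```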

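Common MLIR passes (used by Triton directly)

Triton exposes upstream MLIR passes to its Python pass pipelines through pybind11 wrappers such as these; passes.common.add_cse in make_ttgir is one of them: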

void init_triton_passes_common(py::module &&m) {
    using namespace mlir;
    ADD_PASS_WRAPPER_0("add_sccp",        createSCCPPass);
    ADD_PASS_WRAPPER_0("add_symbol_dce",  createSymbolDCEPass);
    ADD_PASS_WRAPPER_0("add_inliner",     createInlinerPass);
    ADD_PASS_WRAPPER_0("add_canonicalizer", createCanonicalizerPass);
    ADD_PASS_WRAPPER_0("add_cse",         createCSEPass);
    ADD_PASS_WRAPPER_0("add_licm",        createLoopInvariantCodeMotionPass);
}

The Generated C Launcher

Triton’s AOT compilation tool (triton/tools/compile.py) also emits a C launcher that wraps the CUBIN as a byte array and exposes a pure-C entry point. This is how Triton kernels can be shipped without a Python runtime. The generated header (add_kernel.9969bdda_0123.h):
#ifndef TT_KERNEL_INCLUDES
#define TT_KERNEL_INCLUDES

#include <cuda.h>
#include <inttypes.h>
#include <stdint.h>
#include <stdio.h>

#endif

void unload_add_kernel_9969bdda_0123(void);
void load_add_kernel_9969bdda_0123(void);
// tt-linker: add_kernel_9969bdda_0123:CUdeviceptr x_ptr, CUdeviceptr y_ptr,
//            CUdeviceptr output_ptr, int32_t n_elements:64_warps1xstages3
CUresult add_kernel_9969bdda_0123(
    CUstream stream,
    CUdeviceptr x_ptr,
    CUdeviceptr y_ptr,
    CUdeviceptr output_ptr,
    int32_t n_elements
);
The corresponding .c file (abbreviated) shows the kernel parameters baked into the launcher:
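/*
 * ['BLOCK_SIZE=64', 'num_warps=1', 'num_stages=3']
 */
/* File-scope handles, filled in by load_add_kernel_9969bdda_0123() */
static CUmodule   add_kernel_9969bdda_0123_mod  = NULL;
static CUfunction add_kernel_9969bdda_0123_func = NULL;

CUresult add_kernel_9969bdda_0123(
    CUstream stream,
    CUdeviceptr x_ptr,
    CUdeviceptr y_ptr,
    CUdeviceptr output_ptr,
    int32_t n_elements
) {
    /* Lazily load the embedded CUBIN on the first call */
    if (add_kernel_9969bdda_0123_func == NULL)
        load_add_kernel_9969bdda_0123();

    /* Grid (from --grid) and block size (from num_warps) are fixed at compile time */
    unsigned int gX = 1024, gY = 1024, gZ = 1024;
    void *args[4] = { &x_ptr, &y_ptr, &output_ptr, &n_elements };

    if (gX * gY * gZ > 0)
        return cuLaunchKernel(
            add_kernel_9969bdda_0123_func,
            gX, gY, gZ,   /* grid */
            1 * 32, 1, 1, /* block: 1 warp × 32 threads */
            0, stream, args, NULL  /* no dynamic shared memory */
        );
    return CUDA_SUCCESS;  /* empty grid: nothing to launch */
}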
The load_add_kernel_9969bdda_0123 function loads the CUBIN (embedded as a static byte array) into the CUDA driver with cuModuleLoadData and looks up the kernel handle with cuModuleGetFunction.
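The constants baked into the launcher can be sanity-checked with a little arithmetic. The sketch below assumes the values shown above (grid 1024 × 1024 × 1024, num_warps = 1, BLOCK_SIZE = 64); launch_geometry is a hypothetical helper. Because add_kernel only reads tl.program_id(axis=0), the y and z grid dimensions just repeat the same work, and at most gX × BLOCK_SIZE elements are ever written.

```python
from typing import Dict, Tuple

WARP_SIZE = 32


def launch_geometry(grid: Tuple[int, int, int], num_warps: int,
                    block_size: int, n_elements: int) -> Dict[str, int]:
    """Derived quantities for a fixed-grid launch of a 1-D kernel like add_kernel."""
    gx, gy, gz = grid
    threads_per_block = num_warps * WARP_SIZE  # the "1 * 32" passed to cuLaunchKernel
    max_covered = gx * block_size              # only axis 0 indexes the data
    return {
        "threads_per_block": threads_per_block,
        "programs": gx * gy * gz,
        # assumes block_size is a multiple of the thread count
        "elements_per_thread": block_size // threads_per_block,
        "max_covered": max_covered,
        "uncovered": max(n_elements - max_covered, 0),
        "redundant_copies": gy * gz,           # programs sharing each axis-0 pid
    }


if __name__ == "__main__":
    print(launch_geometry((1024, 1024, 1024), num_warps=1, block_size=64, n_elements=100_000))
```

With these numbers each thread handles 2 elements, about a million programs redo every block, and inputs longer than 65,536 elements are only partially processed, so the --grid value should be chosen with the expected problem size in mind.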

JIT Compilation: Accessing Artifacts at Runtime

When a kernel is JIT-compiled, the launch returns a compiled-kernel object whose asm dictionary exposes every intermediate representation (add_kernel as defined above):
size   = 1024
x      = torch.rand(size, device='cuda')
y      = torch.rand(size, device='cuda')
output = torch.empty_like(x)

grid = lambda meta: (triton.cdiv(size, meta['BLOCK_SIZE']),)
compiled_kernel = add_kernel[grid](x, y, output, size, BLOCK_SIZE=1024)

# All stages are available:
print(compiled_kernel.asm.keys())
# dict_keys(['ttir', 'ttgir', 'llir', 'ptx', 'cubin'])

# Inspect the Triton IR
print(compiled_kernel.asm['ttir'])

# Inspect the PTX
print(compiled_kernel.asm['ptx'])
The disk cache (~/.triton/cache/) mirrors these artifacts:
$ ls ~/.triton/cache/z8LaDnJEt9PBQ6kqB-fckHL71-x-aUysRw-hpmbeuJc/
add_kernel.cubin
add_kernel.json
add_kernel.llir
add_kernel.ptx
add_kernel.ttgir
add_kernel.ttir
__grp__add_kernel.json
Use cuobjdump -sass add_kernel.cubin or nvdisasm -gi add_kernel.cubin to inspect the SASS (machine-level) instructions in the final CUBIN. A cubin produced by ptxas contains no PTX, so for PTX read add_kernel.ptx directly.
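The grid lambda in the JIT example above launches triton.cdiv(size, BLOCK_SIZE) programs, one per block of elements. A pure-Python emulation of that launch makes the roles of pid, offsets and mask explicit. It is a sketch for intuition only: emulate_add_kernel is hypothetical, cdiv is a local stand-in for triton.cdiv, and real programs run in parallel on the GPU rather than in a loop. With size = 1024 and BLOCK_SIZE = 1024 the grid is a single program.

```python
from typing import List


def cdiv(a: int, b: int) -> int:
    """Ceiling division, the same rounding triton.cdiv performs."""
    return (a + b - 1) // b


def emulate_add_kernel(x: List[float], y: List[float], block_size: int) -> List[float]:
    n_elements = len(x)
    output = [0.0] * n_elements
    for pid in range(cdiv(n_elements, block_size)):          # grid = (cdiv(n, BLOCK_SIZE),)
        block_start = pid * block_size
        offsets = [block_start + k for k in range(block_size)]  # tl.arange(0, BLOCK_SIZE)
        mask = [off < n_elements for off in offsets]
        for off, live in zip(offsets, mask):
            if live:                      # masked-off lanes neither load nor store
                output[off] = x[off] + y[off]
    return output


if __name__ == "__main__":
    print(emulate_add_kernel([1.0, 2.0, 3.0, 4.0, 5.0], [10.0, 20.0, 30.0, 40.0, 50.0], 4))
```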

Triton and MLIR

Triton’s backend was completely rewritten in 2022 to use MLIR (PR #1004). MLIR (Multi-Level Intermediate Representation) is a flexible compiler infrastructure in the LLVM ecosystem. It provides:
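  • Dialects: modular, extensible IR specifications. Triton defines tt (Triton), ttg (TritonGPU, spelled triton_gpu in older releases such as the IR listing above), and the NVIDIA-specific TritonNvidiaGPU dialect (ttng; its Python pass bindings live under ttnvgpuir).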
  • Passes: composable transformations. Triton uses both standard MLIR passes (CSE, DCE, LICM) and custom passes for GPU optimization.
  • TableGen: DSL for generating MLIR boilerplate (op definitions, pass interfaces).
Enable verbose IR dumping with:
MLIR_ENABLE_DUMP=1 python vector_add.py
This prints the IR after every pass, which is invaluable for understanding what each transformation does.
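Because the dump is long, it helps to extract just the sequence of passes. This sketch assumes MLIR's standard banner format (lines such as "IR Dump Before CSE (cse)"), and it captures both output streams since MLIR's IR printer writes to stderr by default. dump_ir and passes_in_dump are hypothetical helpers, and the regular expression is an assumption about that banner format.

```python
import os
import re
import subprocess
import sys
from typing import List

# MLIR prints one banner per pass, e.g.
# // -----// IR Dump Before CSE (cse) ('builtin.module' operation) //----- //
BANNER = re.compile(r"IR Dump (?:Before|After) (\S+) \(([^)]*)\)")


def passes_in_dump(log: str) -> List[str]:
    """Pass names in the order their banners appear in the dump."""
    return [m.group(1) for m in BANNER.finditer(log)]


def dump_ir(script: str) -> str:
    """Run a script with MLIR_ENABLE_DUMP=1 and return everything it printed."""
    env = dict(os.environ, MLIR_ENABLE_DUMP="1")
    proc = subprocess.run([sys.executable, script], env=env,
                          capture_output=True, text=True)
    return proc.stdout + proc.stderr


if __name__ == "__main__" and len(sys.argv) > 1:
    for name in passes_in_dump(dump_ir(sys.argv[1])):
        print(name)
```

Running it as `python list_passes.py vector_add.py` (the script name is arbitrary) prints one pass name per line instead of the full IR after every pass.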

Example: TritonGPUAccelerateMatmul

This pass maps tt.dot operations onto hardware Tensor Core MMA instructions. It is defined in TableGen:
def TritonGPUAccelerateMatmul
    : Pass<"tritongpu-accelerate-matmul", "mlir::ModuleOp"> {
  let summary = "accelerate matmul";
  let description = [{
    Optimize the input/output layout of `dot` instructions to make them
    compatible with hardware accelerators (e.g., NVIDIA Tensor Cores).
  }];
  let dependentDialects = [
    "mlir::triton::gpu::TritonGPUDialect",
    "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
    "mlir::triton::TritonDialect"
  ];
}
The TableGen definition generates a base class; the pass author implements runOnOperation() in C++:
class TritonGPUAccelerateMatmulPass
    : public impl::TritonGPUAccelerateMatmulBase<TritonGPUAccelerateMatmulPass> {
public:
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    ModuleOp    m        = getOperation();

    auto computeCapability = getNVIDIAComputeCapability(m);

    mlir::RewritePatternSet patterns(context);
    patterns.add<BlockedToMMA>(context, computeCapability);

    if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
      signalPassFailure();

    // Decompose dot ops that are not natively supported by the MMA unit
    decomposeMixedModeDotOp(m, computeCapability);
  }
};

When Triton’s Compiler Choices Can Be Suboptimal

Triton’s automatic optimizations are excellent for the common case but can fall short in a few situations:

Shared memory allocation

Triton manages shared memory automatically. For kernels that need precise control over shared-memory bank conflicts or padding, you may need CUDA.

Warp-level primitives

Triton exposes no explicit warp-level intrinsics such as __shfl_sync or __ballot_sync. Reductions like tl.sum are lowered to warp shuffles by the compiler, but you cannot control how, which can leave performance on the table.

Non-power-of-two shapes

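Triton tensor dimensions (for example, the range passed to tl.arange) must be powers of two. If the optimal tile is not (e.g., 48), you must round up to the next power of two (64) and mask off the extra lanes, which wastes that fraction of every block's work.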

Custom memory hierarchies

tl.load and tl.store accept cache hints (cache_modifier, eviction_policy), but finer control such as explicit L1/L2 management, prefetching, or tuning async copies (e.g., cp.async.ca.shared.global variants) requires hand-written PTX or CUDA.
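The masking cost of rounding a non-power-of-two tile up is easy to quantify. This is a minimal sketch: next_power_of_2 is a local helper written to agree with triton.next_power_of_2 for positive inputs, and masked_fraction (hypothetical) is the share of each block's lanes that the mask switches off.

```python
def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n, for n >= 1."""
    return 1 << (n - 1).bit_length()


def masked_fraction(tile: int) -> float:
    """Fraction of a power-of-two block left idle when only `tile` lanes are useful."""
    block = next_power_of_2(tile)
    return (block - tile) / block


if __name__ == "__main__":
    for tile in (48, 64, 100):
        print(tile, next_power_of_2(tile), masked_fraction(tile))
```

For the 48-wide tile mentioned above, the block rounds up to 64 and a quarter of every block is masked off.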
For most ML operators (elementwise, reductions, attention, matmul), Triton's compiler produces code that is often competitive with hand-tuned libraries such as cuBLAS, for a fraction of the engineering effort.

Further Reading

Kapil's Deep-Dive Blog (Part 1)

Triton compilation pipeline walkthrough with annotated IR examples.

Triton Dialect Reference

Official reference for tt.* and ttg.* ops.

Original Triton Paper

Tillet, Kung & Cox (2019) — the academic foundation of the Triton language.

Practitioner's Guide (Lecture 14)

How to write, tune, and debug Triton kernels from scratch.
