When you decorate a function with @triton.jit and call it with a grid, a sophisticated compilation pipeline quietly converts your Python DSL into a CUDA binary that runs at near-hardware speed. This guide is based on Lecture 29 by Kapil Sharma (Software Engineer @ Meta), who traces that pipeline end-to-end and compares it with the classical NVCC toolchain.

Kapil has written a three-part deep-dive blog series that this lecture draws from: Part 1, Part 2, Part 3.
The @triton.jit decorator wraps the function in a JIT object. On the first launch for a given set of argument types and constexpr values, Triton parses the function’s AST and translates it into Triton IR, a high-level, device-agnostic MLIR dialect (tt.* ops).

Consider the canonical add_kernel from lecture_029/vector_add.py:
```python
import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(
    x_ptr,  # Pointer to first input vector
    y_ptr,  # Pointer to second input vector
    output_ptr,  # Pointer to output vector
    n_elements,  # Size of the vector
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    tl.store(output_ptr + offsets, output, mask=mask)
```
In the Triton IR that this kernel produces with BLOCK_SIZE=1024, every operation works on rank-1 tensors of 1024 elements: the block has been materialized as a vector type, and the masked load/store pattern translates directly to tt.load / tt.store with the mask tensor %6.
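To make the block-and-mask pattern concrete, here is a minimal CPU sketch in plain Python that mimics what each program instance of add_kernel computes. The names `cdiv` and `add_kernel_reference` are hypothetical helpers introduced for illustration; nothing here runs on a GPU or touches Triton itself.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division, the arithmetic used to size the launch grid."""
    return (a + b - 1) // b


def add_kernel_reference(x, y, block_size):
    """Emulate add_kernel on the CPU, one loop iteration per program instance."""
    n_elements = len(x)
    output = [0.0] * n_elements
    for pid in range(cdiv(n_elements, block_size)):  # one iteration per grid entry
        block_start = pid * block_size
        # Equivalent of block_start + tl.arange(0, BLOCK_SIZE)
        offsets = [block_start + i for i in range(block_size)]
        # Equivalent of offsets < n_elements
        mask = [off < n_elements for off in offsets]
        for off, in_bounds in zip(offsets, mask):
            if in_bounds:  # masked-off lanes neither load nor store
                output[off] = x[off] + y[off]
    return output


x = [float(i) for i in range(10)]
y = [1.0] * 10
result = add_kernel_reference(x, y, block_size=4)
print(result)
```

With 10 elements and a block size of 4, the grid has three program instances, and the last one has two lanes masked off. On the GPU the per-lane loop disappears: each program instance handles its whole block as one tensor operation.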
The TritonGPU IR pass manager converts the device-agnostic tt.* ops into GPU-aware ttg.* ops. This is where decisions about warp and CTA layouts, shared memory usage, and instruction scheduling are made.

Key passes applied in make_ttgir:
Triton lowers the TritonGPU IR to LLVM IR (by way of the MLIR LLVM dialect), and LLVM’s NVPTX backend converts that to PTX. ptxas then assembles the PTX into a CUBIN, an ELF-formatted binary that the CUDA driver loads onto the GPU.
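Because a CUBIN is an ELF file, a quick sanity check is to look for the ELF magic bytes at the start of the blob. The sketch below uses a hypothetical helper, `looks_like_elf`, and a stand-in byte string; with a real kernel you would pass the bytes stored under the 'cubin' key of the compiled kernel's asm dictionary, assuming that entry holds raw bytes.

```python
ELF_MAGIC = b"\x7fELF"  # first four bytes of every ELF file


def looks_like_elf(blob: bytes) -> bool:
    """Return True if the blob starts with the ELF magic number."""
    return blob[:4] == ELF_MAGIC


# Stand-in data for illustration only; not a real CUBIN.
fake_cubin = ELF_MAGIC + bytes(12)
fake_ptx = b"// PTX is plain text, not ELF\n"

print(looks_like_elf(fake_cubin))
print(looks_like_elf(fake_ptx))
```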
Triton’s AOT compilation tool (triton/tools/compile.py) also emits a C launcher that wraps the CUBIN as a byte array and exposes a pure-C entry point. This is how Triton kernels can be shipped without a Python runtime.

The generated header (add_kernel.9969bdda_0123.h):
When a kernel is compiled JIT, all intermediate representations are accessible via compiled_kernel.asm:
```python
size = 1024
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
output = torch.empty_like(x)

grid = lambda meta: (triton.cdiv(size, meta['BLOCK_SIZE']),)
compiled_kernel = add_kernel[grid](x, y, output, size, BLOCK_SIZE=1024)

# All stages are available:
print(compiled_kernel.asm.keys())
# dict_keys(['ttir', 'ttgir', 'llir', 'ptx', 'cubin'])

# Inspect the Triton IR
print(compiled_kernel.asm['ttir'])

# Inspect the PTX
print(compiled_kernel.asm['ptx'])
```
The disk cache (~/.triton/cache/) mirrors these artifacts:
```bash
$ ls ~/.triton/cache/z8LaDnJEt9PBQ6kqB-fckHL71-x-aUysRw-hpmbeuJc/
add_kernel.cubin
add_kernel.json
add_kernel.llir
add_kernel.ptx
add_kernel.ttgir
add_kernel.ttir
__grp__add_kernel.json
```
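If you want the same listing from Python, a small helper can walk the cache directory and group files by their hash-named subdirectory. This is a sketch that assumes the one-directory-per-kernel layout shown above; `list_cache_artifacts` is a hypothetical name, not part of Triton.

```python
from collections import defaultdict
from pathlib import Path


def list_cache_artifacts(cache_root):
    """Map each cache subdirectory name to the sorted file names inside it."""
    root = Path(cache_root).expanduser()
    if not root.is_dir():
        return {}
    artifacts = defaultdict(list)
    for path in sorted(root.glob("*/*")):  # files one level below the root
        if path.is_file():
            artifacts[path.parent.name].append(path.name)
    return dict(artifacts)


if __name__ == "__main__":
    for key, files in list_cache_artifacts("~/.triton/cache").items():
        print(key)
        for name in files:
            print("   ", name)
```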
Use cuobjdump add_kernel.cubin -sass -ptx or nvdisasm -gi add_kernel.cubin to inspect the SASS (machine-level) instructions from the final CUBIN.
Triton’s backend was completely rewritten in 2022 to use MLIR (PR #1004). MLIR (Multi-Level Intermediate Representation) is a flexible compiler infrastructure in the LLVM ecosystem. It provides:
Dialects: modular, extensible IR specifications. Triton defines tt (Triton), ttg (TritonGPU), and ttng (TritonNvidiaGPU, NVIDIA-specific) dialects.
Passes: composable transformations. Triton uses both standard MLIR passes (CSE, DCE, LICM) and custom passes for GPU optimization.
TableGen: DSL for generating MLIR boilerplate (op definitions, pass interfaces).
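To illustrate what "composable passes" means, here is a toy model in plain Python. It is not MLIR: the IR is just a list of `(result, op, operands)` tuples, and `cse`, `dce`, and `run_pipeline` are hypothetical stand-ins for the real passes and pass manager. The point is only that each pass maps IR to IR, so passes can be chained in any order.

```python
SIDE_EFFECTING = {"return"}  # ops that must never be removed or merged


def cse(ops):
    """Common-subexpression elimination: reuse the first identical pure op."""
    seen = {}    # (op name, operands) -> name of the first result
    rename = {}  # eliminated result -> surviving result
    out = []
    for result, name, operands in ops:
        operands = tuple(rename.get(o, o) for o in operands)
        key = (name, operands)
        if name not in SIDE_EFFECTING and key in seen:
            rename[result] = seen[key]  # drop the duplicate, remember the alias
        else:
            seen[key] = result
            out.append((result, name, operands))
    return out


def dce(ops):
    """Dead-code elimination: keep only ops that feed a side-effecting op."""
    live = set()
    kept = []
    for result, name, operands in reversed(ops):
        if name in SIDE_EFFECTING or result in live:
            live.update(operands)
            kept.append((result, name, operands))
    kept.reverse()
    return kept


def run_pipeline(ops, passes):
    """Toy pass manager: apply each pass in order."""
    for run_pass in passes:
        ops = run_pass(ops)
    return ops


program = [
    ("t0", "add", ("x", "y")),
    ("t1", "add", ("x", "y")),   # duplicate of t0, removed by cse
    ("t2", "mul", ("t0", "t1")),
    ("t3", "sub", ("x", "y")),   # never used, removed by dce
    (None, "return", ("t2",)),
]
optimized = run_pipeline(program, [cse, dce])
for op in optimized:
    print(op)
```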
Enable verbose IR dumping with:
```bash
MLIR_ENABLE_DUMP=1 python vector_add.py
```
This prints the IR before every MLIR pass that Triton runs, so diffing consecutive dumps shows exactly what each transformation did.
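The dump can be long, so it helps to split it into one section per pass. The sketch below assumes the dump has been captured to a string (for example by redirecting the process's stderr to a file) and that each section begins with MLIR's usual `// -----// IR Dump ... //----- //` banner. Both the banner format and the helper name `split_ir_dump` are assumptions to verify against your own output.

```python
import re

# Assumed banner format, e.g.:
# // -----// IR Dump Before CSE (cse) //----- //
HEADER = re.compile(r"^// -----// IR Dump (Before|After) (.+?) //----- //\s*$")


def split_ir_dump(text):
    """Return a list of (when, pass_label, ir_text) tuples, one per banner."""
    sections = []
    current = None
    for line in text.splitlines():
        match = HEADER.match(line)
        if match:
            current = [match.group(1), match.group(2), []]
            sections.append(current)
        elif current is not None:
            current[2].append(line)
    return [(when, label, "\n".join(body).strip()) for when, label, body in sections]


# Synthetic sample for illustration; real dumps contain full modules.
sample = """\
// -----// IR Dump Before Inliner (inline) //----- //
module { }
// -----// IR Dump Before CSE (cse) //----- //
module {
}
"""

for when, label, ir in split_ir_dump(sample):
    print(when, label, len(ir.splitlines()), "line(s)")
```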
The TritonGPUAccelerateMatmul pass (tritongpu-accelerate-matmul) maps tt.dot operations onto hardware Tensor Core MMA instructions. It is defined in TableGen:
```tablegen
def TritonGPUAccelerateMatmul : Pass<"tritongpu-accelerate-matmul", "mlir::ModuleOp"> {
  let summary = "accelerate matmul";

  let description = [{
    Optimize the input/output layout of `dot` instructions to make them
    compatible with hardware accelerators (e.g., NVIDIA Tensor Cores).
  }];

  let dependentDialects = [
    "mlir::triton::gpu::TritonGPUDialect",
    "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
    "mlir::triton::TritonDialect"
  ];
}
```
The TableGen definition generates a base class; the pass author implements runOnOperation() in C++:
```cpp
class TritonGPUAccelerateMatmulPass
    : public impl::TritonGPUAccelerateMatmulBase<TritonGPUAccelerateMatmulPass> {
public:
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    ModuleOp m = getOperation();

    auto computeCapability = getNVIDIAComputeCapability(m);

    mlir::RewritePatternSet patterns(context);
    patterns.add<BlockedToMMA>(context, computeCapability);
    if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
      signalPassFailure();

    // Decompose dot ops that are not natively supported by the MMA unit
    decomposeMixedModeDotOp(m, computeCapability);
  }
};
```
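The control flow of runOnOperation() (register patterns, apply them greedily until nothing changes, signal failure if that does not converge) can be sketched as a toy model in Python. Everything here is hypothetical: ops are plain dictionaries, `blocked_to_mma_toy` only retags a layout string, and the `min_capability` threshold is an illustrative parameter rather than Triton's actual selection rule.

```python
from functools import partial


def blocked_to_mma_toy(op, compute_capability, min_capability=80):
    """Toy pattern: retag a 'dot' op from a blocked layout to an MMA layout.

    min_capability is an illustrative threshold, not Triton's real rule.
    """
    if (op["name"] == "dot" and op["layout"] == "blocked"
            and compute_capability >= min_capability):
        return {**op, "layout": "mma"}
    return None  # pattern does not match this op


def apply_patterns_greedily(ops, patterns, max_iterations=10):
    """Apply patterns until no op changes. Returns (converged, ops)."""
    for _ in range(max_iterations):
        changed = False
        new_ops = []
        for op in ops:
            for pattern in patterns:
                rewritten = pattern(op)
                if rewritten is not None:
                    op = rewritten
                    changed = True
                    break  # revisit this op on the next sweep
            new_ops.append(op)
        ops = new_ops
        if not changed:
            return True, ops
    return False, ops  # the caller would signal a pass failure here


module = [
    {"name": "load", "layout": "blocked"},
    {"name": "dot", "layout": "blocked"},
]
patterns = [partial(blocked_to_mma_toy, compute_capability=90)]
converged, rewritten_module = apply_patterns_greedily(module, patterns)
print(converged, rewritten_module)
```

The iteration cap matters because a pattern set that keeps undoing its own rewrites would otherwise loop forever; returning a failure flag mirrors the failed() check in the C++ above.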
Triton’s automatic optimizations are excellent for the common case but can fall short in a few situations:
Shared memory allocation
Triton manages shared memory automatically. For kernels that need precise control over shared-memory bank conflicts or padding, you may need CUDA.
Warp-level primitives
__shfl_sync, __ballot_sync, and warp-level reductions have no direct Triton equivalents. Triton abstracts these away, sometimes leaving performance on the table.
Non-power-of-two shapes
Triton block sizes must be powers of two. For shapes where the optimal tile is not a power of two (e.g., 48), you may pay a masking penalty.
Custom memory hierarchies
Direct L1/L2 cache management, prefetch hints, or async copy tuning (e.g., cp.async.ca.shared.global variants) require hand-written PTX or CUDA.
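The masking penalty for non-power-of-two shapes is easy to quantify. The sketch below rounds a tile size up to the next power of two in plain Python and reports the fraction of lanes that end up masked off. `masked_fraction` is a hypothetical helper, and the calculation assumes the whole tile is padded to a single power-of-two block.

```python
def next_power_of_2(n: int) -> int:
    """Smallest power of two that is >= n, for n >= 1."""
    return 1 << (n - 1).bit_length()


def masked_fraction(tile: int) -> float:
    """Fraction of lanes that do no useful work after rounding the tile up."""
    block = next_power_of_2(tile)
    return (block - tile) / block


for tile in (48, 64, 100):
    print(tile, next_power_of_2(tile), masked_fraction(tile))
```

For the tile of 48 mentioned above, the block becomes 64 and a quarter of the lanes are masked. Whether that matters in practice depends on whether the kernel is compute-bound or memory-bound.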
For most ML operators (elementwise, reductions, attention, matmul), Triton’s compiler produces code whose performance approaches that of hand-tuned libraries such as cuBLAS, with a fraction of the engineering effort.