Lecture 1, by Mark Saroufim, covers the full workflow for GPU kernel development: measure first with the PyTorch profiler and NVIDIA tools, then write and integrate a custom CUDA kernel using torch.utils.cpp_extension.load_inline. This page walks through each step with the real code from the lecture.
Source files for this lecture live in lectures/lecture_001/. The shared helpers referenced throughout — cuda_begin, load_cuda — live in lectures/utils.py.
CUDA execution is asynchronous. Python’s time.time() measures when the CPU launches a kernel, not when it finishes. Use CUDA events instead:
pytorch_square.py
```python
import torch

def time_pytorch_function(func, input):
    # CUDA IS ASYNC so can't use python time module
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    # Warmup
    for _ in range(5):
        func(input)

    start.record()
    func(input)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end)  # milliseconds

b = torch.randn(10000, 10000).cuda()

def square_2(a):
    return a * a

def square_3(a):
    return a ** 2

# Print the timings: bare calls would discard the return values when run as a script
print(time_pytorch_function(torch.square, b))
print(time_pytorch_function(square_2, b))
print(time_pytorch_function(square_3, b))
```
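torch.cuda.synchronize() blocks the CPU until all queued CUDA work on the current device has finished. As a result, elapsed_time reports the GPU time between the two recorded events, in milliseconds. The five warmup calls keep one-time first-call overheads out of the measurement.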
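Raw milliseconds are easier to compare against hardware limits once you convert them to effective memory bandwidth. The following is a minimal sketch. It assumes an out-of-place elementwise op such as squaring reads each input element once and writes each output element once, so it moves 2 × numel × element_size bytes. The helper names are hypothetical and are not part of the lecture code.

```python
def bytes_moved_elementwise(numel: int, element_size: int) -> int:
    """Bytes an out-of-place elementwise op moves: one read + one write per element."""
    return 2 * numel * element_size


def effective_bandwidth_gbps(numel: int, element_size: int, elapsed_ms: float) -> float:
    """Convert a CUDA-event timing (milliseconds) into effective GB/s (1 GB = 1e9 bytes)."""
    if elapsed_ms <= 0:
        raise ValueError("elapsed_ms must be positive")
    # bytes / (ms * 1e-3 s) / 1e9  ==  bytes / (ms * 1e6)
    return bytes_moved_elementwise(numel, element_size) / (elapsed_ms * 1e6)


if __name__ == "__main__":
    # A 10000 x 10000 float32 tensor: 1e8 elements of 4 bytes -> 8e8 bytes moved.
    print(effective_bandwidth_gbps(10000 * 10000, 4, 1.0))  # 800.0
```

Combined with the timing helper from pytorch_square.py, you would call effective_bandwidth_gbps(b.numel(), b.element_size(), time_pytorch_function(torch.square, b)). You can then compare the result with your GPU's peak memory bandwidth. This is the same GB/s metric that Triton's benchmark plots use.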
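To see which CUDA kernels each variant actually launches and how long they take, wrap each call in torch.profiler.profile() and print the aggregated table sorted by CUDA time: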
pytorch_square.py
```python
print("=============")
print("Profiling torch.square")
print("=============")

with torch.profiler.profile() as prof:
    torch.square(b)

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

print("=============")
print("Profiling a * a")
print("=============")

with torch.profiler.profile() as prof:
    square_2(b)

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

print("=============")
print("Profiling a ** 2")
print("=============")

with torch.profiler.profile() as prof:
    square_3(b)

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
```
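The calls above rely on the profiler's default activities. The sketch below makes those activities explicit. It also labels the profiled region with record_function so the region appears as its own named row, and it falls back to CPU-only profiling when no GPU is present. The helper name profile_variant is hypothetical.

```python
import torch
from torch.profiler import ProfilerActivity, profile, record_function


def profile_variant(label, fn, x, row_limit=10):
    """Profile a single call fn(x) and return the aggregated table as a string."""
    activities = [ProfilerActivity.CPU]
    if x.is_cuda:
        activities.append(ProfilerActivity.CUDA)

    fn(x)  # warm-up call so one-time setup is not attributed to the profiled call

    with profile(activities=activities) as prof:
        with record_function(label):  # appears as a named row in the table
            fn(x)

    sort_key = "cuda_time_total" if x.is_cuda else "cpu_time_total"
    return prof.key_averages().table(sort_by=sort_key, row_limit=row_limit)


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    b = torch.randn(4096, 4096, device=device)
    variants = {
        "torch.square": torch.square,
        "a * a": lambda a: a * a,
        "a ** 2": lambda a: a ** 2,
    }
    for label, fn in variants.items():
        print(f"===== {label} =====")
        print(profile_variant(label, fn, b))
```

If you prefer a timeline over a table, calling prof.export_chrome_trace("trace.json") inside the helper writes a trace. You can open that trace in chrome://tracing or Perfetto.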
Nsight Systems (nsys) records a timeline of every kernel launch, memory transfer, and CPU-GPU synchronisation point. Use it to understand the shape of your workload before diving into per-kernel metrics.

The lecture wraps the workload in a plain function so nsys captures a clean trace:
nsys_square.py
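```python
import torch

def main():
    # Open the capture range that `nsys --capture-range=cudaProfilerApi` waits for;
    # without cudaProfilerStart/Stop, nsys records nothing under that flag.
    torch.cuda.profiler.start()
    for _ in range(100):
        a = torch.square(torch.randn(10000, 10000).cuda())
    torch.cuda.synchronize()  # let queued kernels finish before closing the range
    torch.cuda.profiler.stop()

if __name__ == "__main__":
    main()
```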
Record and view the trace:
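```bash
# Record. With --capture-range=cudaProfilerApi, nsys only records between the
# cudaProfilerStart() and cudaProfilerStop() calls made by the application.
nsys profile -w true -t cuda,nvtx,osrt --capture-range=cudaProfilerApi \
  -o nsys_square python nsys_square.py

# Open in GUI
nsys-ui nsys_square.nsys-rep
```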
Run the workload in a loop (as above) so the timeline has enough repetitions to show a clear pattern. A single kernel launch is hard to reason about in the timeline view.
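The nsys command above also traces NVTX (-t cuda,nvtx,osrt), so you can label phases of the loop and have them appear as named ranges on the timeline. The sketch below makes the following assumptions:

- It uses PyTorch's torch.cuda.nvtx.range context manager.
- It uses torch.cuda.profiler.start()/stop(), which call cudaProfilerStart/cudaProfilerStop.
- The range names are arbitrary.
- The no-op fallback exists only so the snippet also runs on a machine without a GPU.

```python
import contextlib

import torch


def nvtx_range(label):
    """NVTX range when a GPU is available; a no-op context manager otherwise."""
    if torch.cuda.is_available():
        return torch.cuda.nvtx.range(label)
    return contextlib.nullcontext()


def main(iters=100, n=10000):
    on_gpu = torch.cuda.is_available()
    device = "cuda" if on_gpu else "cpu"
    if on_gpu:
        torch.cuda.profiler.start()  # cudaProfilerStart: opens the nsys capture range
    a = None
    for i in range(iters):
        with nvtx_range(f"iter {i}"):
            with nvtx_range("randn + copy to device"):
                x = torch.randn(n, n).to(device)
            with nvtx_range("square"):
                a = torch.square(x)
    if on_gpu:
        torch.cuda.synchronize()  # let queued kernels finish inside the range
        torch.cuda.profiler.stop()  # cudaProfilerStop: closes the capture range
    return a


if __name__ == "__main__":
    main()
```

NVTX ranges mark time on the CPU thread that issues the work. Because kernel launches are asynchronous, the "square" range covers the launch, and the kernel itself may appear later on the GPU row.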
Once nsys tells you which kernel to optimise, use Nsight Compute (ncu) to measure that kernel at the hardware-counter level: memory bandwidth utilisation, occupancy, warp stalls, and more.
```bash
# Profile all kernels launched by a script
ncu python pytorch_square.py

# Save a full report for later analysis
ncu --set full -o square_kernel_report python pytorch_square.py
```
Running ncu against a script that JIT-compiles extensions with load_inline can fail with a CUDA initialisation error (error 36: API call not supported). Profile JIT-compiled kernels via nsys instead, or pass --target-processes all to ncu.
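torch.utils.cpp_extension.load_inline compiles C++ and CUDA source strings at runtime and loads the result as a Python module you can call from PyTorch code, so no separate build step is required. Builds are cached (by default under ~/.cache/torch_extensions). Only the first call, or a call after the source changes, pays the compile cost.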
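Here is a minimal, self-contained sketch of that workflow that does not depend on the repository helpers. It has three parts:

- One CUDA source string holding a kernel and its host wrapper.
- One C++ declaration from which load_inline generates the Python binding.
- functions=[...] naming what to export.

The extension name square_ext and the float32-only restriction are assumptions made for brevity. Building requires nvcc and a CUDA GPU, so the build is wrapped in a function rather than run at import time.

```python
import torch
from torch.utils.cpp_extension import load_inline

CUDA_SRC = r"""
#include <torch/extension.h>
#include <c10/cuda/CUDAException.h>

// One thread per element: out[i] = in[i] * in[i].
__global__ void square_kernel(const float* __restrict__ in,
                              float* __restrict__ out, int64_t n) {
    int64_t i = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = in[i] * in[i];
}

torch::Tensor square(torch::Tensor x) {
    TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor");
    TORCH_CHECK(x.scalar_type() == torch::kFloat32, "x must be float32");
    auto xc = x.contiguous();
    auto out = torch::empty_like(xc);
    const int64_t n = xc.numel();
    if (n == 0) return out;  // a zero-block launch would be an invalid configuration
    const int threads = 256;
    const int blocks = (int)((n + threads - 1) / threads);  // ceil(n / threads)
    // Launches on the default stream; fine for this sketch.
    square_kernel<<<blocks, threads>>>(xc.data_ptr<float>(), out.data_ptr<float>(), n);
    C10_CUDA_KERNEL_LAUNCH_CHECK();  // surface launch errors as C++ exceptions
    return out;
}
"""

# Declaration only: load_inline generates the pybind11 binding from it.
CPP_SRC = "torch::Tensor square(torch::Tensor x);"


def build_square_extension(verbose=False):
    """JIT-compile (or load from the build cache) and return the extension module."""
    return load_inline(
        name="square_ext",
        cpp_sources=[CPP_SRC],
        cuda_sources=[CUDA_SRC],
        functions=["square"],
        verbose=verbose,
    )


if __name__ == "__main__" and torch.cuda.is_available():
    ext = build_square_extension(verbose=True)
    x = torch.randn(1024, 1024, device="cuda")
    torch.testing.assert_close(ext.square(x), x * x)
    print("custom kernel matches x * x")
```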
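For lectures that follow the repository conventions, utils.py provides two helpers. The first, cuda_begin, is a shared prelude of includes, input-checking macros, and a cdiv helper. The second, load_cuda, is a thin wrapper around load_inline that compiles with -O3 by default (opt=False switches to -O0) and names the extension after the first exported function: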
utils.py
```python
from torch.utils.cpp_extension import load_inline

cuda_begin = r'''
#include <torch/extension.h>
#include <stdio.h>
#include <c10/cuda/CUDAException.h>

#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

#define CUDA_ERR(ans) { gpuAssert((ans), __FILE__, __LINE__); }
inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort=true)
{
   if (code != cudaSuccess)
   {
      fprintf(stderr,"GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
      if (abort) exit(code);
   }
}
__host__ __device__ inline unsigned int cdiv(unsigned int a, unsigned int b) { return (a+b-1)/b;}
'''

def load_cuda(cuda_src, cpp_src, funcs, opt=True, verbose=False, name=None):
    "Simple wrapper for torch.utils.cpp_extension.load_inline"
    if name is None: name = funcs[0]
    flags = "-O3 -Xptxas -O3 -Xcompiler -O3" if opt else "-O0 -Xptxas -O0 -Xcompiler -O0"
    return load_inline(cuda_sources=[cuda_src], cpp_sources=[cpp_src], functions=funcs,
                       extra_cuda_cflags=[flags], verbose=verbose, name=name)
```
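cdiv in the prelude is ceiling division. Host code uses it to choose a grid large enough to cover every element. The last block may then contain threads past the end of the data, which is why kernels guard with if (i < n). Below is the same arithmetic in plain Python, as an illustration. The default of 256 threads per block is an assumed, common choice, not something utils.py fixes.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division, same formula as the cdiv helper in cuda_begin."""
    return (a + b - 1) // b


def launch_config(n: int, threads_per_block: int = 256):
    """Return (blocks, idle_threads) for a 1-D launch with one thread per element."""
    blocks = cdiv(n, threads_per_block)
    idle = blocks * threads_per_block - n  # threads that fail the `i < n` guard
    return blocks, idle


if __name__ == "__main__":
    print(launch_config(10000 * 10000))  # (390625, 0): 1e8 is a multiple of 256
    print(launch_config(1000))           # (4, 24): the last block has 24 idle threads
```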
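To use these helpers in your own kernels, prepend cuda_begin to your CUDA source string and pass it to load_cuda together with a C++ declaration of each function you want to export.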
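Here is a sketch of that pattern for an elementwise square kernel. It relies on CHECK_INPUT, cdiv, and the c10 launch-check include that cuda_begin provides. The names square_cuda and build_square_module are hypothetical. The helpers are passed in as arguments so the snippet does not hard-code an import path. In the lecture layout you would first run `from utils import cuda_begin, load_cuda` and then call build_square_module(cuda_begin, load_cuda).

```python
SQUARE_CUDA = r"""
__global__ void square_kernel(const float* in, float* out, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = in[i] * in[i];  // surplus threads in the last block do nothing
}

torch::Tensor square_cuda(torch::Tensor x) {
    CHECK_INPUT(x);  // from cuda_begin: CUDA tensor + contiguous
    TORCH_CHECK(x.scalar_type() == torch::kFloat32, "x must be float32");
    auto out = torch::empty_like(x);
    int n = (int)x.numel();
    if (n == 0) return out;
    int threads = 256;
    square_kernel<<<cdiv(n, threads), threads>>>(
        x.data_ptr<float>(), out.data_ptr<float>(), n);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return out;
}
"""

# Declaration only; the definition lives in SQUARE_CUDA.
SQUARE_CPP = "torch::Tensor square_cuda(torch::Tensor x);"


def build_square_module(cuda_begin, load_cuda):
    """Prepend the shared prelude and JIT-compile; returns the extension module."""
    return load_cuda(cuda_begin + SQUARE_CUDA, SQUARE_CPP, ["square_cuda"])
```

The returned module exposes square_cuda(x). Because load_cuda names the extension after funcs[0], the build is cached under the name square_cuda.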
The lecture's triton_square.py file includes a triton.testing.perf_report benchmark that compares Triton, native PyTorch, and torch.compile across a range of matrix sizes and produces a GB/s plot.