

Lecture 1, by Mark Saroufim, covers the full workflow for GPU kernel development: measure first with the PyTorch profiler and NVIDIA tools, then write and integrate a custom CUDA kernel using torch.utils.cpp_extension.load_inline. This page walks through each step with the real code from the lecture.
Source files for this lecture live in lectures/lecture_001/. The shared helpers referenced throughout — cuda_begin, load_cuda — live in lectures/utils.py.

Timing GPU operations correctly

CUDA execution is asynchronous: a kernel launch returns to Python as soon as the work is queued. Python's time.time() therefore captures when the CPU launched a kernel, not when the GPU finished it. Use CUDA events instead:
pytorch_square.py
import torch

def time_pytorch_function(func, input):
    # CUDA IS ASYNC so can't use python time module
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    # Warmup
    for _ in range(5):
        func(input)

    start.record()
    func(input)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end)

b = torch.randn(10000, 10000).cuda()

def square_2(a):
    return a * a

def square_3(a):
    return a ** 2

print(time_pytorch_function(torch.square, b))
print(time_pytorch_function(square_2, b))
print(time_pytorch_function(square_3, b))
torch.cuda.synchronize() blocks the CPU until all queued CUDA work has finished, so both events have completed before elapsed_time reads them. elapsed_time returns the GPU time between the two events in milliseconds.
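If you prefer wall-clock timing, the same rule applies: synchronise before every clock read. The helper below is a minimal sketch, not code from the lecture; `benchmark` and its `sync` parameter are hypothetical names. Pass `sync=torch.cuda.synchronize` when timing GPU work. Unlike CUDA events, this approach also counts launch overhead, and reporting the median of several runs damps noise from any single run.

```python
import statistics
import time
from typing import Callable


def benchmark(fn: Callable[[], object], sync: Callable[[], None] = lambda: None,
              warmup: int = 5, reps: int = 20) -> float:
    """Return the median wall-clock time of fn() in milliseconds.

    sync is called before starting and before stopping the clock; for GPU work,
    pass torch.cuda.synchronize so queued kernels finish inside the timed region.
    """
    for _ in range(warmup):  # warm caches, allocators, JIT compilation, etc.
        fn()
    timings_ms = []
    for _ in range(reps):
        sync()  # do not charge earlier queued work to this run
        start = time.perf_counter()
        fn()
        sync()  # wait for the work launched by fn() to actually finish
        timings_ms.append((time.perf_counter() - start) * 1e3)
    return statistics.median(timings_ms)


if __name__ == "__main__":
    # CPU-only demonstration; for GPU work you might call, for example,
    # benchmark(lambda: torch.square(b), sync=torch.cuda.synchronize)
    print(f"{benchmark(lambda: sum(range(100_000))):.3f} ms")
```

Without the second `sync()` call, a GPU measurement would stop the clock as soon as the kernel was queued, which is exactly the mistake described above.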

PyTorch profiler

The PyTorch profiler gives a per-operator breakdown of CPU and CUDA time. It is the quickest way to identify which operation is the hotspot.

Basic usage

pytorch_square.py
import torch
from torch.profiler import profile

b = torch.randn(10000, 10000).cuda()

with torch.profiler.profile() as prof:
    torch.square(b)

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
Run the same block for each variant you want to compare:
pytorch_square.py
print("=============")
print("Profiling torch.square")
print("=============")

with torch.profiler.profile() as prof:
    torch.square(b)

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

print("=============")
print("Profiling a * a")
print("=============")

with torch.profiler.profile() as prof:
    square_2(b)

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

print("=============")
print("Profiling a ** 2")
print("=============")

with torch.profiler.profile() as prof:
    square_3(b)

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

Profiler with schedule and Chrome trace export

For training loops, use a schedule to skip the first iterations, warm up the profiler, and limit how many steps are recorded:
pt_profiler.py
import torch
from torch.profiler import profile, ProfilerActivity

def trace_handler(prof):
    print(prof.key_averages().table(
        sort_by="self_cuda_time_total", row_limit=-1))
    prof.export_chrome_trace("/tmp/test_trace_" + str(prof.step_num) + ".json")

with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ],
    # wait=1: skip first iteration
    # warmup=1: warm up on second
    # active=2: record iterations 3 and 4
    # repeat=1: one cycle
    schedule=torch.profiler.schedule(
        wait=1,
        warmup=1,
        active=2,
        repeat=1),
    on_trace_ready=trace_handler
) as p:
    for _ in range(10):
        torch.square(torch.randn(10000, 10000).cuda())
        p.step()
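To see which iterations that schedule actually records, here is a small pure-Python model of how torch.profiler.schedule maps a 0-based step number to an action. It is an illustrative sketch: `Action` and `schedule_action` are hypothetical names, and the model ignores the `skip_first` argument. Treat the PyTorch source as the authority.

```python
from enum import Enum


class Action(Enum):
    NONE = "none"             # profiler idle
    WARMUP = "warmup"         # tracing on, results discarded
    RECORD = "record"         # tracing on, results kept
    RECORD_AND_SAVE = "save"  # last active step; trace goes to on_trace_ready


def schedule_action(step: int, wait: int, warmup: int, active: int, repeat: int = 0) -> Action:
    """Model the action for a 0-based profiler step (skip_first is not modelled)."""
    cycle = wait + warmup + active
    if repeat > 0 and step // cycle >= repeat:
        return Action.NONE  # every requested cycle has finished
    pos = step % cycle
    if pos < wait:
        return Action.NONE
    if pos < wait + warmup:
        return Action.WARMUP
    return Action.RECORD if pos < cycle - 1 else Action.RECORD_AND_SAVE


if __name__ == "__main__":
    for step in range(6):
        print(step, schedule_action(step, wait=1, warmup=1, active=2, repeat=1).name)
```

Under this model, steps 0 through 3 map to wait, warmup, record, and record-and-save, matching the comments in the profiler code, and the remaining six iterations run unprofiled.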
Each call to trace_handler exports /tmp/test_trace_<step_num>.json. Open that file in Chrome at chrome://tracing or in Perfetto for a visual timeline.
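For a quick text summary without a viewer, you can also read the exported JSON directly. The sketch below assumes the Chrome trace layout PyTorch emits: a top-level "traceEvents" list in which GPU kernels are complete events ("ph": "X") with "cat": "kernel" and a "dur" in microseconds. `top_kernels` is a hypothetical helper; check the field names against one of your own traces before relying on it.

```python
import json
import sys
from collections import defaultdict


def top_kernels(trace_path: str, k: int = 5) -> list[tuple[str, float]]:
    """Return the k kernels with the largest total duration, in microseconds."""
    with open(trace_path) as f:
        trace = json.load(f)
    # Chrome traces are either {"traceEvents": [...]} or a bare list of events
    events = trace["traceEvents"] if isinstance(trace, dict) else trace
    totals: dict[str, float] = defaultdict(float)
    for event in events:
        # Assumption: GPU kernels are complete ("X") events in the "kernel" category
        if event.get("ph") == "X" and event.get("cat") == "kernel":
            totals[event["name"]] += float(event.get("dur", 0))
    return sorted(totals.items(), key=lambda item: item[1], reverse=True)[:k]


if __name__ == "__main__" and len(sys.argv) > 1:
    for name, total_us in top_kernels(sys.argv[1]):
        print(f"{total_us:10.1f} us  {name}")
```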

NSight Systems for timeline profiling

NSight Systems (nsys) records a timeline of every kernel launch, memory transfer, and CPU-GPU synchronisation point. Use it to understand the shape of your workload before diving into per-kernel metrics. The lecture's script simply repeats the workload 100 times inside a main() function:
nsys_square.py
import torch
from torch.profiler import profile, record_function, ProfilerActivity

def main():
    for _ in range(100):
        a = torch.square(torch.randn(10000, 10000).cuda())

if __name__ == "__main__":
    main()
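Record and view the trace. With --capture-range=cudaProfilerApi, nsys collects data only between cudaProfilerStart() and cudaProfilerStop() calls (exposed in PyTorch as torch.cuda.profiler.start() and torch.cuda.profiler.stop()). The script above makes no such calls, so either bracket the loop with them or drop the flag to capture the whole run: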
# Record
nsys profile -w true -t cuda,nvtx,osrt --capture-range=cudaProfilerApi \
  -o nsys_square python nsys_square.py

# Open in GUI
nsys-ui nsys_square.nsys-rep
Run the workload in a loop (as above) so the timeline has enough repetitions to show a clear pattern. A single kernel launch is hard to reason about in the timeline view.

NSight Compute for kernel-level metrics

Once nsys tells you which kernel to optimise, use ncu to measure it at the hardware-counter level: memory bandwidth utilisation, occupancy, warp stalls, and more.
# Profile all kernels launched by a script
ncu python pytorch_square.py

# Save a full report for later analysis
ncu --set full -o square_kernel_report python pytorch_square.py
Running ncu against a script that JIT-compiles extensions with load_inline can fail with a CUDA initialisation error (error 36: API call not supported). Profile JIT-compiled kernels via nsys instead, or pass --target-processes all to ncu.

Writing inline CUDA kernels with load_inline

torch.utils.cpp_extension.load_inline compiles a CUDA kernel at runtime and exposes it as a Python-callable PyTorch extension — no separate build step required.

Minimal example

main.py
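import torch
from torch.utils.cpp_extension import load_inline

# CUDA source: the kernel plus a host function that launches it.
# pybind11 can only bind host functions, so Python calls square(), not the kernel.
cuda_kernel = """
__global__ void square_kernel(const float* __restrict__ input, float* __restrict__ output, int size) {
    const int index = blockIdx.x * blockDim.x + threadIdx.x;
    if (index < size) {  // the last block may contain threads past the end
        output[index] = input[index] * input[index];
    }
}

torch::Tensor square(torch::Tensor input) {
    TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
    TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "input must be float32");

    auto output = torch::empty_like(input);
    const int size = input.numel();
    const int threads_per_block = 1024;
    const int blocks_per_grid = (size + threads_per_block - 1) / threads_per_block;
    square_kernel<<<blocks_per_grid, threads_per_block>>>(
        input.data_ptr<float>(), output.data_ptr<float>(), size);
    return output;
}
"""

# C++ declaration that the generated Python bindings refer to
cpp_source = "torch::Tensor square(torch::Tensor input);"

module = load_inline(
    name='square',
    cpp_sources=cpp_source,
    cuda_sources=cuda_kernel,
    functions=['square'],
)

input_tensor = torch.randn(100, device='cuda')
output_tensor = module.square(input_tensor)
assert torch.allclose(output_tensor, input_tensor * input_tensor)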
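The minimal example sizes its grid with ceiling division so that every element gets a thread, and the `if (index < size)` check stops the surplus threads in the last block. A pure-Python sketch of that arithmetic (`blocks_per_grid` and `idle_threads` are hypothetical helper names):

```python
def blocks_per_grid(num_elements: int, threads_per_block: int) -> int:
    """Ceiling division: the smallest block count whose threads cover every element."""
    return (num_elements + threads_per_block - 1) // threads_per_block


def idle_threads(num_elements: int, threads_per_block: int) -> int:
    """Threads launched past the end of the data; the bounds check makes them no-ops."""
    return blocks_per_grid(num_elements, threads_per_block) * threads_per_block - num_elements


if __name__ == "__main__":
    for n in (100, 1024, 1025, 10_000 * 10_000):
        print(n, blocks_per_grid(n, 1024), idle_threads(n, 1024))
```

For the 100-element input, one block of 1024 threads is launched and 924 of them exit at the bounds check. Plain `n // 1024` would launch zero blocks for that input, which is why the `- 1` rounding term matters.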

2D kernel with a C++ wrapper

The full lecture example uses a 2D thread grid and a proper C++ wrapper that torch can call directly:
load_inline.py
import os

import torch
from torch.utils.cpp_extension import load_inline

cuda_source = '''
__global__ void square_matrix_kernel(const float* matrix, float* result, int width, int height) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;

    if (row < height && col < width) {
        int idx = row * width + col;
        result[idx] = matrix[idx] * matrix[idx];
    }
}

// Host wrapper: expects a contiguous 2D float32 CUDA tensor
torch::Tensor square_matrix(torch::Tensor matrix) {
    const auto height = matrix.size(0);
    const auto width = matrix.size(1);

    auto result = torch::empty_like(matrix);

    // 16x16 threads per block; round the grid up so it covers every element
    dim3 threads_per_block(16, 16);
    dim3 number_of_blocks((width + threads_per_block.x - 1) / threads_per_block.x,
                          (height + threads_per_block.y - 1) / threads_per_block.y);

    square_matrix_kernel<<<number_of_blocks, threads_per_block>>>(
        matrix.data_ptr<float>(), result.data_ptr<float>(), width, height);

    return result;
}
'''

# Declaration for the auto-generated Python bindings
cpp_source = "torch::Tensor square_matrix(torch::Tensor matrix);"

# load_inline writes its generated sources into build_directory; create it first
os.makedirs('./load_inline_cuda', exist_ok=True)

square_matrix_extension = load_inline(
    name='square_matrix_extension',
    cpp_sources=cpp_source,
    cuda_sources=cuda_source,
    functions=['square_matrix'],
    with_cuda=True,
    extra_cuda_cflags=["-O2"],
    build_directory='./load_inline_cuda',
)

a = torch.tensor([[1., 2., 3.], [4., 5., 6.]], device='cuda')
print(square_matrix_extension.square_matrix(a))
# tensor([[ 1.,  4.,  9.],
#         [16., 25., 36.]], device='cuda:0')
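To see how the 2D launch maps threads onto elements, here is a CPU emulation of the kernel's index math. It is an illustrative sketch, not the lecture's code: `emulate_square_matrix` and `cdiv` are hypothetical names, and it visits every (block, thread) pair serially where the GPU runs them in parallel. The x dimension walks columns (width) and y walks rows (height), and the flat index assumes a contiguous row-major buffer.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division, as used for number_of_blocks."""
    return (a + b - 1) // b


def emulate_square_matrix(matrix: list[list[float]], block_x: int = 16,
                          block_y: int = 16) -> list[list[float]]:
    """Apply square_matrix_kernel's indexing serially to a row-major flat buffer."""
    height, width = len(matrix), len(matrix[0])
    flat = [v for row in matrix for v in row]  # row-major, like a contiguous tensor
    result = [0.0] * len(flat)
    grid_x, grid_y = cdiv(width, block_x), cdiv(height, block_y)
    for block_idx_y in range(grid_y):
        for block_idx_x in range(grid_x):
            for thread_idx_y in range(block_y):
                for thread_idx_x in range(block_x):
                    row = block_idx_y * block_y + thread_idx_y
                    col = block_idx_x * block_x + thread_idx_x
                    if row < height and col < width:  # same bounds check as the kernel
                        idx = row * width + col
                        result[idx] = flat[idx] * flat[idx]
    return [result[r * width:(r + 1) * width] for r in range(height)]


if __name__ == "__main__":
    print(emulate_square_matrix([[1., 2., 3.], [4., 5., 6.]]))
```

For the 2x3 example, a single 16x16 block is launched and only 6 of its 256 threads pass the bounds check.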

The load_cuda helper from utils.py

For lectures that follow the repository conventions, utils.py provides a thin wrapper around load_inline that enables -O3 by default (for nvcc, ptxas, and the host compiler) and names the extension after the first entry in funcs unless name is given:
utils.py
from torch.utils.cpp_extension import load_inline

cuda_begin = r'''
#include <torch/extension.h>
#include <stdio.h>
#include <c10/cuda/CUDAException.h>

#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
#define CUDA_ERR(ans) { gpuAssert((ans), __FILE__, __LINE__); }
inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort=true)
{
   if (code != cudaSuccess)
   {
      fprintf(stderr,"GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
      if (abort) exit(code);
   }
}
__host__ __device__ inline unsigned int cdiv(unsigned int a, unsigned int b) { return (a+b-1)/b;}
'''

def load_cuda(cuda_src, cpp_src, funcs, opt=True, verbose=False, name=None):
    "Simple wrapper for torch.utils.cpp_extension.load_inline"
    if name is None: name = funcs[0]
    flags = "-O3 -Xptxas -O3 -Xcompiler -O3" if opt else "-O0 -Xptxas -O0 -Xcompiler -O0"
    return load_inline(cuda_sources=[cuda_src], cpp_sources=[cpp_src], functions=funcs,
                       extra_cuda_cflags=[flags], verbose=verbose, name=name)
Use it in your own kernels by prepending cuda_begin to your source string:
from utils import cuda_begin, load_cuda

cuda_src = cuda_begin + r'''
__global__ void my_kernel(...) { ... }

torch::Tensor my_func(torch::Tensor x) {
    CHECK_INPUT(x);
    ...
}
'''

cpp_src = "torch::Tensor my_func(torch::Tensor x);"

module = load_cuda(cuda_src, cpp_src, ['my_func'])

Writing a Triton kernel for comparison

The lecture also shows how the same square operation looks in Triton. This is useful for benchmarking Triton against hand-written CUDA:
triton_square.py
import triton
import triton.language as tl
import torch

@triton.jit
def square_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE: tl.constexpr):
    row_idx = tl.program_id(0)
    row_start_ptr = input_ptr + row_idx * input_row_stride
    col_offsets = tl.arange(0, BLOCK_SIZE)
    input_ptrs = row_start_ptr + col_offsets
    row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))

    square_output = row * row

    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    output_ptrs = output_row_start_ptr + col_offsets
    tl.store(output_ptrs, square_output, mask=col_offsets < n_cols)


def square(x):
    n_rows, n_cols = x.shape
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    num_warps = 4
    if BLOCK_SIZE >= 2048:
        num_warps = 8
    if BLOCK_SIZE >= 4096:
        num_warps = 16
    y = torch.empty_like(x)
    square_kernel[(n_rows, )](
        y, x,
        x.stride(0), y.stride(0),
        n_cols,
        num_warps=num_warps,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return y


torch.manual_seed(0)
x = torch.randn(1823, 781, device='cuda')
y_triton = square(x)
y_torch = torch.square(x)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
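Each Triton program handles one full row, so BLOCK_SIZE must be a power of two (tl.arange requires a power-of-two range) at least as large as n_cols, and the mask disables the extra lanes. The pure-Python sketch below reproduces the wrapper's choice of BLOCK_SIZE and num_warps; `next_power_of_2` is written to mirror triton.next_power_of_2 for n >= 1, and `launch_params` is a hypothetical name.

```python
def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n, for n >= 1."""
    return 1 << (n - 1).bit_length()


def launch_params(n_cols: int) -> tuple[int, int]:
    """BLOCK_SIZE and num_warps as chosen by the square() wrapper."""
    block_size = next_power_of_2(n_cols)
    num_warps = 4
    if block_size >= 2048:
        num_warps = 8
    if block_size >= 4096:
        num_warps = 16
    return block_size, num_warps


if __name__ == "__main__":
    block_size, num_warps = launch_params(781)
    print(block_size, num_warps, "masked lanes per row:", block_size - 781)
```

For the (1823, 781) test tensor, this gives 1823 programs with BLOCK_SIZE 1024 and 4 warps each, so 243 lanes per row are masked off.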
The triton_square.py file also includes a triton.testing.perf_report benchmark that compares Triton, native PyTorch, and torch.compile across a range of matrix sizes and produces a GB/s plot.
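A common way to derive a GB/s figure for an elementwise op like square is from the bytes it must move: each element is read once and written once. A minimal sketch of that arithmetic, assuming a time in milliseconds such as the value returned by time_pytorch_function; `effective_bandwidth_gbps` is a hypothetical helper, not part of the lecture files, and GB here means 1e9 bytes.

```python
def effective_bandwidth_gbps(num_elements: int, bytes_per_element: int, ms: float) -> float:
    """GB/s for an elementwise op that reads each element once and writes it once."""
    if ms <= 0:
        raise ValueError("elapsed time must be positive")
    bytes_moved = 2 * num_elements * bytes_per_element  # one read + one write per element
    return bytes_moved / (ms * 1e-3) / 1e9


if __name__ == "__main__":
    # Example: a 10000 x 10000 float32 tensor (4 bytes per element) squared in 1 ms
    print(f"{effective_bandwidth_gbps(10_000 * 10_000, 4, 1.0):.1f} GB/s")
```

Comparing the result with your GPU's peak memory bandwidth shows how close the kernel is to the memory-bound limit, which is the same question ncu's memory throughput metrics answer at the hardware-counter level.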

