AMD’s GPU stack has matured considerably with ROCm, HIP, and the Composable Kernel (CK) library. For practitioners coming from CUDA, the transition is mostly mechanical — HIP is deliberately API-compatible — but squeezing peak performance on AMD Instinct accelerators requires understanding the differences in memory hierarchy, matrix units, and the tile-based programming model that CK exposes. This page accompanies Lecture 25 by Haocong Wang.
Lecture 25 slides are available in the lecture repository at lecture_025/AMD_ROCm_Speaking_Composable_Kernel_July_20_2024.pdf.

AMD GPU architecture: RDNA vs. CDNA

AMD produces two distinct GPU microarchitecture families with very different design goals:

RDNA (consumer/gaming)

Optimized for rasterization throughput and display output. Found in the Radeon RX series. Has limited FP64 throughput and no MFMA matrix units (RDNA 3 adds WMMA instructions instead), so it is not the primary target for ML training workloads.

CDNA (compute/datacenter)

Optimized for HPC and ML. Found in AMD Instinct MI100, MI200, MI300 series. Features large register files, high-bandwidth HBM memory, and the MFMA (Matrix Fused Multiply-Add) instruction set.

MI200 and MI300 series highlights

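| Feature | MI250X (CDNA2) | MI300X (CDNA3) |
|---|---|---|
| Architecture | CDNA2 | CDNA3 |
| HBM capacity | 128 GB (2 GCDs × 64 GB) | 192 GB |
| HBM bandwidth | 3.2 TB/s | 5.3 TB/s |
| Peak FP16 TFLOPS (dense) | 383 | 1307 |
| Matrix units | MFMA | MFMA |
| GPU interconnect (NVLink equivalent) | Infinity Fabric | Infinity Fabric |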
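The MI300X is AMD's answer to the H100 SXM. It is the GPU-only member of the MI300 family (its sibling, the MI300A, is the part that integrates CPU and GPU dies on one package) and is targeted at large-model inference, where memory capacity is the binding constraint.
The MI300X's 192 GB of HBM3 makes it possible to serve a 70B-parameter model in FP16 on a single card without model parallelism: the weights alone take about 140 GB, more than the 80 GB of an A100 or H100 SXM, but leaving headroom on the MI300X for the KV cache. This capacity is its primary competitive advantage over NVIDIA A100/H100 for inference.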
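The capacity argument is simple arithmetic. FP16 stores each parameter in 2 bytes, so weight memory is the parameter count × 2 bytes. The C++ sketch below checks that arithmetic; the function names are illustrative and do not come from any library. It counts weights only, so a real deployment also needs room for the KV cache, activations, and runtime overhead.

```cpp
#include <cassert>

// Weight footprint in decimal gigabytes (1e9 bytes): parameters x bytes per parameter.
// Ignores KV cache, activations, and framework overhead, which all add to this.
constexpr double weight_gb(double num_params, int bytes_per_param) {
  return num_params * bytes_per_param / 1e9;
}

// True if the weights alone fit in `hbm_gb` gigabytes of device memory.
constexpr bool weights_fit(double num_params, int bytes_per_param, double hbm_gb) {
  return weight_gb(num_params, bytes_per_param) <= hbm_gb;
}
```

For example, a 70B-parameter model in FP16 gives 140 GB. That fits in 192 GB but not in 80 GB.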

ROCm: AMD’s GPU computing platform

ROCm (Radeon Open Compute) is the open-source software stack that sits between hardware and user code, analogous to the CUDA Toolkit. It includes:
  • HIP runtime: CUDA-compatible API layer
  • rocBLAS / hipBLAS: BLAS for AMD GPUs (equivalent to cuBLAS)
  • MIOpen: DNN primitives (equivalent to cuDNN)
  • rocprof: GPU profiler (equivalent to Nsight)
  • Composable Kernel: high-performance kernel template library (closest NVIDIA analogue: CUTLASS)
# Install ROCm 6.1.2 (Ubuntu 22.04 "jammy", AMD GPU)
wget https://repo.radeon.com/amdgpu-install/6.1.2/ubuntu/jammy/amdgpu-install_6.1.60102-1_all.deb
sudo apt install ./amdgpu-install_6.1.60102-1_all.deb
sudo amdgpu-install --usecase=rocm

# Verify installation
rocm-smi
hipcc --version

HIP: CUDA-compatible programming model

HIP (Heterogeneous-computing Interface for Portability) mirrors the CUDA API almost exactly. Most CUDA code can be mechanically converted with the hipify-perl tool:
# Convert a CUDA source file to HIP
hipify-perl my_kernel.cu > my_kernel.hip
The CUDA version of a vector-add kernel is shown here. Under HIP the kernel body is unchanged and only the runtime calls change:
__global__ void vector_add(const float *a, const float *b,
                           float *c, int n) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) c[idx] = a[idx] + b[idx];
}

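// Launch (grid and block are dim3 launch dimensions set up by the host code)
vector_add<<<grid, block>>>(d_a, d_b, d_c, n);
cudaDeviceSynchronize();
// HIP: hipcc accepts the same kernel and <<<...>>> launch; in this snippet
// only cudaDeviceSynchronize() becomes hipDeviceSynchronize().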
HIP code compiles to either AMD ROCm targets or NVIDIA CUDA by setting HIP_PLATFORM:
# Compile for AMD GPU (gfx90a = MI200 series; use gfx942 for MI300X)
hipcc --offload-arch=gfx90a my_kernel.hip -o my_kernel

# Compile for NVIDIA GPU (via CUDA backend)
HIP_PLATFORM=nvidia hipcc my_kernel.hip -o my_kernel
The wavefront (AMD's equivalent of a CUDA warp) is 64 threads wide on CDNA, compared with NVIDIA's 32. This is a meaningful architectural difference: occupancy calculations and register-pressure analysis differ from CUDA accordingly. RDNA natively runs 32-thread wavefronts (wave32), though it also supports wave64.
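One practical consequence of the wider wavefront is that block sizes that are multiples of 32 but not of 64 leave lanes idle on CDNA. The sketch below shows that arithmetic. The helper names are hypothetical, and the code models only lane allocation, not a full occupancy calculation.

```cpp
#include <cassert>

// Hardware wavefronts (or warps) needed for a block of `threads` threads.
// Rounds up: a partial wavefront still occupies a full wavefront slot.
constexpr int waves_per_block(int threads, int wave_size) {
  return (threads + wave_size - 1) / wave_size;
}

// Lanes that are allocated but idle because `threads` is not a multiple of wave_size.
constexpr int idle_lanes(int threads, int wave_size) {
  return waves_per_block(threads, wave_size) * wave_size - threads;
}
```

A 96-thread block, for instance, fills three 32-wide warps exactly. On CDNA the same block needs two 64-wide wavefronts and leaves 32 lanes idle. In HIP device code the width is available as the built-in `warpSize`, which is safer than hardcoding 32 or 64.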

Composable Kernel (CK) overview

Composable Kernel is AMD's high-performance kernel library for ML workloads on ROCm. Unlike monolithic BLAS libraries such as cuBLAS or rocBLAS, CK is designed to be composed: library users can mix and match operation types, data types, and layout configurations without forking kernel code. CK's key design principles:
  • Tile-based programming: all work is expressed as operations on tiles that map to the GPU memory hierarchy (global → LDS → registers).
  • Template metaprogramming: kernel configurations (tile sizes, pipeline stages, instruction types) are compile-time template parameters, not runtime switches.
  • MFMA-first: the inner loop targets AMD’s MFMA matrix instructions directly, rather than relying on the compiler to discover them.
# Clone and build CK
git clone https://github.com/ROCm/composable_kernel
cd composable_kernel
mkdir build && cd build
# GPU_TARGETS: gfx90a = MI200 series (e.g. MI250X), gfx942 = MI300X
cmake -DCMAKE_BUILD_TYPE=Release \
      -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
      -DGPU_TARGETS="gfx90a;gfx942" \
      ..
make -j$(nproc) examples

CK’s tile-based programming model

CK decomposes a GEMM into a three-level tile hierarchy that maps directly onto AMD GPU memory levels:
Grid of threadblocks
  └── Threadblock tile  (A/B tiles staged from global memory into LDS)
        └── Warp tile   (C fragment accumulated in registers by MFMA)
              └── MFMA instruction  (native 16×16×16 or 32×32×8 matrix multiply)
The core abstraction is the GridwiseGemm template. You specify:
  • BlockSize: threads per threadblock
  • MPerBlock, NPerBlock, KPerBlock: tile dimensions at the threadblock level
  • MPerWave, NPerWave: tile dimensions at the wavefront level
  • MRepeat, NRepeat: how many MFMA output tiles each wavefront computes (and accumulates in registers) along M and N
This explicit hierarchy lets CK generate kernels that saturate MFMA throughput while hiding global memory latency via double-buffering in LDS.
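These parameters constrain each other. The sketch below is a hypothetical config struct, not CK's actual template, that checks two relationships at compile time. First, the wavefronts' tiles must exactly cover the block tile. Second, a double-buffered A/B stage must fit in the 64 KB of LDS per CU. Here MPerXdl/NPerXdl stand for the MFMA output tile, and MRepeat/NRepeat give how many of those tiles each wavefront computes. The LDS estimate ignores any padding a real kernel adds to avoid bank conflicts.

```cpp
#include <cassert>

// Hypothetical compile-time GEMM tile configuration (illustrative only).
template <int BlockSize, int MPerBlock, int NPerBlock, int KPerBlock,
          int MPerXdl, int NPerXdl, int MRepeat, int NRepeat,
          int WaveSize = 64, int BytesPerElem = 2>
struct GemmTileConfig {
  static constexpr int kWaves = BlockSize / WaveSize;
  // Each wavefront owns an (MRepeat*MPerXdl) x (NRepeat*NPerXdl) patch of C.
  static constexpr int kMPerWave = MRepeat * MPerXdl;
  static constexpr int kNPerWave = NRepeat * NPerXdl;
  static_assert(BlockSize % WaveSize == 0, "block must be whole wavefronts");
  static_assert(MPerBlock % kMPerWave == 0 && NPerBlock % kNPerWave == 0,
                "wave tiles must evenly cover the block tile");
  static_assert((MPerBlock / kMPerWave) * (NPerBlock / kNPerWave) == kWaves,
                "exactly one wave tile per wavefront");
  // One LDS stage holds an A tile (MPerBlock x KPerBlock) and a B tile (KPerBlock x NPerBlock).
  static constexpr int kLdsBytesPerStage = (MPerBlock + NPerBlock) * KPerBlock * BytesPerElem;
  static constexpr int kLdsBytesDoubleBuffered = 2 * kLdsBytesPerStage;
  static_assert(kLdsBytesDoubleBuffered <= 64 * 1024, "exceeds 64 KB LDS per CU");
};

// 256 threads, 128x128x32 block tile, 32x32 MFMA tiles, 2x2 repeats, FP16 inputs.
using ExampleConfig = GemmTileConfig<256, 128, 128, 32, 32, 32, 2, 2>;
```

With this configuration, the four wavefronts each own a 64×64 patch of C. One FP16 A/B stage takes 16 KB of LDS, or 32 KB when double-buffered.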

Writing a GEMM with Composable Kernel

CK provides both a high-level client API and direct access to device-level and building-block templates. The example below instantiates a device-level XDL (MFMA-based) GEMM template directly, with an abbreviated parameter list:
#include <stdexcept>

#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

using F16        = ck::half_t;
using F32        = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

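// Select a GEMM instance (FP16 in, FP32 accumulate, FP16 out).
// Abbreviated: the full template also takes A/B/C layouts, a GEMM
// specialization, and block-transfer parameters; see CK's GEMM examples.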
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl<
    F16,          // ADataType
    F16,          // BDataType
    F16,          // CDataType
    F32,          // AccDataType
    PassThrough,  // AElementOp
    PassThrough,  // BElementOp
    PassThrough,  // CElementOp
    256,          // BlockSize
    128,          // MPerBlock
    128,          // NPerBlock
    32,           // KPerBlock
    8,            // AK1
    8,            // BK1
    32,           // MPerXDL  (MFMA tile M)
    32,           // NPerXDL  (MFMA tile N)
    2,            // MXdlPerWave
    2>;           // NXdlPerWave

int run_gemm(int M, int N, int K,
             const ck::half_t *d_A,
             const ck::half_t *d_B,
             ck::half_t       *d_C) {
  auto gemm    = DeviceGemmInstance{};
  auto invoker = gemm.MakeInvoker();
  auto argument = gemm.MakeArgument(
      d_A, d_B, d_C,
      M, N, K,
      K,    // lda (row-major)
      N,    // ldb (row-major)
      N,    // ldc
      PassThrough{}, PassThrough{}, PassThrough{});

  if (!gemm.IsSupportedArgument(argument)) {
    throw std::runtime_error("Unsupported GEMM configuration");
  }
  invoker.Run(argument, StreamConfig{nullptr, false});
  return 0;
}
CK ships with a ckProfiler tool that benchmarks every available instance for a given problem shape and reports the fastest one. This is exhaustive measurement, in contrast to the heuristic algorithm selection offered by cublasLt.
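A common reason `IsSupportedArgument` in the example above returns false is a problem shape that does not split into whole block tiles. For instances built without padding, M, N, and K must be multiples of the block-tile dimensions; CK also provides padded GEMM specializations for irregular shapes. The hypothetical helper below models only that one check. CK's real check covers more conditions. The sketch also computes the resulting number of output tiles.

```cpp
#include <cassert>

// Simplified model of one support check: without padding, the problem
// must divide into whole MPerBlock x NPerBlock x KPerBlock tiles.
constexpr bool tiles_evenly(int M, int N, int K,
                            int MPerBlock, int NPerBlock, int KPerBlock) {
  return M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0;
}

// Number of MPerBlock x NPerBlock output tiles of C (assumes tiles_evenly).
constexpr int num_block_tiles(int M, int N, int MPerBlock, int NPerBlock) {
  return (M / MPerBlock) * (N / NPerBlock);
}
```

With a 128×128×32 block tile, a 4096×4096×4096 GEMM passes this check and yields 1024 C tiles. A shape with M = 1000 would fail.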

The MFMA instruction

At the hardware level, CK targets the v_mfma_* instructions. For FP16:
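// Compiler builtin (not inline assembly) for one 32×32×8 FP16 MFMA per wavefront:
// A is 32×8, B is 8×32, C is a 32×32 FP32 accumulator; per lane that is
// 4 halves of A, 4 halves of B, and 16 floats of C. The result must be assigned.
typedef _Float16 halfx4 __attribute__((ext_vector_type(4)));   // type of a_frag, b_frag
typedef float floatx16 __attribute__((ext_vector_type(16)));   // type of c_frag
c_frag = __builtin_amdgcn_mfma_f32_32x32x8f16(a_frag, b_frag, c_frag, 0, 0, 0);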
You rarely call MFMA intrinsics directly — CK’s template machinery generates the correct instruction variant and register layout based on your template parameters.
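The per-lane fragment sizes follow from spreading each tile evenly over the 64 lanes of a wavefront. The constexpr sketch below works through that arithmetic. The helpers are hypothetical, and the even split holds for single-block MFMA variants such as the 32×32×8 and 16×16×16 FP16 instructions.

```cpp
#include <cassert>

// Output/reduction shape of one MFMA instruction: C[m x n] += A[m x k] * B[k x n].
struct MfmaShape { int m, n, k; };

constexpr int kWavefront = 64;  // CDNA wavefront width

// Elements of each operand held by one lane when the tile is split evenly.
constexpr int a_elems_per_lane(MfmaShape s) { return s.m * s.k / kWavefront; }
constexpr int b_elems_per_lane(MfmaShape s) { return s.k * s.n / kWavefront; }
constexpr int c_elems_per_lane(MfmaShape s) { return s.m * s.n / kWavefront; }

// Each multiply-add counts as 2 FLOPs.
constexpr long flops_per_instr(MfmaShape s) { return 2L * s.m * s.n * s.k; }
```

For 32×32×8, this gives 4 A values, 4 B values, and 16 accumulators per lane, for 16,384 FLOPs per instruction. The 16×16×16 variant also holds 4 A values per lane but only 4 accumulators.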

Key differences from CUDA and cuBLAS

Understanding these differences avoids subtle performance bugs when porting CUDA kernels to ROCm:
| Aspect | CUDA / NVIDIA | HIP / AMD CDNA |
|---|---|---|
| Warp / wavefront width | 32 threads | 64 threads (CDNA), 32 (RDNA) |
| Shared memory | 48–228 KB per SM (configurable) | LDS: 64 KB per CU |
| Matrix instruction | WMMA / MMA (Tensor Cores) | MFMA (v_mfma_*) |
| Matrix-instruction tile (FP16) | 16×16 (WMMA) | 16×16 or 32×32 (MFMA) |
| L1 cache | Per SM, configurable | Per CU, less configurable |
| Occupancy tuning | cudaFuncSetCacheConfig | LDS and register pressure govern occupancy directly |
| Profiling | ncu, Nsight Systems | rocprof, Omniperf |
CUDA’s __shfl_* warp shuffle intrinsics map to HIP’s __shfl_* equivalents, but the shuffle operates across 64-thread wavefronts on CDNA. Algorithms that assume 32-thread warps (e.g., warp-level reductions with 5 shuffle rounds) need adjustment — you need 6 rounds for 64 lanes.
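To make the round count concrete, here is a host-side C++ simulation of a butterfly (XOR-shuffle) sum. It illustrates the data movement only and is not HIP code; `butterfly_reduce` is a hypothetical name. The simulation shows that a reduction written for 32-lane warps silently sums only half the lanes when it runs on a 64-lane wavefront.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Value lane 0 ends up with, and the number of shuffle rounds executed.
struct ReduceResult { int lane0_sum; int rounds; };

// Simulates v += __shfl_xor(v, offset) for offset = assumed_width/2 ... 1.
// `lanes` holds one value per hardware lane (size assumed to be a power of two);
// `assumed_width` is the warp width the reduction was written for.
ReduceResult butterfly_reduce(std::vector<int> lanes, int assumed_width) {
  assert(assumed_width > 0 && assumed_width <= static_cast<int>(lanes.size()));
  int rounds = 0;
  for (int offset = assumed_width / 2; offset > 0; offset /= 2) {
    std::vector<int> next(lanes.size());
    for (std::size_t lane = 0; lane < lanes.size(); ++lane)
      next[lane] = lanes[lane] + lanes[lane ^ static_cast<std::size_t>(offset)];
    lanes = next;  // all lanes exchange simultaneously, as in hardware
    ++rounds;
  }
  return {lanes[0], rounds};
}
```

With lane values 0 to 63, the six-round version leaves 2016 in lane 0. The five-round, 32-lane version leaves 496, which is the sum of lanes 0 to 31 only.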

Profiling with rocprof and Omniperf

# Basic kernel timing
rocprof --stats ./your_hip_binary

# Hardware counter collection; counters.txt lists counters on a pmc: line, e.g.
#   pmc: SQ_INSTS_VALU_MFMA_MOPS_F16 SQ_INSTS_VALU FETCH_SIZE
rocprof --hsa-trace \
        -i counters.txt \
        ./your_hip_binary
Omniperf provides a richer, web-based analysis:
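# Omniperf ships as a ROCm package on recent releases (installation varies by ROCm version)
sudo apt install omniperf
omniperf profile -n my_gemm -- ./my_gemm_binary
# Results are written under workloads/my_gemm/, in a subdirectory per GPU model
omniperf analyze -p workloads/my_gemm/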

SYCL mode: Intel GPU portability (Lecture 26)

Lecture 26 by Patric Zhao covers SYCL MODE: running the same compute kernels on Intel GPUs using SYCL, via Intel's oneAPI DPC++ compiler. SYCL is a Khronos-standard C++ programming model whose implementations target several backends. These include OpenCL, Level Zero (Intel), CUDA, and HIP. The CUDA and HIP backends are available, for example, through Codeplay's oneAPI plugins for NVIDIA and AMD GPUs or through AdaptiveCpp.
#include <sycl/sycl.hpp>

// a, b, c must be USM pointers usable on q's device (e.g. from sycl::malloc_device)
void vector_add_sycl(sycl::queue &q,
                     const float *a, const float *b,
                     float *c, int n) {
  q.parallel_for(sycl::range<1>(n), [=](sycl::id<1> idx) {
    c[idx] = a[idx] + b[idx];
  }).wait();
}

  • Lecture 26: SYCL MODE (Patric Zhao's slides on Intel GPU programming with SYCL and oneAPI)
  • Intel oneAPI docs (official Intel oneAPI developer documentation and DPC++ reference)

Cross-platform portability considerations

Writing portable GPU kernels that run on NVIDIA, AMD, and Intel hardware is achievable with the right abstraction layer:
1. Use HIP as the primary layer

HIP compiles to both CUDA and ROCm. For NVIDIA targets, set HIP_PLATFORM=nvidia; for AMD, use the default ROCm backend. This covers most kernel code with minimal changes. Vendor-specific features such as matrix instructions are the main exception.
2. Abstract matrix instructions behind compile-time dispatch

Use #ifdef __HIP_PLATFORM_AMD__ to select MFMA vs. #ifdef __CUDA_ARCH__ for WMMA. CK and CUTLASS handle this internally, so prefer library APIs over raw intrinsics when possible.
3. Profile on each target separately

Optimal tile sizes and pipeline depths differ between MI300X and H100, so a configuration tuned for one will typically underperform on the other. Use autotuning tools (ckProfiler, the CUTLASS profiler) rather than hardcoding tile sizes.
4. Consider Triton for new kernels

Triton generates both CUDA PTX and AMD GCN ISA from the same Python source. For new kernels where peak performance is not yet critical, Triton offers a faster path to cross-platform coverage than maintaining dual HIP/CUDA implementations.
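Steps 2 and 3 can be sketched together on the host. In this hedged illustration, `native_tile`, `blocked_gemm`, `pick_fastest`, and `time_seconds` are hypothetical names. The platform macros are the ones named in step 2, and the tile shapes are examples rather than a complete table. In real HIP code, a matrix intrinsic would also need a device-pass guard such as `__HIP_DEVICE_COMPILE__`, because hipcc defines `__HIP_PLATFORM_AMD__` for host compilation too.

```cpp
#include <cassert>
#include <chrono>
#include <functional>
#include <limits>
#include <vector>

struct TileShape { int m, n, k; };

// Step 2: pick the default tile at compile time from the platform macros.
constexpr TileShape native_tile() {
#if defined(__HIP_PLATFORM_AMD__)
  return {32, 32, 8};   // e.g. v_mfma_f32_32x32x8f16 on CDNA
#elif defined(__CUDA_ARCH__)
  return {16, 16, 16};  // e.g. WMMA m16n16k16 on Tensor Cores
#else
  return {4, 4, 4};     // plain host build: scalar fallback
#endif
}

// Row-major C[MxN] += A[MxK] * B[KxN], blocked by tile `t`. On a GPU the
// innermost tile would map to one matrix instruction; here it is plain loops.
void blocked_gemm(int M, int N, int K, const std::vector<float>& A,
                  const std::vector<float>& B, std::vector<float>& C,
                  TileShape t = native_tile()) {
  for (int i0 = 0; i0 < M; i0 += t.m)
    for (int j0 = 0; j0 < N; j0 += t.n)
      for (int k0 = 0; k0 < K; k0 += t.k)
        for (int i = i0; i < i0 + t.m && i < M; ++i)
          for (int j = j0; j < j0 + t.n && j < N; ++j)
            for (int k = k0; k < k0 + t.k && k < K; ++k)
              C[i * N + j] += A[i * K + k] * B[k * N + j];
}

// Step 3: measure every candidate and keep the cheapest instead of hardcoding one.
// `cost` returns the measured seconds for one candidate.
TileShape pick_fastest(const std::vector<TileShape>& candidates,
                       const std::function<double(const TileShape&)>& cost) {
  assert(!candidates.empty());
  TileShape best = candidates.front();
  double best_cost = std::numeric_limits<double>::infinity();
  for (const TileShape& c : candidates) {
    const double s = cost(c);
    if (s < best_cost) { best_cost = s; best = c; }
  }
  return best;
}

// Wall-clock seconds for one call; any GPU work must be synchronized inside `fn`.
double time_seconds(const std::function<void()>& fn) {
  const auto t0 = std::chrono::steady_clock::now();
  fn();
  return std::chrono::duration<double>(std::chrono::steady_clock::now() - t0).count();
}
```

In practice, the `cost` callback would wrap `time_seconds` around a synchronized kernel launch and average several runs. ckProfiler and the CUTLASS profiler apply the same idea across many kernel instances.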

  • Composable Kernel on GitHub (source code, examples, and instance profiler for AMD's CK library)
  • Lecture 25 slides (Haocong Wang's full slide deck: Speaking Composable Kernel, July 2024)
  • ROCm documentation (AMD's official ROCm programming guide and API reference)
  • HIP porting guide (official guide for porting CUDA code to HIP/ROCm)
