Efficient low-bit computation requires more than choosing a number format: it requires kernels that fuse operations, exploit hardware-specific instructions, and minimize memory traffic. This page brings together techniques from four GPU Mode lectures: Liger Kernel (Lecture 28, Byron Hsu), BitBLAS (Lecture 33, Wang Lei), Low-Bit Triton Kernels (Lecture 34, Hicham Badri), and ARM low-bit kernels (Lecture 38, Scott Roy).

Liger Kernel

Liger Kernel is a collection of Triton kernels for training efficiency, developed at LinkedIn. It integrates directly with Hugging Face transformers models through a patching API, requiring no model code changes.

What it is

Liger provides fused Triton implementations of common LLM training operations:
  • RMSNorm: fused forward + backward, eliminates intermediate activation storage
  • Fused linear + cross entropy: combines the output projection and loss computation to avoid materializing the full vocabulary logit matrix
  • SwiGLU / GeGLU: fused gating activations
  • Rotary positional embeddings (RoPE): fused in-place application

Integrating with Hugging Face models

from liger_kernel.transformers import apply_liger_kernel_to_llama

# Patch a LLaMA model — no model code changes required
apply_liger_kernel_to_llama()

from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
# Model now uses Liger kernels for RMSNorm, SwiGLU, RoPE, etc.
You can also enable individual components:
from liger_kernel.transformers import apply_liger_kernel_to_llama

apply_liger_kernel_to_llama(
    rope=True,
    swiglu=True,
    rms_norm=True,
    fused_linear_cross_entropy=True,
)
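
In the example above the patch is applied before the model is created. One plausible reason is that a patching API of this kind swaps classes on the modeling module, and the model only picks up whatever class is registered when it is constructed. The sketch below illustrates that general monkey-patching pattern; `modeling`, `ReferenceNorm`, `FusedNorm`, `Model`, and `apply_patch` are hypothetical stand-ins, not Liger or transformers code.

```python
import types

# Hypothetical stand-in for a modeling module that owns the layer classes.
modeling = types.SimpleNamespace()


class ReferenceNorm:
    name = "reference"


class FusedNorm:
    name = "fused"


modeling.Norm = ReferenceNorm


class Model:
    def __init__(self):
        # The class is looked up on the module at construction time,
        # so a patch applied earlier is picked up here.
        self.norm = modeling.Norm()


def apply_patch():
    """Swap the layer class on the module (hypothetical patching call)."""
    modeling.Norm = FusedNorm


before = Model()   # built before patching: keeps the reference layer
apply_patch()
after = Model()    # built after patching: gets the fused layer
print(before.norm.name, after.norm.name)
```

Under this assumption, a model built before the patch keeps its original layers, which is why the call order in the example matters.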

Fused linear + cross entropy for memory savings

The largest single memory cost during LLM training is often the logit matrix: for a vocabulary size of 128,000 and a sequence length of 4,096 in BF16 (2 bytes per element), the logit tensor alone is about 1 GB (128,000 × 4,096 × 2 bytes) for each sequence in the batch. Liger’s fused linear cross-entropy kernel computes the loss in a chunked fashion without ever fully materializing the logit matrix.
from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss

# Instead of:
# logits = model.lm_head(hidden_states)          # (B, T, vocab_size), huge
# loss = F.cross_entropy(logits.view(-1, vocab), labels.view(-1))

# Use the fused version (hidden states are passed directly to the loss;
# chunking happens inside the kernel, so no chunk size is passed here):
loss_fn = LigerFusedLinearCrossEntropyLoss()
loss = loss_fn(
    model.lm_head.weight,                             # (vocab_size, hidden_dim)
    hidden_states.view(-1, hidden_states.size(-1)),   # (B*T, hidden_dim)
    labels.view(-1),                                  # (B*T,)
)
The fused linear cross entropy is the single highest-impact Liger optimization for training memory. Expect 20–40% reduction in peak activation memory on large vocabulary models.
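
To make the memory argument concrete, here is a minimal pure-Python sketch of a chunked linear + cross-entropy loss. It is not Liger's kernel: it covers only the forward pass, it chunks over tokens (an assumption about one reasonable chunking scheme), and every function name in it is hypothetical. The point is that only a `chunk_size × vocab_size` slice of logits is alive at any time, while the loss matches the unchunked computation.

```python
import math


def logit_bytes(tokens, vocab_size, bytes_per_elem=2):
    """Bytes needed to hold a (tokens x vocab_size) logit matrix."""
    return tokens * vocab_size * bytes_per_elem


def project(weight, h):
    """Output projection for one token; weight is (vocab_size x hidden_dim)."""
    return [sum(w * x for w, x in zip(row, h)) for row in weight]


def token_loss(logits, label):
    """Cross entropy for one token via a numerically stable log-sum-exp."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return lse - logits[label]


def full_loss(hidden, weight, labels):
    """Reference: materialize all logits, then average the per-token loss."""
    logits = [project(weight, h) for h in hidden]
    return sum(token_loss(z, y) for z, y in zip(logits, labels)) / len(labels)


def chunked_loss(hidden, weight, labels, chunk_size):
    """Same loss, but only chunk_size rows of logits exist at once."""
    total = 0.0
    for start in range(0, len(hidden), chunk_size):
        stop = start + chunk_size
        chunk_logits = [project(weight, h) for h in hidden[start:stop]]
        total += sum(
            token_loss(z, y) for z, y in zip(chunk_logits, labels[start:stop])
        )
    return total / len(labels)


hidden = [[0.5, -1.0], [1.5, 0.25], [-0.75, 2.0], [0.0, 1.0], [2.0, -0.5]]
weight = [[1.0, 0.0], [0.0, 1.0], [-1.0, 1.0]]   # vocab_size=3, hidden_dim=2
labels = [0, 2, 1, 1, 0]

print(full_loss(hidden, weight, labels))
print(chunked_loss(hidden, weight, labels, chunk_size=2))
print(logit_bytes(4096, 128_000))   # full BF16 logits for one 4,096-token sequence
print(logit_bytes(1024, 128_000))   # logits for a 1,024-token chunk
```

With the numbers from the text, the full BF16 logit matrix for one sequence is 1,048,576,000 bytes, while a 1,024-token chunk needs a quarter of that. A real fused kernel also has to produce gradients inside the same pass, which this sketch leaves out.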

RMSNorm

Liger’s RMSNorm kernel fuses the normalization, scale multiplication, and (optionally) the backward pass into a single Triton kernel, avoiding multiple reads and writes of the activation tensor.
from liger_kernel.ops.rms_norm import LigerRMSNormFunction

# Replaces torch.nn.RMSNorm in a training loop.
# autograd Function.apply takes positional arguments only, so eps is passed by position.
output = LigerRMSNormFunction.apply(
    hidden_states, weight, 1e-6
)
The correctness verification notebook from Lecture 28 benchmarks Liger RMSNorm against PyTorch’s native implementation.
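
For reference, the computation that the kernel fuses can be written as a short unfused function. This is a plain-Python sketch of the standard RMSNorm formula for one hidden vector, meant as a correctness baseline; the function name `rms_norm` is chosen here for illustration and is not part of Liger.

```python
import math


def rms_norm(x, weight, eps=1e-6):
    """RMSNorm over one hidden vector: x / sqrt(mean(x^2) + eps) * weight."""
    mean_sq = sum(v * v for v in x) / len(x)        # pass 1: reduction over the row
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]  # pass 2: normalize and scale


x = [3.0, 4.0]
y = rms_norm(x, weight=[1.0, 1.0], eps=0.0)
print(y)
```

Written naively on a GPU, the reduction, the normalization, and the scale multiplication can each read or write the activation tensor separately; a fused kernel does all three while the row is held on chip.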

BitBLAS: bit-level BLAS operations

BitBLAS (Lecture 33, Wang Lei) is a library from Microsoft Research for matrix operations at sub-byte precisions (INT1, INT2, INT4). It uses TVM-based code generation to produce optimized CUDA kernels for each target architecture.

What BitBLAS provides

  • Matrix multiplication with weight precisions of 1, 2, 3, 4, 5, 6, 7, or 8 bits
  • Mixed-precision GEMM: low-bit weights × FP16 activations
  • Automatic kernel generation and tuning for the target GPU
  • Integration with torch as a drop-in layer

Using BitBLAS for INT4 weight-only GEMM

import bitblas
import torch

# Define a mixed-precision matmul: FP16 activation × INT4 weight
matmul_config = bitblas.MatmulConfig(
    M=1,          # batch size (use symbolic for dynamic shapes)
    N=4096,       # output features
    K=4096,       # input features
    A_dtype="float16",
    W_dtype="int4",
    accum_dtype="float16",
    out_dtype="float16",
    layout="nt",  # A: row-major, B: column-major (transposed)
    with_bias=False,
    group_size=128,     # per-group quantization scale
    with_scaling=True,
    with_zeros=True,
    zeros_mode="original",
)

matmul = bitblas.Matmul(config=matmul_config)

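# Quantize and run
A = torch.randn(1, 4096, dtype=torch.float16, device="cuda")
# transform_weight expects already-quantized integer values stored as int8
# (random placeholders here; a real model supplies its own quantized weights)
W_int = torch.randint(0, 7, (4096, 4096), dtype=torch.int8, device="cuda")
W_quantized = matmul.transform_weight(W_int)  # pack INT4 weights
# One scale and one zero point per (output row, group of 128 inputs)
scale = torch.rand(4096, 4096 // 128, dtype=torch.float16, device="cuda")
zeros = torch.rand(4096, 4096 // 128, dtype=torch.float16, device="cuda")
output = matmul(A, W_quantized, scale=scale, zeros=zeros)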
BitBLAS auto-tunes kernels the first time they are called for a given configuration. This takes seconds to minutes but is cached. Use bitblas.auto_detect_nvidia_target() to select the right target for your GPU.
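
The `group_size`, `with_scaling`, and `with_zeros` options describe group-wise quantization: each group of consecutive weights shares one scale and one zero point. The sketch below shows the idea in pure Python for 4-bit unsigned codes. It is an illustration under assumptions, not BitBLAS code: the function names are hypothetical, and the dequantization convention `(code - zero) * scale` is assumed here for the sake of the example.

```python
def quantize_groups(weights, group_size, bits=4):
    """Quantize a flat list of floats to unsigned codes, one scale/zero per group."""
    qmax = 2 ** bits - 1
    codes, scales, zeros = [], [], []
    for start in range(0, len(weights), group_size):
        group = weights[start:start + group_size]
        lo, hi = min(group), max(group)
        scale = (hi - lo) / qmax or 1.0   # avoid a zero scale for constant groups
        zero = -lo / scale                # chosen so that lo maps to code 0
        for w in group:
            code = round(w / scale + zero)
            codes.append(min(qmax, max(0, code)))
        scales.append(scale)
        zeros.append(zero)
    return codes, scales, zeros


def dequantize_groups(codes, scales, zeros, group_size):
    """Recover approximate floats: (code - zero) * scale within each group."""
    return [
        (code - zeros[i // group_size]) * scales[i // group_size]
        for i, code in enumerate(codes)
    ]


weights = [-1.0, -0.5, 0.0, 0.5, 1.0, 2.0, 3.0, 4.0]
codes, scales, zeros = quantize_groups(weights, group_size=4)
restored = dequantize_groups(codes, scales, zeros, group_size=4)
print(codes)
print(restored)
```

Smaller groups track the local weight range more closely at the cost of storing more scales and zero points; the rounding error per weight is bounded by half the group's scale.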

BitBLAS tutorials

The official tutorials from Lecture 33 include:
  • int4_weight_only_gemm.py — INT4 weight-only GEMM walkthrough
  • int2_weight_only_gemm.py — INT2 extreme quantization
  • fp16_int4_gemv.py — Optimized GEMV (single-token decode) with INT4 weights

Writing low-bit Triton kernels for INT2/INT4

Lecture 34 by Hicham Badri covers writing Triton kernels for sub-8-bit integer formats directly. The key challenge is that Triton operates on standard integer types; packing and unpacking must be done manually.

INT4 packing in Triton

import triton
import triton.language as tl

@triton.jit
def unpack_int4(packed):
    """Unpack two INT4 values from each byte of a uint8 tensor."""
    low = packed & 0xF          # lower nibble: bits [3:0]
    high = (packed >> 4) & 0xF  # upper nibble: bits [7:4]
    # Cast to a signed type before shifting the range: subtracting 8 from a
    # uint8 would wrap around instead of going negative.
    low = low.to(tl.int8) - 8   # unsigned [0,15] -> signed [-8,7]
    high = high.to(tl.int8) - 8
    return low, high

@triton.jit
def int4_gemv_kernel(
    A_ptr,          # FP16 activations (1 × K)
    B_ptr,          # INT4 packed weights (N × K/2)
    S_ptr,          # FP16 scales (N × K//group_size)
    C_ptr,          # FP16 output (1 × N)
    K: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
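    pid = tl.program_id(0)
    # Output rows handled by this program (assumes N is a multiple of BLOCK_N)
    offs_n = pid * BLOCK_N + tl.arange(0, BLOCK_N)
    # Byte offsets inside one group: each byte holds two INT4 weights
    offs_b = tl.arange(0, GROUP_SIZE // 2)

    acc = tl.zeros((BLOCK_N,), dtype=tl.float32)
    for k in range(0, K, GROUP_SIZE):
        # Load packed INT4 weights for this group: (BLOCK_N, GROUP_SIZE // 2) bytes
        b_packed = tl.load(B_ptr + offs_n[:, None] * (K // 2) + k // 2 + offs_b[None, :])
        b_low, b_high = unpack_int4(b_packed)

        # Load one scale per output row for this group
        scale = tl.load(S_ptr + offs_n * (K // GROUP_SIZE) + k // GROUP_SIZE).to(tl.float32)

        # Load activations: even indices pair with low nibbles, odd with high nibbles
        a_even = tl.load(A_ptr + k + 2 * offs_b).to(tl.float32)
        a_odd = tl.load(A_ptr + k + 2 * offs_b + 1).to(tl.float32)

        # Accumulate the group's dot product, then apply the per-row scale
        partial = tl.sum(
            a_even[None, :] * b_low.to(tl.float32) + a_odd[None, :] * b_high.to(tl.float32),
            axis=1,
        )
        acc += partial * scale

    tl.store(C_ptr + offs_n, acc.to(tl.float16))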
Triton does not have a native INT4 type. You must pack/unpack INT4 values using bitwise operations on INT8/UINT8 tensors. The packing layout (low nibble first vs. high nibble first) must match between the quantization step and the kernel.
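
The layout requirement is easy to check on the host before writing any kernel. The following pure-Python sketch packs signed INT4 values two per byte with the low nibble first and an offset of 8, then unpacks them both ways; `pack_int4` and `unpack_int4` are illustrative helper names, and the offset-by-8 encoding is an assumption that mirrors the unpacking shown above.

```python
def pack_int4(values):
    """Pack signed INT4 values ([-8, 7]) two per byte, low nibble first."""
    assert len(values) % 2 == 0, "need an even number of values"
    out = bytearray()
    for lo, hi in zip(values[0::2], values[1::2]):
        assert -8 <= lo <= 7 and -8 <= hi <= 7
        out.append(((hi + 8) << 4) | (lo + 8))   # store as unsigned [0, 15]
    return bytes(out)


def unpack_int4(packed, low_first=True):
    """Unpack bytes into signed INT4 values; low_first must match the packer."""
    values = []
    for byte in packed:
        lo = (byte & 0xF) - 8
        hi = (byte >> 4) - 8
        values.extend((lo, hi) if low_first else (hi, lo))
    return values


original = [-8, 7, 0, -1]
packed = pack_int4(original)
print(packed.hex())
print(unpack_int4(packed))                   # matches the packer's layout
print(unpack_int4(packed, low_first=False))  # mismatched layout swaps each pair
```

A mismatched layout does not crash; it silently swaps neighbouring weights, so a round-trip test like this is worth running against the quantization code that produced the weights.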

ARM CPU low-bit kernels (Lecture 38)

Lecture 38 by Scott Roy covers low-bit kernel design for ARM CPUs, targeting deployment on mobile and edge devices where NVIDIA GPUs are unavailable.

ARM NEON and SVE for low-bit inference

ARM processors include SIMD extensions (NEON on mobile, SVE/SVE2 on server) that can be used for INT4 and INT8 matrix operations:
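#include <arm_neon.h>
#include <stdint.h>

// INT8 dot product using ARM NEON (AArch64: vaddvq_s32 is A64-only)
void int8_dot_product_neon(
    const int8_t* a, const int8_t* b, int32_t* c, int n) {
  int32x4_t acc = vdupq_n_s32(0);
  int i = 0;
  // Process 16 elements per iteration
  for (; i + 16 <= n; i += 16) {
    int8x16_t va = vld1q_s8(a + i);
    int8x16_t vb = vld1q_s8(b + i);
    // Widening multiply: int8 x int8 -> int16 lanes
    int16x8_t prod_lo = vmull_s8(vget_low_s8(va), vget_low_s8(vb));
    int16x8_t prod_hi = vmull_s8(vget_high_s8(va), vget_high_s8(vb));
    // Pairwise add adjacent int16 lanes and accumulate into int32 lanes
    acc = vpadalq_s16(acc, prod_lo);
    acc = vpadalq_s16(acc, prod_hi);
  }
  // Horizontal sum of the four int32 lanes
  int32_t sum = vaddvq_s32(acc);
  // Scalar tail for lengths that are not a multiple of 16
  for (; i < n; ++i) {
    sum += (int32_t)a[i] * (int32_t)b[i];
  }
  *c = sum;
}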
INT4 on ARM requires manual unpacking from packed uint8 storage, similar to the Triton approach above. See Lecture 38 slides for detailed ARM-specific optimization patterns.
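
The widening steps in the NEON routine exist because of integer ranges: the largest INT8 product, (-128) × (-128) = 16,384, fits in an int16 lane, but the sum of two such products does not, so pairs are accumulated into int32 lanes. The Python sketch below models that lane structure so the arithmetic can be checked on any machine; `neon_style_dot` is a hypothetical name and the model follows the intrinsics only loosely.

```python
def neon_style_dot(a, b):
    """Model of a widening multiply + pairwise accumulate INT8 dot product."""
    acc = [0, 0, 0, 0]                       # four int32 accumulator lanes
    vector_end = len(a) - len(a) % 16
    for i in range(0, vector_end, 16):
        for half in (0, 8):                  # low and high 8 elements of the vector
            # Widening multiply: each product fits in int16
            prod = [a[i + half + j] * b[i + half + j] for j in range(8)]
            assert all(-32768 <= p <= 32767 for p in prod)
            # Pairwise add adjacent products into the int32 lanes
            for lane in range(4):
                acc[lane] += prod[2 * lane] + prod[2 * lane + 1]
    total = sum(acc)                         # horizontal add across lanes
    for i in range(vector_end, len(a)):      # scalar tail
        total += a[i] * b[i]
    return total


a = list(range(-20, 20))
b = [(-1) ** i * (i % 7) for i in range(40)]
print(neon_style_dot(a, b))
print(sum(x * y for x, y in zip(a, b)))
```

Both printed values agree, including for lengths that are not a multiple of 16, which is the case the scalar tail handles.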

Approach comparison

| Approach | Target hardware | Precision | Effort | Best for |
| --- | --- | --- | --- | --- |
| Liger Kernel | NVIDIA GPU | FP16/BF16 | Low (patching API) | Training memory reduction |
| BitBLAS | NVIDIA GPU | INT1–INT8 | Medium (config + tuning) | Extreme quantization inference |
| Custom Triton (INT4) | NVIDIA GPU | INT4/INT2 | High | Research, non-standard formats |
| ARM NEON/SVE | ARM CPU | INT4/INT8 | High | Mobile / edge deployment |

For production LLM fine-tuning on NVIDIA GPUs, Liger Kernel is the fastest path to memory savings with minimal code changes. For INT4 inference deployment on NVIDIA GPUs, BitBLAS covers the widest range of weight precisions among the approaches here (1 to 8 bits) and tunes its kernels for the target GPU.
