Efficient low-bit computation requires more than choosing a number format: it requires kernels that fuse operations, exploit hardware-specific instructions, and minimize memory traffic. This page brings together techniques from four GPU Mode lectures: Liger Kernel (Lecture 28, Byron Hsu), BitBLAS (Lecture 33, Wang Lei), low-bit Triton kernels (Lecture 34, Hicham Badri), and ARM low-bit kernels (Lecture 38, Scott Roy).
Liger Kernel
Liger Kernel is a collection of Triton kernels for training efficiency, developed at LinkedIn. It integrates directly with Hugging Face transformers models through a patching API, requiring no model code changes.
What it is
Liger provides fused Triton implementations of common LLM training operations:
- RMSNorm: fused forward + backward, eliminates intermediate activation storage
- Fused linear + cross entropy: combines the output projection and loss computation to avoid materializing the full vocabulary logit matrix
- SwiGLU / GeGLU: fused gating activations
- Rotary positional embeddings (RoPE): fused in-place application
Integrating with Hugging Face models
from liger_kernel.transformers import apply_liger_kernel_to_llama
# Patch a LLaMA model — no model code changes required
apply_liger_kernel_to_llama()
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
# Model now uses Liger kernels for RMSNorm, SwiGLU, RoPE, etc.
You can also enable individual components:
from liger_kernel.transformers import apply_liger_kernel_to_llama
apply_liger_kernel_to_llama(
    rope=True,
    swiglu=True,
    rms_norm=True,
    fused_linear_cross_entropy=True,
)
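A patching API of this kind generally works by swapping classes or functions inside the model library's module before the model is constructed, which is why the patch call precedes `from_pretrained` in the snippet above. The sketch below illustrates the mechanism with stand-in objects only; `fake_modeling`, `SlowNorm`, `FastNorm`, and `apply_fast_kernels` are hypothetical names for illustration, not Liger or transformers APIs.

```python
import types

# Stand-in for a model library module (hypothetical; not transformers itself).
fake_modeling = types.SimpleNamespace()


class SlowNorm:
    """Reference layer that the library ships by default."""

    def forward(self, x):
        return [v * 1.0 for v in x]


class FastNorm(SlowNorm):
    """Drop-in replacement with the same interface (imagine a fused kernel)."""

    def forward(self, x):
        return [v * 1.0 for v in x]


fake_modeling.Norm = SlowNorm


def build_model():
    # The model looks the class up at construction time.
    return fake_modeling.Norm()


def apply_fast_kernels(norm=True):
    """Swap the library's class for the optimized one (hypothetical helper)."""
    if norm:
        fake_modeling.Norm = FastNorm


before = build_model()  # built before patching: still SlowNorm
apply_fast_kernels()
after = build_model()   # built after patching: FastNorm

print(type(before).__name__, type(after).__name__)  # SlowNorm FastNorm
```

Because the lookup happens at construction time, a model built before the patch keeps its original layers.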
Fused linear + cross entropy for memory savings
The largest single memory cost during LLM training is often the logit matrix. For a vocabulary size of 128,000 and a sequence length of 4,096 in BF16 (2 bytes per element), the logit tensor is 128,000 × 4,096 × 2 bytes ≈ 1.05 GB per sequence, before multiplying by the batch size. Liger’s fused linear cross-entropy kernel computes the loss chunk by chunk without ever materializing the full logit matrix.
from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss
# Instead of:
# logits = model.lm_head(hidden_states) # (B, T, vocab_size), huge
# loss = F.cross_entropy(logits.view(-1, vocab), labels.view(-1))
# Use the fused version (hidden states are passed directly to the loss):
loss_fn = LigerFusedLinearCrossEntropyLoss()
loss = loss_fn(
    model.lm_head.weight,                            # (vocab_size, hidden_dim)
    hidden_states.view(-1, hidden_states.size(-1)),  # (B*T, hidden_dim)
    labels.view(-1),                                 # (B*T,)
)
# Chunking happens inside the kernel, over token rows; there is no chunk size to pass here.
The fused linear cross entropy is the single highest-impact Liger optimization for training memory. Expect 20–40% reduction in peak activation memory on large vocabulary models.
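To make the chunking idea concrete, here is a minimal pure-Python sketch of the forward pass only: it projects a few token rows at a time, reduces them to a loss, and discards the logits before moving on. The function names (`chunked_linear_cross_entropy`, `project`, `token_loss`) are illustrative, and chunking over token rows is an assumption of this sketch; it shows the principle, not Liger's Triton implementation, and omits the gradient computation.

```python
import math
import random


def project(W, h):
    """Logits for one token: W is (vocab, hidden), h is (hidden,)."""
    return [sum(w_i * h_i for w_i, h_i in zip(row, h)) for row in W]


def token_loss(logits, label):
    """Numerically stable cross entropy for a single token."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return lse - logits[label]


def chunked_linear_cross_entropy(hidden, W, labels, chunk_rows):
    """Mean loss, holding logits for only `chunk_rows` tokens at a time."""
    total = 0.0
    for start in range(0, len(hidden), chunk_rows):
        chunk = hidden[start:start + chunk_rows]
        logits = [project(W, h) for h in chunk]  # (chunk_rows, vocab) only
        for row, label in zip(logits, labels[start:start + chunk_rows]):
            total += token_loss(row, label)
        # `logits` is dropped here, before the next chunk is projected
    return total / len(hidden)


random.seed(0)
tokens, hidden_dim, vocab = 8, 4, 16
hidden = [[random.gauss(0, 1) for _ in range(hidden_dim)] for _ in range(tokens)]
W = [[random.gauss(0, 1) for _ in range(hidden_dim)] for _ in range(vocab)]
labels = [random.randrange(vocab) for _ in range(tokens)]

full = chunked_linear_cross_entropy(hidden, W, labels, chunk_rows=tokens)
chunked = chunked_linear_cross_entropy(hidden, W, labels, chunk_rows=2)
```

With `chunk_rows=2`, peak logit storage is 2 × vocab values instead of 8 × vocab, while the loss is unchanged.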
RMSNorm
Liger’s RMSNorm kernel fuses the normalization and scale multiplication into a single Triton kernel, paired with a fused backward kernel, so the activation tensor is not read and written multiple times.
from liger_kernel.ops.rms_norm import LigerRMSNormFunction
# Replaces torch.nn.RMSNorm in a training loop.
# autograd.Function.apply accepts positional arguments only, so eps is passed positionally.
output = LigerRMSNormFunction.apply(hidden_states, weight, 1e-6)
The correctness verification notebook from Lecture 28 benchmarks Liger RMSNorm against PyTorch’s native implementation.
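For reference, the computation being fused is small. The sketch below is a plain-Python version of RMSNorm for a single vector, useful as a mental model or a tiny test oracle; `rms_norm` is an illustrative helper, not part of Liger, and it ignores dtype casting details that a real kernel has to handle.

```python
import math


def rms_norm(x, weight, eps=1e-6):
    """y_i = x_i / sqrt(mean(x^2) + eps) * weight_i"""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


x = [3.0, 4.0]
y = rms_norm(x, [1.0, 1.0], eps=0.0)
# With unit weights and eps=0, the output has a mean square of 1.
mean_sq_out = sum(v * v for v in y) / len(y)
```

An unfused implementation runs the square, mean, rsqrt, and two multiplies as separate passes over the tensor; fusing them means each element is read once and written once.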
BitBLAS: bit-level BLAS operations
BitBLAS (Lecture 33, Wang Lei) is a library from Microsoft Research for matrix operations at sub-byte precisions (INT1, INT2, INT4). It uses TVM-based code generation to produce optimized CUDA kernels for each target architecture.
What BitBLAS provides
- Matrix multiplication with weight precisions of 1, 2, 3, 4, 5, 6, 7, or 8 bits
- Mixed-precision GEMM: low-bit weights × FP16 activations
- Automatic kernel generation and tuning for the target GPU
- Integration with torch as a drop-in layer
Using BitBLAS for INT4 weight-only GEMM
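import bitblas
import torch
# Define a mixed-precision matmul: FP16 activation × INT4 weight
matmul_config = bitblas.MatmulConfig(
    M=1, # batch size (use symbolic for dynamic shapes)
    N=4096, # output features
    K=4096, # input features
    A_dtype="float16",
    W_dtype="int4",
    accum_dtype="float16",
    out_dtype="float16",
    layout="nt", # A: row-major, B: column-major (transposed)
    with_bias=False,
    group_size=128, # per-group quantization scale
    with_scaling=True,
    with_zeros=True,
    zeros_mode="original",
)
matmul = bitblas.Matmul(config=matmul_config)
# Inputs: weights are already-quantized integers in the INT4 range, stored as int8
A = torch.randn(1, 4096, dtype=torch.float16, device="cuda")
W_int = torch.randint(0, 7, (4096, 4096), dtype=torch.int8, device="cuda")
# One scale and zero point per (output row, group); shapes assume zeros_mode="original"
scale = torch.rand(4096, 4096 // 128, dtype=torch.float16, device="cuda")
zeros = torch.rand(4096, 4096 // 128, dtype=torch.float16, device="cuda")
# Pack and run
W_packed = matmul.transform_weight(W_int) # pack INT4 weights
output = matmul(A, W_packed, scale=scale, zeros=zeros)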
BitBLAS auto-tunes kernels the first time they are called for a given configuration. This takes seconds to minutes but is cached. Use bitblas.auto_detect_nvidia_target() to select the right target for your GPU.
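The `group_size`, `with_scaling`, and `with_zeros` settings describe group-wise quantization: each run of `group_size` weights in a row shares one scale and one zero point. The pure-Python sketch below shows what those per-group parameters mean, assuming the dequantization convention `w ≈ (code - zero) * scale`. The helper names are hypothetical and this is not BitBLAS code; it only illustrates the arithmetic behind the config.

```python
def quantize_group(values, bits=4):
    """Asymmetric quantization of one group to unsigned `bits`-bit codes."""
    qmax = (1 << bits) - 1
    lo, hi = min(values), max(values)
    scale = (hi - lo) / qmax if hi > lo else 1.0
    zero = -lo / scale  # chosen so that `lo` maps to code 0
    codes = [min(qmax, max(0, round(v / scale + zero))) for v in values]
    return codes, scale, zero


def dequantize_group(codes, scale, zero):
    """Assumed convention: w ≈ (code - zero) * scale."""
    return [(c - zero) * scale for c in codes]


def quantize_row(row, group_size, bits=4):
    """Split a weight row into groups, each with its own scale and zero point."""
    return [
        quantize_group(row[start:start + group_size], bits)
        for start in range(0, len(row), group_size)
    ]


row = [0.1 * i - 0.35 for i in range(8)]  # 8 weights, two groups of 4
groups = quantize_row(row, group_size=4)
restored = [w for codes, s, z in groups for w in dequantize_group(codes, s, z)]
max_err = max(abs(a - b) for a, b in zip(row, restored))
max_scale = max(s for _, s, _ in groups)
```

The reconstruction error of each weight is bounded by half its group's scale, which is why smaller groups (more scales) trade memory for accuracy.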
BitBLAS tutorials
The official tutorials from Lecture 33 include:
- int4_weight_only_gemm.py: INT4 weight-only GEMM walkthrough
- int2_weight_only_gemm.py: INT2 extreme quantization
- fp16_int4_gemv.py: Optimized GEMV (single-token decode) with INT4 weights
Writing low-bit Triton kernels for INT2/INT4
Lecture 34 by Hicham Badri covers writing Triton kernels for sub-8-bit integer formats directly. The key challenge is that Triton operates on standard integer types; packing and unpacking must be done manually.
INT4 packing in Triton
import triton
import triton.language as tl


@triton.jit
def unpack_int4(packed):
    """Unpack two signed INT4 values from each uint8 element."""
    # Cast to int8 before subtracting: uint8 arithmetic would wrap around.
    low = (packed & 0xF).to(tl.int8) - 8          # lower nibble: bits [3:0]
    high = ((packed >> 4) & 0xF).to(tl.int8) - 8  # upper nibble: bits [7:4]
    # Values are now in the signed range [-8, 7]
    return low, high


@triton.jit
def int4_gemv_kernel(
    A_ptr,  # FP16 activations (1 × K)
    B_ptr,  # INT4 packed weights, uint8 (N × K/2)
    S_ptr,  # FP16 scales (N × K//group_size)
    C_ptr,  # FP16 output (1 × N)
    K: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    pid = tl.program_id(0)
    offs_n = pid * BLOCK_N + tl.arange(0, BLOCK_N)  # output rows (assumes N % BLOCK_N == 0)
    offs_p = tl.arange(0, GROUP_SIZE // 2)          # packed bytes within one group
    acc = tl.zeros((BLOCK_N,), dtype=tl.float32)
    for k in range(0, K, GROUP_SIZE):
        # Load packed INT4 weights for this group: (BLOCK_N, GROUP_SIZE/2) bytes
        b_packed = tl.load(B_ptr + offs_n[:, None] * (K // 2) + k // 2 + offs_p[None, :])
        b_low, b_high = unpack_int4(b_packed)
        # Load one scale per (row, group)
        scale = tl.load(S_ptr + offs_n * (K // GROUP_SIZE) + k // GROUP_SIZE).to(tl.float32)
        # Byte j holds element 2j in its low nibble and element 2j+1 in its high nibble
        a_even = tl.load(A_ptr + k + 2 * offs_p).to(tl.float32)
        a_odd = tl.load(A_ptr + k + 2 * offs_p + 1).to(tl.float32)
        # Accumulate both nibbles, then apply the group scale
        partial = tl.sum(a_even[None, :] * b_low.to(tl.float32)
                         + a_odd[None, :] * b_high.to(tl.float32), axis=1)
        acc += partial * scale
    tl.store(C_ptr + offs_n, acc.to(tl.float16))
Triton does not have a native INT4 type. You must pack/unpack INT4 values using bitwise operations on INT8/UINT8 tensors. The packing layout (low nibble first vs. high nibble first) must match between the quantization step and the kernel.
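Because a layout mismatch silently produces wrong results, it helps to keep a small CPU reference for the packing scheme. The sketch below packs signed INT4 values two per byte (low nibble first, stored with a +8 offset), unpacks them again, and computes a group-scaled GEMV that a kernel's output could be checked against. The function names are illustrative, and the low-nibble-first, offset-by-8 layout is an assumption that must match whatever the quantizer actually produces.

```python
def pack_int4(values):
    """Pack signed INT4 values (-8..7) two per byte, low nibble first."""
    assert len(values) % 2 == 0
    packed = bytearray()
    for lo, hi in zip(values[0::2], values[1::2]):
        assert -8 <= lo <= 7 and -8 <= hi <= 7
        packed.append(((hi + 8) << 4) | (lo + 8))  # store with a +8 offset
    return bytes(packed)


def unpack_int4(packed):
    """Inverse of pack_int4: returns signed values in original order."""
    values = []
    for byte in packed:
        values.append((byte & 0xF) - 8)         # element 2j
        values.append(((byte >> 4) & 0xF) - 8)  # element 2j + 1
    return values


def int4_gemv_reference(a, packed_rows, scales, group_size):
    """y[n] = sum over groups g of scales[n][g] * dot(a_group, w_group)."""
    out = []
    for row, row_scales in zip(packed_rows, scales):
        w = unpack_int4(row)
        acc = 0.0
        for g, s in enumerate(row_scales):
            lo = g * group_size
            hi = lo + group_size
            acc += s * sum(x * q for x, q in zip(a[lo:hi], w[lo:hi]))
        out.append(acc)
    return out


weights = [-8, 7, 0, -1]
packed = pack_int4(weights)
print(packed.hex())  # f078
y = int4_gemv_reference([1.0, 2.0, 3.0, 4.0], [packed], [[0.5, 2.0]], group_size=2)
print(y)  # [-5.0]
```

Swapping the nibble order in either `pack_int4` or `unpack_int4` alone changes the result, which is exactly the mismatch this kind of round-trip check is meant to catch.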
ARM CPU low-bit kernels (Lecture 38)
Lecture 38 by Scott Roy covers low-bit kernel design for ARM CPUs, targeting deployment on mobile and edge devices where NVIDIA GPUs are unavailable.
ARM NEON and SVE for low-bit inference
ARM processors include SIMD extensions (NEON on mobile, SVE/SVE2 on server) that can be used for INT4 and INT8 matrix operations:
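#include <arm_neon.h>
#include <stdint.h>
// INT8 dot product using ARM NEON (vaddvq_s32 requires AArch64)
void int8_dot_product_neon(
        const int8_t* a, const int8_t* b, int32_t* c, int n) {
    int32x4_t acc = vdupq_n_s32(0);
    int i = 0;
    // Main loop: 16 int8 elements per iteration
    for (; i + 16 <= n; i += 16) {
        int8x16_t va = vld1q_s8(a + i);
        int8x16_t vb = vld1q_s8(b + i);
        // Widen to int16 and multiply (an int8 product always fits in int16)
        int16x8_t prod_lo = vmull_s8(vget_low_s8(va), vget_low_s8(vb));
        int16x8_t prod_hi = vmull_s8(vget_high_s8(va), vget_high_s8(vb));
        // Pairwise add int16 lanes into the int32 accumulator
        acc = vpadalq_s16(acc, prod_lo);
        acc = vpadalq_s16(acc, prod_hi);
    }
    int32_t sum = vaddvq_s32(acc);  // horizontal add of the 4 lanes
    // Scalar tail: avoids reading past the end when n is not a multiple of 16
    for (; i < n; ++i) {
        sum += (int32_t)a[i] * (int32_t)b[i];
    }
    *c = sum;
}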
INT4 on ARM requires manual unpacking from packed uint8 storage, similar to the Triton approach above. See Lecture 38 slides for detailed ARM-specific optimization patterns.
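The lane-level behaviour of the NEON routine can be modelled in a few lines of Python, which is handy as a test oracle when no ARM hardware is at hand. The functions below are named after the intrinsics they imitate but are plain-Python models written for this page, not bindings to NEON. The model assumes a scalar tail for lengths that are not a multiple of 16.

```python
def vmull_s8(a8, b8):
    """Model of a widening multiply: 8 int8 lanes -> 8 int16 lanes."""
    prods = [x * y for x, y in zip(a8, b8)]
    # Largest magnitude is (-128) * (-128) = 16384, which fits in int16.
    assert all(-32768 <= p <= 32767 for p in prods)
    return prods


def vpadalq_s16(acc4, v8):
    """Model of pairwise add-accumulate: 8 int16 lanes into 4 int32 lanes."""
    return [acc4[i] + v8[2 * i] + v8[2 * i + 1] for i in range(4)]


def int8_dot_product_model(a, b):
    """Dot product of two int8 sequences using the modelled lane operations."""
    acc = [0, 0, 0, 0]
    i = 0
    while i + 16 <= len(a):
        acc = vpadalq_s16(acc, vmull_s8(a[i:i + 8], b[i:i + 8]))            # low half
        acc = vpadalq_s16(acc, vmull_s8(a[i + 8:i + 16], b[i + 8:i + 16]))  # high half
        i += 16
    total = sum(acc)                                   # horizontal add of 4 lanes
    total += sum(x * y for x, y in zip(a[i:], b[i:]))  # scalar tail
    return total


a = list(range(-20, 20))                 # 40 values: two full blocks plus a tail of 8
b = [(i % 7) - 3 for i in range(40)]
result = int8_dot_product_model(a, b)
expected = sum(x * y for x, y in zip(a, b))
```

Accumulating in 32-bit lanes is what makes the routine safe: each 16-element block adds at most 4 × 16384 to a lane, so overflow only becomes a concern for vectors with tens of thousands of elements at extreme values.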
Approach comparison
| Approach | Target hardware | Precision | Effort | Best for |
|---|---|---|---|---|
| Liger Kernel | NVIDIA GPU | FP16/BF16 | Low (patching API) | Training memory reduction |
| BitBLAS | NVIDIA GPU | INT1–INT8 | Medium (config + tuning) | Extreme quantization inference |
| Custom Triton (INT4) | NVIDIA GPU | INT4/INT2 | High | Research, non-standard formats |
| ARM NEON/SVE | ARM CPU | INT4/INT8 | High | Mobile / edge deployment |
For production LLM fine-tuning on NVIDIA GPUs, Liger Kernel is the fastest path to memory savings with minimal code changes. For INT4 inference deployment on NVIDIA GPUs, BitBLAS covers the widest range of weight precisions of the approaches compared here and generates tuned kernels for each target architecture.
Further reading