ARM CPUs power the majority of the world’s mobile devices, edge accelerators, and a growing number of cloud servers. Optimizing matrix operations for ARM means understanding its SIMD instruction sets (NEON for fixed-width 128-bit vectors, SVE for scalable-width vectors) and exploiting low-bit quantization formats (INT8, INT4) that increase arithmetic throughput while reducing memory traffic. This page accompanies Lecture 38 by Scott Roy.
Lecture 38 slides are available in the lecture repository at lecture_038/lowbit_kernels.pdf.

Why ARM CPU optimization matters

GPU availability is not universal. A large fraction of ML inference workloads run on ARM CPUs — in smartphones, IoT devices, on-premises edge servers, and increasingly in ARM-based cloud instances (AWS Graviton, Ampere Altra). Even when a GPU is present, the CPU often handles data preprocessing, tokenization, and small-batch inference where GPU launch overhead dominates.

Mobile and edge

Nearly all smartphones, and many microcontrollers, run ARM cores. CPU-based on-device LLM inference (e.g., llama.cpp) depends heavily on NEON/SVE performance.

Cloud servers

AWS Graviton3 and Ampere Altra are ARM-based server processors, and Apple M-series chips use the same architecture on the desktop. Cost per inference can be lower than x86 for many workloads.

Memory bandwidth

Low-bit quantization reduces memory traffic, the dominant bottleneck for memory-bound transformer inference such as token-by-token decoding, which makes INT8 and INT4 key to throughput on ARM.

ARM NEON SIMD: 128-bit vector registers

NEON is ARM’s Advanced SIMD extension, present on every modern ARM application processor (Cortex-A, Apple Silicon, Graviton). Each NEON register is 128 bits wide and can hold:
  • 16 × INT8 (or UINT8)
  • 8 × INT16
  • 4 × INT32
  • 4 × FP32
  • 8 × FP16
NEON intrinsics follow a consistent naming pattern:
#include <arm_neon.h>

// Load 16 int8 values from memory into a 128-bit NEON register
int8x16_t a = vld1q_s8(ptr);

// vmlaq_s32(acc, x, y) computes acc + x * y lane-wise on int32x4_t operands;
// it does not widen, so it cannot multiply int8 lanes into int32 directly
int32x4_t acc = vdupq_n_s32(0);
// NEON has no single int8→int32 widening multiply-accumulate;
// use vmull_s8 followed by a widening add (vpadalq_s16 or vaddw_s16),
// or the SDOT instruction (see below).
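The v prefix denotes a NEON operation. The suffix after the underscore encodes the element type: s8 = signed 8-bit, u8 = unsigned 8-bit, s32 = signed 32-bit. A q at the end of the operation name (vld1q, vmlaq) selects a 128-bit (quad) register; without it, the operation works on a 64-bit register (vld1_s8 returns an int8x8_t). The lane count is the register width divided by the element width, as the type names (int8x16_t, int32x4_t) show.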
Prefer intrinsics over inline assembly for portability across NEON targets. Compilers (GCC, Clang) map each intrinsic to the corresponding instruction and take care of register allocation and scheduling.

Key NEON intrinsics for matrix multiply

#include <arm_neon.h>
#include <stdint.h>

// Vectorized INT8 dot product building block (pre-SDOT)
// Multiply two int8x8 vectors → int16x8, then widen-accumulate into int32
void int8_dot_neon(const int8_t *a, const int8_t *b, int32_t *acc, int n) {
  int32x4_t sum = vdupq_n_s32(0);
  int i = 0;
  for (; i + 8 <= n; i += 8) {
    int8x8_t  va = vld1_s8(a + i);
    int8x8_t  vb = vld1_s8(b + i);
    int16x8_t prod = vmull_s8(va, vb);        // 8 × int8 → 8 × int16 (|a·b| ≤ 16384, no overflow)
    sum = vpadalq_s16(sum, prod);              // pairwise-add int16 pairs into 4 × int32
  }
  int32_t total = vaddvq_s32(sum);             // horizontal sum (AArch64)
  for (; i < n; i++)                           // scalar tail when n is not a multiple of 8
    total += (int32_t)a[i] * b[i];
  *acc = total;
}
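When developing these kernels, it helps to keep a plain scalar reference and compare against it on many lengths, especially lengths that are not multiples of the vector width, where tail bugs hide. The following is a minimal sketch; int8_dot_ref, check_dot_kernel, and the choice of testing every length from 0 to 67 are illustrative assumptions, not part of the lecture code.

```c
#include <stdint.h>
#include <stdlib.h>

// Scalar reference with the same signature as int8_dot_neon:
// *acc = sum of a[i] * b[i], accumulated in int32.
void int8_dot_ref(const int8_t *a, const int8_t *b, int32_t *acc, int n) {
  int32_t sum = 0;
  for (int i = 0; i < n; i++)
    sum += (int32_t)a[i] * (int32_t)b[i];
  *acc = sum;
}

typedef void (*dot_kernel_fn)(const int8_t *, const int8_t *, int32_t *, int);

// Hypothetical test helper: runs `kernel` on random inputs for every length
// 0..MAX_N (covering non-multiples of 8 and 16, which exercise tail code)
// and returns how many lengths disagree with the scalar reference.
int check_dot_kernel(dot_kernel_fn kernel, unsigned seed) {
  enum { MAX_N = 67 };
  int8_t a[MAX_N], b[MAX_N];
  int failures = 0;
  srand(seed);
  for (int n = 0; n <= MAX_N; n++) {
    for (int i = 0; i < n; i++) {
      a[i] = (int8_t)(rand() % 256 - 128);   // uniform-ish values in [-128, 127]
      b[i] = (int8_t)(rand() % 256 - 128);
    }
    int32_t got = 0, want = 0;
    kernel(a, b, &got, n);
    int8_dot_ref(a, b, &want, n);
    if (got != want) failures++;
  }
  return failures;
}
```

On an AArch64 machine, check_dot_kernel(int8_dot_neon, 42) returning 0 means the NEON kernel agreed with the reference at every tested length.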

ARM SVE: scalable vector extension

SVE (Scalable Vector Extension) generalizes NEON by making the vector length a runtime property. Instead of hardcoding 128-bit registers, SVE uses scalable vectors whose width ranges from 128 to 2048 bits in 128-bit increments, depending on the hardware implementation (e.g., 512-bit on Fujitsu A64FX, 256-bit on AWS Graviton3). SVE code written once therefore runs correctly on all SVE implementations: the compiler does not need to know the vector length at compile time.
#include <arm_sve.h>

// SVE vectorized dot product — adapts to any hardware vector width
int32_t sve_dot_s8(const int8_t *a, const int8_t *b, int n) {
  svint32_t sum = svdup_n_s32(0);
  int64_t i = 0;
  svbool_t pg;
  while (i < n) {
    pg = svwhilelt_b8(i, (int64_t)n);          // predicate: active lanes
    svint8_t va = svld1_s8(pg, a + i);
    svint8_t vb = svld1_s8(pg, b + i);
    sum = svdot_s32(sum, va, vb);               // SVE SDOT: 4 int8 → 1 int32
    i  += svcntb();                             // advance by hardware vector width
  }
  return svaddv_s32(svptrue_b32(), sum);        // horizontal sum
}
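svcntb() returns the vector length in bytes at runtime, and svwhilelt_b8 deactivates the lanes past n on the final iteration, so the loop needs no separate tail handling and is entirely vector-length-agnostic.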
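To see why the predicated loop gives the same answer on every SVE implementation, its control flow can be modeled in portable C. In this sketch, vl_bytes is a hypothetical stand-in for svcntb(), and the per-lane loop imitates what svwhilelt_b8 and the predicated svld1_s8 do; it is a model for reasoning and testing, not SVE code.

```c
#include <stdint.h>

// Portable model of the predicated loop in sve_dot_s8.
// vl_bytes stands in for svcntb(): any multiple of 16 from 16 to 256
// mimics a 128- to 2048-bit SVE implementation.
int32_t vla_dot_model(const int8_t *a, const int8_t *b, int n, int vl_bytes) {
  int32_t sum = 0;
  for (int64_t i = 0; i < n; i += vl_bytes) {      // i += svcntb()
    for (int lane = 0; lane < vl_bytes; lane++) {
      // svwhilelt_b8(i, n): a lane is active only while i + lane < n
      int active = (i + lane) < n;
      // predicated svld1_s8 yields zero in inactive lanes, so they add nothing
      int32_t va = active ? a[i + lane] : 0;
      int32_t vb = active ? b[i + lane] : 0;
      sum += va * vb;
    }
  }
  return sum;
}
```

Running the model with vl_bytes = 16, 32, 64, and 256 yields identical sums, mirroring how the same binary behaves on 128-bit and 2048-bit hardware. The real kernel accumulates into int32 lanes and sums them at the end; adding the products directly gives the same total.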
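SVE2, which is mandatory in ARMv9-A, extends SVE with additional integer, bit-manipulation, and DSP-style instructions. The SMMLA and UMMLA INT8 matrix multiply-accumulate instructions come from the separate I8MM extension (ARMv8.6-A, covered below) and exist in both NEON and SVE forms; they operate on small matrix tiles, a role loosely analogous to GPU tensor-core MMA instructions.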

Low-bit quantization on ARM: INT8 and INT4

Quantization reduces weight and activation precision to shrink memory footprint and improve throughput. On ARM:
  • INT8: supported natively in NEON/SVE via the SDOT/UDOT instructions on ARMv8.2-A and later (see below). Typical accuracy loss is < 0.5% on most LLMs with per-channel weight quantization.
  • INT4: NEON has no native INT4 arithmetic, so INT4 weights must be unpacked before use. Weights stored as INT4 (2 values per byte) are unpacked to INT8 before being fed to SDOT; storing them this way halves weight memory traffic compared with INT8.
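The per-channel weight quantization mentioned above can be sketched as follows, assuming a symmetric scheme: each output row (channel) of the weight matrix gets its own scale, max|w| / 127, and weights are rounded to the nearest integer in [-127, 127]. The function name and row-major layout are hypothetical; real toolchains may also use zero points, group-wise scales, or other rounding modes.

```c
#include <stdint.h>

// Symmetric per-row INT8 quantization: W (rows × cols, row-major) is mapped
// to q with W[r][c] ≈ scale[r] * q[r][c].
void quantize_per_row_int8(const float *W, int rows, int cols,
                           int8_t *q, float *scale) {
  for (int r = 0; r < rows; r++) {
    const float *w = W + (long)r * cols;
    float max_abs = 0.0f;
    for (int c = 0; c < cols; c++) {
      float m = w[c] < 0.0f ? -w[c] : w[c];
      if (m > max_abs) max_abs = m;
    }
    // Map [-max_abs, max_abs] onto [-127, 127]; all-zero rows get scale 1
    float s = max_abs > 0.0f ? max_abs / 127.0f : 1.0f;
    scale[r] = s;
    for (int c = 0; c < cols; c++) {
      float x = w[c] / s;
      int v = (int)(x >= 0.0f ? x + 0.5f : x - 0.5f);  // round half away from zero
      if (v > 127) v = 127;                            // clamp against rounding drift
      if (v < -127) v = -127;
      q[(long)r * cols + c] = (int8_t)v;
    }
  }
}
```

Keeping the range symmetric (-127..127 rather than -128..127) means no zero point is needed, so the INT32 result of an SDOT-based GEMM only has to be multiplied by the two scales to return to floating point.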
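#include <stdint.h>

// Unpack packed signed INT4 nibbles (two's complement, low nibble first) to INT8
void unpack_int4_to_int8(const uint8_t *src, int8_t *dst, int n_bytes) {
  for (int i = 0; i < n_bytes; i++) {
    // (v ^ 8) - 8 sign-extends a 4-bit value v in [0, 15] to [-8, 7]
    // without relying on implementation-defined narrowing or right shifts
    dst[2 * i]     = (int8_t)(((src[i] & 0x0F) ^ 8) - 8);   // low nibble
    dst[2 * i + 1] = (int8_t)(((src[i] >> 4) ^ 8) - 8);     // high nibble
  }
}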
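NEON can vectorize INT4 unpacking: a 128-bit register holds 16 packed bytes (32 INT4 weights), and a few instructions (vandq_u8 to mask the low nibbles, vshrq_n_u8 to extract the high nibbles, plus a sign-extension step for signed weights) unpack all 32 at once, which keeps unpacking cheap relative to the SDOT compute.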
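A minimal sketch of the vectorized unpack for signed weights, assuming the same low-nibble-first layout as unpack_int4_to_int8. For signed values it swaps the unsigned mask-and-shift pair for signed shifts (vshlq_n_s8 and vshrq_n_s8), since an arithmetic right shift performs the sign extension; vst2q_s8 then interleaves the low and high nibbles into output order. The function name is hypothetical, and the scalar #else branch only keeps the sketch compilable on non-ARM hosts.

```c
#include <stdint.h>
#if defined(__ARM_NEON)
#include <arm_neon.h>
#endif

// Unpack 16 packed bytes (32 signed INT4 weights, low nibble first)
// into 32 INT8 values: dst[2i] = low nibble, dst[2i+1] = high nibble.
void unpack16_int4_to_int8(const uint8_t *src, int8_t *dst) {
#if defined(__ARM_NEON)
  int8x16_t packed = vreinterpretq_s8_u8(vld1q_u8(src));
  // Low nibble: move it into the top 4 bits, then arithmetic-shift back down
  int8x16_t lo = vshrq_n_s8(vshlq_n_s8(packed, 4), 4);
  // High nibble: an arithmetic right shift by 4 sign-extends it directly
  int8x16_t hi = vshrq_n_s8(packed, 4);
  // vst2q_s8 interleaves on store: lo[0], hi[0], lo[1], hi[1], ...
  int8x16x2_t out = { { lo, hi } };
  vst2q_s8(dst, out);
#else
  // Portable fallback with identical results
  for (int i = 0; i < 16; i++) {
    dst[2 * i]     = (int8_t)(((src[i] & 0x0F) ^ 8) - 8);
    dst[2 * i + 1] = (int8_t)(((src[i] >> 4) ^ 8) - 8);
  }
#endif
}
```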

SDOT and UDOT: INT8 matrix multiply instructions

The SDOT (signed dot product) and UDOT (unsigned dot product) instructions, introduced in ARMv8.2-A, are the key to efficient INT8 GEMM on ARM. For each 32-bit accumulator lane, the instruction multiplies 4 INT8 values from one operand by the corresponding 4 INT8 values from the other and adds their sum into that lane, so one 128-bit SDOT performs 16 multiply-adds. In C intrinsics:
#include <arm_neon.h>

// vsdotq_s32: signed 8-bit dot product accumulate (SDOT)
// Computes 4 groups of 4×int8 dot products per call
// dst[i] += a[4i..4i+3] · b[4i..4i+3]  for i in 0..3

void gemv_int8_sdot(const int8_t *A,   // M × K matrix, row-major
                    const int8_t *x,   // K-element vector
                    int32_t      *y,   // M-element output
                    int M, int K) {
  for (int m = 0; m < M; m++) {
    const int8_t *row = A + m * K;
    int32x4_t acc = vdupq_n_s32(0);
    int k = 0;
    for (; k + 16 <= K; k += 16) {
      int8x16_t va = vld1q_s8(row + k);
      int8x16_t vx = vld1q_s8(x + k);
      acc = vsdotq_s32(acc, va, vx);    // 4 SDOT lanes, 4 int8 each
    }
    int32_t total = vaddvq_s32(acc);    // horizontal sum of the 4 lanes
    for (; k < K; k++)                  // scalar tail when K is not a multiple of 16
      total += (int32_t)row[k] * x[k];
    y[m] = total;
  }
}
A single vsdotq_s32 call performs 16 INT8 multiply-adds and accumulates them directly into INT32 lanes, whereas vmull_s8 performs 8 multiplications and needs a separate widening accumulate (such as vpadalq_s16) to reach INT32.
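The lane-wise grouping is easy to misread, so here is a scalar model of what one vsdotq_s32 computes, written from the description above; sdot_model is a hypothetical helper for unit-testing kernels on any host. Because each lane holds a partial sum over its own 4-byte groups, a kernel like gemv_int8_sdot needs a final horizontal add (vaddvq_s32) to obtain the full dot product.

```c
#include <stdint.h>

// Scalar model of vsdotq_s32(acc, a, b) (SDOT, vector form):
// lane i adds the dot product of bytes 4i..4i+3 of a and b.
void sdot_model(int32_t acc[4], const int8_t a[16], const int8_t b[16]) {
  for (int lane = 0; lane < 4; lane++) {
    int32_t dot = 0;
    for (int j = 0; j < 4; j++)
      dot += (int32_t)a[4 * lane + j] * (int32_t)b[4 * lane + j];
    acc[lane] += dot;  // 4 multiply-adds per lane, 16 per instruction
  }
}
```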
SDOT/UDOT require ARMv8.2-A or later (-march=armv8.2-a+dotprod with GCC/Clang). Check target support at compile time with #ifdef __ARM_FEATURE_DOTPROD. Cortex-A55, A75, A76, and later cores implement it, as do all Apple M-series chips.

I8MM: INT8 matrix multiply extension

The I8MM (Int8 Matrix Multiply) extension, part of ARMv8.6-A, adds SMMLA and UMMLA: instructions that multiply a 2×8 INT8 matrix by an 8×2 INT8 matrix in a single instruction, accumulating the result into a 2×2 block of INT32 accumulators.
#include <arm_neon.h>

// vmmlaq_s32: INT8 matrix multiply-accumulate (SMMLA, I8MM)
// Multiplies a 2×8 block of A by an 8×2 block of B,
// producing a 2×2 INT32 output block
// Requires I8MM, e.g. -march=armv8.6-a or -march=armv8.2-a+i8mm

void i8mm_tile(const int8_t *A, const int8_t *B, int32_t *C) {
  int32x4_t acc = vdupq_n_s32(0);           // 2×2 accumulator {c00, c01, c10, c11}
  int8x16_t va  = vld1q_s8(A);              // 2 rows × 8 elements, row-major
  int8x16_t vb  = vld1q_s8(B);              // B transposed: 2 columns × 8 elements each
  acc = vmmlaq_s32(acc, va, vb);            // SMMLA: 2×2 tile result
  vst1q_s32(C, acc);                        // store C as a row-major 2×2 block
}
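Each SMMLA performs 32 multiply-adds (4 outputs × 8 products each), twice the 16 of a 128-bit SDOT, and every loaded value contributes to two outputs. On cores that issue both instructions at the same rate, I8MM can therefore roughly double INT8 matrix-multiply throughput per cycle compared with SDOT.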
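The operand layout of SMMLA is the usual source of bugs, so a scalar model of one vmmlaq_s32 call can help when writing packing code. This sketch follows the Arm architecture description of SMMLA: the first operand is a row-major 2×8 matrix, the second holds the 8×2 matrix one column per 8-byte half, and the 2×2 result is row-major. smmla_model is a hypothetical name.

```c
#include <stdint.h>

// Scalar model of vmmlaq_s32(acc, a, b) (SMMLA):
//   a:   2×8 INT8 matrix, row-major (16 bytes)
//   b:   8×2 INT8 matrix, column j stored in bytes 8j..8j+7
//   acc: 2×2 INT32 matrix, row-major {c00, c01, c10, c11}
void smmla_model(int32_t acc[4], const int8_t a[16], const int8_t b[16]) {
  for (int i = 0; i < 2; i++)
    for (int j = 0; j < 2; j++) {
      int32_t dot = 0;
      for (int k = 0; k < 8; k++)
        dot += (int32_t)a[8 * i + k] * (int32_t)b[8 * j + k];
      acc[2 * i + j] += dot;  // 8 multiply-adds per output, 32 in total
    }
}
```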
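I8MM is available on ARMv9-A cores such as Cortex-A710 and Cortex-X2 and newer, on Neoverse V1 (AWS Graviton3), and on Apple Silicon from M2 onward. Compile with -march=armv8.6-a (or add +i8mm to an earlier -march), check at compile time with #ifdef __ARM_FEATURE_MATMUL_INT8, and detect at runtime on Linux with getauxval(AT_HWCAP2) & HWCAP2_I8MM from sys/auxv.h.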
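A minimal sketch of runtime feature detection on Linux/AArch64, assuming glibc's <sys/auxv.h> provides getauxval and the HWCAP_ASIMDDP / HWCAP2_I8MM bit definitions; the #ifdef guards make both functions report false anywhere else (macOS exposes equivalent flags through sysctlbyname, which this sketch does not cover). The function names are hypothetical.

```c
#include <stdbool.h>
#if defined(__linux__) && defined(__aarch64__)
#include <sys/auxv.h>
#endif

// True if the kernel reports SDOT/UDOT support (HWCAP_ASIMDDP).
bool cpu_has_dotprod(void) {
#if defined(__linux__) && defined(__aarch64__) && defined(HWCAP_ASIMDDP)
  return (getauxval(AT_HWCAP) & HWCAP_ASIMDDP) != 0;
#else
  return false;  // unknown platform: conservatively assume no support
#endif
}

// True if the kernel reports SMMLA/UMMLA support (HWCAP2_I8MM).
bool cpu_has_i8mm(void) {
#if defined(__linux__) && defined(__aarch64__) && defined(HWCAP2_I8MM)
  return (getauxval(AT_HWCAP2) & HWCAP2_I8MM) != 0;
#else
  return false;
#endif
}
```

A kernel compiled with +dotprod or +i8mm should only be called when the corresponding check returns true; otherwise, fall back to a baseline NEON path.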

Memory layout for ARM cache hierarchy

Cache efficiency is as important as instruction throughput on ARM. ARM Cortex and Neoverse cores typically use:
  • L1 cache: 32–64 KB, 4-cycle latency
  • L2 cache: 256 KB – 1 MB, 12–15-cycle latency
  • L3 / system cache: 4–32 MB, 30–50-cycle latency
For matrix multiply, the classic tiling strategy applies: choose tile sizes so that the working set of A, B, and C tiles fits in L1 or L2.
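// Tiled GEMM skeleton for ARM cache efficiency (C++)
#include <algorithm>
#include <cstdint>

// Placeholder micro-kernel (scalar stand-in): C[mc×nc] += A[mc×kc] · B[kc×nc].
// lda = row stride of A (K); ldbc = row stride of B and C (N).
// A production kernel would use SDOT/SMMLA on packed tiles instead.
static void inner_kernel(const int8_t *A, const int8_t *B, int32_t *C,
                         int mc, int nc, int kc, int lda, int ldbc) {
  for (int i = 0; i < mc; i++)
    for (int j = 0; j < nc; j++) {
      int32_t s = 0;
      for (int p = 0; p < kc; p++)
        s += int32_t(A[i * lda + p]) * B[p * ldbc + j];
      C[i * ldbc + j] += s;                 // accumulate across K blocks
    }
}

// A is M×K, B is K×N, C is M×N (all row-major); C must be zero-initialized.
void gemm_tiled(const int8_t *A, const int8_t *B, int32_t *C,
                int M, int N, int K) {
  constexpr int MC = 64;   // rows per A tile  (tune for L1 size)
  constexpr int NC = 64;   // cols per B tile
  constexpr int KC = 256;  // K reduction block

  for (int m = 0; m < M; m += MC)
  for (int n = 0; n < N; n += NC)
  for (int k = 0; k < K; k += KC) {
    int mc = std::min(MC, M - m);
    int nc = std::min(NC, N - n);
    int kc = std::min(KC, K - k);
    // inner kernel: mc×nc output block, kc reduction
    inner_kernel(A + m*K + k, B + k*N + n, C + m*N + n,
                 mc, nc, kc, K, N);
  }
}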
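To check tile sizes against the cache sizes listed earlier, it helps to compute the working set of one block. This sketch assumes INT8 A and B tiles and an INT32 C tile, as in gemm_tiled; tile_working_set_bytes is a hypothetical helper. With the skeleton's defaults (MC = NC = 64, KC = 256) it comes to 16 KiB + 16 KiB + 16 KiB = 48 KiB, which exceeds a 32 KB L1 but fits comfortably in L2, so on such a core either the tiles shrink or L2 becomes the target level.

```c
#include <stddef.h>
#include <stdint.h>

// Bytes touched by one (mc, nc, kc) block of the tiled INT8 GEMM.
size_t tile_working_set_bytes(int mc, int nc, int kc) {
  return (size_t)mc * kc * sizeof(int8_t)      // A tile: mc × kc int8
       + (size_t)kc * nc * sizeof(int8_t)      // B tile: kc × nc int8
       + (size_t)mc * nc * sizeof(int32_t);    // C tile: mc × nc int32
}
```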
Pack B into column-major layout before the outer loop to achieve sequential memory access in the inner kernel. Libraries like libxsmm and XNNPACK do this automatically.
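A simplified packing routine, assuming row-major B with leading dimension ldb: it copies a kc × nc block into a contiguous column-major buffer, so each column's kc values sit next to each other, which is the order a dot-product kernel such as SDOT consumes. pack_b_colmajor is hypothetical; production libraries typically pack into narrower panels matched to their micro-kernel's register blocking.

```c
#include <stdint.h>

// packed[j * kc + p] = B[p * ldb + j] for p < kc, j < nc
void pack_b_colmajor(const int8_t *B, int ldb, int kc, int nc, int8_t *packed) {
  for (int j = 0; j < nc; j++)        // one output column at a time
    for (int p = 0; p < kc; p++)      // walk down the column of B
      packed[j * kc + p] = B[p * ldb + j];
}
```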

Prefetching

ARM CPUs have hardware prefetchers that detect sequential strides, but software prefetch hints help for indirect or strided access patterns:
// Prefetch 512 bytes (8 × 64-byte cache lines) ahead, assuming ptr is a byte pointer
__builtin_prefetch(ptr + 512, 0, 1);  // read (0), low temporal locality (1)
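A minimal sketch of where a software prefetch pays off: a gather loop whose addresses come from an index array, which hardware stride prefetchers cannot predict. The distance PREFETCH_DIST = 16 iterations is an assumption to tune per core, and the bounds check keeps the loop from reading idx past its end.

```c
#include <stdint.h>

#define PREFETCH_DIST 16  // iterations ahead; an assumed starting point

// Sum table[idx[i]] for i in [0, n), hinting the element needed
// PREFETCH_DIST iterations from now.
int64_t gather_sum_prefetch(const int32_t *table, const int32_t *idx, int n) {
  int64_t sum = 0;
  for (int i = 0; i < n; i++) {
    if (i + PREFETCH_DIST < n)  // only prefetch indices that exist
      __builtin_prefetch(&table[idx[i + PREFETCH_DIST]], 0 /* read */, 1 /* low locality */);
    sum += table[idx[i]];
  }
  return sum;
}
```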

Benchmarking on ARM hardware

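Measure wall-clock time with clock_gettime(CLOCK_MONOTONIC) on Linux, or mach_absolute_time() on macOS/Apple Silicon. std::chrono::steady_clock is typically built on the same clock sources, so the choice of API matters less than keeping timer calls out of tight loops: time a whole kernel call (or many repetitions), not individual iterations.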
#include <stdio.h>
#include <time.h>

double elapsed_ms(struct timespec start, struct timespec end) {
  return (end.tv_sec - start.tv_sec) * 1e3 +
         (end.tv_nsec - start.tv_nsec) * 1e-6;
}

// Usage (after an untimed warm-up call)
struct timespec t0, t1;
clock_gettime(CLOCK_MONOTONIC, &t0);
gemm_tiled(A, B, C, M, N, K);
clock_gettime(CLOCK_MONOTONIC, &t1);
double ms = elapsed_ms(t0, t1);
// GOPS = 2·M·N·K operations / (ms · 1e-3 s) / 1e9
printf("%.2f ms  %.1f GOPS\n", ms, 2.0 * M * N * K / ms * 1e-6);
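A single timed call is noisy (cold caches, page faults, frequency ramp-up). A minimal sketch of a sturdier harness, assuming a POSIX system with clock_gettime: it runs untimed warm-up calls, then reports the fastest of several timed runs. bench_min_ms, sum_job, and sum_job_run are hypothetical names; sum_job_run is only a stand-in workload, to be replaced by a wrapper around gemm_tiled.

```c
#define _POSIX_C_SOURCE 199309L
#include <time.h>

typedef void (*bench_fn)(void *arg);

// Run fn(arg) `warmup` times untimed, then `reps` (>= 1) timed runs;
// return the fastest run in milliseconds.
double bench_min_ms(bench_fn fn, void *arg, int warmup, int reps) {
  for (int i = 0; i < warmup; i++)
    fn(arg);  // warm caches, fault in pages, let clocks settle
  double best = -1.0;
  for (int r = 0; r < reps; r++) {
    struct timespec t0, t1;
    clock_gettime(CLOCK_MONOTONIC, &t0);
    fn(arg);
    clock_gettime(CLOCK_MONOTONIC, &t1);
    double ms = (t1.tv_sec - t0.tv_sec) * 1e3 + (t1.tv_nsec - t0.tv_nsec) * 1e-6;
    if (best < 0.0 || ms < best)
      best = ms;  // the minimum filters out interrupts and other interference
  }
  return best;
}

// Stand-in workload: sum a byte buffer; the volatile store keeps the
// compiler from discarding the work.
typedef struct { const unsigned char *buf; long n; volatile long result; } sum_job;

void sum_job_run(void *arg) {
  sum_job *job = (sum_job *)arg;
  long s = 0;
  for (long i = 0; i < job->n; i++) s += job->buf[i];
  job->result = s;
}
```

The minimum returned by bench_min_ms can be plugged into the GOPS formula above in place of a single-run time.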

Useful profiling tools

# Count retired instructions, cycles, and cache misses
perf stat -e instructions,cache-misses,cycles \
          ./your_benchmark

# Per-function breakdown
perf record -g ./your_benchmark && perf report

Reference implementations

Production-quality ARM INT8 kernels are in these open-source libraries:

XNNPACK

Google’s optimized neural network operators for ARM. Used by TensorFlow Lite and PyTorch Mobile. Hand-written NEON and SVE assembly for GEMM, convolution, and depthwise ops.

llama.cpp

LLM inference on CPU. Contains highly optimized ARM NEON kernels for INT4 and INT8 quantized GEMM, targeting Apple Silicon and mobile Cortex-A devices.

torchao

PyTorch quantization and sparsity library. Includes ARM-specific lowbit kernel backends contributed by the PyTorch team.

Lecture 38 slides

Scott Roy’s full slide deck covering low-bit kernels for ARM CPUs.
