mlx-rs
Unofficial Rust bindings for the MLX framework, providing efficient machine learning primitives for Apple Silicon.
Overview
mlx-rs is a Rust wrapper around Apple's MLX framework, designed to leverage the unified memory architecture and lazy evaluation for high-performance machine learning on Apple Silicon.
Key features
- Lazy evaluation - Operations build compute graphs that execute only when needed
- Unified memory - CPU and GPU share memory, no explicit device transfers
- Type-safe arrays - Strongly-typed n-dimensional arrays with compile-time safety
- Automatic differentiation - Function transforms for gradient computation
- Hardware acceleration - Optimized for Apple Silicon Metal GPUs
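Taken together, a minimal end-to-end sketch (using only the array! macro and the lazy evaluation API documented below):

```rust
use mlx_rs::array;

fn main() {
    let a = array!([1.0, 2.0, 3.0, 4.0]);
    let b = array!([5.0, 6.0, 7.0, 8.0]);

    // Builds a compute graph; nothing executes yet.
    let c = &a + &b;

    // Forces evaluation of the graph.
    c.eval().unwrap();
    println!("{:?}", c); // [6.0, 8.0, 10.0, 12.0]
}
```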
Array
The core Array type represents n-dimensional tensors.
Construction
Create an array from a slice with an explicit shape:

```rust
use mlx_rs::Array;

let data = vec![1i32, 2, 3, 4, 5, 6];
let arr = Array::from_slice(&data, &[2, 3]);
```

Create a scalar array from a boolean value:

```rust
let arr = Array::from_bool(true);
```

Create a scalar array from an i32 value:

```rust
let arr = Array::from_int(42);
```

Create a scalar array from an f32 value:

```rust
let arr = Array::from_f32(3.14);
```
Array macro
The array! macro provides convenient array construction:
```rust
use mlx_rs::{array, Dtype};

let a = array!([1, 2, 3, 4]);
assert_eq!(a.shape(), &[4]);
assert_eq!(a.dtype(), Dtype::Int32);

let b = array!([1.0, 2.0, 3.0, 4.0]);
assert_eq!(b.dtype(), Dtype::Float32);
```
Properties
Returns the shape of the array:

```rust
let arr = array!([[1, 2], [3, 4]]);
assert_eq!(arr.shape(), &[2, 2]);
```

Returns the data type of the array:

```rust
let arr = array!([1.0, 2.0]);
assert_eq!(arr.dtype(), Dtype::Float32);
```

Returns the total number of elements in the array:

```rust
let arr = array!([[1, 2], [3, 4]]);
assert_eq!(arr.size(), 4);
```

Returns the number of dimensions:

```rust
let arr = array!([[1, 2], [3, 4]]);
assert_eq!(arr.ndim(), 2);
```

Check whether the array is contiguous in memory (row-major/C-style):

```rust
use mlx_rs::ops::indexing::IndexOp;

let arr = Array::from_slice(&[1, 2, 3, 4, 5, 6], &[2, 3]);
assert!(arr.is_contiguous());

let sliced = arr.index((.., ..2));
// May not be contiguous after indexing
```
Evaluation
Evaluate the array, forcing computation of pending lazy operations:

```rust
use mlx_rs::array;

let a = array!([1, 2, 3, 4]);
let b = array!([1.0, 2.0, 3.0, 4.0]);
let c = &a + &b; // Not evaluated yet
c.eval().unwrap(); // Evaluate now
```

Extract a scalar value from the array (evaluates automatically):

```rust
let arr = array!(42);
let value: i32 = arr.item();
assert_eq!(value, 42);
```

Access the underlying data as a slice (evaluates automatically):

```rust
let arr = array!([1, 2, 3, 4]);
let slice: &[i32] = arr.as_slice();
assert_eq!(slice, &[1, 2, 3, 4]);
```
Operations (ops)
The ops module provides array operations.
Factory functions
ops::zeros
fn<T>(shape: &[i32]) -> Result<Array>

Create an array filled with zeros:

```rust
use mlx_rs::ops::zeros;

let arr = zeros::<f32>(&[2, 3]).unwrap();
```

ops::ones
fn<T>(shape: &[i32]) -> Result<Array>

Create an array filled with ones:

```rust
use mlx_rs::ops::ones;

let arr = ones::<f32>(&[2, 3]).unwrap();
```

ops::arange
fn<U, T>(start: U, stop: U, step: U) -> Result<Array>

Create an array with evenly spaced values:

```rust
use mlx_rs::ops::arange;

let arr = arange(0, 10, 2).unwrap(); // [0, 2, 4, 6, 8]
```

ops::eye
fn<T>(n: i32, m: Option<i32>, k: i32) -> Result<Array>

Create a 2D identity matrix:

```rust
use mlx_rs::ops::eye;

let arr = eye::<f32>(3, None, 0).unwrap();
// [[1, 0, 0],
//  [0, 1, 0],
//  [0, 0, 1]]
```
Arithmetic operations
Arrays support the standard arithmetic operators:

```rust
use mlx_rs::array;

let a = array!([1.0, 2.0, 3.0]);
let b = array!([4.0, 5.0, 6.0]);

let sum = &a + &b;  // Element-wise addition
let diff = &a - &b; // Element-wise subtraction
let prod = &a * &b; // Element-wise multiplication
let quot = &a / &b; // Element-wise division
```
Reduction operations
ops::sum
fn(array: &Array, axes: Option<&[i32]>) -> Result<Array>

Sum array elements over the given axes:

```rust
use mlx_rs::ops::sum;

let arr = array!([[1, 2], [3, 4]]);
let total = sum(&arr, None).unwrap(); // Sum all: 10
let row_sums = sum(&arr, Some(&[1])).unwrap(); // [3, 7]
```

ops::mean
fn(array: &Array, axes: Option<&[i32]>) -> Result<Array>

Compute the mean over the given axes:

```rust
use mlx_rs::ops::mean;

let arr = array!([1.0, 2.0, 3.0, 4.0]);
let avg = mean(&arr, None).unwrap(); // 2.5
```

ops::argmax_axis
fn(array: &Array, axis: i32) -> Result<Array>

Indices of the maximum values along an axis:

```rust
use mlx_rs::argmax_axis;

let arr = array!([[1, 3, 2], [4, 2, 5]]);
let indices = argmax_axis!(&arr, -1).unwrap(); // [1, 2]
```
Shape operations
ops::reshape
fn(array: &Array, shape: &[i32]) -> Result<Array>

Reshape the array to new dimensions:

```rust
use mlx_rs::ops::reshape;

let arr = array!([1, 2, 3, 4, 5, 6]);
let reshaped = reshape(&arr, &[2, 3]).unwrap();
```

ops::transpose
fn(array: &Array, axes: &[i32]) -> Result<Array>

Permute the array's dimensions:

```rust
use mlx_rs::ops::transpose;

let arr = array!([[1, 2, 3], [4, 5, 6]]); // Shape: [2, 3]
let transposed = transpose(&arr, &[1, 0]).unwrap(); // Shape: [3, 2]
```
ops::concatenate_axis
fn(arrays: &[Array], axis: i32) -> Result<Array>

Join arrays along an existing axis:

```rust
use mlx_rs::ops::concatenate_axis;

let a = array!([1, 2]);
let b = array!([3, 4]);
let concat = concatenate_axis(&[a, b], 0).unwrap(); // [1, 2, 3, 4]
```
Indexing
The ops::indexing module provides array indexing operations:
```rust
use mlx_rs::array;
use mlx_rs::ops::indexing::{IndexOp, NewAxis};

let arr = array!([[1, 2, 3], [4, 5, 6]]);

// Slice the first row
let row = arr.index(0);

// Slice with ranges
let sub = arr.index((.., 1..3)); // All rows, columns 1-2

// Add a dimension
let expanded = arr.index(NewAxis); // Shape: [1, 2, 3]
```
Transforms
Function transformations for automatic differentiation and compilation.
Gradient computation
Compute the gradient of a function with respect to its first argument:

```rust
use mlx_rs::{Array, error::Exception, transforms::grad};

fn f(args: &[Array]) -> Result<Array, Exception> {
    let x = &args[0];
    x.square() // f(x) = x²
}

let mut grad_fn = grad(f);
let x = Array::from_f32(3.0);
let df_dx = grad_fn(&[x]).unwrap();

// df/dx = 2x = 6.0
assert_eq!(df_dx.item::<f32>(), 6.0);
```
transforms::value_and_grad
Compute both the function value and its gradient:

```rust
use mlx_rs::{Array, error::Exception, transforms::value_and_grad};

fn loss(args: &[Array]) -> Result<Array, Exception> {
    let x = &args[0];
    x.square()
}

let mut vg_fn = value_and_grad(loss);
let x = Array::from_f32(3.0);
let (value, grad) = vg_fn(&[x]).unwrap();
```
Evaluation
transforms::eval
fn(outputs: impl IntoIterator<Item = &Array>) -> Result<()>
Evaluate multiple arrays in one graph evaluation:

```rust
use mlx_rs::{array, transforms::eval};

let a = array!([1, 2, 3]);
let b = array!([4, 5, 6]);
let c = &a + &b;
let d = &a * &b;

eval(&[&c, &d]).unwrap(); // Evaluate both together
```
transforms::eval_params
fn(params: ModuleParamRef) -> Result<()>
Evaluate all parameters of a module:

```rust
use mlx_rs::transforms::eval_params;

// After updating model parameters:
eval_params(model.parameters()).unwrap();
```
Random
Random number generation operations.
Generate an array from the standard normal distribution (mean 0, std 1):

```rust
use mlx_rs::normal;

let arr = normal!(shape = &[100]).unwrap();
```

Generate an array from a uniform distribution:

```rust
use mlx_rs::uniform;

let arr = uniform!(low = 0.0, high = 1.0, shape = &[100]).unwrap();
```

Sample from a categorical distribution:

```rust
use mlx_rs::{array, categorical};

let logits = array!([0.1, 0.5, 0.4]);
let sample = categorical!(&logits).unwrap();
```
Neural networks (nn)
The nn module provides neural network layers and utilities.
Layers
- nn::Linear - Fully connected layer
- nn::Conv1d, nn::Conv2d - Convolution layers
- nn::LayerNorm - Layer normalization
- nn::Dropout - Dropout regularization
- nn::Embedding - Embedding lookup table
Activations
- nn::relu - ReLU activation
- nn::gelu - GELU activation
- nn::silu - SiLU/Swish activation
- nn::softmax - Softmax function
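For illustration, a hedged sketch of composing a layer with an activation. The constructor signature Linear::new(input_dims, output_dims) and the exact forward/relu signatures are assumptions to verify against the crate docs:

```rust
use mlx_rs::{error::Exception, module::Module, nn, Array};

// A sketch only: Linear::new and nn::relu signatures are assumptions.
fn tiny_forward(x: &Array) -> Result<Array, Exception> {
    let mut fc1 = nn::Linear::new(4, 8)?; // assumed: Linear::new(input_dims, output_dims)
    let mut fc2 = nn::Linear::new(8, 2)?;

    // Linear -> ReLU -> Linear
    let h = nn::relu(&fc1.forward(x)?)?;  // assumed: nn::relu(&Array) -> Result<Array>
    fc2.forward(&h)
}
```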
Rotary position embedding
Rotary position embedding for transformers:

```rust
use mlx_rs::nn::{Rope, RopeBuilder};

let rope = RopeBuilder::new(64)
    .traditional(false)
    .base(10000.0)
    .build()
    .unwrap();
```
Module system
The module module provides traits for neural network modules.
Trait for neural network modules with learnable parameters:

```rust
use mlx_rs::module::Module;

// All nn layers implement Module.
let output = model.forward(&input).unwrap();
let params = model.parameters();
```
Saving and loading
Arrays and models can be saved/loaded:
| Type | Load | Save |
|---|---|---|
| Array | Array::load_numpy | Array::save_numpy |
| HashMap<String, Array> | Array::load_safetensors | Array::save_safetensors |
| Module | ModuleParametersExt::load_safetensors | ModuleParametersExt::save_safetensors |
```rust
// Save a single array
arr.save_numpy("array.npy").unwrap();

// Load a single array
let arr = Array::load_numpy("array.npy").unwrap();

// Save model parameters
model.save_safetensors("model.safetensors").unwrap();

// Load model parameters
model.load_safetensors("model.safetensors").unwrap();
```
Lazy evaluation
MLX uses lazy evaluation: operations build compute graphs without executing them.

```rust
use mlx_rs::array;

let a = array!([1, 2, 3, 4]);
let b = array!([1.0, 2.0, 3.0, 4.0]);

// No computation happens yet
let c = &a + &b;
let d = &c * 2.0;

// Evaluate when needed
d.eval().unwrap();

// Or evaluation happens automatically when:
println!("{:?}", d);             // Printing
let val: f32 = d.item();         // Getting a scalar value
let slice = d.as_slice::<f32>(); // Accessing data
```
When to evaluate
A natural place to use eval() is at each iteration of an outer training loop:

```rust
for batch in dataset {
    // Build the compute graph
    let (loss, grad) = value_and_grad_fn(&mut model, batch)?;

    // Update parameters (still lazy)
    optimizer.update(&mut model, grad)?;

    // Evaluate the loss and parameters together
    eval_params(model.parameters())?;
}
```
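The optimizer in the loop above comes from the optimizers module. A hedged construction sketch, assuming Sgd::new takes a learning rate (check the crate docs for the exact signature and available options):

```rust
use mlx_rs::optimizers::Sgd;

// Assumption: Sgd::new(learning_rate); a builder with additional
// options such as momentum may exist instead.
let mut optimizer = Sgd::new(0.01);
// optimizer.update(&mut model, grad)?; // as used in the loop above
```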
Unified memory
On Apple Silicon, CPU and GPU share unified memory:
```rust
use mlx_rs::{normal, StreamOrDevice};

let a = normal!(shape = &[100]).unwrap();
let b = normal!(shape = &[100]).unwrap();
// Both live in unified memory

// Operations specify the device at runtime:
let c = mlx_rs::add!(&a, &b, stream = StreamOrDevice::cpu()).unwrap();
let d = mlx_rs::add!(&a, &b, stream = StreamOrDevice::gpu()).unwrap();
```
No explicit memory transfers are needed - arrays are accessible from all devices.
Example: Linear regression
```rust
use mlx_rs::{ops, transforms, Array};
use mlx_rs::error::Exception;

fn main() -> Result<(), Exception> {
    // Generate synthetic data
    let w_star = mlx_rs::normal!(shape = &[100])?;
    let x = mlx_rs::normal!(shape = &[1000, 100])?;
    let eps = mlx_rs::normal!(shape = &[1000])? * 1e-2;
    let y = x.matmul(&w_star)? + eps;

    // Initialize weights
    let w = mlx_rs::normal!(shape = &[100])? * 1e-2;

    // Define the loss function
    let loss_fn = |inputs: &[Array]| -> Result<Array, Exception> {
        let w = &inputs[0];
        let x = &inputs[1];
        let y = &inputs[2];
        let y_pred = x.matmul(w)?;
        let loss = Array::from_f32(0.5) * ops::mean(&ops::square(y_pred - y)?, None)?;
        Ok(loss)
    };

    // Train with gradient descent
    let mut grad_fn = transforms::grad(loss_fn);
    let mut inputs = [w, x, y];
    for _ in 0..10000 {
        let grad = grad_fn(&inputs)?;
        inputs[0] = &inputs[0] - Array::from_f32(0.01) * grad;
        inputs[0].eval()?;
    }

    let loss = loss_fn(&inputs)?;
    println!("Final loss: {:.5}", loss.item::<f32>());
    Ok(())
}
```
Module structure
- array - Core Array type and operations
- ops - Array operations (arithmetic, reduction, shape)
- transforms - Function transforms (grad, eval, compile)
- random - Random number generation
- nn - Neural network layers and activations
- module - Module system for composable networks
- optimizers - Optimization algorithms (SGD, Adam, etc.)
- losses - Loss functions
- linalg - Linear algebra operations
- fft - Fast Fourier transform
- fast - Hardware-optimized operations
- dtype - Data type definitions
- device - Device management
- stream - Computation stream management