SmolLM2 is a compact decoder-only transformer language model. Meganeura includes a pre-built graph definition that matches the HuggingFace checkpoint format, allowing you to load weights directly and run inference or fine-tuning.

Configuration

  • vocab_size (usize): Vocabulary size, i.e. the number of token embeddings. SmolLM2-135M uses 49152.
  • hidden_size (usize): Dimensionality of the transformer hidden state. SmolLM2-135M uses 576.
  • num_hidden_layers (usize): Number of transformer decoder blocks. SmolLM2-135M uses 30.
  • num_attention_heads (u32): Number of query heads in grouped-query attention (GQA). SmolLM2-135M uses 9.
  • num_key_value_heads (u32): Number of key/value heads; fewer than query heads for GQA. SmolLM2-135M uses 3.
  • intermediate_size (usize): Inner dimension of the SwiGLU feed-forward network. SmolLM2-135M uses 1536.
  • rms_norm_eps (f32): Epsilon for RMSNorm numerical stability. SmolLM2-135M uses 1e-5.
  • rope_theta (f32): Base frequency for Rotary Position Embeddings (RoPE). SmolLM2-135M uses 10000.0.
Use the built-in preset for SmolLM2-135M:
use meganeura::models::smollm2::{SmolLM2Config, build_graph};

let config = SmolLM2Config::smollm2_135m();

Architecture

Each transformer block follows the standard LLaMA-style decoder design:
  1. Pre-attention RMSNorm — normalizes the input before the attention block
  2. Grouped-query attention — projects Q/K/V, applies RoPE, then runs causal attention with fewer KV heads than Q heads
  3. Residual connection — adds attention output back to the input
  4. Post-attention RMSNorm — normalizes before the FFN
  5. SwiGLU FFN — gate and up projections followed by element-wise SwiGLU, then a down projection
  6. Residual connection — adds FFN output back
After all layers: a final RMSNorm and a linear LM head project to vocabulary logits.
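As a reference for steps 1, 4, and 5 above, here is a minimal, library-independent sketch of the RMSNorm and SwiGLU element-wise math over a single vector (the gate/up/down projections themselves are ordinary matrix multiplies and are omitted):

```rust
/// RMSNorm: x / sqrt(mean(x^2) + eps), scaled by a learned weight.
fn rms_norm(x: &[f32], weight: &[f32], eps: f32) -> Vec<f32> {
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let scale = 1.0 / (mean_sq + eps).sqrt();
    x.iter().zip(weight).map(|(v, w)| v * scale * w).collect()
}

/// SiLU (swish): x * sigmoid(x).
fn silu(x: f32) -> f32 {
    x / (1.0 + (-x).exp())
}

/// SwiGLU combine step: silu(gate) * up, element-wise.
fn swiglu(gate: &[f32], up: &[f32]) -> Vec<f32> {
    gate.iter().zip(up).map(|(g, u)| silu(*g) * u).collect()
}

fn main() {
    let x = vec![1.0f32, -2.0, 3.0, 0.5];
    let w = vec![1.0f32; 4]; // unit norm weight for illustration
    let normed = rms_norm(&x, &w, 1e-5);
    println!("{normed:?}");
    println!("{:?}", swiglu(&normed, &normed));
}
```

With a unit norm weight, the output of rms_norm has a mean square of approximately 1, which is the property the epsilon-stabilized normalization is meant to enforce.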

Building the inference graph

The build_graph function constructs the full forward pass for a given sequence length:
use meganeura::{Graph, build_inference_session};
use meganeura::models::smollm2::{SmolLM2Config, build_graph};

let config = SmolLM2Config::smollm2_135m();
let seq_len = 64;

let mut g = Graph::new();
let logits = build_graph(&mut g, &config, seq_len);
g.set_outputs(vec![logits]);

let mut session = build_inference_session(&g);
The graph expects one input:
  • "token_ids" — U32 tensor of shape [seq_len]

Loading HuggingFace weights

Weight names follow the HuggingFace safetensors convention. HuggingFace stores linear layer weights as [out, in]; they must be transposed to [in, out] when loading:
use meganeura::data::safetensors::SafeTensorsModel;
use meganeura::models::smollm2::{SmolLM2Config, weight_names, transposed_weight_names};

let hf = SafeTensorsModel::download("HuggingFaceTB/SmolLM2-135M").unwrap();
let config = SmolLM2Config::smollm2_135m();

let transposed = transposed_weight_names(&config);
for name in weight_names(&config) {
    let data = if transposed.contains(&name) {
        hf.tensor_f32_transposed(&name).unwrap()
    } else {
        hf.tensor_f32(&name).unwrap()
    };
    session.set_parameter(&name, &data);
}
lm_head.weight is often weight-tied to model.embed_tokens.weight in SmolLM2. If lm_head.weight is absent from the checkpoint, load model.embed_tokens.weight transposed and use it for both.
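If a loader only returns tensors in their stored layout, the [out, in] to [in, out] transpose is straightforward to do by hand. A minimal row-major sketch of the operation that tensor_f32_transposed presumably performs internally:

```rust
/// Transpose a row-major [rows, cols] matrix into row-major [cols, rows].
/// HuggingFace stores linear weights as [out_features, in_features];
/// this produces the [in_features, out_features] layout.
fn transpose(data: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    let mut out = vec![0.0f32; data.len()];
    for r in 0..rows {
        for c in 0..cols {
            out[c * rows + r] = data[r * cols + c];
        }
    }
    out
}

fn main() {
    // 2x3 matrix [[1,2,3],[4,5,6]] -> 3x2 [[1,4],[2,5],[3,6]]
    let m = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let t = transpose(&m, 2, 3);
    println!("{t:?}"); // [1.0, 4.0, 2.0, 5.0, 3.0, 6.0]
}
```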

Prefill and decode graphs

For autoregressive generation with KV cache, use the prefill and decode graph builders:
use meganeura::models::smollm2::{build_prefill_graph, build_decode_graph};

// Prefill: process the full prompt, capture K/V for each layer
let prompt_len = 16; // length of the tokenized prompt
let mut prefill_g = Graph::new();
let (logits, k_outs, v_outs) = build_prefill_graph(&mut prefill_g, &config, prompt_len);
let mut outputs = vec![logits];
outputs.extend_from_slice(&k_outs);
outputs.extend_from_slice(&v_outs);
prefill_g.set_outputs(outputs);

// Decode: single token with KV cache (pre-allocated max_seq_len slots)
let max_seq_len = 256; // total positions the cache can hold
let mut decode_g = Graph::new();
let (logits, k_caches, v_caches) = build_decode_graph(&mut decode_g, &config, max_seq_len);
decode_g.set_outputs(vec![logits]);
The decode graph expects two inputs per step:
  • "token_ids" — U32 tensor of shape [1] (the current token)
  • "kv_pos" — U32 tensor of shape [1] (number of already-cached positions)
See KV cache for the full decode loop pattern.
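The bookkeeping behind kv_pos can be illustrated with a minimal single-head cache: a pre-allocated buffer of max_seq_len slots that prefill fills from position 0 and decode appends to one slot per step. This is a sketch with assumed shapes, not the library's internal cache layout:

```rust
/// Minimal single-head KV cache: pre-allocated slots, filled left to right.
struct KvCache {
    keys: Vec<Vec<f32>>,   // max_seq_len slots of head_dim floats each
    values: Vec<Vec<f32>>,
    len: usize,            // number of cached positions ("kv_pos")
}

impl KvCache {
    fn new(max_seq_len: usize, head_dim: usize) -> Self {
        Self {
            keys: vec![vec![0.0; head_dim]; max_seq_len],
            values: vec![vec![0.0; head_dim]; max_seq_len],
            len: 0,
        }
    }

    /// Prefill: write K/V for all prompt positions at once.
    fn prefill(&mut self, ks: &[Vec<f32>], vs: &[Vec<f32>]) {
        for (k, v) in ks.iter().zip(vs) {
            self.keys[self.len] = k.clone();
            self.values[self.len] = v.clone();
            self.len += 1;
        }
    }

    /// Decode: append one position; returns the kv_pos to feed for this step.
    fn append(&mut self, k: Vec<f32>, v: Vec<f32>) -> usize {
        let pos = self.len;
        self.keys[pos] = k;
        self.values[pos] = v;
        self.len += 1;
        pos
    }
}

fn main() {
    let mut cache = KvCache::new(8, 4);
    // 3-token prompt: prefill writes positions 0..3.
    cache.prefill(&vec![vec![0.1; 4]; 3], &vec![vec![0.2; 4]; 3]);
    // First decoded token: kv_pos is 3 (positions already cached).
    let pos = cache.append(vec![0.3; 4], vec![0.4; 4]);
    println!("kv_pos = {pos}, cached = {}", cache.len); // 3, 4
}
```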
