A KV cache stores the key and value tensors computed in previous decoding steps so that the model does not recompute them for every new token. Without a cache, generating N tokens requires N full forward passes, each reprocessing the entire, growing context. With a cache, each decode step runs the model on only the one new token and reuses the accumulated K/V history. Attention still reads every cached position, so its per-step cost grows linearly with the context. However, the K/V projections for earlier tokens are never recomputed.
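As a rough illustration of the saving, the sketch below counts only how many token rows pass through the K/V projections while generating `new_tokens` tokens after a `prompt_len`-token prompt. The function names are hypothetical. The count deliberately ignores attention itself, whose per-step cost still grows with the cached length.

```rust
/// Rows pushed through the K/V projections without a cache: every step
/// re-encodes the prompt plus all tokens generated so far.
fn kv_rows_without_cache(prompt_len: usize, new_tokens: usize) -> usize {
    (0..new_tokens).map(|step| prompt_len + step).sum()
}

/// Rows pushed through the K/V projections with a cache: one prefill pass
/// over the prompt (which also yields the first new token), then one row
/// per remaining decode step.
fn kv_rows_with_cache(prompt_len: usize, new_tokens: usize) -> usize {
    if new_tokens == 0 {
        return 0;
    }
    prompt_len + (new_tokens - 1)
}
```

For a 4-token prompt and 3 new tokens, the uncached path processes 4 + 5 + 6 = 15 rows. The cached path processes 4 + 1 + 1 = 6 rows. The gap widens quadratically as generation gets longer.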

How it works in Meganeura

Meganeura models follow a two-graph design:

Prefill graph

Processes the full prompt in one shot using standard CausalAttention over the entire sequence. Outputs logits for the last prompt token plus the computed K/V tensors for every layer.

Decode graph

Processes exactly one new token per step. Uses CacheWrite to insert the new token’s K/V into a pre-allocated cache buffer, then uses CachedAttention to attend over all cached positions.
The K/V cache buffers are represented as graph Parameter nodes (named kv_cache.layer.{i}.k and kv_cache.layer.{i}.v). They persist in GPU memory between decode steps because parameters are not reset between calls to step().

The CacheWrite op

Op::CacheWrite writes a single-token K or V tensor into row kv_pos of a pre-allocated cache buffer.
  • new_kv: shape [1, dim] — the K or V vector for the current token
  • cache: shape [max_seq, dim] — the pre-allocated cache buffer (a Parameter node)
  • kv_pos: a u32 input scalar — the row index to write into
The operation writes in-place and returns a node representing the updated cache. The write happens at exactly row kv_pos[0], leaving all other rows unchanged.
// From graph.rs:
/// Write `new_kv` [1, dim] into row `kv_pos` of `cache` [max_seq, dim].
/// Returns a node representing the updated cache buffer.
pub fn cache_write(&mut self, new_kv: NodeId, cache: NodeId, kv_pos: NodeId) -> NodeId
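To make the semantics above concrete, here is a CPU reference sketch of the same row write on a row-major `[max_seq, dim]` buffer. It is not Meganeura's shader, and the name `cache_write_ref` is hypothetical. Unlike the GPU op, which performs no bounds check, this version rejects an out-of-range `kv_pos` instead of writing past the buffer.

```rust
/// CPU reference for the CacheWrite semantics (illustrative sketch only).
/// `cache` is a row-major [max_seq, dim] buffer and `new_kv` is one [1, dim] row.
fn cache_write_ref(
    cache: &mut [f32],
    max_seq: usize,
    dim: usize,
    new_kv: &[f32],
    kv_pos: u32,
) -> Result<(), String> {
    assert_eq!(cache.len(), max_seq * dim, "cache must be [max_seq, dim]");
    assert_eq!(new_kv.len(), dim, "new_kv must be [1, dim]");
    let row = kv_pos as usize;
    if row >= max_seq {
        // The GPU op has no such guard; this reference refuses the write.
        return Err(format!("kv_pos {row} out of range for max_seq {max_seq}"));
    }
    // Overwrite exactly one row; every other row is left untouched.
    cache[row * dim..(row + 1) * dim].copy_from_slice(new_kv);
    Ok(())
}
```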

The CachedAttention op

Op::CachedAttention computes scaled dot-product attention where the query comes from the current token and the keys and values come from the cache.
  • q: shape [1, num_heads * head_dim] — query for the current token
  • k_cache: shape [max_seq, num_kv_heads * head_dim] — accumulated key cache
  • v_cache: shape [max_seq, num_kv_heads * head_dim] — accumulated value cache
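  • kv_pos: a u32 scalar (shape [1]). It is the position of the current token, which equals the number of tokens cached before this step. Attention is masked to the valid prefix of the cache. Because cache_write has already placed the current token's K/V at row kv_pos, that prefix must include row kv_pos for the token to attend to itself.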
The output has shape [1, num_heads * head_dim].
// From graph.rs:
/// Cached attention: Q attends to K/V cache.
/// q: [1, num_heads*head_dim], k_cache/v_cache: [max_seq, kv_dim],
/// kv_pos: u32 scalar (number of valid positions in cache).
pub fn cached_attention(
    &mut self,
    q: NodeId,
    k_cache: NodeId,
    v_cache: NodeId,
    kv_pos: NodeId,
    num_heads: u32,
    num_kv_heads: u32,
    head_dim: u32,
) -> NodeId
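The following CPU sketch spells out what single-query attention over a cache computes, including the grouped-query case where `num_kv_heads < num_heads` and each KV head is shared by `num_heads / num_kv_heads` query heads. The name `cached_attention_ref` and the explicit `valid_len` argument are illustrative assumptions, not Meganeura's API. The sketch assumes that when the current token's K/V has already been written at row `kv_pos`, the caller passes `valid_len = kv_pos + 1` so the token attends to itself.

```rust
/// CPU reference for cached attention with grouped-query attention (sketch).
/// q: [num_heads * head_dim]; k_cache, v_cache: row-major [max_seq, num_kv_heads * head_dim].
/// Only rows 0..valid_len are read; rows beyond that are ignored.
fn cached_attention_ref(
    q: &[f32],
    k_cache: &[f32],
    v_cache: &[f32],
    valid_len: usize,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
) -> Vec<f32> {
    assert_eq!(num_heads % num_kv_heads, 0);
    assert!(valid_len > 0);
    let kv_dim = num_kv_heads * head_dim;
    let group = num_heads / num_kv_heads;
    let scale = 1.0 / (head_dim as f32).sqrt();
    let mut out = vec![0.0f32; num_heads * head_dim];

    for h in 0..num_heads {
        let kv_h = h / group; // KV head shared by this query head
        let qh = &q[h * head_dim..(h + 1) * head_dim];

        // Scaled dot-product score against each valid cache row.
        let scores: Vec<f32> = (0..valid_len)
            .map(|t| {
                let k = &k_cache[t * kv_dim + kv_h * head_dim..][..head_dim];
                qh.iter().zip(k).map(|(a, b)| a * b).sum::<f32>() * scale
            })
            .collect();

        // Numerically stable softmax over the scores.
        let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let weights: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
        let total: f32 = weights.iter().sum();

        // Softmax-weighted sum of the value rows.
        let oh = &mut out[h * head_dim..(h + 1) * head_dim];
        for (t, w) in weights.iter().enumerate() {
            let v = &v_cache[t * kv_dim + kv_h * head_dim..][..head_dim];
            for (o, x) in oh.iter_mut().zip(v) {
                *o += w / total * x;
            }
        }
    }
    out
}
```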

RoPE with a dynamic position offset

During prefill, RoPE positions are computed statically from row indices (position 0 for the first token, 1 for the second, and so on). During decode, the single new token has a position equal to kv_pos — the number of tokens already in the cache. Use graph.rope_dynamic_offset to apply RoPE with a position read from a runtime input buffer rather than a compile-time constant:
// From graph.rs:
/// RoPE with a dynamic position offset read from an input buffer.
/// The position for each row is `row_index + offset_buf[0]`.
pub fn rope_dynamic_offset(
    &mut self,
    x: NodeId,
    theta: f32,
    offset_input: NodeId,
    head_dim: u32,
) -> NodeId
The offset_input node must be a u32 input of shape [1]. At decode time, set it to kv_pos (the current cache fill level). Because x has shape [1, dim] in the decode graph, row index 0 plus the offset gives the correct absolute position.
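The CPU sketch below shows why a single-row decode tensor with offset `kv_pos` gets the same rotation as row `kv_pos` of the prefill tensor with offset 0. In both cases the position is `row_index + offset`, as the graph.rs doc comment states. The function name is hypothetical. The sketch assumes the "rotate-half" pairing used by Llama-style checkpoints, where dimension `i` is paired with `i + head_dim / 2` inside each head. Meganeura's kernel may pair dimensions differently.

```rust
/// RoPE with a runtime position offset (CPU sketch, rotate-half pairing assumed).
/// `x` is row-major [rows, dim] with dim = num_heads * head_dim.
/// Row r is rotated as absolute position r + offset.
fn rope_with_offset(x: &mut [f32], dim: usize, head_dim: usize, theta: f32, offset: u32) {
    assert_eq!(dim % head_dim, 0);
    assert_eq!(head_dim % 2, 0);
    let half = head_dim / 2;
    for (row, chunk) in x.chunks_mut(dim).enumerate() {
        let pos = (row + offset as usize) as f32;
        // Apply the same rotation independently to every head in the row.
        for head in chunk.chunks_mut(head_dim) {
            for i in 0..half {
                let freq = theta.powf(-(2.0 * i as f32) / head_dim as f32);
                let (sin, cos) = (pos * freq).sin_cos();
                let (a, b) = (head[i], head[i + half]);
                head[i] = a * cos - b * sin;
                head[i + half] = a * sin + b * cos;
            }
        }
    }
}
```

In the decode graph `x` has a single row, so `offset = kv_pos` alone determines the position.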

Difference from the prefill path

| | Prefill | Decode |
|---|---|---|
| Attention op | causal_attention(q, k, v, …) | cached_attention(q, k_cache, v_cache, kv_pos, …) |
| Input shape | [seq_len, hidden] | [1, hidden] |
| Context | Full prompt sequence | Accumulated KV cache |
| RoPE | Static offset (row index = absolute position) | Dynamic offset via rope_dynamic_offset |
| KV output | Written to graph outputs for cache initialization | Written in-place via cache_write |
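The two paths should agree. Decoding token `t` against a cache that holds rows 0 through `t` gives the same result as row `t` of causal attention over the full sequence. The CPU sketch below checks this for one head, with no projections or RoPE. The cache is seeded from the "prefill" rows and then extended one row at a time. All function names are hypothetical. As above, it assumes the decode query attends to rows 0 through kv_pos inclusive.

```rust
/// Single-head attention of `q` over the first `len` rows of `k`/`v`.
fn attend(q: &[f32], k: &[Vec<f32>], v: &[Vec<f32>], len: usize) -> Vec<f32> {
    let scale = 1.0 / (q.len() as f32).sqrt();
    let scores: Vec<f32> = k[..len]
        .iter()
        .map(|kr| q.iter().zip(kr).map(|(a, b)| a * b).sum::<f32>() * scale)
        .collect();
    let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let w: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
    let total: f32 = w.iter().sum();
    let mut out = vec![0.0; v[0].len()];
    for (wi, vr) in w.iter().zip(&v[..len]) {
        for (o, x) in out.iter_mut().zip(vr) {
            *o += wi / total * x;
        }
    }
    out
}

/// Prefill path: causal attention, where row t sees rows 0..=t.
fn causal_attention(q: &[Vec<f32>], k: &[Vec<f32>], v: &[Vec<f32>]) -> Vec<Vec<f32>> {
    (0..q.len()).map(|t| attend(&q[t], k, v, t + 1)).collect()
}

/// Decode path: seed the cache with the prompt's K/V, then for each new token
/// write its K/V at row kv_pos and attend over rows 0..=kv_pos.
fn decode_from_prefill(
    q: &[Vec<f32>],
    k: &[Vec<f32>],
    v: &[Vec<f32>],
    prompt_len: usize,
) -> Vec<Vec<f32>> {
    let mut k_cache: Vec<Vec<f32>> = k[..prompt_len].to_vec();
    let mut v_cache: Vec<Vec<f32>> = v[..prompt_len].to_vec();
    let mut outputs = Vec::new();
    for kv_pos in prompt_len..q.len() {
        k_cache.push(k[kv_pos].clone()); // stands in for cache_write
        v_cache.push(v[kv_pos].clone());
        outputs.push(attend(&q[kv_pos], &k_cache, &v_cache, kv_pos + 1));
    }
    outputs
}
```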

Building the decode graph

The following is the decode layer loop from src/models/smollm2.rs. Each transformer layer declares a K cache parameter and a V cache parameter, applies cache_write to update them, and uses cached_attention to compute attention over the full cached history.
// Single token input
let token_ids = g.input_u32("token_ids", &[1]);
// Dynamic position for cache write (u32 scalar)
let kv_pos = g.input_u32("kv_pos", &[1]);

let embed_weight = g.parameter("model.embed_tokens.weight", &[config.vocab_size, hidden]);
let mut x = g.embedding(token_ids, embed_weight);

for i in 0..config.num_hidden_layers {
    let prefix = format!("model.layers.{}", i);

    let ln1_w = g.parameter(&format!("{}.input_layernorm.weight", prefix), &[hidden]);
    let h = g.rms_norm(x, ln1_w, eps);

    let wq = g.parameter(&format!("{}.self_attn.q_proj.weight", prefix), &[hidden, hidden]);
    let wk = g.parameter(&format!("{}.self_attn.k_proj.weight", prefix), &[hidden, kv_dim]);
    let wv = g.parameter(&format!("{}.self_attn.v_proj.weight", prefix), &[hidden, kv_dim]);

    let q = g.matmul(h, wq); // [1, hidden]
    let k = g.matmul(h, wk); // [1, kv_dim]
    let v = g.matmul(h, wv); // [1, kv_dim]

    // RoPE with dynamic position offset from kv_pos input
    let q = g.rope_dynamic_offset(q, theta, kv_pos, config.head_dim());
    let k = g.rope_dynamic_offset(k, theta, kv_pos, config.head_dim());

    // Pre-allocated KV cache buffers (treated as mutable parameters)
    let k_cache = g.parameter(&format!("kv_cache.layer.{}.k", i), &[max_seq_len, kv_dim]);
    let v_cache = g.parameter(&format!("kv_cache.layer.{}.v", i), &[max_seq_len, kv_dim]);

    // Write new K/V into cache at kv_pos
    let _k_updated = g.cache_write(k, k_cache, kv_pos);
    let _v_updated = g.cache_write(v, v_cache, kv_pos);

    // Cached attention: Q attends to full cache up to kv_pos positions
    let attn = g.cached_attention(
        q,
        k_cache,
        v_cache,
        kv_pos,
        config.num_attention_heads,
        config.num_key_value_heads,
        config.head_dim(),
    );

    let wo = g.parameter(&format!("{}.self_attn.o_proj.weight", prefix), &[hidden, hidden]);
    let attn_out = g.matmul(attn, wo);
    x = g.add(x, attn_out);

    // FFN omitted for brevity …
}
cache_write returns an updated cache node, but cached_attention reads from the original k_cache and v_cache parameter nodes. This works because the write happens in place on the GPU buffer, so the parameter node and the updated node refer to the same underlying buffer. Nothing reads _k_updated and _v_updated. They exist only so that the writes are included in the execution plan. The writes must also run before cached_attention reads the buffers, because the new token's K/V has to be visible to attention.

Initializing the cache from prefill

After running the prefill graph, copy each layer’s K and V outputs into the corresponding decode-graph parameter buffers. Use session.read_output_by_index to read prefill K/V outputs and session.upload_param to write them into the decode cache.
// Prefill: build_prefill_graph returns (logits, k_outputs, v_outputs) node IDs.
// The compiled session exposes them via read_output_by_index(index, buf).

// For each layer i, the prefill graph outputs k at index (1 + i) and v at (1 + num_layers + i).
// After running the prefill session:
let kv_size = prompt_len * kv_dim;
for i in 0..config.num_hidden_layers {
    let mut k_buf = vec![0.0f32; kv_size];
    let mut v_buf = vec![0.0f32; kv_size];
    prefill_session.read_output_by_index(1 + i, &mut k_buf);
    prefill_session.read_output_by_index(1 + config.num_hidden_layers + i, &mut v_buf);

    // Upload into the decode session's cache parameters
    decode_session.upload_param(&format!("kv_cache.layer.{}.k", i), &k_buf);
    decode_session.upload_param(&format!("kv_cache.layer.{}.v", i), &v_buf);
}

Running the decode loop

Each decode step processes one token, updates the cache, and reads back logits for the next token prediction.
// current_token starts as the token predicted from the prefill logits.
let mut kv_pos: u32 = prompt_len as u32;

for _step in 0..max_new_tokens {
    // Stop before cache_write would target a row past the end of the cache.
    if kv_pos as usize >= max_seq_len {
        break;
    }

    decode_session.set_input_u32("token_ids", &[current_token]);
    decode_session.set_input_u32("kv_pos",    &[kv_pos]);
    decode_session.step();
    decode_session.wait();

    // Output is [1, vocab_size]: read the single-token logits
    let logits = decode_session.read_output(config.vocab_size);
    // Greedy pick; total_cmp avoids a panic if any logit is NaN
    let next_token = logits.iter()
        .enumerate()
        .max_by(|a, b| a.1.total_cmp(b.1))
        .unwrap()
        .0 as u32;

    current_token = next_token;
    kv_pos += 1;
}
kv_pos must stay strictly below max_seq_len, since the valid rows are 0 through max_seq_len - 1. The CacheWrite shader writes at row kv_pos without bounds checking, so if kv_pos >= max_seq_len, the write lands outside the allocated buffer.
The decode session’s cache parameters persist between step() calls because parameters are stored on the GPU and are not reset between invocations. You only need to initialize them once from prefill output.
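The host-side logic of the loop can be tested without a GPU by factoring it out. In the sketch below, the `step` closure stands in for the set_input_u32 / step / wait / read_output sequence and returns logits for a given `(token, kv_pos)` pair. The function names are hypothetical. The sketch keeps the two safeguards: a NaN-tolerant argmax and a stop before `kv_pos` reaches `max_seq_len`.

```rust
/// Index of the largest logit. f32::total_cmp is a total order, so a NaN
/// logit cannot trigger a panic the way partial_cmp(..).unwrap() can.
fn argmax(logits: &[f32]) -> u32 {
    logits
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.total_cmp(b.1))
        .map(|(i, _)| i as u32)
        .expect("logits must not be empty")
}

/// Greedy decode loop with the kv_pos bound enforced on the host.
/// Returns the prefill-predicted first token followed by every decoded token.
fn greedy_decode<F>(
    first_token: u32,
    prompt_len: u32,
    max_new_tokens: usize,
    max_seq_len: u32,
    mut step: F,
) -> Vec<u32>
where
    F: FnMut(u32, u32) -> Vec<f32>,
{
    let mut generated = vec![first_token];
    let mut current = first_token;
    let mut kv_pos = prompt_len;
    for _ in 0..max_new_tokens {
        // cache_write targets row kv_pos, so it must be a valid row index.
        if kv_pos >= max_seq_len {
            break;
        }
        current = argmax(&step(current, kv_pos));
        generated.push(current);
        kv_pos += 1;
    }
    generated
}
```

With `prompt_len = 3` and `max_seq_len = 6`, only rows 3, 4, and 5 are free, so at most three decode steps run regardless of `max_new_tokens`.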
