## How it works in Meganeura

Meganeura models follow a two-graph design:

**Prefill graph.** Processes the full prompt in one shot using standard `CausalAttention` over the entire sequence. Outputs logits for the last prompt token plus the computed K/V tensors for every layer.

**Decode graph.** Processes exactly one new token per step. Uses `CacheWrite` to insert the new token’s K/V into a pre-allocated cache buffer, then uses `CachedAttention` to attend over all cached positions.

The caches themselves are parameter nodes (named `kv_cache.layer.{i}.k` and `kv_cache.layer.{i}.v`). They persist in GPU memory between decode steps because parameters are not reset between calls to `step()`.
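The step-level control flow this design implies can be sketched as a toy driver. Every name below is a hypothetical stand-in (the "model" just increments tokens), not Meganeura's API; the point is the shape of the loop: prefill runs once over the whole prompt, then decode runs once per generated token while the cache fill level advances.

```rust
// Toy two-graph control flow. `prefill` consumes the whole prompt once
// and reports how many cache rows are now filled; `decode_step` consumes
// one token and fills one more row.
fn prefill(prompt: &[u32]) -> (u32, usize) {
    // stand-in "model": next token = last prompt token + 1
    (prompt.last().unwrap() + 1, prompt.len())
}

fn decode_step(token: u32, kv_pos: usize) -> (u32, usize) {
    // one token in, one token out; the cache grows by exactly one row
    (token + 1, kv_pos + 1)
}

fn main() {
    let prompt = [5, 6, 7];
    let (mut tok, mut kv_pos) = prefill(&prompt);
    let mut generated = vec![tok];
    for _ in 0..3 {
        let (next, pos) = decode_step(tok, kv_pos);
        tok = next;
        kv_pos = pos;
        generated.push(tok);
    }
    assert_eq!(generated, vec![8, 9, 10, 11]);
    assert_eq!(kv_pos, 6); // 3 prompt rows from prefill + 3 decode writes
}
```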
## The CacheWrite op

`Op::CacheWrite` writes a single-token K or V tensor into row `kv_pos` of a pre-allocated cache buffer. Its inputs are:

- `new_kv`: shape `[1, dim]` — the K or V vector for the current token
- `cache`: shape `[max_seq, dim]` — the pre-allocated cache buffer (a `Parameter` node)
- `kv_pos`: a `u32` input scalar — the row index to write into

The op copies `new_kv` into row `kv_pos[0]` of the cache, leaving all other rows unchanged.
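The write semantics can be sketched on plain CPU slices. This is an illustrative model of the op, not the actual GPU kernel; the cache is stored row-major as one flat buffer.

```rust
// Sketch of CacheWrite semantics: copy a [1, dim] row into row `kv_pos`
// of a row-major [max_seq, dim] cache, leaving every other row untouched.
fn cache_write(cache: &mut [f32], dim: usize, kv_pos: usize, new_kv: &[f32]) {
    assert_eq!(new_kv.len(), dim);
    let start = kv_pos * dim;
    cache[start..start + dim].copy_from_slice(new_kv);
}

fn main() {
    // max_seq = 4, dim = 2
    let mut cache = vec![0.0f32; 4 * 2];
    cache_write(&mut cache, 2, 1, &[3.0, 4.0]);
    // Only row 1 changed.
    assert_eq!(cache, vec![0.0, 0.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0]);
}
```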
## The CachedAttention op

`Op::CachedAttention` computes scaled dot-product attention where the query comes from the current token and the keys and values come from the cache. Its inputs are:

- `q`: shape `[1, num_heads * head_dim]` — query for the current token
- `k_cache`: shape `[max_seq, num_kv_heads * head_dim]` — accumulated key cache
- `v_cache`: shape `[max_seq, num_kv_heads * head_dim]` — accumulated value cache
- `kv_pos`: a `u32` scalar — number of valid positions in the cache (attention is masked to rows `0..kv_pos`)

The output has shape `[1, num_heads * head_dim]`.
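A single-head sketch of the masked computation on plain slices (illustrative only; the real op also handles multiple heads and grouped KV heads). The key point is that rows at or beyond `kv_pos` never enter the softmax, so stale cache contents cannot leak into the output.

```rust
// Single-head cached attention: the query attends only to the first
// `kv_pos` rows of the caches; later rows are masked out entirely.
fn cached_attention(
    q: &[f32],
    k_cache: &[f32],
    v_cache: &[f32],
    head_dim: usize,
    kv_pos: usize,
) -> Vec<f32> {
    let scale = 1.0 / (head_dim as f32).sqrt();
    // Scaled dot-product scores over the valid rows only.
    let mut scores: Vec<f32> = (0..kv_pos)
        .map(|r| {
            let k = &k_cache[r * head_dim..(r + 1) * head_dim];
            q.iter().zip(k).map(|(a, b)| a * b).sum::<f32>() * scale
        })
        .collect();
    // Numerically stable softmax.
    let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let mut denom = 0.0;
    for s in scores.iter_mut() {
        *s = (*s - max).exp();
        denom += *s;
    }
    // Weighted sum over the valid value rows.
    let mut out = vec![0.0; head_dim];
    for (r, s) in scores.iter().enumerate() {
        let w = s / denom;
        let v = &v_cache[r * head_dim..(r + 1) * head_dim];
        for (o, x) in out.iter_mut().zip(v) {
            *o += w * x;
        }
    }
    out
}

fn main() {
    // max_seq = 3, head_dim = 2; only the first 2 rows are valid.
    let k = vec![1.0, 0.0, 0.0, 1.0, 9.0, 9.0];
    let v = vec![10.0, 0.0, 0.0, 10.0, 99.0, 99.0];
    let out = cached_attention(&[1.0, 0.0], &k, &v, 2, 2);
    // Row 2 (the 9s/99s) is masked; the output is a convex mix of rows 0 and 1.
    assert!(out[0] > out[1]);
    assert!((out[0] + out[1] - 10.0).abs() < 1e-4);
}
```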
## RoPE with a dynamic position offset

During prefill, RoPE positions are computed statically from row indices (position 0 for the first token, 1 for the second, and so on). During decode, the single new token has a position equal to `kv_pos` — the number of tokens already in the cache.

Use `graph.rope_dynamic_offset` to apply RoPE with a position read from a runtime input buffer rather than a compile-time constant. The `offset_input` node must be a `u32` input of shape `[1]`. At decode time, set it to `kv_pos` (the current cache fill level). Because `x` has shape `[1, dim]` in the decode graph, row index 0 plus the offset gives the correct absolute position.
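The equivalence between "row 0 plus a dynamic offset" and a static absolute position can be checked with a minimal rotary sketch. The `rope_rotate` function and its adjacent-pair layout are assumptions for illustration, not the real kernel, but the position arithmetic is the same: rotating the decode row at position `0 + kv_pos` gives exactly what prefill would have produced for row `kv_pos`.

```rust
// Minimal RoPE: rotate each (even, odd) element pair of `x` by an angle
// that depends on the absolute position `pos` and the pair index.
fn rope_rotate(x: &mut [f32], pos: u32, base: f32) {
    let half = x.len() / 2;
    for i in 0..half {
        let theta = pos as f32 / base.powf(2.0 * i as f32 / x.len() as f32);
        let (sin, cos) = theta.sin_cos();
        let (a, b) = (x[2 * i], x[2 * i + 1]);
        x[2 * i] = a * cos - b * sin;
        x[2 * i + 1] = a * sin + b * cos;
    }
}

fn main() {
    // Decode: the single row is at index 0, with a dynamic offset of 5.
    let mut decode_row = vec![1.0f32, 0.0];
    rope_rotate(&mut decode_row, 0 + 5, 10000.0);
    // Prefill: the same token would have sat at row index 5.
    let mut prefill_row5 = vec![1.0f32, 0.0];
    rope_rotate(&mut prefill_row5, 5, 10000.0);
    assert_eq!(decode_row, prefill_row5);
}
```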
## Difference from the prefill path

| | Prefill | Decode |
|---|---|---|
| Attention op | causal_attention(q, k, v, …) | cached_attention(q, k_cache, v_cache, kv_pos, …) |
| Input shape | [seq_len, hidden] | [1, hidden] |
| Context | Full prompt sequence | Accumulated KV cache |
| RoPE | Static offset (row index = absolute position) | Dynamic offset via rope_dynamic_offset |
| KV output | Written to graph outputs for cache initialization | Written in-place via cache_write |
## Building the decode graph

The decode layer loop lives in `src/models/smollm2.rs`. Each transformer layer allocates a K cache and V cache parameter, applies `cache_write` to update them, and uses `cached_attention` to compute attention over the full cached history.
`cache_write` returns an updated cache node, but the decode graph reads from the original `k_cache` and `v_cache` parameter nodes in `cached_attention`. The write happens in-place on the GPU buffer, so both the original parameter node and the updated node refer to the same underlying buffer. The `_k_updated` and `_v_updated` nodes exist only to ensure the write is included in the execution plan.

## Initializing the cache from prefill
After running the prefill graph, copy each layer’s K and V outputs into the corresponding decode-graph parameter buffers. Use `session.read_output_by_index` to read prefill K/V outputs and `session.upload_param` to write them into the decode cache.
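The hand-off can be sketched with a stand-in session. The struct and method names below are hypothetical substitutes for the real `read_output_by_index` / `upload_param` calls; the pattern is simply a per-layer read-then-upload loop that fills the front of each `[max_seq, dim]` cache buffer with the prompt's K/V rows.

```rust
// Stand-in for the session API, holding prefill outputs and the
// pre-allocated decode cache buffers as flat row-major slices.
struct FakeSession {
    prefill_outputs: Vec<Vec<f32>>, // per-layer K (or V) tensors from prefill
    decode_params: Vec<Vec<f32>>,   // per-layer [max_seq, dim] cache buffers
}

impl FakeSession {
    fn read_output(&self, layer: usize) -> Vec<f32> {
        self.prefill_outputs[layer].clone()
    }
    fn upload_param(&mut self, layer: usize, data: &[f32]) {
        // The prompt has `seq_len` rows, so it occupies the front of the cache.
        self.decode_params[layer][..data.len()].copy_from_slice(data);
    }
}

fn main() {
    let mut s = FakeSession {
        prefill_outputs: vec![vec![1.0, 2.0]],  // 1 layer, seq_len=1, dim=2
        decode_params: vec![vec![0.0; 6]],      // max_seq=3, dim=2
    };
    for layer in 0..s.prefill_outputs.len() {
        let kv = s.read_output(layer);
        s.upload_param(layer, &kv);
    }
    // Rows beyond the prompt stay zero until decode steps write them.
    assert_eq!(s.decode_params[0], vec![1.0, 2.0, 0.0, 0.0, 0.0, 0.0]);
}
```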