Meganeura provides both a high-level nn::CausalSelfAttention struct and low-level Graph ops for every attention variant. The nn:: struct is suitable for typical prefill paths in decoder models. Use the graph ops directly when you need non-causal, cross-attention, differentiable training, or decode-time KV-cache access.

nn::CausalSelfAttention

Causal self-attention with grouped-query attention (GQA) and rotary position embeddings (RoPE). All four projection matrices are bias-free.

Fields
q_proj
Linear
required
Query projection: hidden → hidden.
k_proj
Linear
required
Key projection: hidden → kv_dim.
v_proj
Linear
required
Value projection: hidden → kv_dim.
o_proj
Linear
required
Output projection: hidden → hidden.
num_heads
u32
required
Number of query attention heads.
num_kv_heads
u32
required
Number of key/value heads. Must divide num_heads evenly (GQA).
head_dim
u32
required
Dimension per head.
rope_theta
f32
required
Base frequency for rotary position embeddings.

AttentionConfig

pub struct AttentionConfig {
    pub hidden: usize,
    pub kv_dim: usize,
    pub num_heads: u32,
    pub num_kv_heads: u32,
    pub head_dim: u32,
    pub rope_theta: f32,
}
hidden
usize
required
Model hidden dimension. Used for the query and output projection sizes.
kv_dim
usize
required
Key/value projection dimension (num_kv_heads * head_dim).
num_heads
u32
required
Number of query heads.
num_kv_heads
u32
required
Number of key/value heads.
head_dim
u32
required
Dimension per head.
rope_theta
f32
required
RoPE base frequency (e.g. 10000.0 for standard, 500000.0 for Llama 3).

CausalSelfAttention::new

g
&mut Graph
required
The computation graph to register parameters into.
name
&str
required
Name prefix. Registers {name}.q_proj, {name}.k_proj, {name}.v_proj, {name}.o_proj weights.
cfg
&AttentionConfig
required
Attention configuration struct.
let cfg = nn::AttentionConfig {
    hidden: 512,
    kv_dim: 128,
    num_heads: 8,
    num_kv_heads: 2,
    head_dim: 64,
    rope_theta: 10000.0,
};
let attn = nn::CausalSelfAttention::new(&mut g, "model.layers.0.self_attn", &cfg);

forward

Projects Q/K/V, applies RoPE to Q and K, then runs causal masked attention.

Forward computation:
q = q_proj(x)         // [seq, hidden]
k = k_proj(x)         // [seq, kv_dim]
v = v_proj(x)         // [seq, kv_dim]
q = rope(q, theta, head_dim)
k = rope(k, theta, head_dim)
attn = causal_attention(q, k, v, ...)
out = o_proj(attn)    // [seq, hidden]
g
&mut Graph
required
The computation graph to append ops to.
x
NodeId
required
Input tensor of shape [seq, hidden].
NodeId
NodeId
Output tensor of shape [seq, hidden].
let y = attn.forward(&mut g, x);
// y shape: [seq, 512]

Graph attention ops

All attention ops operate on 2D tensors. There are no explicit batch or head dimensions — sequences are flattened and heads are interleaved in the last dimension.
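To make the flattened layout concrete, here is a minimal sketch (not part of the Meganeura API, and assuming each head occupies a contiguous head_dim-wide slice of the last dimension, the usual layout) of how an element at (row, head, d) is addressed in a [seq, num_heads * head_dim] tensor:

```rust
// Hypothetical illustration: locating element (row, head, d) in a flattened
// [seq, num_heads * head_dim] tensor, assuming contiguous per-head slices.
fn flat_index(row: usize, head: usize, d: usize, num_heads: usize, head_dim: usize) -> usize {
    row * (num_heads * head_dim) + head * head_dim + d
}

fn main() {
    // 8 heads of dim 64 give a last dimension of 512.
    let (num_heads, head_dim) = (8, 64);
    // Element 3 of head 2 in row 5 lands at 5*512 + 2*64 + 3 = 2691.
    assert_eq!(flat_index(5, 2, 3, num_heads, head_dim), 2691);
}
```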

g.causal_attention

Fused causal masked multi-head attention with GQA support. Intended for prefill (processing a full sequence at once).
q
NodeId
required
Query tensor of shape [seq, num_heads * head_dim].
k
NodeId
required
Key tensor of shape [seq, num_kv_heads * head_dim].
v
NodeId
required
Value tensor of shape [seq, num_kv_heads * head_dim].
num_heads
u32
required
Number of query heads.
num_kv_heads
u32
required
Number of key/value heads.
head_dim
u32
required
Dimension per head.
NodeId
NodeId
Output tensor of shape [seq, num_heads * head_dim].
let attn = g.causal_attention(q, k, v, 8, 2, 64);
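With GQA, each query head attends using the K/V head of its group. A hedged sketch of the grouping rule (illustration only, not a Meganeura function), matching the 8-query-head / 2-KV-head example above:

```rust
// Hypothetical sketch: which KV head a given query head reads under GQA.
// num_kv_heads must divide num_heads evenly, as the docs require.
fn kv_head_for(q_head: u32, num_heads: u32, num_kv_heads: u32) -> u32 {
    let group_size = num_heads / num_kv_heads; // queries per KV head
    q_head / group_size
}

fn main() {
    // 8 query heads, 2 KV heads: heads 0-3 share KV head 0, heads 4-7 share KV head 1.
    assert_eq!(kv_head_for(0, 8, 2), 0);
    assert_eq!(kv_head_for(3, 8, 2), 0);
    assert_eq!(kv_head_for(4, 8, 2), 1);
    assert_eq!(kv_head_for(7, 8, 2), 1);
}
```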

g.full_attention

Non-causal (bidirectional) multi-head attention with GQA support. Used in vision encoders and other encoder-only contexts where every token attends to every other token.
q
NodeId
required
Query tensor of shape [seq, num_heads * head_dim].
k
NodeId
required
Key tensor of shape [seq, num_kv_heads * head_dim].
v
NodeId
required
Value tensor of shape [seq, num_kv_heads * head_dim].
num_heads
u32
required
Number of query heads.
num_kv_heads
u32
required
Number of key/value heads.
head_dim
u32
required
Dimension per head.
NodeId
NodeId
Output tensor of shape [seq, num_heads * head_dim].
// Vision encoder: every patch attends to every other patch
let attn = g.full_attention(q, k, v, 16, 16, 64);

g.cross_attention

Cross-attention where the query sequence attends to a separate key/value sequence. Q and K/V may have different sequence lengths.
q
NodeId
required
Query tensor of shape [q_seq, num_heads * head_dim].
k
NodeId
required
Key tensor of shape [kv_seq, num_kv_heads * head_dim].
v
NodeId
required
Value tensor of shape [kv_seq, num_kv_heads * head_dim].
num_heads
u32
required
Number of query heads.
num_kv_heads
u32
required
Number of key/value heads.
head_dim
u32
required
Dimension per head.
NodeId
NodeId
Output tensor of shape [q_seq, num_heads * head_dim].
// Decoder attending to encoder output
let out = g.cross_attention(q_dec, k_enc, v_enc, 8, 8, 64);

g.multi_head_attn

Differentiable multi-head attention that saves the log-sum-exp (LSE) buffer needed for an exact backward pass. Use this variant when training. Handles both self-attention (is_cross = false) and cross-attention (is_cross = true). The autodiff engine uses MultiHeadAttnGradQ/K/V ops that reference the saved LSE.
q
NodeId
required
Query tensor of shape [q_seq, num_heads * head_dim].
k
NodeId
required
Key tensor of shape [kv_seq, num_kv_heads * head_dim].
v
NodeId
required
Value tensor of shape [kv_seq, num_kv_heads * head_dim].
num_heads
u32
required
Number of query heads.
num_kv_heads
u32
required
Number of key/value heads.
head_dim
u32
required
Dimension per head.
is_cross
bool
required
false for self-attention, true for cross-attention.
NodeId
NodeId
Output tensor of shape [q_seq, num_heads * head_dim]. An LSE buffer [q_seq * num_heads] is also allocated internally during compilation for backward use.
// Training: differentiable self-attention
let out = g.multi_head_attn(q, k, v, 8, 8, 64, false);

g.cached_attention

Decode-time attention where a single new query token attends to the entire KV cache. Requires q to have seq_len = 1.
q
NodeId
required
Query tensor of shape [1, num_heads * head_dim].
k_cache
NodeId
required
Key cache buffer of shape [max_seq, num_kv_heads * head_dim].
v_cache
NodeId
required
Value cache buffer of shape [max_seq, num_kv_heads * head_dim].
kv_pos
NodeId
required
U32 scalar input: the number of valid positions currently stored in the cache.
num_heads
u32
required
Number of query heads.
num_kv_heads
u32
required
Number of key/value heads.
head_dim
u32
required
Dimension per head.
NodeId
NodeId
Output tensor of shape [1, num_heads * head_dim].
let kv_pos = g.input_u32("kv_pos", &[1]);
let out = g.cached_attention(q, k_cache, v_cache, kv_pos, 8, 2, 64);
// out shape: [1, 512]
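A scalar reference of what a single head of cached attention computes (illustration only, not the library's kernel): one query row takes scaled dot-product scores against the first kv_pos cache rows, softmaxes them, and mixes the value rows.

```rust
// Reference single-head cached attention: q attends over the valid
// cache prefix of length kv_pos. Rows beyond kv_pos are ignored.
fn cached_attention_1head(q: &[f32], k: &[Vec<f32>], v: &[Vec<f32>], kv_pos: usize) -> Vec<f32> {
    let scale = (q.len() as f32).sqrt();
    // scaled dot-product scores against valid cache rows only
    let scores: Vec<f32> = (0..kv_pos)
        .map(|t| q.iter().zip(&k[t]).map(|(a, b)| a * b).sum::<f32>() / scale)
        .collect();
    let m = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scores.iter().map(|s| (s - m).exp()).collect();
    let z: f32 = exps.iter().sum();
    // softmax-weighted sum of value rows
    let mut out = vec![0.0; q.len()];
    for t in 0..kv_pos {
        for (o, vi) in out.iter_mut().zip(&v[t]) {
            *o += exps[t] / z * vi;
        }
    }
    out
}

fn main() {
    // With kv_pos = 1 the softmax is a single 1.0 weight, so the
    // output equals the one cached value row; row 1 is never touched.
    let q = vec![1.0, 0.0];
    let k = vec![vec![1.0, 0.0], vec![0.0, 0.0]];
    let v = vec![vec![2.0, 3.0], vec![9.0, 9.0]];
    let out = cached_attention_1head(&q, &k, &v, 1);
    assert!((out[0] - 2.0).abs() < 1e-6 && (out[1] - 3.0).abs() < 1e-6);
}
```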

Rotary position embedding ops

g.rope

Applies RoPE starting at position 0. Use for prefill.
x
NodeId
required
2D input tensor of shape [seq, dim]. dim must be even and divisible by head_dim.
theta
f32
required
RoPE base frequency.
head_dim
u32
required
Dimension of each attention head. Rotations are applied independently within each head.
NodeId
NodeId
Rotated tensor, same shape [seq, dim].
let q = g.rope(q, 10000.0, 64);
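For intuition about the per-head rotations, here is an illustrative RoPE rotation for a single head using the half-split pairing convention (an assumption for this sketch; the library's exact pairing may differ):

```rust
// Illustrative RoPE for one head of dimension x.len(): element i is paired
// with element i + head_dim/2 and rotated by pos * theta^(-2i/head_dim).
fn rope_head(x: &mut [f32], pos: u32, theta: f32) {
    let half = x.len() / 2;
    for i in 0..half {
        let freq = theta.powf(-(2.0 * i as f32) / x.len() as f32);
        let (sin, cos) = (pos as f32 * freq).sin_cos();
        let (a, b) = (x[i], x[i + half]);
        x[i] = a * cos - b * sin;
        x[i + half] = a * sin + b * cos;
    }
}

fn main() {
    // At position 0 every rotation angle is 0, so RoPE is the identity.
    let mut x = vec![1.0_f32, 2.0, 3.0, 4.0];
    rope_head(&mut x, 0, 10000.0);
    assert_eq!(x, vec![1.0, 2.0, 3.0, 4.0]);
}
```

This is also why prefill can use g.rope directly: row i simply gets position i, while the offset variants below shift pos before the rotation.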

g.rope_with_offset

Applies RoPE with a static position offset. Use when the current sequence starts at a known non-zero position.
x
NodeId
required
2D input tensor of shape [seq, dim].
theta
f32
required
RoPE base frequency.
pos_offset
u32
required
Static offset added to each row’s position index.
head_dim
u32
required
Dimension per head.
NodeId
NodeId
Rotated tensor, same shape [seq, dim].
// Apply RoPE starting at position 128
let q = g.rope_with_offset(q, 10000.0, 128, 64);

g.rope_dynamic_offset

Applies RoPE with a position offset read at runtime from a U32 input buffer. Use for decode-time single-token inference where kv_pos is known only at runtime.
x
NodeId
required
2D input tensor of shape [seq, dim].
theta
f32
required
RoPE base frequency.
offset_input
NodeId
required
U32 scalar input. Its value is added to each row’s position (position = row_index + offset_input[0]).
head_dim
u32
required
Dimension per head.
NodeId
NodeId
Rotated tensor, same shape [seq, dim].
let kv_pos = g.input_u32("kv_pos", &[1]);
let q = g.rope_dynamic_offset(q, 10000.0, kv_pos, 64);

KV cache write op

g.cache_write

Writes a single new key or value row into a pre-allocated cache buffer at position kv_pos. Returns a node representing the updated cache.
new_kv
NodeId
required
New key or value tensor of shape [1, dim].
cache
NodeId
required
Cache buffer of shape [max_seq, dim].
kv_pos
NodeId
required
U32 scalar input: the position index at which to write.
NodeId
NodeId
Updated cache buffer, same shape [max_seq, dim] as the input cache.
let kv_pos = g.input_u32("kv_pos", &[1]);
let k_cache = g.input("k_cache", &[max_seq, kv_dim]);
let k_new = k_proj.forward(&mut g, x);  // shape [1, kv_dim]
let k_cache = g.cache_write(k_new, k_cache, kv_pos);
cache_write performs an in-place write at the GPU level. The returned NodeId aliases the cache buffer — use it (not the original cache input) as the source for subsequent cached_attention calls within the same graph.
