This module exposes the nn::CausalSelfAttention struct and low-level Graph ops covering every attention variant. The nn:: struct is suitable for typical prefill paths in decoder models. Use the graph ops directly when you need non-causal attention, cross-attention, differentiable training, or decode-time KV-cache access.
nn::CausalSelfAttention
Causal self-attention with grouped-query attention (GQA) and rotary position embeddings (RoPE). All four projection matrices are bias-free.
Fields
Query projection: hidden → hidden.
Key projection: hidden → kv_dim.
Value projection: hidden → kv_dim.
Output projection: hidden → hidden.
Number of query attention heads.
Number of key/value heads. Must divide num_heads evenly (GQA).
Dimension per head.
Base frequency for rotary position embeddings.
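The GQA constraint above implies a fixed mapping from query heads to key/value heads. A minimal sketch of that mapping, assuming the common contiguous-group convention (each KV head serves num_heads / num_kv_heads adjacent query heads; the exact grouping used by this library is not stated on this page):

```python
# Assumed GQA head grouping: query head h reads KV head h // group.
num_heads, num_kv_heads = 8, 2
assert num_heads % num_kv_heads == 0
group = num_heads // num_kv_heads  # 4 query heads per KV head

mapping = {h: h // group for h in range(num_heads)}
print(mapping)  # {0: 0, 1: 0, 2: 0, 3: 0, 4: 1, 5: 1, 6: 1, 7: 1}
```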
AttentionConfig
Model hidden dimension. Used for the query and output projection sizes.
Key/value projection dimension (num_kv_heads * head_dim).
Number of query heads.
Number of key/value heads.
Dimension per head.
RoPE base frequency (e.g. 10000.0 for standard, 500000.0 for Llama 3).
CausalSelfAttention::new
The computation graph to register parameters into.
Name prefix. Registers {name}.q_proj, {name}.k_proj, {name}.v_proj, {name}.o_proj weights.
Attention configuration struct.
forward
Projects Q/K/V, applies RoPE to Q and K, then runs causal masked attention.
Forward computation:
The computation graph to append ops to.
Input tensor of shape [seq, hidden].
Output tensor of shape [seq, hidden].
Graph attention ops
All attention ops operate on 2D tensors. There are no explicit batch or head dimensions: sequences are flattened and heads are interleaved in the last dimension.
g.causal_attention
Fused causal masked multi-head attention with GQA support. Intended for prefill (processing a full sequence at once).
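The op's semantics, as inferred from the documented shapes and layout, can be sketched as a pure-Python reference. The 1/sqrt(head_dim) score scale and the contiguous-group GQA mapping are assumptions not confirmed by this page:

```python
import math

def causal_attention(q, k, v, num_heads, num_kv_heads, head_dim):
    """Reference semantics sketch: q is [seq][num_heads*head_dim];
    k, v are [seq][num_kv_heads*head_dim]; heads interleaved in the last dim."""
    seq = len(q)
    group = num_heads // num_kv_heads
    scale = 1.0 / math.sqrt(head_dim)       # assumed standard scaling
    out = [[0.0] * (num_heads * head_dim) for _ in range(seq)]
    for h in range(num_heads):
        qs = h * head_dim
        ks = (h // group) * head_dim        # assumed GQA mapping
        for i in range(seq):
            # Causal mask: position i attends to positions 0..i only.
            scores = [scale * sum(q[i][qs + d] * k[j][ks + d]
                                  for d in range(head_dim))
                      for j in range(i + 1)]
            m = max(scores)
            w = [math.exp(s - m) for s in scores]
            z = sum(w)
            for d in range(head_dim):
                out[i][qs + d] = sum(w[j] * v[j][ks + d]
                                     for j in range(i + 1)) / z
    return out
```

Note how row 0 of the output depends only on row 0 of K/V: that is the property that makes this op suitable for prefill.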
Query tensor of shape [seq, num_heads * head_dim].
Key tensor of shape [seq, num_kv_heads * head_dim].
Value tensor of shape [seq, num_kv_heads * head_dim].
Number of query heads.
Number of key/value heads.
Dimension per head.
Output tensor of shape [seq, num_heads * head_dim].
g.full_attention
Non-causal (bidirectional) multi-head attention with GQA support. Used in vision encoders and other encoder-only contexts where every token attends to every other token.
Query tensor of shape [seq, num_heads * head_dim].
Key tensor of shape [seq, num_kv_heads * head_dim].
Value tensor of shape [seq, num_kv_heads * head_dim].
Number of query heads.
Number of key/value heads.
Dimension per head.
Output tensor of shape [seq, num_heads * head_dim].
g.cross_attention
Cross-attention where the query sequence attends to a separate key/value sequence. Q and K/V may have different sequence lengths.
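Cross-attention differs from the causal op only in having no mask and in allowing Q and K/V sequences of different lengths; with identical sequences it reduces to full_attention. A sketch under the same assumptions (score scale and GQA mapping inferred, not confirmed here):

```python
import math

def cross_attention(q, k, v, num_heads, num_kv_heads, head_dim):
    """q: [q_seq][num_heads*head_dim]; k, v: [kv_seq][num_kv_heads*head_dim].
    Every query row attends to every KV row (no causal mask)."""
    q_seq, kv_seq = len(q), len(k)
    group = num_heads // num_kv_heads
    scale = 1.0 / math.sqrt(head_dim)       # assumed standard scaling
    out = [[0.0] * (num_heads * head_dim) for _ in range(q_seq)]
    for h in range(num_heads):
        qs = h * head_dim
        ks = (h // group) * head_dim        # assumed GQA mapping
        for i in range(q_seq):
            scores = [scale * sum(q[i][qs + d] * k[j][ks + d]
                                  for d in range(head_dim))
                      for j in range(kv_seq)]
            m = max(scores)
            w = [math.exp(s - m) for s in scores]
            z = sum(w)
            for d in range(head_dim):
                out[i][qs + d] = sum(w[j] * v[j][ks + d]
                                     for j in range(kv_seq)) / z
    return out
```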
Query tensor of shape [q_seq, num_heads * head_dim].
Key tensor of shape [kv_seq, num_kv_heads * head_dim].
Value tensor of shape [kv_seq, num_kv_heads * head_dim].
Number of query heads.
Number of key/value heads.
Dimension per head.
Output tensor of shape [q_seq, num_heads * head_dim].
g.multi_head_attn
Differentiable multi-head attention that saves the log-sum-exp (LSE) buffer needed for an exact backward pass. Use this variant when training.
Handles both self-attention (is_cross = false) and cross-attention (is_cross = true). The autodiff engine uses MultiHeadAttnGradQ/K/V ops that reference the saved LSE.
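The role of the saved LSE buffer can be illustrated in isolation: given the raw scores for one (query row, head) pair, the softmax weights can be recovered in the backward pass as exp(s - lse) without re-running the normalization. A small sketch of that identity (the exact buffer layout used by the engine is not specified here):

```python
import math

def lse(scores):
    """Numerically stable log-sum-exp of one row of attention scores."""
    m = max(scores)
    return m + math.log(sum(math.exp(s - m) for s in scores))

scores = [0.5, 1.5, -0.25]
l = lse(scores)
# Softmax weights recovered from the saved LSE alone.
weights = [math.exp(s - l) for s in scores]
assert abs(sum(weights) - 1.0) < 1e-12
```

Storing one scalar per (query row, head), i.e. [q_seq * num_heads] values, is what lets the gradient ops recompute exact attention weights without materializing the full score matrix in the forward pass.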
Query tensor of shape [q_seq, num_heads * head_dim].
Key tensor of shape [kv_seq, num_kv_heads * head_dim].
Value tensor of shape [kv_seq, num_kv_heads * head_dim].
Number of query heads.
Number of key/value heads.
Dimension per head.
false for self-attention, true for cross-attention.
Output tensor of shape [q_seq, num_heads * head_dim]. An LSE buffer [q_seq * num_heads] is also allocated internally during compilation for backward use.
g.cached_attention
Decode-time attention where a single new query token attends to the entire KV cache. Requires q to have seq_len = 1.
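The decode-time semantics can be sketched as a single-row variant of the attention reference: only the first kv_len rows of each cache buffer participate. Scale and GQA mapping are the same assumptions as before:

```python
import math

def cached_attention(q, k_cache, v_cache, kv_len,
                     num_heads, num_kv_heads, head_dim):
    """q: [1][num_heads*head_dim]; caches: [max_seq][num_kv_heads*head_dim].
    Attends the single query row over the first kv_len cache rows."""
    group = num_heads // num_kv_heads
    scale = 1.0 / math.sqrt(head_dim)       # assumed standard scaling
    out = [0.0] * (num_heads * head_dim)
    for h in range(num_heads):
        qs = h * head_dim
        ks = (h // group) * head_dim        # assumed GQA mapping
        scores = [scale * sum(q[0][qs + d] * k_cache[j][ks + d]
                              for d in range(head_dim))
                  for j in range(kv_len)]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        for d in range(head_dim):
            out[qs + d] = sum(w[j] * v_cache[j][ks + d]
                              for j in range(kv_len)) / z
    return [out]
```

Rows of the cache past kv_len are ignored entirely, which is why the caches can be pre-allocated at max_seq and filled incrementally.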
Query tensor of shape [1, num_heads * head_dim].
Key cache buffer of shape [max_seq, num_kv_heads * head_dim].
Value cache buffer of shape [max_seq, num_kv_heads * head_dim].
U32 scalar input: the number of valid positions currently stored in the cache.
Number of query heads.
Number of key/value heads.
Dimension per head.
Output tensor of shape [1, num_heads * head_dim].
Rotary position embedding ops
g.rope
Applies RoPE starting at position 0. Use for prefill.
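A reference sketch of the rotation, assuming the adjacent-pair convention (dimensions (2i, 2i+1) within each head rotated by angle pos * theta^(-2i/head_dim)); whether this library pairs adjacent dimensions or split halves is not stated on this page. The offset parameter shows how rope_with_offset relates to the base op:

```python
import math

def rope(x, theta, head_dim, offset=0):
    """x: [seq][dim], dim divisible by head_dim. Rotates dimension pairs
    (2i, 2i+1) within each head by angle pos * theta**(-2i/head_dim),
    where pos = row_index + offset (offset=0 matches g.rope)."""
    seq, dim = len(x), len(x[0])
    out = [row[:] for row in x]
    for pos in range(seq):
        for start in range(0, dim, head_dim):      # one head at a time
            for i in range(head_dim // 2):
                freq = theta ** (-2.0 * i / head_dim)
                a = (pos + offset) * freq
                c, s = math.cos(a), math.sin(a)
                x0 = x[pos][start + 2 * i]
                x1 = x[pos][start + 2 * i + 1]
                out[pos][start + 2 * i]     = x0 * c - x1 * s
                out[pos][start + 2 * i + 1] = x0 * s + x1 * c
    return out
```

Position 0 is left unchanged, and the rotation preserves the norm of each pair, which is why RoPE can be applied to both Q and K without rescaling scores.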
2D input tensor of shape [seq, dim]. dim must be even and divisible by head_dim.
RoPE base frequency.
Dimension of each attention head. Rotations are applied independently within each head.
Rotated tensor, same shape [seq, dim].
g.rope_with_offset
Applies RoPE with a static position offset. Use when the current sequence starts at a known non-zero position.
2D input tensor of shape [seq, dim].
RoPE base frequency.
Static offset added to each row’s position index.
Dimension per head.
Rotated tensor, same shape [seq, dim].
g.rope_dynamic_offset
Applies RoPE with a position offset read at runtime from a U32 input buffer. Use for decode-time single-token inference where kv_pos is known only at runtime.
2D input tensor of shape [seq, dim].
RoPE base frequency.
U32 scalar input. Its value is added to each row's position (position = row_index + offset_input[0]).
Dimension per head.
Rotated tensor, same shape [seq, dim].
KV cache write op
g.cache_write
Writes a single new key or value row into a pre-allocated cache buffer at position kv_pos. Returns a node representing the updated cache.
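The write itself is trivial to state in Python; the sketch below mirrors the documented aliasing behavior (the return value is the same buffer, not a copy):

```python
def cache_write(new_row, cache, kv_pos):
    """new_row: [1][dim]; cache: [max_seq][dim]. Overwrites row kv_pos in
    place and returns the same buffer, mirroring the op's aliasing."""
    cache[kv_pos] = list(new_row[0])
    return cache
```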
New key or value tensor of shape [1, dim].
Cache buffer of shape [max_seq, dim].
U32 scalar input: the position index at which to write.
Updated cache buffer, same shape [max_seq, dim] as the input cache.
cache_write performs an in-place write at the GPU level. The returned NodeId aliases the cache buffer; use it (not the original cache input) as the source for subsequent cached_attention calls within the same graph.
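Putting the two decode-time ops together, one step writes the new K/V row at kv_pos and then attends over kv_pos + 1 valid rows. A toy end-to-end loop (1 head, head_dim = 1, so the score scale is 1; this models the semantics only, not the library's graph API):

```python
import math

MAX_SEQ = 8
k_cache = [[0.0] for _ in range(MAX_SEQ)]   # pre-allocated caches
v_cache = [[0.0] for _ in range(MAX_SEQ)]

def attend(q, kv_len):
    """Single-query attention over the first kv_len cache rows."""
    scores = [q * k_cache[j][0] for j in range(kv_len)]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return sum(w[j] * v_cache[j][0] for j in range(kv_len)) / z

outs = []
for kv_pos, (k_new, v_new) in enumerate([(1.0, 10.0), (1.0, 20.0)]):
    k_cache[kv_pos] = [k_new]               # cache_write for K
    v_cache[kv_pos] = [v_new]               # cache_write for V
    outs.append(attend(1.0, kv_pos + 1))    # cached_attention, kv_len = kv_pos + 1

# Step 0 sees only v=10; step 1 averages v=10 and v=20 (equal scores).
print(outs)  # [10.0, 15.0]
```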