The myvllm.layers.attention module implements two core attention strategies:
  • Prefill — flash attention over variable-length sequences using a Triton kernel.
  • Decode — paged attention over a block-structured KV cache using a separate Triton kernel.
The Attention class acts as a unified interface and selects the correct path at runtime via the shared request context.
All kernels are written in Triton and run on CUDA GPUs. Ensure your tensors are on a CUDA device before calling any function in this module.

Attention

myvllm.layers.attention.Attention

A PyTorch nn.Module that wraps both attention kernels and owns the paged KV cache buffers. On each forward pass it reads the global Context object to decide whether the engine is in prefill or decode mode, and dispatches accordingly. When the KV cache is allocated and slot_mapping is present in the context, the module calls store_kvcache before computing attention so that new key/value pairs are written to the cache automatically.

Constructor

num_heads
int
required
Total number of query heads per GPU.
head_dim
int
required
Dimension of each attention head.
scale
float
default:"1.0"
Multiplied by 1 / sqrt(head_dim) to produce the final attention scale factor.
num_kv_heads
int
Number of key/value heads. Defaults to num_heads (multi-head attention). Set to a smaller value for grouped-query attention (GQA).
block_size
int
default:"16"
Number of token slots per block in the paged KV cache. Must match the block size used when allocating k_cache and v_cache.
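
As a quick check on the scale parameter, the description above implies an effective factor of scale / sqrt(head_dim). The following is a minimal sketch of that arithmetic only; effective_scale is a hypothetical helper introduced for illustration and is not part of myvllm.

```python
import math


def effective_scale(scale: float, head_dim: int) -> float:
    """Hypothetical helper: combine the user-supplied scale with 1 / sqrt(head_dim)."""
    return scale / math.sqrt(head_dim)


# With the default scale of 1.0 and head_dim=64, the factor is 1/8.
print(effective_scale(1.0, 64))  # 0.125
```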

forward

forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor
Runs either flash attention (prefill) or paged attention (decode) based on context.is_prefill. Returns a tensor of shape (total_tokens, num_heads * head_dim) in prefill mode, or (batch_size, num_heads * head_dim) in decode mode.
The k_cache and v_cache attributes start as empty tensors. The engine’s memory manager must allocate and assign them before the first decode step.
Example usage
import torch
from myvllm.layers import Attention

attn = Attention(num_heads=8, head_dim=64, num_kv_heads=2, block_size=16).cuda()

# Prefill — context.is_prefill must be True and context.cu_seqlens_q must be set
q = torch.randn(32, 8, 64).cuda()   # (total_tokens, num_heads, head_dim)
k = torch.randn(32, 2, 64).cuda()   # (total_tokens, num_kv_heads, head_dim)
v = torch.randn(32, 2, 64).cuda()
out = attn(q, k, v)  # (32, 512)

flash_attention_prefill

flash_attention_prefill(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
    scale: float,
    num_heads: int,
    num_kv_heads: int,
    head_dim: int,
) -> torch.Tensor
Triton-based flash attention for the prefill phase. Supports batched variable-length sequences packed into a single tensor (varlen format). Implements online softmax with causal masking and grouped-query attention.
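
To make the term "online softmax" concrete, here is a pure-Python sketch for a single query. It processes scores in chunks while keeping a running maximum, a running denominator, and a rescaled accumulator, and it matches a direct softmax. This illustrates the general technique, not the Triton kernel itself; the function names and the chunk parameter are assumptions made for the example.

```python
import math


def online_softmax_attention(scores, values, chunk=2):
    """Compute softmax(scores) @ values one chunk at a time.

    scores: list of floats (one per key); values: list of equal-length float lists.
    """
    dim = len(values[0])
    m = float("-inf")   # running maximum of the scores seen so far
    l = 0.0             # running sum of exp(score - m)
    acc = [0.0] * dim   # running weighted sum of values, relative to m
    for start in range(0, len(scores), chunk):
        s_blk = scores[start:start + chunk]
        v_blk = values[start:start + chunk]
        m_new = max(m, max(s_blk))
        alpha = math.exp(m - m_new)  # rescales old state; 0.0 on the first chunk
        l *= alpha
        acc = [a * alpha for a in acc]
        for s, v in zip(s_blk, v_blk):
            p = math.exp(s - m_new)
            l += p
            acc = [a + p * x for a, x in zip(acc, v)]
        m = m_new
    return [a / l for a in acc]


def direct_softmax_attention(scores, values):
    """Reference: materialize all softmax weights at once."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return [sum(wi * v[d] for wi, v in zip(w, values)) / z
            for d in range(len(values[0]))]
```

The chunked version never holds more than one chunk of weights at a time, which is what lets a kernel stream over keys tile by tile.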

Parameters

q
torch.Tensor
required
Query tensor of shape (total_tokens, num_heads, head_dim).
k
torch.Tensor
required
Key tensor of shape (total_tokens, num_kv_heads, head_dim).
v
torch.Tensor
required
Value tensor of shape (total_tokens, num_kv_heads, head_dim).
cu_seqlens
torch.Tensor
required
Cumulative sequence lengths tensor of shape (num_seqs + 1,). cu_seqlens[i] is the start token index of sequence i; cu_seqlens[-1] equals total_tokens. For example, two sequences of lengths 5 and 7 would give [0, 5, 12].
scale
float
required
Attention scale factor, typically 1.0 / sqrt(head_dim).
num_heads
int
required
Number of query heads.
num_kv_heads
int
required
Number of key/value heads. When num_kv_heads < num_heads, each KV head is shared across num_heads / num_kv_heads query heads (GQA).
head_dim
int
required
Dimension of each head.
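
The cu_seqlens layout and the GQA head sharing described in the parameters can be sketched in plain Python. Both helpers are hypothetical names used only for illustration. The mapping in kv_head_for assumes consecutive query heads share one KV head, which is the common convention but is not confirmed by this page.

```python
from itertools import accumulate


def build_cu_seqlens(seq_lens):
    """Prefix sums with a leading 0, matching the cu_seqlens description."""
    return [0, *accumulate(seq_lens)]


def kv_head_for(q_head, num_heads, num_kv_heads):
    """Assumed mapping: consecutive query heads share one KV head."""
    assert num_heads % num_kv_heads == 0
    return q_head // (num_heads // num_kv_heads)


print(build_cu_seqlens([5, 7]))                  # [0, 5, 12]
print([kv_head_for(h, 8, 2) for h in range(8)])  # [0, 0, 0, 0, 1, 1, 1, 1]
```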

Returns

torch.Tensor — shape (total_tokens, num_heads, head_dim).
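
Because sequences are packed back to back, causal masking has to stop at sequence boundaries as well as at the query position. The sketch below shows which packed key indices a given query token may attend to, assuming standard causal attention within each sequence. visible_key_range is a hypothetical helper, not a myvllm function.

```python
from bisect import bisect_right


def visible_key_range(token_idx, cu_seqlens):
    """Return the half-open range [start, end) of packed key indices that the
    query at token_idx may attend to: keys of its own sequence up to itself."""
    seq = bisect_right(cu_seqlens, token_idx) - 1  # sequence that owns this token
    start = cu_seqlens[seq]
    return start, token_idx + 1


cu_seqlens = [0, 5, 12]  # two sequences of lengths 5 and 7
print(visible_key_range(4, cu_seqlens))  # (0, 5)
print(visible_key_range(5, cu_seqlens))  # (5, 6)
print(visible_key_range(6, cu_seqlens))  # (5, 7)
```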

Block sizing

Tile sizes BLOCK_M and BLOCK_N are chosen automatically to stay within the ~48 KB shared memory budget of most GPUs:
head_dim    BLOCK_M    BLOCK_N
≤ 64        64         64
≤ 128       32         32
> 128       16         16

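The tile-size rule for the prefill kernel can be written as a small selection function. This is a sketch of the documented thresholds only; select_prefill_blocks is a hypothetical name, and the real module may structure the choice differently.

```python
def select_prefill_blocks(head_dim):
    """Return (BLOCK_M, BLOCK_N) following the documented head_dim thresholds."""
    if head_dim <= 64:
        return 64, 64
    if head_dim <= 128:
        return 32, 32
    return 16, 16


print(select_prefill_blocks(64))   # (64, 64)
print(select_prefill_blocks(128))  # (32, 32)
print(select_prefill_blocks(256))  # (16, 16)
```
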
Kernel grid

The Triton kernel is launched with grid (ceil(max_seq_len / BLOCK_M), num_heads, num_seqs). Each program processes one tile of queries for one head of one sequence.
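
The launch grid follows directly from the sequence lengths. Here is a minimal sketch of that computation, using a hypothetical prefill_grid helper. Since the first grid dimension is sized by the longest sequence, programs assigned to tiles past the end of a shorter sequence presumably have no work to do.

```python
def prefill_grid(seq_lens, num_heads, block_m):
    """Return (ceil(max_seq_len / BLOCK_M), num_heads, num_seqs)."""
    max_seq_len = max(seq_lens)
    num_q_tiles = (max_seq_len + block_m - 1) // block_m  # integer ceiling
    return num_q_tiles, num_heads, len(seq_lens)


print(prefill_grid([5, 7], num_heads=8, block_m=64))      # (1, 8, 2)
print(prefill_grid([100, 300], num_heads=8, block_m=64))  # (5, 8, 2)
```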

paged_attention_decode

paged_attention_decode(
    query: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    block_tables: torch.Tensor,
    context_lens: torch.Tensor,
    scale: float,
    num_heads: int,
    num_kv_heads: int,
    head_dim: int,
    block_size: int,
) -> torch.Tensor
Triton-based paged attention for the decode phase. Each sequence in the batch generates exactly one new token, attending over all previously cached key/value pairs stored in non-contiguous physical memory blocks. Uses online softmax for numerically stable attention.

Parameters

query
torch.Tensor
required
Query tensor of shape (batch_size, num_heads, head_dim). One query vector per sequence.
k_cache
torch.Tensor
required
Paged key cache of shape (num_blocks, block_size, num_kv_heads, head_dim).
v_cache
torch.Tensor
required
Paged value cache of shape (num_blocks, block_size, num_kv_heads, head_dim).
block_tables
torch.Tensor
required
Physical block index lookup table of shape (batch_size, max_num_blocks). block_tables[i, j] is the physical block index for logical block j of sequence i. Use -1 for unallocated slots.
context_lens
torch.Tensor
required
1-D tensor of shape (batch_size,) containing the number of valid KV tokens for each sequence in the cache.
scale
float
required
Attention scale factor, typically 1.0 / sqrt(head_dim).
num_heads
int
required
Number of query heads.
num_kv_heads
int
required
Number of key/value heads.
head_dim
int
required
Dimension of each head.
block_size
int
required
Number of token slots per physical block. Must match the block size used when the cache was allocated.
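
To illustrate how block_tables and context_lens work together, the sketch below maps a logical token position to its physical location in the cache. It uses plain Python lists in place of tensors. The helper names and the sample block table are assumptions made for the example.

```python
def blocks_needed(context_len, block_size):
    """Number of logical blocks that hold context_len tokens."""
    return (context_len + block_size - 1) // block_size


def locate_token(block_table_row, token_pos, block_size):
    """Map a logical token position to (physical_block, offset_in_block)."""
    logical_block, offset = divmod(token_pos, block_size)
    physical_block = block_table_row[logical_block]
    if physical_block < 0:
        raise ValueError("token position falls in an unallocated block")
    return physical_block, offset


row = [7, 2, -1, -1]  # one sequence: logical blocks 0 and 1 are allocated
print(blocks_needed(20, 16))      # 2
print(locate_token(row, 0, 16))   # (7, 0)
print(locate_token(row, 19, 16))  # (2, 3)
```

Only positions below the sequence's entry in context_lens are valid, so a correct lookup never reaches a -1 entry.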

Returns

torch.Tensor — shape (batch_size, num_heads, head_dim).

Kernel grid

The kernel is launched with grid (batch_size, num_heads). Each program is responsible for one query head of one sequence and iterates over the full context in chunks of BLOCK_N (64 when head_dim ≤ 128, otherwise 32).
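
The decode launch parameters can be sketched the same way. The per-sequence iteration count below is inferred from the statement that each program walks the full context in chunks of BLOCK_N. The helper names are hypothetical.

```python
def decode_block_n(head_dim):
    """BLOCK_N for decode: 64 when head_dim <= 128, otherwise 32."""
    return 64 if head_dim <= 128 else 32


def decode_launch(batch_size, num_heads, head_dim, context_lens):
    """Return (grid, BLOCK_N, chunks per sequence) for a decode step."""
    block_n = decode_block_n(head_dim)
    grid = (batch_size, num_heads)
    chunks = [(n + block_n - 1) // block_n for n in context_lens]
    return grid, block_n, chunks


print(decode_launch(2, 8, 64, [20, 130]))  # ((2, 8), 64, [1, 3])
```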

store_kvcache

store_kvcache(
    key: torch.Tensor,
    value: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    block_size: int,
)
Writes key and value tensors into a paged KV cache using a Triton kernel. Each token is mapped to a specific cache slot via slot_mapping. Tokens whose slot is -1 (padding tokens) are silently skipped. The kernel is launched with grid (num_tokens, num_kv_heads), so each program instance handles one token for one KV head.

Parameters

key
torch.Tensor
required
Keys to store, shape (num_tokens, num_kv_heads, head_dim).
value
torch.Tensor
required
Values to store, shape (num_tokens, num_kv_heads, head_dim).
k_cache
torch.Tensor
required
Destination key cache, shape (num_blocks, block_size, num_kv_heads, head_dim).
v_cache
torch.Tensor
required
Destination value cache, shape (num_blocks, block_size, num_kv_heads, head_dim).
slot_mapping
torch.Tensor
required
1-D integer tensor of shape (num_tokens,). Each element is a flat cache slot index in the range [0, num_blocks * block_size), or -1 to skip the token.
The slot maps to a physical location as:
block_idx    = slot // block_size
block_offset = slot %  block_size
block_size
int
required
Number of token slots per physical block.
Example
from myvllm.layers.attention import store_kvcache
import torch

num_tokens, num_kv_heads, head_dim = 4, 2, 64
num_blocks, block_size = 8, 16

key   = torch.randn(num_tokens, num_kv_heads, head_dim).cuda()
value = torch.randn(num_tokens, num_kv_heads, head_dim).cuda()
k_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_dim).cuda()
v_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_dim).cuda()
slot_mapping = torch.tensor([0, 1, 16, -1], dtype=torch.int32).cuda()

store_kvcache(key, value, k_cache, v_cache, slot_mapping, block_size)
# token 3 is skipped; tokens 0, 1, 2 are stored at slots 0, 1, 16
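
For readers without a GPU at hand, the write semantics can be emulated in plain Python. This sketch uses nested lists for the caches and strings as stand-ins for per-token tensors. store_kvcache_reference is a hypothetical reference function, not part of myvllm.

```python
def store_kvcache_reference(key, value, k_cache, v_cache, slot_mapping, block_size):
    """Emulate the documented behavior: write each token to its slot, skip -1."""
    for token, slot in enumerate(slot_mapping):
        if slot == -1:
            continue  # padding token, nothing is written
        block_idx, block_offset = divmod(slot, block_size)
        k_cache[block_idx][block_offset] = key[token]
        v_cache[block_idx][block_offset] = value[token]


num_blocks, block_size = 8, 16
k_cache = [[None] * block_size for _ in range(num_blocks)]
v_cache = [[None] * block_size for _ in range(num_blocks)]
key = ["k0", "k1", "k2", "k3"]    # stand-ins for per-token key tensors
value = ["v0", "v1", "v2", "v3"]  # stand-ins for per-token value tensors

store_kvcache_reference(key, value, k_cache, v_cache, [0, 1, 16, -1], block_size)
print(k_cache[0][0], k_cache[0][1], k_cache[1][0])  # k0 k1 k2
```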
