The layers described here are the building blocks of the Qwen3 model implementation inside nano-vLLM. They are designed for efficiency: tensor-parallel sharding, CUDA-graph compatibility, and Triton/FlashAttention kernels.
File: nanovllm/layers/attention.py

The Attention module wraps FlashAttention and routes to the appropriate kernel depending on whether the current step is prefill or decode.

KV cache writing — Triton kernel

Before attention is computed, the new key and value tensors are written to the paged KV cache using a custom Triton kernel:
@triton.jit
def store_kvcache_kernel(
    key_ptr,
    key_stride,
    value_ptr,
    value_stride,
    k_cache_ptr,
    v_cache_ptr,
    slot_mapping_ptr,
    D: tl.constexpr,
):
    idx = tl.program_id(0)
    slot = tl.load(slot_mapping_ptr + idx)
    if slot == -1:
        return
    key_offsets = idx * key_stride + tl.arange(0, D)
    value_offsets = idx * value_stride + tl.arange(0, D)
    key = tl.load(key_ptr + key_offsets)
    value = tl.load(value_ptr + value_offsets)
    cache_offsets = slot * D + tl.arange(0, D)
    tl.store(k_cache_ptr + cache_offsets, key)
    tl.store(v_cache_ptr + cache_offsets, value)
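The flat slot indices come from the engine's block tables; a minimal sketch of how a token position can map to a slot, assuming a paged layout with block_size tokens per block (names are illustrative, not the exact nano-vLLM API):

```python
# Sketch: map a token's position in its sequence to a flat KV-cache slot.
# Each sequence owns a list of physical block IDs; each block holds
# block_size tokens, so the flat cache is indexed block_id * block_size + offset.

def position_to_slot(block_table: list[int], pos: int, block_size: int) -> int:
    block_id = block_table[pos // block_size]  # which physical block
    offset = pos % block_size                  # offset inside the block
    return block_id * block_size + offset      # flat slot index

# A sequence whose tokens live in physical blocks 7 and 2 (block_size=4):
slots = [position_to_slot([7, 2], p, 4) for p in range(6)]
# positions 0-3 land in block 7 (slots 28-31), positions 4-5 in block 2 (slots 8-9)
```

This is the mapping the scheduler materializes into slot_mapping, one entry per token, before launching the kernel.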
Each Triton program handles one token. The slot_mapping tensor maps token positions to flat slots in the paged KV cache. A slot value of -1 indicates a padding token and is skipped.

forward()
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    context = get_context()
    k_cache, v_cache = self.k_cache, self.v_cache
    if k_cache.numel() and v_cache.numel():
        store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
    if context.is_prefill:
        if context.block_tables is not None:    # prefix cache
            k, v = k_cache, v_cache
        o = flash_attn_varlen_func(
            q, k, v,
            max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
            max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
            softmax_scale=self.scale, causal=True, block_table=context.block_tables,
        )
    else:    # decode
        o = flash_attn_with_kvcache(
            q.unsqueeze(1), k_cache, v_cache,
            cache_seqlens=context.context_lens, block_table=context.block_tables,
            softmax_scale=self.scale, causal=True,
        )
    return o
  • Prefill uses flash_attn_varlen_func which handles variable-length sequences packed into a single batch tensor via cumulative sequence length arrays (cu_seqlens_q, cu_seqlens_k).
  • Decode uses flash_attn_with_kvcache which reads from the paged KV cache directly, one token per sequence.
  • When prefix caching is active during prefill (block_tables is not None), k and v are replaced with the full cache tensors so that cached key/value entries are attended to.
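The cumulative sequence-length arrays are plain prefix sums over per-sequence lengths; a minimal sketch of how they can be built (hypothetical lengths):

```python
import torch
import torch.nn.functional as F

# Three packed sequences of lengths 3, 5, and 2. cu_seqlens marks where each
# sequence starts and ends inside the packed (total_tokens, ...) tensor.
seqlens = torch.tensor([3, 5, 2], dtype=torch.int32)
cu_seqlens = F.pad(seqlens.cumsum(0), (1, 0)).to(torch.int32)
# [0, 3, 8, 10]: sequence i occupies rows cu_seqlens[i]:cu_seqlens[i+1]
max_seqlen = int(seqlens.max())  # 5
```

These are the cu_seqlens_q / cu_seqlens_k and max_seqlen_q / max_seqlen_k values that flash_attn_varlen_func consumes in place of a padded batch dimension.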
File: nanovllm/layers/linear.py

All linear layers are tensor-parallel aware and implement a weight_loader hook used during model loading to shard weights correctly across ranks.

QKVParallelLinear

A column-parallel linear layer that fuses Q, K, and V projections into a single weight matrix. On each TP rank, only the shard corresponding to that rank's heads is stored.
class QKVParallelLinear(ColumnParallelLinear):
    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: int | None = None,
        bias: bool = False,
    ):
        tp_size = dist.get_world_size()
        total_num_kv_heads = total_num_kv_heads or total_num_heads
        self.num_heads = divide(total_num_heads, tp_size)
        self.num_kv_heads = divide(total_num_kv_heads, tp_size)
        output_size = (total_num_heads + 2 * total_num_kv_heads) * head_size
        super().__init__(hidden_size, output_size, bias)
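For a grouped-query configuration, the per-rank shard sizes follow directly from the constructor arithmetic; a quick shape check with hypothetical Qwen3-like sizes:

```python
# Per-rank shard sizes for a fused QKV projection (hypothetical sizes:
# 32 query heads, 8 KV heads, head_size 128, tensor parallelism of 4).
total_num_heads, total_num_kv_heads, head_size, tp_size = 32, 8, 128, 4

num_heads = total_num_heads // tp_size        # 8 query heads per rank
num_kv_heads = total_num_kv_heads // tp_size  # 2 KV heads per rank

full_output = (total_num_heads + 2 * total_num_kv_heads) * head_size  # 6144
per_rank = (num_heads + 2 * num_kv_heads) * head_size                 # 1536
assert per_rank * tp_size == full_output
```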
The weight_loader shards Q, K, and V independently by shard ID ("q", "k", "v").

MergedColumnParallelLinear

Used for fused gate+up projections in the MLP. Accepts a list of output_sizes (one per merged sub-weight) and shards each sub-weight independently.
class MergedColumnParallelLinear(ColumnParallelLinear):
    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = False,
    ):
        self.output_sizes = output_sizes
        super().__init__(input_size, sum(output_sizes), bias)
RowParallelLinear

Used for output projections. Each rank holds a shard of the input dimension. Results are summed across ranks via dist.all_reduce.
class RowParallelLinear(LinearBase):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = F.linear(x, self.weight, self.bias if self.tp_rank == 0 else None)
        if self.tp_size > 1:
            dist.all_reduce(y)
        return y
Bias is only added by rank 0 to avoid double-counting after the all-reduce.
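The column-parallel then row-parallel pairing can be verified without any distributed setup; a single-process sketch that emulates TP=2 by slicing the weights:

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 8)
w1 = torch.randn(16, 8)   # column-parallel: split along the output dim
w2 = torch.randn(8, 16)   # row-parallel: split along the input dim

# Reference: the full, unsharded two-layer projection
ref = x @ w1.t() @ w2.t()

# "TP=2": each rank holds half of w1's output rows and the matching
# half of w2's input columns, and produces a partial sum.
parts = []
for rank in range(2):
    w1_shard = w1[rank * 8:(rank + 1) * 8]         # (8, 8)
    w2_shard = w2[:, rank * 8:(rank + 1) * 8]      # (8, 8)
    parts.append(x @ w1_shard.t() @ w2_shard.t())  # partial result per rank

out = parts[0] + parts[1]  # what dist.all_reduce computes across ranks
assert torch.allclose(out, ref, atol=1e-5)
```

The sum of the per-rank partials equals the unsharded result, which is why a single all_reduce at the end of each MLP/attention block is sufficient.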
File: nanovllm/layers/layernorm.py

RMSNorm implements Root Mean Square Layer Normalization. It provides two @torch.compile-decorated paths: a standard forward pass and a fused residual+norm path that avoids a separate addition.
class RMSNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    @torch.compile
    def rms_forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        x = x.float()
        var = x.pow(2).mean(dim=-1, keepdim=True)
        x.mul_(torch.rsqrt(var + self.eps))
        x = x.to(orig_dtype).mul_(self.weight)
        return x

    @torch.compile
    def add_rms_forward(
        self, x: torch.Tensor, residual: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        orig_dtype = x.dtype
        x = x.float().add_(residual.float())
        residual = x.to(orig_dtype)
        var = x.pow(2).mean(dim=-1, keepdim=True)
        x.mul_(torch.rsqrt(var + self.eps))
        x = x.to(orig_dtype).mul_(self.weight)
        return x, residual

    def forward(self, x, residual=None):
        if residual is None:
            return self.rms_forward(x)
        else:
            return self.add_rms_forward(x, residual)
When residual is passed, forward returns (normed_x, updated_residual). This pattern allows the residual stream to be carried separately and added just before each norm, reducing memory bandwidth.
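The fused path computes the same result as adding the residual first and then normalizing; a plain-eager sketch (no torch.compile) mirroring the arithmetic above:

```python
import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Same arithmetic as RMSNorm.rms_forward, without in-place ops
    xf = x.float()
    var = xf.pow(2).mean(dim=-1, keepdim=True)
    return (xf * torch.rsqrt(var + eps)).to(x.dtype) * weight

torch.manual_seed(0)
x = torch.randn(2, 4, dtype=torch.float16)
residual = torch.randn(2, 4, dtype=torch.float16)
weight = torch.ones(4, dtype=torch.float16)

# Fused path: add the residual in fp32, norm, return the updated residual
summed = x.float() + residual.float()
new_residual = summed.to(x.dtype)
fused = rms_norm(summed, weight).to(x.dtype)

# Unfused path: separate half-precision add, then norm
unfused = rms_norm(x + residual, weight)
assert torch.allclose(fused, unfused, atol=1e-2)
```

The fused path also performs the addition in float32 before rounding back, so it is at least as accurate as a separate half-precision add.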
File: nanovllm/layers/rotary_embedding.py

RotaryEmbedding precomputes a cos_sin_cache of shape (max_position_embeddings, 1, rotary_dim) at construction time and looks up the relevant entries by position index at runtime.
class RotaryEmbedding(nn.Module):
    def __init__(self, head_size, rotary_dim, max_position_embeddings, base):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
        t = torch.arange(max_position_embeddings, dtype=torch.float)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1).unsqueeze_(1)
        self.register_buffer("cos_sin_cache", cache, persistent=False)

    @torch.compile
    def forward(self, positions, query, key):
        cos_sin = self.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        query = apply_rotary_emb(query, cos, sin)
        key = apply_rotary_emb(key, cos, sin)
        return query, key
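apply_rotary_emb itself is not shown in the excerpt; a typical rotate-half implementation consistent with the concatenated (cos, sin) cache layout might look like this (a sketch, not the exact nano-vLLM code):

```python
import torch

def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Split the rotary dimension in half and rotate each (x1, x2) pair:
    # [x1, x2] -> [x1*cos - x2*sin, x2*cos + x1*sin]
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)

# At position 0, cos = 1 and sin = 0, so the rotation is the identity.
q = torch.randn(1, 1, 8)
cos0, sin0 = torch.ones(4), torch.zeros(4)
assert torch.allclose(apply_rotary_emb(q, cos0, sin0), q)
```

Note that cos and sin each cover half the rotary dimension, matching the cache built above from freqs of shape (max_position_embeddings, rotary_dim // 2).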
get_rope() factory

A module-level LRU-cached factory function ensures only one RotaryEmbedding instance is created per unique set of parameters (used across all attention layers with the same config):
@lru_cache(1)
def get_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: float,
    rope_scaling: dict | None = None,
):
    assert rope_scaling is None
    return RotaryEmbedding(head_size, rotary_dim, max_position, base)
rope_scaling is accepted as a parameter for API compatibility but is not implemented. Passing a non-None value raises an AssertionError.
File: nanovllm/layers/activation.py

SiluAndMul implements the gated activation used in the SwiGLU MLP variant. It splits the input in half along the last dimension, applies SiLU to the first half, and multiplies the result element-wise by the second half.
class SiluAndMul(nn.Module):
    @torch.compile
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, y = x.chunk(2, -1)
        return F.silu(x) * y
This is used after MergedColumnParallelLinear projects the hidden state to 2 * intermediate_size, producing both the gate and the value in a single matmul.
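Putting the pieces together, the gate and up projections come out of one matmul and SiluAndMul recombines them; a minimal sketch with hypothetical sizes:

```python
import torch
import torch.nn.functional as F

hidden_size, intermediate_size = 8, 16
x = torch.randn(2, hidden_size)
gate_up_w = torch.randn(2 * intermediate_size, hidden_size)  # fused gate+up weight

gate_up = F.linear(x, gate_up_w)     # (2, 32): gate and up in a single matmul
gate, up = gate_up.chunk(2, dim=-1)  # what SiluAndMul splits apart
out = F.silu(gate) * up              # (2, 16), fed into the down projection
assert out.shape == (2, intermediate_size)
```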
File: nanovllm/layers/sampler.py

Sampler converts the final logits tensor into sampled token IDs using temperature scaling and the Gumbel-max trick (equivalent to multinomial sampling).
class Sampler(nn.Module):
    @torch.compile
    def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
        logits = logits.float().div_(temperatures.unsqueeze(dim=1))
        probs = torch.softmax(logits, dim=-1)
        sample_tokens = probs.div_(
            torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)
        ).argmax(dim=-1)
        return sample_tokens
Dividing probabilities by independent Exponential(1) samples and taking the argmax is equivalent to sampling from the categorical distribution (Gumbel-max trick). clamp_min_(1e-10) prevents division by zero.

The sampler only runs on TP rank 0; other ranks return None.
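The equivalence is easy to check empirically; a quick simulation with a large sample count:

```python
import torch

torch.manual_seed(0)
probs = torch.tensor([0.7, 0.2, 0.1])
n = 200_000

# Exponential-race sampling: argmax of p_i / E_i with E_i ~ Exponential(1).
# argmax(p_i / E_i) = argmax(log p_i - log E_i), and -log E_i is Gumbel(0, 1),
# so this is exactly the Gumbel-max trick.
e = torch.empty(n, 3).exponential_(1)
samples = (probs / e).argmax(dim=-1)

freq = torch.bincount(samples, minlength=3).float() / n
# Empirical frequencies track the target distribution
assert torch.allclose(freq, probs, atol=0.01)
```

This formulation is CUDA-graph friendly: it avoids torch.multinomial and uses only elementwise ops plus an argmax.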
File: nanovllm/layers/embed_head.py

VocabParallelEmbedding

Shards the vocabulary embedding table across TP ranks. Each rank stores num_embeddings // tp_size rows. During the forward pass, token IDs outside a rank's shard are remapped to index 0, their output rows are zeroed by the mask, and the per-rank results are summed via all_reduce.
class VocabParallelEmbedding(nn.Module):
    def forward(self, x: torch.Tensor):
        if self.tp_size > 1:
            mask = (x >= self.vocab_start_idx) & (x < self.vocab_end_idx)
            x = mask * (x - self.vocab_start_idx)
        y = F.embedding(x, self.weight)
        if self.tp_size > 1:
            y = mask.unsqueeze(1) * y
            dist.all_reduce(y)
        return y
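The mask-and-reduce scheme can be checked in a single process; a sketch that emulates TP=2 by slicing the embedding table:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab, hidden, tp_size = 8, 4, 2
weight = torch.randn(vocab, hidden)
tokens = torch.tensor([0, 3, 5, 7])

ref = F.embedding(tokens, weight)  # non-parallel reference

# Each "rank" holds vocab // tp_size rows of the table.
shard = vocab // tp_size
y = torch.zeros(len(tokens), hidden)
for rank in range(tp_size):
    start, end = rank * shard, (rank + 1) * shard
    mask = (tokens >= start) & (tokens < end)
    local = mask * (tokens - start)              # out-of-shard ids clamp to 0
    out = F.embedding(local, weight[start:end])  # lookup in this rank's shard
    y += mask.unsqueeze(1) * out                 # zero out-of-shard rows, then sum

assert torch.allclose(y, ref)
```

Each token's embedding is produced by exactly one rank; the masking guarantees every other rank contributes zeros to the all_reduce sum.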
ParallelLMHead

Subclasses VocabParallelEmbedding and reuses the same weight shard for the output projection (weight tying). During prefill, it first selects the last token of each sequence, the only positions that need logits, before computing the full-vocabulary linear projection.
class ParallelLMHead(VocabParallelEmbedding):
    def forward(self, x: torch.Tensor):
        context = get_context()
        if context.is_prefill:
            last_indices = context.cu_seqlens_q[1:] - 1
            x = x[last_indices].contiguous()
        logits = F.linear(x, self.weight)
        if self.tp_size > 1:
            all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)] \
                if self.tp_rank == 0 else None
            dist.gather(logits, all_logits, 0)
            logits = torch.cat(all_logits, -1) if self.tp_rank == 0 else None
        return logits
Logit shards are gathered to rank 0 via dist.gather and concatenated to produce the full vocabulary logits. Ranks 1…N return None.
