The Engine class provides efficient inference for NanoChat models using KV caching and supports advanced features like tool use and multi-sample generation.

Overview

The engine is built around four features:
  • KV Cache: Stores the keys and values computed for previous tokens so they are not recomputed at each step
  • Streaming Generation: Yields tokens one at a time for real-time output
  • Batch Generation: Generates multiple samples in parallel
  • Tool Use: Built-in calculator tool with automatic result injection

Basic Usage

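from nanochat.common import compute_init, autodetect_device_type
from nanochat.engine import Engine
from nanochat.checkpoint_manager import load_model

# Set up the device (same initialization as the complete example at the end of this page)
device_type = autodetect_device_type()
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)

# Load model and tokenizer
model, tokenizer, meta = load_model("sft", device, phase="eval")

# Create engine
engine = Engine(model, tokenizer)

# Generate tokens
bos_token_id = tokenizer.get_bos_token_id()
prompt_tokens = tokenizer.encode("What is 2+2?", prepend=bos_token_id)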
for token_column, token_masks in engine.generate(prompt_tokens, num_samples=1, max_tokens=100):
    token = token_column[0]
    print(tokenizer.decode([token]), end="", flush=True)

Generation Methods

Streaming Generation

generate(tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42)

Streaming generator that yields tokens one at a time.

Parameters:
  • tokens (list[int]): Input token sequence
  • num_samples (int): Number of parallel samples to generate (default: 1)
  • max_tokens (int | None): Maximum number of tokens to generate (default: None, meaning no explicit cap; the decode KV cache is then sized to the model's sequence_len)
  • temperature (float): Sampling temperature, 0.0 = greedy (default: 1.0)
  • top_k (int): Top-k sampling parameter (default: None)
  • seed (int): Random seed (default: 42)
Yields:
  • token_column (list[int]): Next token for each sample (length = num_samples)
  • token_masks (list[int]): 1 if sampled, 0 if forced by tool (length = num_samples)
Example:
for token_column, token_masks in engine.generate(
    prompt_tokens,
    num_samples=4,  # Generate 4 samples in parallel
    max_tokens=256,
    temperature=0.8,
    top_k=50,
    seed=12345
):
    for i, (token, mask) in enumerate(zip(token_column, token_masks)):
        if mask == 1:
            print(f"Sample {i}: {tokenizer.decode([token])}")
        else:
            print(f"Sample {i}: [FORCED] {tokenizer.decode([token])}")
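When complete sequences are needed instead of a live printout, the columns yielded by generate can be accumulated row by row. The sketch below is a minimal illustration that runs without a model: `fake_stream` is a hypothetical stand-in for `engine.generate(...)`, and `collect_stream` is an illustrative helper name, not part of nanochat.

```python
def collect_stream(stream, num_samples):
    """Accumulate (token_column, token_masks) pairs into per-sample lists."""
    rows = [[] for _ in range(num_samples)]
    masks = [[] for _ in range(num_samples)]
    for token_column, token_masks in stream:
        for i, (token, mask) in enumerate(zip(token_column, token_masks)):
            rows[i].append(token)
            masks[i].append(mask)
    return rows, masks


def fake_stream():
    """Hypothetical stand-in for engine.generate(..., num_samples=2)."""
    yield [10, 20], [1, 1]
    yield [11, 21], [1, 0]
    yield [12, 22], [1, 0]


rows, masks = collect_stream(fake_stream(), num_samples=2)
print(rows)   # [[10, 11, 12], [20, 21, 22]]
print(masks)  # [[1, 1, 1], [1, 0, 0]]
```

Each yielded column holds one token per sample, so the accumulation is effectively a transpose of the stream.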

Batch Generation

generate_batch(tokens, num_samples=1, **kwargs)

Non-streaming batch generation that returns complete token sequences.

Returns:
  • results (list[list[int]]): Token sequences for each sample
  • masks (list[list[int]]): Mask sequences (1=sampled, 0=forced)
Example:
results, masks = engine.generate_batch(
    prompt_tokens,
    num_samples=4,
    max_tokens=128,
    temperature=0.7
)

for i, (tokens, mask) in enumerate(zip(results, masks)):
    text = tokenizer.decode(tokens)
    print(f"Sample {i}: {text}")
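The masks make it possible to separate what the model sampled from what the engine injected. The sketch below assumes that each mask is aligned element for element with its token sequence; the helper names and token ids are hypothetical and chosen only for illustration.

```python
def sampled_only(tokens, mask):
    """Keep the positions whose mask is 1 (sampled by the model)."""
    return [t for t, m in zip(tokens, mask) if m == 1]


def forced_only(tokens, mask):
    """Keep the positions whose mask is 0 (forced, e.g. by a tool result)."""
    return [t for t, m in zip(tokens, mask) if m == 0]


# Hypothetical token ids and mask for one sample
tokens = [5, 6, 7, 8, 9]
mask = [1, 1, 0, 0, 1]

print(sampled_only(tokens, mask))  # [5, 6, 9]
print(forced_only(tokens, mask))   # [7, 8]
```

This kind of filtering is useful, for instance, when a statistic should count only the tokens the model itself produced.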

KV Cache

The KV cache stores the key and value tensors computed by each attention layer so that they do not have to be recomputed for earlier tokens at every decode step.

Architecture

From nanochat/engine.py:83-133:
class KVCache:
    """
    KV Cache designed for Flash Attention 3's flash_attn_with_kvcache API.
    
    Key differences from FA2-style cache:
    - Tensors are (B, T, H, D) not (B, H, T, D)
    - FA3 updates the cache in-place during flash_attn_with_kvcache
    - Position tracked per batch element via cache_seqlens tensor
    """
    
    def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers, device, dtype):
        self.batch_size = batch_size
        self.max_seq_len = seq_len
        self.n_layers = num_layers
        self.n_heads = num_heads
        self.head_dim = head_dim
        # Pre-allocate cache tensors: (n_layers, B, T, H, D)
        self.k_cache = torch.zeros(num_layers, batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
        self.v_cache = torch.zeros(num_layers, batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
        # Current sequence length per batch element (FA3 needs int32)
        self.cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
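Because both cache tensors are pre-allocated with shape (n_layers, B, T, H, D), the memory cost can be estimated before creating a cache. The sketch below is a back-of-the-envelope calculation in plain Python; the model dimensions are hypothetical and not taken from any nanochat configuration, and the small cache_seqlens tensor is ignored.

```python
def kv_cache_bytes(num_layers, batch_size, seq_len, num_heads, head_dim, bytes_per_element):
    """Bytes used by k_cache plus v_cache for the given shape."""
    elements_per_tensor = num_layers * batch_size * seq_len * num_heads * head_dim
    return 2 * elements_per_tensor * bytes_per_element  # two tensors: keys and values


# Hypothetical dimensions; bytes_per_element=2 corresponds to a 16-bit dtype
size = kv_cache_bytes(
    num_layers=20,
    batch_size=4,
    seq_len=2048,
    num_heads=10,
    head_dim=128,
    bytes_per_element=2,
)
print(size / 2**20)  # 800.0 (MiB)
```

The cost grows linearly in both batch_size and seq_len, which is why sizing the decode cache to the prompt length plus max_tokens, rather than to the full context length, can save memory.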

Key Methods

  • reset(): Reset cache to empty state
  • get_pos(): Get current position (assumes all batch elements at same position)
  • get_layer_cache(layer_idx): Return (k_cache, v_cache) views for a specific layer
  • advance(num_tokens): Advance the cache position by num_tokens
  • prefill(other): Copy cached KV from another cache (used for multi-sample generation)

Prefill-then-Decode Pattern

The engine uses an efficient two-phase approach:
  1. Prefill: Process the entire prompt once with a batch size of 1
  2. Decode: Copy the prefilled KV cache into a cache with one row per sample and generate all samples in parallel
From nanochat/engine.py:194-218:
# 1) Run a batch 1 prefill of the prompt tokens
m = self.model.config
kv_model_kwargs = {"num_heads": m.n_kv_head, "head_dim": m.n_embd // m.n_head, "num_layers": m.n_layer}
kv_cache_prefill = KVCache(
    batch_size=1,
    seq_len=len(tokens),
    device=device,
    dtype=dtype,
    **kv_model_kwargs,
)
ids = torch.tensor([tokens], dtype=torch.long, device=device)
logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
logits = logits[:, -1, :].expand(num_samples, -1)  # (num_samples, vocab_size)

# 2) Replicate the KV cache for each sample/row
kv_length_hint = (len(tokens) + max_tokens) if max_tokens is not None else self.model.config.sequence_len
kv_cache_decode = KVCache(
    batch_size=num_samples,
    seq_len=kv_length_hint,
    device=device,
    dtype=dtype,
    **kv_model_kwargs,
)
kv_cache_decode.prefill(kv_cache_prefill)
del kv_cache_prefill  # no need to keep this memory around
This approach processes the prompt once and then generates multiple diverse samples efficiently.
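The copy step can be pictured with a toy cache in plain Python. This is only a sketch of the idea: `ToyCache` is a hypothetical class that stores strings in lists, whereas the real KVCache holds tensors and its prefill implementation is not shown on this page.

```python
class ToyCache:
    """Toy stand-in for a KV cache: one list of slots per batch row."""

    def __init__(self, batch_size, seq_len):
        self.max_seq_len = seq_len
        self.slots = [[None] * seq_len for _ in range(batch_size)]
        self.seqlens = [0] * batch_size

    def write(self, row, entries):
        """Append entries to one row and advance that row's position."""
        pos = self.seqlens[row]
        assert pos + len(entries) <= self.max_seq_len, "cache overflow"
        self.slots[row][pos:pos + len(entries)] = entries
        self.seqlens[row] = pos + len(entries)

    def prefill(self, other):
        """Copy a batch-1 cache into every row of this (possibly longer) cache."""
        assert len(other.slots) == 1, "source must have batch size 1"
        n = other.seqlens[0]
        assert n <= self.max_seq_len, "destination too short"
        for row in range(len(self.slots)):
            self.slots[row][:n] = other.slots[0][:n]
            self.seqlens[row] = n


prompt = ["kv(The)", "kv(sky)", "kv(is)"]
num_samples, max_tokens = 3, 2

# 1) Prefill once with batch size 1
prefill_cache = ToyCache(batch_size=1, seq_len=len(prompt))
prefill_cache.write(0, prompt)

# 2) Replicate into a larger cache with one row per sample
decode_cache = ToyCache(batch_size=num_samples, seq_len=len(prompt) + max_tokens)
decode_cache.prefill(prefill_cache)
del prefill_cache

# Each row then continues with its own sampled token
for row, entry in enumerate(["kv(blue)", "kv(grey)", "kv(clear)"]):
    decode_cache.write(row, [entry])

print(decode_cache.seqlens)   # [4, 4, 4]
print(decode_cache.slots[1])  # ['kv(The)', 'kv(sky)', 'kv(is)', 'kv(grey)', None]
```

All rows share the same prompt prefix and advance in lockstep, while their contents diverge from the first decoded token onward.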

Token Sampling

The engine uses a custom sampling function that supports temperature and top-k sampling. From nanochat/engine.py:135-152:
@torch.inference_mode()
def sample_next_token(logits, rng, temperature=1.0, top_k=None):
    """Sample a single next token from given logits of shape (B, vocab_size). Returns (B, 1)."""
    assert temperature >= 0.0, "temperature must be non-negative"
    if temperature == 0.0:
        return torch.argmax(logits, dim=-1, keepdim=True)
    if top_k is not None and top_k > 0:
        k = min(top_k, logits.size(-1))
        vals, idx = torch.topk(logits, k, dim=-1)
        vals = vals / temperature
        probs = F.softmax(vals, dim=-1)
        choice = torch.multinomial(probs, num_samples=1, generator=rng)
        return idx.gather(1, choice)
    else:
        logits = logits / temperature
        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1, generator=rng)
Sampling Modes:
  • temperature=0.0: Greedy decoding (always pick most likely token)
  • temperature=1.0: Standard sampling from full distribution
  • temperature>1.0: More random/creative (flattens distribution)
  • temperature<1.0: More focused/deterministic (sharpens distribution)
  • top_k: Only sample from top-k most likely tokens
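The effect of these settings can be checked without a GPU. The following is a pure-Python analogue for a single row of logits, written with the standard library only; `sample_index` and `top_prob` are hypothetical helper names, and the logits are made up for illustration.

```python
import math
import random


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def sample_index(logits, rng, temperature=1.0, top_k=None):
    """Return one token index using greedy, temperature, or top-k sampling."""
    assert temperature >= 0.0, "temperature must be non-negative"
    if temperature == 0.0:
        return max(range(len(logits)), key=lambda i: logits[i])
    indices = list(range(len(logits)))
    if top_k is not None and top_k > 0:
        k = min(top_k, len(logits))
        # Keep only the k highest-logit candidates
        indices = sorted(indices, key=lambda i: logits[i], reverse=True)[:k]
    probs = softmax([logits[i] / temperature for i in indices])
    return rng.choices(indices, weights=probs, k=1)[0]


def top_prob(logits, temperature):
    """Probability assigned to the most likely token at a given temperature."""
    return max(softmax([v / temperature for v in logits]))


logits = [2.0, 1.0, 0.0, -1.0]
rng = random.Random(0)

greedy = sample_index(logits, rng, temperature=0.0)
draws = {sample_index(logits, rng, temperature=1.0, top_k=2) for _ in range(200)}

print(greedy)                 # 0
print(draws <= {0, 1})        # True: top_k=2 never selects indices 2 or 3
print(top_prob(logits, 0.5) > top_prob(logits, 1.0) > top_prob(logits, 2.0))  # True
```

Lower temperatures concentrate probability on the most likely token, higher temperatures spread it out, and top-k removes all but the k best candidates before sampling.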

Tool Use: Calculator

The engine includes built-in support for a calculator tool. When the model emits the <|python_start|> and <|python_end|> special tokens around an expression, the engine evaluates the expression and injects the result into the token stream.

How It Works

  1. Model generates <|python_start|> token
  2. Engine enters “python block” mode and accumulates tokens
  3. Model generates <|python_end|> token
  4. Engine evaluates the expression using use_calculator()
  5. If successful, engine forces <|output_start|> + result + <|output_end|> tokens
  6. Model continues generation with the result in context
From nanochat/engine.py:251-267:
# Handle tool logic
if next_token == python_start:
    state.in_python_block = True
    state.python_expr_tokens = []
elif next_token == python_end and state.in_python_block:
    state.in_python_block = False
    if state.python_expr_tokens:
        expr = self.tokenizer.decode(state.python_expr_tokens)
        result = use_calculator(expr)
        if result is not None:
            result_tokens = self.tokenizer.encode(str(result))
            state.forced_tokens.append(output_start)
            state.forced_tokens.extend(result_tokens)
            state.forced_tokens.append(output_end)
    state.python_expr_tokens = []
elif state.in_python_block:
    state.python_expr_tokens.append(next_token)
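The interaction between the tool logic and the forced-token queue can be simulated without a model. The sketch below makes several assumptions: tokens are plain strings, `toy_calculator` is a hypothetical stand-in for use_calculator that handles only digit sums, the "model" is a pre-scripted list, and forced tokens are assumed to take precedence over sampled ones with a mask of 0.

```python
from collections import deque

PYTHON_START, PYTHON_END = "<|python_start|>", "<|python_end|>"
OUTPUT_START, OUTPUT_END = "<|output_start|>", "<|output_end|>"


def toy_calculator(expr):
    """Hypothetical stand-in for use_calculator: sums of non-negative integers only."""
    if expr and all(c in "0123456789+" for c in expr):
        parts = expr.split("+")
        if all(parts):
            return sum(int(p) for p in parts)
    return None


def run_row(model_tokens):
    """Pass scripted 'sampled' tokens through the tool logic; return (tokens, masks)."""
    sampled = deque(model_tokens)
    forced = deque()
    in_python_block = False
    expr_tokens = []
    out_tokens, out_masks = [], []
    while forced or sampled:
        # Forced tokens are emitted before anything else is taken from the model
        is_forced = len(forced) > 0
        token = forced.popleft() if is_forced else sampled.popleft()
        out_tokens.append(token)
        out_masks.append(0 if is_forced else 1)
        if token == PYTHON_START:
            in_python_block = True
            expr_tokens = []
        elif token == PYTHON_END and in_python_block:
            in_python_block = False
            if expr_tokens:
                result = toy_calculator("".join(expr_tokens))
                if result is not None:
                    forced.append(OUTPUT_START)
                    forced.extend(str(result))  # one token per character in this toy
                    forced.append(OUTPUT_END)
            expr_tokens = []
        elif in_python_block:
            expr_tokens.append(token)
    return out_tokens, out_masks


tokens, masks = run_row([PYTHON_START, "2", "+", "2", PYTHON_END, "done"])
print(tokens)
print(masks)  # [1, 1, 1, 1, 1, 0, 0, 0, 1]
```

The three zeros in the mask correspond to <|output_start|>, the result 4, and <|output_end|>, which the engine injected rather than sampled.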

Supported Expressions

The calculator supports:
  • Math expressions: 2+2, 3.14*10, 100/5
  • String operations: only the .count() method, for example "hello".count("l") or "world".count("o")
Safety features:
  • Timeout after 3 seconds
  • No access to builtins or dangerous operations
  • Disallows power operator **
  • Restricts input to an allowlist of characters and rejects expressions containing blocklisted patterns such as __, import, or eval
From nanochat/engine.py:47-80:
def use_calculator(expr):
    """
    Evaluate a Python expression safely.
    Supports both math expressions and string operations like .count()
    """
    # Remove commas from numbers
    expr = expr.replace(",", "")
    
    # Check if it's a pure math expression
    if all([x in "0123456789*+-/.() " for x in expr]):
        if "**" in expr:  # disallow power operator
            return None
        return eval_with_timeout(expr)
    
    # Check if it's a string operation we support
    allowed_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'\"()._ "
    if not all([x in allowed_chars for x in expr]):
        return None
    
    # Disallow dangerous patterns
    dangerous_patterns = ['__', 'import', 'exec', 'eval', 'compile', 'open', 'file',
                         'input', 'raw_input', 'globals', 'locals', 'vars', 'dir',
                         'getattr', 'setattr', 'delattr', 'hasattr']
    expr_lower = expr.lower()
    if any(pattern in expr_lower for pattern in dangerous_patterns):
        return None
    
    # Only allow .count() method for now
    if '.count(' not in expr:
        return None
    
    return eval_with_timeout(expr)
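To see which inputs pass these checks, the validation steps can be replayed without evaluating anything. The sketch below mirrors the order of the checks in the excerpt above; `classify` is a hypothetical helper, not part of nanochat, and it only reports which branch an expression would reach.

```python
MATH_CHARS = "0123456789*+-/.() "
STRING_CHARS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'\"()._ "
DANGEROUS = ['__', 'import', 'exec', 'eval', 'compile', 'open', 'file',
             'input', 'raw_input', 'globals', 'locals', 'vars', 'dir',
             'getattr', 'setattr', 'delattr', 'hasattr']


def classify(expr):
    """Return 'math', 'string', or None, following the checks shown above."""
    expr = expr.replace(",", "")
    if all(c in MATH_CHARS for c in expr):
        return None if "**" in expr else "math"
    if not all(c in STRING_CHARS for c in expr):
        return None
    if any(p in expr.lower() for p in DANGEROUS):
        return None
    return "string" if ".count(" in expr else None


examples = ["2+2", "1,000*3", "2**8", '"hello".count("l")',
            '"hello".upper()', "__import__('os')"]
for e in examples:
    print(repr(e), "->", classify(e))
```

Note that the pattern check is a substring match, so a harmless expression such as "compile".count("e") is also rejected because it contains a blocklisted word.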

Row State Tracking

When generating multiple samples in parallel, the engine maintains per-row state to track tool use independently. From nanochat/engine.py:155-162:
class RowState:
    # Per-row state tracking during generation
    def __init__(self, current_tokens=None):
        self.current_tokens = current_tokens or []  # Current token sequence for this row
        self.forced_tokens = deque()  # Queue of tokens to force inject
        self.in_python_block = False  # Whether we are inside a python block
        self.python_expr_tokens = []  # Tokens of the current python expression
        self.completed = False  # Whether this row has completed generation
Each sample maintains:
  • current_tokens: Full token history
  • forced_tokens: Queue of tokens to inject (from tool results)
  • in_python_block: Whether the row is currently between <|python_start|> and <|python_end|>
  • python_expr_tokens: Accumulated expression tokens
  • completed: Whether generation has ended for this sample

Performance Testing

The engine includes a built-in test to verify correctness and benchmark performance.
python -m nanochat.engine
This compares the engine’s output against the model’s naive generation function and reports timing. From nanochat/engine.py:302-357:
if __name__ == "__main__":
    """
    Quick inline test to make sure that the naive/slow model.generate function
    is equivalent to the faster Engine.generate function here.
    """
    # Load model
    model, tokenizer, meta = load_model("base", device, phase="eval")
    bos_token_id = tokenizer.get_bos_token_id()
    kwargs = dict(max_tokens=64, temperature=0.0)
    prompt_tokens = tokenizer.encode("The chemical formula of water is", prepend=bos_token_id)
    
    # Generate with reference implementation
    generated_tokens = []
    torch.cuda.synchronize()
    t0 = time.time()
    stream = model.generate(prompt_tokens, **kwargs)
    with autocast_ctx:
        for token in stream:
            generated_tokens.append(token)
    torch.cuda.synchronize()
    t1 = time.time()
    print(f"Reference time: {t1 - t0:.2f}s")
    reference_ids = generated_tokens
    
    # Generate with Engine
    generated_tokens = []
    engine = Engine(model, tokenizer)
    stream = engine.generate(prompt_tokens, num_samples=1, **kwargs)
    torch.cuda.synchronize()
    t0 = time.time()
    with autocast_ctx:
        for token_column, token_masks in stream:
            token = token_column[0]
            generated_tokens.append(token)
    torch.cuda.synchronize()
    t1 = time.time()
    print(f"Engine time: {t1 - t0:.2f}s")
    
    # Compare
    print(f"Match: {reference_ids == generated_tokens}")

Complete Example: Multi-Sample Generation

import torch
from nanochat.engine import Engine
from nanochat.checkpoint_manager import load_model
from nanochat.common import compute_init, autodetect_device_type

# Initialize
device_type = autodetect_device_type()
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
model, tokenizer, meta = load_model("sft", device, phase="eval")
engine = Engine(model, tokenizer)

# Prepare prompt
bos = tokenizer.get_bos_token_id()
user_start = tokenizer.encode_special("<|user_start|>")
user_end = tokenizer.encode_special("<|user_end|>")
assistant_start = tokenizer.encode_special("<|assistant_start|>")

tokens = [bos, user_start]
tokens.extend(tokenizer.encode("Tell me a joke"))
tokens.extend([user_end, assistant_start])

# Generate 4 different jokes in parallel
results, masks = engine.generate_batch(
    tokens,
    num_samples=4,
    max_tokens=200,
    temperature=1.0,
    top_k=50,
    seed=42
)

for i, (result_tokens, mask) in enumerate(zip(results, masks)):
    # Only decode the assistant's response (after assistant_start)
    response_start = len(tokens)
    response_tokens = result_tokens[response_start:]
    text = tokenizer.decode(response_tokens)
    print(f"\n=== Sample {i+1} ===")
    print(text)
This efficiently generates 4 diverse responses by processing the prompt once and then sampling 4 times in parallel.
