Overview
Nanochat uses a unified Flash Attention interface that automatically selects the best implementation based on available hardware:
Flash Attention 3 (FA3): used on Hopper GPUs (H100, H200); the fastest option
PyTorch SDPA: fallback for Ada, Blackwell, Ampere, MPS, and CPU
Each platform therefore gets the fastest available kernel, and the model code does not change.
Key Features
Drop-in replacement for FA3 with an identical API
Zero-overhead hardware detection (once at import time)
Supports causal attention and sliding window attention
KV cache support for inference
Automatic layout conversion for SDPA fallback
Reference: flash_attention.py:1-15
Usage
```python
from nanochat.flash_attention import flash_attn

# Training (no KV cache)
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)

# Inference (with KV cache)
y = flash_attn.flash_attn_with_kvcache(
    q, k_cache, v_cache,
    k=k, v=v,
    cache_seqlens=cache_seqlens,
    causal=True,
    window_size=window_size,
)
```
The API is identical whether FA3 or SDPA is used under the hood.
Hardware Detection
Loading FA3
FA3 is loaded at import time only if a CUDA device is available and that device is a Hopper (sm90) GPU:
```python
import torch

def _load_flash_attention_3():
    if not torch.cuda.is_available():
        return None
    major, _ = torch.cuda.get_device_capability()
    # FA3 only on Hopper (sm90)
    if major != 9:
        return None
    from kernels import get_kernel
    return get_kernel('varunneal/flash-attention-3').flash_attn_interface
```
Supported hardware:
✅ Hopper (H100, H200) - compute capability 9.0
❌ Blackwell - compute capability 10.0 (needs recompilation)
❌ Ada (RTX 4090) - compute capability 8.9
❌ Ampere (A100) - compute capability 8.0
Reference: flash_attention.py:23-38
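The detection rule depends only on the device's compute capability. The sketch below restates it as a pure function, so it can be unit-tested without a GPU. `select_backend` is a hypothetical helper, not part of nanochat, and it ignores whether the `kernels` package is actually installed.

```python
def select_backend(cuda_available, capability):
    """Hypothetical helper mirroring the FA3 detection rule.

    capability is a (major, minor) tuple, like the value returned by
    torch.cuda.get_device_capability().
    """
    if not cuda_available:
        return "sdpa"  # MPS / CPU: FA3 is never used
    major, _ = capability
    if major != 9:
        return "sdpa"  # FA3 kernels exist only for Hopper (sm90)
    return "fa3"


# The architectures from the table above
for name, cap in [("Hopper", (9, 0)), ("Blackwell", (10, 0)), ("Ada", (8, 9)), ("Ampere", (8, 0))]:
    print(name, select_backend(True, cap))
```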
Why Not Blackwell?
FA3 kernels are compiled specifically for sm90 (Hopper), and Blackwell (sm100) would need them recompiled. Until FA3 adds Blackwell support, Blackwell GPUs run on the SDPA fallback.
Detection Result
```python
from nanochat.flash_attention import HAS_FA3

if HAS_FA3:
    print("Using Flash Attention 3")
else:
    print("Using PyTorch SDPA fallback")
```
Reference: flash_attention.py:42
API Reference
flash_attn_func
```python
flash_attn.flash_attn_func(
    q,                     # (B, T, H, D) - queries
    k,                     # (B, T, H_kv, D) - keys
    v,                     # (B, T, H_kv, D) - values
    causal=False,          # use causal masking
    window_size=(-1, -1),  # (left, right) window
)
```
Tensor layout: (batch, sequence, heads, dim), which is FA3's native layout
Window size:
(-1, 0): full causal attention
(N, 0): sliding window; each query attends to itself and the previous N tokens (N+1 tokens in total)
(-1, -1): full bidirectional attention (not causal)
Returns: output tensor of shape (B, T, H, D)
Reference: flash_attention.py:99-120
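The `window_size` convention is easier to see with concrete positions. The dependency-free sketch below lists which key positions a query at position `i` may attend to. It assumes queries and keys cover the same `T` positions and treats -1 as "unbounded", as in the table above. `visible_keys` is a hypothetical illustration, not a nanochat function.

```python
def visible_keys(i, T, window_size, causal):
    """Key positions visible to query i under a (left, right) window.

    -1 means unbounded on that side; causal=True forbids looking ahead.
    Assumes queries and keys cover the same T positions.
    """
    left, right = window_size
    if causal:
        right = 0  # causal masking never looks at future tokens
    lo = 0 if left < 0 else max(0, i - left)
    hi = T - 1 if right < 0 else min(T - 1, i + right)
    return list(range(lo, hi + 1))


print(visible_keys(5, 10, (-1, 0), causal=True))    # full causal: positions 0..5
print(visible_keys(5, 10, (2, 0), causal=True))     # sliding window N=2: positions 3, 4, 5
print(visible_keys(5, 10, (-1, -1), causal=False))  # bidirectional: positions 0..9
```

With `(N, 0)`, the current token plus `N` earlier tokens are visible. That is where the "N+1 tokens" count comes from.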
flash_attn_with_kvcache
```python
flash_attn.flash_attn_with_kvcache(
    q,                   # (B, T_new, H, D) - queries
    k_cache,             # (B, T_max, H_kv, D) - key cache
    v_cache,             # (B, T_max, H_kv, D) - value cache
    k=None,              # (B, T_new, H_kv, D) - new keys
    v=None,              # (B, T_new, H_kv, D) - new values
    cache_seqlens=None,  # (B,) - current position in cache
    causal=False,
    window_size=(-1, -1),
)
```
In-place updates: both the FA3 and SDPA versions write the new keys and values into k_cache and v_cache in place.
Cache management:
k_cache, v_cache: pre-allocated tensors (T_max is usually 4096 or larger)
cache_seqlens: tensor of shape (B,) tracking the current position in the cache
New keys/values are written to positions cache_seqlens[0]:cache_seqlens[0]+T_new
Returns: output tensor of shape (B, T_new, H, D)
Reference: flash_attention.py:123-169
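The cache bookkeeping can be shown with plain Python lists standing in for tensors. This is a toy model, not nanochat's KV cache, and `ToyKVCache` is hypothetical. It stores one key per slot and, like the slice described above, uses a single position (`cache_seqlens[0]`) for the whole batch.

```python
class ToyKVCache:
    """Hypothetical list-based stand-in for a pre-allocated key cache."""

    def __init__(self, batch, t_max):
        self.t_max = t_max
        self.k_cache = [[None] * t_max for _ in range(batch)]  # pre-allocated slots
        self.cache_seqlens = [0] * batch                       # current position per row

    def write(self, k_new):
        """Write k_new (one list of T_new keys per batch row) in place.

        Returns what attention would see: all keys up to and including
        the newly written slots.
        """
        pos = self.cache_seqlens[0]
        t_new = len(k_new[0])
        if pos + t_new > self.t_max:
            raise ValueError("cache overflow: increase T_max")
        for row, new in zip(self.k_cache, k_new):
            row[pos:pos + t_new] = new  # fills slots pos .. pos + t_new - 1
        return [row[:pos + t_new] for row in self.k_cache]

    def advance(self, t_new):
        self.cache_seqlens = [s + t_new for s in self.cache_seqlens]


cache = ToyKVCache(batch=1, t_max=8)
print(cache.write([["p0", "p1", "p2"]]))  # prompt chunk lands in slots 0..2
cache.advance(3)
print(cache.write([["g3"]]))              # next generated token lands in slot 3
cache.advance(1)
```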
SDPA Fallback Implementation
Layout Conversion
FA3 uses (B, T, H, D) layout, but SDPA expects (B, H, T, D):
```python
# Convert to SDPA layout
q = q.transpose(1, 2)  # (B, T, H, D) -> (B, H, T, D)
k = k.transpose(1, 2)
v = v.transpose(1, 2)

# Call SDPA
y = F.scaled_dot_product_attention(q, k, v, ...)

# Convert back
y = y.transpose(1, 2)  # (B, H, T, D) -> (B, T, H, D)
```
Reference: flash_attention.py:115-120
Sliding Window Support
SDPA doesn’t natively support sliding windows, so we build an explicit mask:
```python
import torch
import torch.nn.functional as F

def _sdpa_attention(q, k, v, window_size, enable_gqa):
    # q: (B, H, Tq, D); k, v: (B, H_kv, Tk, D)
    Tq, Tk = q.size(2), k.size(2)
    window = window_size[0]  # left window size

    # Full context, same length -> use is_causal=True
    if (window < 0 or window >= Tq) and Tq == Tk:
        return F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)

    # Single token (generation) -> no mask needed
    if Tq == 1:
        if window >= 0 and window < Tk:
            k = k[:, :, -(window + 1):, :]  # keep only the last window+1 keys
            v = v[:, :, -(window + 1):, :]
        return F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)

    # Sliding window or chunk inference -> explicit boolean mask (True = may attend)
    # Build index tensors on the same device as q so the mask matches the inputs
    row_idx = (Tk - Tq) + torch.arange(Tq, device=q.device).unsqueeze(1)
    col_idx = torch.arange(Tk, device=q.device).unsqueeze(0)
    mask = col_idx <= row_idx  # causal mask, aligned to the end of the keys
    if window >= 0:
        mask = mask & ((row_idx - col_idx) <= window)  # sliding window
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa)
```
Reference: flash_attention.py:61-94
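The same mask logic can be reproduced in plain Python. This shows how the `(Tk - Tq)` offset aligns a chunk of new queries with the end of the cache. `build_mask` is a hypothetical, dependency-free transcription of the mask lines above (True means "may attend"). It is useful for checking edge cases by hand.

```python
def build_mask(Tq, Tk, window):
    """Boolean (Tq x Tk) mask mirroring the explicit-mask branch.

    Query row r sits at absolute position (Tk - Tq) + r, so a chunk of
    new queries is aligned with the last Tq key positions.
    """
    mask = []
    for r in range(Tq):
        row_idx = (Tk - Tq) + r
        row = []
        for col_idx in range(Tk):
            allowed = col_idx <= row_idx  # causal
            if window >= 0:
                allowed = allowed and (row_idx - col_idx) <= window  # sliding window
            row.append(allowed)
        mask.append(row)
    return mask


# Two new queries attending to a cache that holds 4 keys in total
for row in build_mask(Tq=2, Tk=4, window=-1):
    print(["x" if m else "." for m in row])
```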
GQA Support
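SDPA supports grouped-query attention (GQA) natively through its enable_gqa argument (available since PyTorch 2.5). The wrapper enables it automatically whenever queries and keys have a different number of heads:
```python
# After the layout conversion, dim 1 is the head dimension
enable_gqa = q.size(1) != k.size(1)
y = F.scaled_dot_product_attention(q, k, v, enable_gqa=enable_gqa)
```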
Reference: flash_attention.py:72, flash_attention.py:118
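With `enable_gqa=True`, SDPA behaves as if each key/value head were repeated with `repeat_interleave` across a contiguous group of query heads. The sketch below computes that head assignment in plain Python. `kv_head_for_query_heads` is a hypothetical helper, and it assumes the number of query heads is a multiple of the number of KV heads.

```python
def kv_head_for_query_heads(n_head, n_kv_head):
    """Index of the KV head each query head reads from under GQA."""
    if n_head % n_kv_head != 0:
        raise ValueError("n_head must be a multiple of n_kv_head")
    group = n_head // n_kv_head  # number of query heads sharing one KV head
    return [h // group for h in range(n_head)]


# H=12 query heads sharing H_kv=4 KV heads, as in the inference examples in Common Patterns
print(kv_head_for_query_heads(12, 4))
```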
KV Cache Pattern
Typical usage in the GPT model's attention layer:
```python
class CausalSelfAttention(nn.Module):
    def forward(self, x, ..., window_size, kv_cache):
        B, T, C = x.size()
        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
        # ... apply RoPE and QK norm ...
        if kv_cache is None:
            # Training: no cache
            y = flash_attn.flash_attn_func(
                q, k, v, causal=True, window_size=window_size
            )
        else:
            # Inference: use this layer's slice of the cache
            k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx)
            y = flash_attn.flash_attn_with_kvcache(
                q, k_cache, v_cache,
                k=k, v=v,
                cache_seqlens=kv_cache.cache_seqlens,
                causal=True,
                window_size=window_size,
            )
            # Advance the shared position only after the last layer
            if self.layer_idx == kv_cache.n_layers - 1:
                kv_cache.advance(T)
        return y
```
Reference: gpt.py:98-113
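Every layer reads the same `cache_seqlens`, so the position must not move until all layers have written their keys and values for the current step. The toy simulation below shows why the advance happens only in the last layer. `ToyLayeredCache` and `forward_step` are hypothetical and only mimic the `get_layer_cache`, `cache_seqlens`, `n_layers`, and `advance` names used in the forward() above.

```python
class ToyLayeredCache:
    """Hypothetical per-layer cache that records which slots each layer filled."""

    def __init__(self, n_layers):
        self.n_layers = n_layers
        self.cache_seqlens = [0]  # one position shared by all layers
        self.writes = [[] for _ in range(n_layers)]

    def get_layer_cache(self, layer_idx):
        return self.writes[layer_idx]

    def advance(self, t):
        self.cache_seqlens = [s + t for s in self.cache_seqlens]


def forward_step(kv_cache, T):
    """Run T new tokens through every layer, following the pattern above."""
    for layer_idx in range(kv_cache.n_layers):
        layer_writes = kv_cache.get_layer_cache(layer_idx)
        start = kv_cache.cache_seqlens[0]
        layer_writes.append((start, start + T))  # half-open slot range this layer fills
        if layer_idx == kv_cache.n_layers - 1:
            kv_cache.advance(T)  # only after the last layer has written


cache = ToyLayeredCache(n_layers=3)
forward_step(cache, T=5)  # prefill 5 prompt tokens
forward_step(cache, T=1)  # decode one token
print(cache.cache_seqlens, cache.writes)
```

If any earlier layer advanced the position, the layers after it would write to the wrong slots.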
FA3 (Hopper)
Speed: ~3x faster than SDPA on H100
Memory: more efficient (FlashAttention algorithm)
Precision: BFloat16
Limitations: Hopper-only (sm90)
SDPA Fallback
Speed: good performance on all hardware
Memory: standard memory-efficient attention
Precision: adapts to input dtype
Coverage: works everywhere (CUDA, MPS, CPU)
Testing and Override
For testing, you can force a specific implementation:
```python
import nanochat.flash_attention as fa_module

# Force SDPA (even on Hopper)
fa_module._override_impl = 'sdpa'

# Force FA3 (asserts if FA3 is not available)
fa_module._override_impl = 'fa3'

# Auto-detect (default)
fa_module._override_impl = None
```
Reference: flash_attention.py:45-55
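The override semantics reduce to a small resolution rule. The sketch restates that rule as a pure function so the intended behavior is easy to test. `resolve_impl` is hypothetical, and the exact assertion and fallback behavior in flash_attention.py may differ.

```python
def resolve_impl(override, has_fa3):
    """Hypothetical restatement of the _override_impl rules."""
    if override == "fa3":
        assert has_fa3, "FA3 requested but not available on this device"
        return "fa3"
    if override == "sdpa":
        return "sdpa"
    if override is None:
        return "fa3" if has_fa3 else "sdpa"  # auto-detect
    raise ValueError(f"unknown override: {override!r}")


print(resolve_impl(None, has_fa3=True))    # auto-detect on Hopper
print(resolve_impl("sdpa", has_fa3=True))  # force SDPA even on Hopper
```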
Common Patterns
Training (Full Sequences)
```python
# Shapes: (B=32, T=2048, H=12, D=64)
y = flash_attn.flash_attn_func(
    q, k, v,
    causal=True,
    window_size=(-1, 0),  # full context
)
```
Training (Sliding Window)
```python
# Each query attends to itself and the previous 1024 tokens (1025 in total)
y = flash_attn.flash_attn_func(
    q, k, v,
    causal=True,
    window_size=(1024, 0),
)
```
Inference (Single Token)
```python
# q: (B=1, T_new=1, H=12, D=64)
# k_cache, v_cache: (B=1, T_max=4096, H_kv=4, D=64)
# cache_seqlens: (B=1,) = [current_position]
y = flash_attn.flash_attn_with_kvcache(
    q, k_cache, v_cache,
    k=k, v=v,  # (B=1, 1, H_kv=4, D=64)
    cache_seqlens=cache_seqlens,
    causal=True,
    window_size=(-1, 0),
)
cache_seqlens += 1  # advance position
```
Inference (Chunk)
```python
# Process multiple tokens at once (e.g., prompt encoding)
# q: (B=1, T_new=128, H=12, D=64)
y = flash_attn.flash_attn_with_kvcache(
    q, k_cache, v_cache,
    k=k, v=v,  # (B=1, 128, H_kv=4, D=64)
    cache_seqlens=cache_seqlens,
    causal=True,
    window_size=(-1, 0),
)
cache_seqlens += 128  # advance by chunk size
```
Advantages of Unified Interface
Write once, run anywhere: the same code works on H100, A100, RTX 4090, Mac, etc.
No conditional logic in the model: model code doesn't need to check the hardware
Easy testing: the SDPA path can be tested on Hopper by setting the override
Future-proof: when FA3 supports Blackwell, the model code will not need to change
Limitations
FA3 Limitations
Hopper-only (H100, H200)
BFloat16 only
Requires the kernels package, which loads the varunneal/flash-attention-3 kernel from the Hugging Face Hub
SDPA Limitations
Slower than FA3 on Hopper
Sliding windows require an explicit (Tq, Tk) boolean mask, which adds memory overhead
Chunk inference needs careful mask construction (queries must be offset to the end of the cache)
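To gauge the mask overhead, note that the explicit mask has Tq x Tk elements and a `torch.bool` tensor stores one byte per element. The arithmetic below is a back-of-the-envelope estimate for the mask tensor itself. It assumes the mask is broadcast over batch and heads rather than copied, and it ignores any workspace the SDPA backend allocates. `mask_bytes` is a hypothetical helper.

```python
def mask_bytes(Tq, Tk, bytes_per_elem=1):
    """Size of one (Tq, Tk) boolean attention mask in bytes (torch.bool = 1 byte)."""
    return Tq * Tk * bytes_per_elem


MiB = 1024 ** 2
# Sliding-window training at T=2048 (the shapes used in Common Patterns)
print(mask_bytes(2048, 2048) / MiB)  # 4.0
# A 128-token prompt chunk attending to 4096 keys
print(mask_bytes(128, 4096) / MiB)   # 0.5
```

The mask grows quadratically with sequence length, so the overhead matters most for long training sequences.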