What is prefix caching?
When multiple requests share a common prefix — for example, the same system prompt or a few-shot example block — the attention keys and values computed for that prefix are identical across all requests. Prefix caching detects this and reuses the already-computed KV cache blocks instead of running the transformer layers again.
The result is a shorter effective prefill for every request that hits the cache, which lowers latency and improves throughput on workloads with repetitive prefixes.
Prefix caching is always enabled. There is no configuration flag to toggle it.
How it works
The BlockManager divides each sequence’s token stream into fixed-size blocks (kvcache_block_size = 256 tokens by default). Each fully-populated block is identified by a content hash that chains through the sequence:
@classmethod
def compute_hash(cls, token_ids: list[int], prefix: int = -1):
    h = xxhash.xxh64()
    if prefix != -1:
        h.update(prefix.to_bytes(8, "little"))
    h.update(np.array(token_ids).tobytes())
    return h.intdigest()
The prefix argument is the hash of the previous block, so each hash encodes the entire token history up to and including the current block, not just the block's own tokens. Two blocks produce the same hash when they hold the same tokens and are preceded by the same token history; identical tokens at the same position after a different prefix hash differently.
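The chaining can be illustrated with a minimal, self-contained sketch. It swaps xxhash for hashlib.blake2b from the standard library so that it runs without extra dependencies, and it uses a 4-token block size for readability. The names chain_hash and block_hashes are hypothetical stand-ins, not nano-vllm code.

```python
import hashlib
import struct

BLOCK_SIZE = 4  # tiny block size for illustration only


def chain_hash(token_ids: list[int], prefix: int = -1) -> int:
    """Hash one full block, mixing in the previous block's hash if there is one."""
    h = hashlib.blake2b(digest_size=8)  # assumption: stand-in for xxhash.xxh64
    if prefix != -1:
        h.update(prefix.to_bytes(8, "little"))
    h.update(struct.pack(f"<{len(token_ids)}q", *token_ids))
    return int.from_bytes(h.digest(), "little")


def block_hashes(tokens: list[int]) -> list[int]:
    """Return the chained hash of every full block in a token sequence."""
    hashes, prev = [], -1
    for start in range(0, len(tokens) - BLOCK_SIZE + 1, BLOCK_SIZE):
        prev = chain_hash(tokens[start:start + BLOCK_SIZE], prev)
        hashes.append(prev)
    return hashes


a = block_hashes([1, 2, 3, 4, 5, 6, 7, 8])
b = block_hashes([1, 2, 3, 4, 5, 6, 7, 8, 9])  # same two full blocks, partial tail
c = block_hashes([9, 2, 3, 4, 5, 6, 7, 8])     # first block differs

print(a == b)        # True: identical history gives identical hashes
print(a[1] == c[1])  # False: same second-block tokens, different history
```

The second comparison is the important one: the second block holds the tokens 5, 6, 7, 8 in both sequences, yet its hash changes because the block before it changed.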
Allocation with cache lookup
When BlockManager.allocate assigns blocks to a new sequence it walks the blocks in order and checks hash_to_block_id for each one:
def allocate(self, seq: Sequence):
    h = -1
    cache_miss = False
    for i in range(seq.num_blocks):
        token_ids = seq.block(i)
        # Only full blocks are hashed; a partial last block gets h = -1.
        h = self.compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1
        block_id = self.hash_to_block_id.get(h, -1)
        # Compare token_ids as well, so a matching hash alone is not trusted.
        if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
            cache_miss = True  # stays True for all remaining blocks
        if cache_miss:
            block_id = self.free_block_ids[0]
            block = self._allocate_block(block_id)
        else:
            seq.num_cached_tokens += self.block_size
            ...
        if h != -1:
            block.update(h, token_ids)
            self.hash_to_block_id[h] = block_id
        seq.block_table.append(block_id)
Every cache hit increments seq.num_cached_tokens by block_size. The scheduler later uses this value to skip those tokens during prefill:
# scheduler.py
num_batched_tokens += len(seq) - seq.num_cached_tokens
Only the uncached suffix of the prompt is sent through the model.
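To make the accounting concrete, here is a toy model of the lookup under simplifying assumptions: Python's built-in hash of a tuple stands in for the xxhash chain, there is no free list or reference counting, and ToyBlockCache is a hypothetical name that is not part of nano-vllm.

```python
class ToyBlockCache:
    """Simplified model of hash-based block lookup (no free list, no ref counts)."""

    def __init__(self, block_size: int):
        self.block_size = block_size
        self.hash_to_tokens: dict[int, list[int]] = {}

    def allocate(self, tokens: list[int]) -> int:
        """Register a prompt's blocks and return how many tokens were already cached."""
        h = -1
        cache_miss = False
        num_cached_tokens = 0
        for start in range(0, len(tokens), self.block_size):
            block = tokens[start:start + self.block_size]
            # Only full blocks get a hash; a partial last block is never cached.
            h = hash((h, tuple(block))) if len(block) == self.block_size else -1
            if self.hash_to_tokens.get(h) != block:
                cache_miss = True  # once one block misses, all later blocks miss too
            if not cache_miss:
                num_cached_tokens += self.block_size
            if h != -1:
                self.hash_to_tokens[h] = block
        return num_cached_tokens


cache = ToyBlockCache(block_size=4)
shared = list(range(10))           # 10 shared tokens: 2 full blocks plus 2 tokens
first = shared + [100, 101, 102]   # 13 tokens
second = shared + [200, 201]       # 12 tokens

for prompt in (first, second):
    cached = cache.allocate(prompt)
    # Mirrors the scheduler arithmetic: len(seq) - seq.num_cached_tokens
    print(len(prompt) - cached)    # prints 13, then 4
```

The second prompt shares 10 tokens with the first, but only the 8 tokens that fall into two complete blocks are reused. The third block mixes shared and new tokens, so it misses.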
Cache persistence across requests
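Blocks are reference-counted. When a sequence finishes and is deallocated, each of its blocks has its reference count decremented, and blocks that reach zero return to the free list with their hash and token_ids metadata preserved. A future request whose prefix hashes match reuses those physical blocks, and the KV data they hold, as long as the blocks have not been reallocated to other content in the meantime.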
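The lifecycle described above can be sketched as a small reference-counted pool. This is an illustrative model, not the nano-vllm implementation: ToyBlock, ToyPool, and their methods are hypothetical, and the sketch assumes that taking a block from the free list for new content overwrites its old metadata.

```python
from collections import deque
from dataclasses import dataclass, field


@dataclass
class ToyBlock:
    block_id: int
    ref_count: int = 0
    hash: int = -1
    token_ids: list[int] = field(default_factory=list)


class ToyPool:
    def __init__(self, num_blocks: int):
        self.blocks = [ToyBlock(i) for i in range(num_blocks)]
        self.free_ids: deque[int] = deque(range(num_blocks))
        self.hash_to_id: dict[int, int] = {}

    def acquire(self, h: int, token_ids: list[int]) -> tuple[int, bool]:
        """Return (block_id, was_hit) for a full block with chained hash h."""
        block_id = self.hash_to_id.get(h, -1)
        if block_id != -1 and self.blocks[block_id].token_ids == token_ids:
            block = self.blocks[block_id]
            if block.ref_count == 0:
                self.free_ids.remove(block_id)  # revive a freed block
            block.ref_count += 1
            return block_id, True
        # Miss: take the oldest free block and overwrite its metadata.
        block_id = self.free_ids.popleft()
        block = self.blocks[block_id]
        block.ref_count, block.hash, block.token_ids = 1, h, list(token_ids)
        self.hash_to_id[h] = block_id
        return block_id, False

    def release(self, block_id: int) -> None:
        """Drop one reference; a block with no users goes back on the free list."""
        block = self.blocks[block_id]
        block.ref_count -= 1
        if block.ref_count == 0:
            self.free_ids.append(block_id)  # hash and token_ids are kept


pool = ToyPool(num_blocks=2)
bid, hit = pool.acquire(111, [1, 2, 3, 4])    # miss: fresh block
pool.release(bid)                             # the sequence finished
bid2, hit2 = pool.acquire(111, [1, 2, 3, 4])  # hit: same physical block
print(hit, hit2, bid == bid2)                 # False True True
```

In this model a freed block stays reusable only until the pool hands it to different content. After that, the stale hash_to_id entry still points at the block, and the token_ids comparison is what turns the lookup into a miss.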
Example: repeated system prompt
The most common use case is a fixed system prompt prepended to every user message.
from nanovllm import LLM, SamplingParams
from transformers import AutoTokenizer
import os

path = os.path.expanduser("~/huggingface/Qwen3-0.6B/")
tokenizer = AutoTokenizer.from_pretrained(path)
llm = LLM(path, enforce_eager=False)

system_prompt = (
    "You are a helpful assistant. Answer concisely and accurately."
)

questions = [
    "What is the speed of light?",
    "Who wrote Hamlet?",
    "What is the boiling point of water at sea level?",
    "Name the planets in the solar system.",
]

# Each prompt shares the same system prefix
prompts = [
    tokenizer.apply_chat_template(
        [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": q},
        ],
        tokenize=False,
        add_generation_prompt=True,
    )
    for q in questions
]

sampling_params = SamplingParams(temperature=0.6, max_tokens=128)
outputs = llm.generate(prompts, sampling_params)

for q, out in zip(questions, outputs):
    print(f"Q: {q}")
    print(f"A: {out['text']}\n")
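After the first request populates the KV cache blocks for the shared prefix, subsequent requests skip prefill for every full block of that prefix. Only complete blocks of kvcache_block_size tokens (256 by default) are hashed, so a system prompt as short as the one above does not fill a block and yields no hits on its own. The savings appear once the shared prefix (system message plus any few-shot examples) spans at least one full block.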
To maximise cache hits, place the static part of the prompt (system message, few-shot examples) at the beginning of the sequence and the dynamic part (user query) at the end. Reuse stops at the first block that contains a differing token, because each block's hash depends on every block before it, so any change early in the sequence prevents reuse of all subsequent blocks.
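The effect of prompt ordering can be estimated with a short calculation. The sketch below is a simplified model rather than library code: reusable_tokens is a hypothetical helper, the token ids are made up, and it assumes reuse covers exactly the full blocks inside the common leading run of two prompts.

```python
def reusable_tokens(cached: list[int], new: list[int], block_size: int) -> int:
    """Tokens of `new` covered by full blocks that match the start of `cached`."""
    common = 0
    for x, y in zip(cached, new):
        if x != y:
            break
        common += 1
    return (common // block_size) * block_size


BLOCK_SIZE = 256
static = list(range(1000, 1600))  # 600 made-up token ids: system prompt + examples
query_a = [1, 2, 3, 4, 5]
query_b = [6, 7, 8]

# Static part first: 600 shared leading tokens, which cover 2 full blocks.
print(reusable_tokens(static + query_a, static + query_b, BLOCK_SIZE))  # 512

# Dynamic part first: the prompts diverge at the first token.
print(reusable_tokens(query_a + static, query_b + static, BLOCK_SIZE))  # 0
```

With the static content first, 512 of the 600 shared tokens are reusable; the remaining 88 sit in a block that also contains query tokens. With the query first, nothing is reusable even though 600 tokens are identical.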
Observing cache hits
The use_tqdm=True default in generate() prints live prefill throughput. A significantly higher prefill tok/s on the second and later requests in a batch with a shared prefix indicates that cache hits are reducing the number of tokens actually processed:
Generating: 25%|████ | 1/4 [00:01<00:04, Prefill: 4821tok/s, Decode: 312tok/s]
Generating: 50%|████████ | 2/4 [00:01<00:02, Prefill: 9103tok/s, Decode: 318tok/s]
The jump in prefill throughput after the first request is caused by the scheduler sending fewer tokens through the model.