Overview

The Transformer class implements the core neural network architecture for Llama 2 models. It consists of token embeddings, a stack of transformer blocks (each with attention and feedforward layers), a final RMSNorm, and an output projection layer.

Class Definition

class Transformer(nn.Module)

Methods

__init__

def __init__(self, params: ModelArgs)
Initialize a Transformer model.
params
ModelArgs
required
Model configuration parameters containing:
  • dim (int): Model dimension (default: 4096)
  • n_layers (int): Number of transformer layers (default: 32)
  • n_heads (int): Number of attention heads (default: 32)
  • n_kv_heads (Optional[int]): Number of key-value heads for grouped-query attention (default: None, which uses n_heads)
  • vocab_size (int): Size of the vocabulary (default: -1; set from the tokenizer before the model is built)
  • multiple_of (int): Make SwiGLU hidden layer size multiple of this value (default: 256)
  • ffn_dim_multiplier (Optional[float]): Multiplier for feedforward dimension (default: None)
  • norm_eps (float): Epsilon for RMSNorm (default: 1e-5)
  • max_batch_size (int): Maximum batch size (default: 32)
  • max_seq_len (int): Maximum sequence length (default: 2048)
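The sizes derived from these parameters can be sketched in plain Python. The helper below is hypothetical (derived_sizes is not part of the library), and the rounding rule reflects one reading of the reference feedforward layer, so treat it as an assumption rather than the definitive computation:

```python
from typing import Optional


def derived_sizes(
    dim: int,
    n_heads: int,
    n_kv_heads: Optional[int] = None,
    multiple_of: int = 256,
    ffn_dim_multiplier: Optional[float] = None,
) -> dict:
    """Hypothetical helper: sizes implied by a ModelArgs-style configuration."""
    # With n_kv_heads unset, every query head has its own key-value head.
    kv_heads = n_heads if n_kv_heads is None else n_kv_heads
    head_dim = dim // n_heads
    # Number of query heads that share one key-value head.
    n_rep = n_heads // kv_heads

    # Assumed rule: start from 4 * dim, take two thirds, optionally scale,
    # then round up to the next multiple of `multiple_of`.
    hidden = int(2 * (4 * dim) / 3)
    if ffn_dim_multiplier is not None:
        hidden = int(ffn_dim_multiplier * hidden)
    hidden = multiple_of * ((hidden + multiple_of - 1) // multiple_of)

    return {
        "head_dim": head_dim,
        "n_kv_heads": kv_heads,
        "n_rep": n_rep,
        "ffn_hidden_dim": hidden,
    }


# Default-sized configuration from the parameter list above.
print(derived_sizes(dim=4096, n_heads=32))
# Illustrative grouped-query configuration (values chosen for the example).
print(derived_sizes(dim=8192, n_heads=64, n_kv_heads=8,
                    multiple_of=4096, ffn_dim_multiplier=1.3))
```

Under these assumptions the default configuration yields a head dimension of 128 and a feedforward hidden size of 11008.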
Attributes: After initialization, the Transformer instance has the following attributes:
  • params (ModelArgs): Model configuration parameters.
  • vocab_size (int): Vocabulary size.
  • n_layers (int): Number of layers in the model.
  • tok_embeddings (ParallelEmbedding): Token embeddings.
  • layers (torch.nn.ModuleList): List of Transformer blocks.
  • norm (RMSNorm): RMS normalization applied to the final hidden states before the output projection.
  • output (ColumnParallelLinear): Linear layer for final output.
  • freqs_cis (torch.Tensor): Precomputed complex exponentials (cos θ + i·sin θ) for rotary positional embeddings.
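To make the freqs_cis attribute more concrete, here is a dependency-free sketch of how such a rotary table could be computed. The function names rope_table and rotate_pair, and the base theta=10000.0, are assumptions for illustration; the library builds its table with PyTorch tensors instead:

```python
import cmath
from typing import List, Tuple


def rope_table(head_dim: int, end: int, theta: float = 10000.0) -> List[List[complex]]:
    """Hypothetical sketch: unit complex numbers e^{i * pos * freq} per position."""
    # One frequency per pair of channels, decreasing geometrically.
    freqs = [1.0 / (theta ** (i / head_dim)) for i in range(0, head_dim, 2)]
    # Row `pos` holds the rotation applied to each channel pair at that position.
    return [[cmath.rect(1.0, pos * f) for f in freqs] for pos in range(end)]


def rotate_pair(x0: float, x1: float, rotation: complex) -> Tuple[float, float]:
    """Rotate one (even, odd) channel pair by treating it as a complex number."""
    z = complex(x0, x1) * rotation
    return z.real, z.imag


table = rope_table(head_dim=8, end=4)
print(len(table), len(table[0]))          # positions, channel pairs
print(rotate_pair(1.0, 0.0, table[1][0])) # pair rotated by 1 radian
```

Position 0 maps to the identity rotation, so the first token's queries and keys are left unchanged.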
Example:
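import torch
from fairscale.nn.model_parallel.initialize import (
    initialize_model_parallel,
    model_parallel_is_initialized,
)
from llama.model import Transformer, ModelArgs

# Transformer is built from fairscale model-parallel layers, so the process
# group and model-parallel state must exist before it is constructed. This
# follows the setup in Llama.build() and assumes a torchrun launch on a CUDA machine.
if not torch.distributed.is_initialized():
    torch.distributed.init_process_group("nccl")
if not model_parallel_is_initialized():
    initialize_model_parallel(1)

# Define model configuration
# Note: key-value caches are allocated on the GPU at construction time and
# grow with max_batch_size * max_seq_len.
model_args = ModelArgs(
    dim=4096,
    n_layers=32,
    n_heads=32,
    vocab_size=32000,
    max_seq_len=2048,
    max_batch_size=32
)

# Initialize transformer (weights are randomly initialized, not loaded)
transformer = Transformer(params=model_args)

print(f"Model layers: {transformer.n_layers}")
print(f"Vocabulary size: {transformer.vocab_size}")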

forward

@torch.inference_mode()
def forward(self, tokens: torch.Tensor, start_pos: int) -> torch.Tensor
Perform a forward pass through the Transformer model.
tokens
torch.Tensor
required
Input token indices. Shape: (batch_size, sequence_length)
start_pos
int
required
Position of the first token of tokens within the full sequence. Keys and values for earlier positions are read from the cache, so only the new tokens need to be processed during generation.
return
torch.Tensor
Output logits after applying the Transformer model. Shape: (batch_size, sequence_length, vocab_size)
Note: This method is decorated with @torch.inference_mode(), which disables gradient tracking, so it is intended for inference rather than training. It processes input tokens through:
  1. Token embeddings
  2. Multiple transformer blocks (attention + feedforward)
  3. Final layer normalization
  4. Output projection to vocabulary size
The method applies a causal attention mask for language modeling and supports key-value caching for efficient autoregressive generation.
Example:
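import torch
from fairscale.nn.model_parallel.initialize import (
    initialize_model_parallel,
    model_parallel_is_initialized,
)
from llama.model import Transformer, ModelArgs

# Model-parallel setup, as in Llama.build(); assumes a torchrun launch on a CUDA machine
if not torch.distributed.is_initialized():
    torch.distributed.init_process_group("nccl")
if not model_parallel_is_initialized():
    initialize_model_parallel(1)

# Create parameters on the GPU in half precision, as Llama.build() does
torch.set_default_tensor_type(torch.cuda.HalfTensor)

# Initialize model (random weights; load a checkpoint for meaningful logits)
model_args = ModelArgs(
    dim=4096,
    n_layers=32,
    n_heads=32,
    vocab_size=32000,
    max_seq_len=2048,
    max_batch_size=4
)

transformer = Transformer(params=model_args)
transformer.eval()

# Prepare input tokens on the same device as the model
batch_size = 2
seq_len = 10
tokens = torch.randint(0, 32000, (batch_size, seq_len), device="cuda")

# First forward pass (empty cache): process the whole prompt
start_pos = 0
logits = transformer.forward(tokens, start_pos)
print(f"Output shape: {logits.shape}")  # (2, 10, 32000)

# Get next token predictions from the last position
next_token_logits = logits[:, -1, :]  # (batch_size, vocab_size)
print(f"Next token logits shape: {next_token_logits.shape}")

# Autoregressive generation with caching: feed only the new token
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)  # (batch_size, 1)
start_pos = seq_len  # number of tokens already in the cache
logits = transformer.forward(next_token, start_pos)
print(f"Cached forward output shape: {logits.shape}")  # (2, 1, 32000)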
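The causal masking used by forward can be illustrated without PyTorch. The sketch below assumes the mask is additive (0 where attention is allowed, negative infinity where it is blocked), that it spans the cached positions plus the new tokens, and that no mask is needed for a single new token; causal_mask is a hypothetical helper, not a library function:

```python
from typing import List, Optional

NEG_INF = float("-inf")


def causal_mask(seqlen: int, start_pos: int) -> Optional[List[List[float]]]:
    """Hypothetical sketch of an additive causal mask.

    Rows index the new query tokens; columns index all key positions
    (start_pos cached positions followed by the seqlen new ones).
    """
    if seqlen <= 1:
        # A single new token may attend to everything already cached.
        return None
    total = start_pos + seqlen
    return [
        [0.0 if key_pos <= start_pos + q else NEG_INF for key_pos in range(total)]
        for q in range(seqlen)
    ]


# Prompt pass: three tokens, nothing cached yet.
for row in causal_mask(seqlen=3, start_pos=0):
    print(row)

# Later pass with two new tokens and three cached positions.
for row in causal_mask(seqlen=2, start_pos=3):
    print(row)
```

Each query token can see every cached position and every new token up to and including itself, which is why the cached columns are all zero.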

Architecture Details

The Transformer model implements the Llama 2 architecture with the following key components:

Token Embeddings

Converts input token IDs to dense vector representations using a model-parallel embedding layer (ParallelEmbedding), which allows the embedding table to be split across model-parallel ranks.

Transformer Blocks

Each block contains:
  • Multi-head Attention: Implements grouped-query attention with rotary positional embeddings (RoPE)
  • Feedforward Network: Uses SwiGLU activation function
  • RMSNorm: Root Mean Square Layer Normalization applied before attention and feedforward layers
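The pre-normalization layout described above can be sketched for a single token vector in plain Python. The function names and the residual wiring are assumptions for illustration; the attention and feedforward sublayers are passed in as ordinary callables rather than implemented here:

```python
import math
from typing import Callable, List

Vector = List[float]


def rms_norm(x: Vector, weight: Vector, eps: float = 1e-5) -> Vector:
    """Scale x by the reciprocal of its root mean square, then by a learned weight."""
    mean_sq = sum(v * v for v in x) / len(x)
    scale = 1.0 / math.sqrt(mean_sq + eps)
    return [v * scale * w for v, w in zip(x, weight)]


def pre_norm_block(
    x: Vector,
    attention: Callable[[Vector], Vector],
    feed_forward: Callable[[Vector], Vector],
    attn_norm_weight: Vector,
    ffn_norm_weight: Vector,
    eps: float = 1e-5,
) -> Vector:
    """Hypothetical sketch: normalize before each sublayer, add a residual after it."""
    h = [a + b for a, b in zip(x, attention(rms_norm(x, attn_norm_weight, eps)))]
    return [a + b for a, b in zip(h, feed_forward(rms_norm(h, ffn_norm_weight, eps)))]


x = [3.0, 4.0]
ones = [1.0, 1.0]
print(rms_norm(x, ones))
# Stand-in sublayers that simply halve their input.
halve = lambda v: [0.5 * t for t in v]
print(pre_norm_block(x, halve, halve, ones, ones))
```

Unlike standard layer normalization, this form does not subtract the mean; it only rescales the vector.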

Attention Caching

The model supports key-value caching for efficient autoregressive generation:
  • First pass: Process entire prompt with start_pos=0
  • Subsequent passes: Process only the newly generated token, with start_pos set to the number of tokens already in the cache
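The two-phase pattern above can be illustrated with a toy cache for a single sequence. KVCacheSketch is a hypothetical class that stores placeholder strings instead of tensors; it only shows how start_pos selects where new entries are written and how much of the cache is read back:

```python
from typing import List, Tuple


class KVCacheSketch:
    """Hypothetical single-sequence cache with a fixed capacity."""

    def __init__(self, max_seq_len: int) -> None:
        self.keys: List[object] = [None] * max_seq_len
        self.values: List[object] = [None] * max_seq_len

    def update(
        self, start_pos: int, new_keys: List[object], new_values: List[object]
    ) -> Tuple[List[object], List[object]]:
        """Write new entries at start_pos and return everything cached so far."""
        if len(new_keys) != len(new_values):
            raise ValueError("keys and values must have the same length")
        end = start_pos + len(new_keys)
        if end > len(self.keys):
            raise ValueError("sequence exceeds max_seq_len")
        self.keys[start_pos:end] = new_keys
        self.values[start_pos:end] = new_values
        # Attention for the new tokens runs over positions 0 .. end - 1.
        return self.keys[:end], self.values[:end]


cache = KVCacheSketch(max_seq_len=8)

# First pass: the whole prompt, start_pos=0.
keys, values = cache.update(0, ["k0", "k1", "k2"], ["v0", "v1", "v2"])
print(keys)

# Subsequent pass: one new token, start_pos equals the tokens already cached.
keys, values = cache.update(3, ["k3"], ["v3"])
print(keys)
```

Because earlier keys and values are reused, each generation step only computes projections for the new token.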

Output Layer

Final linear projection from model dimension to vocabulary size produces logits for next token prediction.

Usage with Llama Class

While you can use the Transformer class directly, it’s typically used internally by the Llama class:
from llama import Llama

# Llama.build() initializes torch.distributed, so launch this script with torchrun.
# It creates a Transformer internally and loads the checkpoint weights into it.
llama = Llama.build(
    ckpt_dir="/path/to/llama-2-7b",
    tokenizer_path="/path/to/tokenizer.model",
    max_seq_len=2048,
    max_batch_size=32
)

# Access the transformer model
transformer = llama.model
print(f"Model has {transformer.n_layers} layers")
