Overview
Slime uses several key data structures to represent training samples, rollout batches, and outputs throughout the training pipeline.
Sample
Sample Class
The fundamental unit representing a single prompt-response pair.
from slime.utils.types import Sample

sample = Sample(
    prompt="What is 2+2?",
    response="4",
    reward=1.0,
    status=Sample.Status.COMPLETED,
)
Source: slime/utils/types.py:9
Fields
- group_index: Index of the sample group (all responses to the same prompt)
- index: Unique sample index across all samples
- prompt (str | list[dict], default ""): Input prompt. Can be:
  - a string for single-turn prompts
  - a list of message dicts for chat format: [{"role": "user", "content": "..."}]
- Full token sequence (prompt + response tokens)
- multimodal_inputs (dict | None, default None): Raw multimodal data:
  {"images": [PIL.Image, ...], "videos": [...], "audio": [...]}
- multimodal_train_inputs (dict | None, default None): Processed multimodal tensors for training (e.g., pixel_values)
- response_length: Number of generated tokens
- Ground truth label (for reward computation)
- reward (float | dict | None, default None): Reward value. Can be:
  - a float for a single reward
  - a dict for multi-objective rewards: {"task": 0.8, "length": -0.1}
- loss_mask (list[int] | None, default None): Binary mask for loss computation (1 = compute loss, 0 = ignore); used for partial rollout masking and for tool-calling masking
- List of model weight versions used during generation (for multi-turn)
- rollout_log_probs (list[float] | None, default None): Log probabilities from the rollout engine for each generated token
- rollout_routed_experts (list[list[int]] | None, default None): Routed expert indices for MoE models (for routing replay)
- Flag to remove this sample from training
- teacher_log_probs (list[float] | None, default None): Teacher-model log probabilities for on-policy distillation
- status (Status, default Status.PENDING): Sample status enum (see below)
- Arbitrary metadata dictionary
- train_metadata (dict | None, default None): Training-specific metadata (e.g., custom loss configuration)
- Session ID for consistent hashing routing
Sample.Status
Enum representing sample generation status.
class Status(Enum):
    PENDING = "pending"      # Not yet generated
    COMPLETED = "completed"  # Successfully completed
    TRUNCATED = "truncated"  # Hit max length
    ABORTED = "aborted"      # Aborted by system
    FAILED = "failed"        # Recoverable failure (e.g., API error)
Source: slime/utils/types.py:31-39
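The status values partition finished samples for downstream handling. A minimal sketch of one plausible policy (the routing rules below are illustrative, not slime's actual scheduler logic), using a local copy of the enum so it is self-contained:

```python
from enum import Enum

# Local copy of the Status enum so the sketch is self-contained.
class Status(Enum):
    PENDING = "pending"
    COMPLETED = "completed"
    TRUNCATED = "truncated"
    ABORTED = "aborted"
    FAILED = "failed"

def partition(statuses):
    """Split finished samples into trainable ones and ones to retry.

    Illustrative policy: COMPLETED and TRUNCATED samples can still be
    trained on; FAILED samples (recoverable errors) are requeued.
    """
    trainable = [s for s in statuses if s in (Status.COMPLETED, Status.TRUNCATED)]
    retry = [s for s in statuses if s is Status.FAILED]
    return trainable, retry

trainable, retry = partition(
    [Status.COMPLETED, Status.FAILED, Status.TRUNCATED, Status.ABORTED]
)
len(trainable), len(retry)  # (2, 1)
```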
Methods
get_reward_value(args)
Get the reward value, handling both float and dict rewards.
Parameters: args (Namespace): arguments with a reward_key attribute
Returns: float - the reward value

# Float reward
sample.reward = 0.8
sample.get_reward_value(args)  # 0.8

# Dict reward
sample.reward = {"task": 0.8, "length": -0.1}
args.reward_key = "task"
sample.get_reward_value(args)  # 0.8

effective_response_length
Get the effective response length, taking the loss mask into account.
Returns: int - sum of loss_mask if set, otherwise response_length

sample.response_length = 10
sample.loss_mask = [1, 1, 0, 0, 1, 1, 1, 1, 1, 1]
sample.effective_response_length  # 8

Convert the sample to a dictionary for serialization.
Returns: dict

Create a sample from a dictionary.
Parameters: data (dict): serialized sample data
Returns: Sample
SpecInfo
Nested dataclass for speculative decoding statistics.
@dataclass
class SpecInfo:
    spec_accept_token_num: int = 0  # Accepted draft tokens
    spec_draft_token_num: int = 0   # Total draft tokens
    spec_verify_ct: int = 0         # Verification iterations
    completion_token_num: int = 0   # Total completion tokens

    @property
    def spec_accept_rate(self) -> float:
        """Acceptance rate of draft tokens"""

    @property
    def spec_accept_length(self) -> float:
        """Average accepted tokens per verification"""
Source: slime/utils/types.py:53-81
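The two properties are straightforward ratios. A self-contained sketch with assumed formulas (accepted/draft for the rate, accepted/verifications for the average length; these are plausible readings of the field comments, not code copied from slime):

```python
from dataclasses import dataclass

@dataclass
class SpecInfo:
    spec_accept_token_num: int = 0  # Accepted draft tokens
    spec_draft_token_num: int = 0   # Total draft tokens
    spec_verify_ct: int = 0         # Verification iterations
    completion_token_num: int = 0   # Total completion tokens

    @property
    def spec_accept_rate(self) -> float:
        # Fraction of draft tokens the target model accepted.
        if self.spec_draft_token_num == 0:
            return 0.0
        return self.spec_accept_token_num / self.spec_draft_token_num

    @property
    def spec_accept_length(self) -> float:
        # Average accepted tokens per verification iteration.
        if self.spec_verify_ct == 0:
            return 0.0
        return self.spec_accept_token_num / self.spec_verify_ct

info = SpecInfo(spec_accept_token_num=30, spec_draft_token_num=40, spec_verify_ct=10)
info.spec_accept_rate    # 0.75
info.spec_accept_length  # 3.0
```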
PrefixCacheInfo
Nested dataclass for prefix caching statistics.
@dataclass
class PrefixCacheInfo:
    cached_tokens: int = 0        # Number of cached tokens
    total_prompt_tokens: int = 0  # Total prompt tokens

    @property
    def prefix_cache_hit_rate(self) -> float:
        """Cache hit rate"""
Source: slime/utils/types.py:93-118
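The hit rate is presumably the fraction of prompt tokens served from the prefix cache. A minimal sketch under that assumption (not slime's exact implementation), with a divide-by-zero guard:

```python
from dataclasses import dataclass

@dataclass
class PrefixCacheInfo:
    cached_tokens: int = 0        # Number of cached tokens
    total_prompt_tokens: int = 0  # Total prompt tokens

    @property
    def prefix_cache_hit_rate(self) -> float:
        # Fraction of prompt tokens served from the prefix cache.
        if self.total_prompt_tokens == 0:
            return 0.0
        return self.cached_tokens / self.total_prompt_tokens

info = PrefixCacheInfo(cached_tokens=96, total_prompt_tokens=128)
info.prefix_cache_hit_rate  # 0.75
```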
RolloutBatch
Type Definition
Dictionary-based batch structure for rollout data.
from slime.utils.types import RolloutBatch
# RolloutBatch is a type alias
RolloutBatch = dict[str, list[torch.Tensor] | list[int] | list[float] | list[str]]
Typical Fields:
rollout_batch = {
    "tokens": [torch.Tensor, ...],           # Token tensors
    "response_length": [128, 64, 256, ...],  # Response lengths
    "rewards": [0.8, 0.0, 1.0, ...],         # Reward values
    "loss_mask": [torch.Tensor, ...],        # Loss masks
    "prompt": ["prompt1", "prompt2", ...],   # Original prompts (for logging)
    "response": ["resp1", "resp2", ...],     # Responses (for logging)
}
Source: slime/utils/types.py:190
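A RolloutBatch is essentially a column-oriented view of a list of samples: one key per field, one list entry per sample. A minimal sketch of that transposition (the helper name is hypothetical, and plain lists of token ids stand in for torch.Tensor so the example stays dependency-free):

```python
def samples_to_rollout_batch(samples: list[dict]) -> dict:
    """Transpose a row-oriented sample list into a column-oriented batch."""
    return {
        "tokens": [s["tokens"] for s in samples],
        "response_length": [s["response_length"] for s in samples],
        "rewards": [s["reward"] for s in samples],
        "prompt": [s["prompt"] for s in samples],
    }

batch = samples_to_rollout_batch([
    {"tokens": [1, 2, 3], "response_length": 2, "reward": 1.0, "prompt": "p0"},
    {"tokens": [4, 5], "response_length": 1, "reward": 0.0, "prompt": "p1"},
])
batch["rewards"]  # [1.0, 0.0]
```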
Rollout Output Types
RolloutFnTrainOutput
Output from training rollout functions.
from slime.rollout.base_types import RolloutFnTrainOutput
@dataclass
class RolloutFnTrainOutput:
    samples: list[list[Sample]]     # Shape: [rollout_batch_size, n_samples_per_prompt]
    metrics: dict[str, Any] = None  # Optional metrics
Example:
output = RolloutFnTrainOutput(
    samples=[
        [sample1, sample2, sample3, sample4],  # Group 0: 4 responses for prompt 0
        [sample5, sample6, sample7, sample8],  # Group 1: 4 responses for prompt 1
        # ... rollout_batch_size groups total
    ],
    metrics={
        "filter_drop_rate": 0.15,
        "avg_response_length": 234.5,
    },
)
Source: slime/rollout/base_types.py:8
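Because samples is grouped by prompt, per-group statistics fall out naturally. A sketch computing the mean reward of each group (plain dicts stand in for Sample objects here):

```python
def group_mean_rewards(samples: list[list[dict]]) -> list[float]:
    # One mean per prompt group; each inner list holds
    # n_samples_per_prompt responses to the same prompt.
    return [sum(s["reward"] for s in group) / len(group) for group in samples]

groups = [
    [{"reward": 1.0}, {"reward": 0.0}, {"reward": 1.0}, {"reward": 1.0}],  # group 0
    [{"reward": 0.0}, {"reward": 0.0}, {"reward": 1.0}, {"reward": 0.0}],  # group 1
]
group_mean_rewards(groups)  # [0.75, 0.25]
```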
RolloutFnEvalOutput
Output from evaluation rollout functions.
from slime.rollout.base_types import RolloutFnEvalOutput
@dataclass
class RolloutFnEvalOutput:
    data: dict[str, dict[str, Any]]  # dataset_name -> results
    metrics: dict[str, Any] = None   # Optional metrics
Example:
output = RolloutFnEvalOutput(
    data={
        "gsm8k": {
            "rewards": [1.0, 0.0, 1.0, ...],
            "truncated": [False, False, True, ...],
            "samples": [sample1, sample2, ...],
        },
        "math": {
            "rewards": [0.0, 1.0, 0.0, ...],
            "truncated": [False, False, False, ...],
            "samples": [...],
        },
    },
    metrics={
        "gsm8k_accuracy": 0.87,
        "math_accuracy": 0.52,
    },
)
Source: slime/rollout/base_types.py:14
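Metrics like the per-dataset accuracies above can be derived directly from the data dict. A sketch (the helper name is hypothetical; with 0/1 rewards the mean reward is the accuracy):

```python
def dataset_accuracies(data: dict) -> dict:
    # Mean reward per dataset; with 0/1 rewards this is accuracy.
    return {
        name: sum(results["rewards"]) / len(results["rewards"])
        for name, results in data.items()
    }

data = {
    "gsm8k": {"rewards": [1.0, 0.0, 1.0, 1.0], "truncated": [False] * 4},
    "math": {"rewards": [0.0, 1.0, 0.0, 1.0], "truncated": [False] * 4},
}
dataset_accuracies(data)  # {"gsm8k": 0.75, "math": 0.5}
```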
Multimodal Types
MultimodalType
Type information for multimodal data.
from slime.utils.types import MultimodalType, MultimodalTypes
@dataclass
class MultimodalType:
    name: str         # Type identifier (e.g., "image")
    placeholder: str  # Placeholder token (e.g., "<image>")
Predefined Types:
MultimodalTypes.IMAGE # MultimodalType(name="image", placeholder="<image>")
MultimodalTypes.VIDEO # MultimodalType(name="video", placeholder="<video>")
MultimodalTypes.AUDIO # MultimodalType(name="audio", placeholder="<audio>")
# Get all types
all_types = MultimodalTypes.all() # [IMAGE, VIDEO, AUDIO]
# Get by name
image_type = MultimodalTypes.get("image")
Source: slime/utils/types.py:193-210
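The placeholder strings let a prompt be checked against its attached multimodal data. A small sketch of that kind of sanity check (the helper is hypothetical, not a slime API; only the "<image>" placeholder string is taken from the source):

```python
def count_placeholders(prompt: str, placeholder: str) -> int:
    # Each occurrence of the placeholder should correspond to one
    # attached multimodal item (e.g., one image per "<image>").
    return prompt.count(placeholder)

prompt = "Compare <image> with <image> and describe the difference."
n_images = 2
assert count_placeholders(prompt, "<image>") == n_images
```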
ParamInfo
Metadata for model parameters during weight updates.
from slime.utils.types import ParamInfo
@dataclass(frozen=True)
class ParamInfo:
    name: str           # Parameter name
    dtype: torch.dtype  # Data type
    shape: torch.Size   # Tensor shape
    attrs: dict         # Additional attributes
    size: int           # Number of elements
    src_rank: int       # Source rank for distributed transfer
Example:
param_info = ParamInfo(
    name="model.layers.0.self_attn.q_proj.weight",
    dtype=torch.bfloat16,
    shape=torch.Size([4096, 4096]),
    attrs={},
    size=16777216,
    src_rank=0,
)
Source: slime/utils/types.py:177-184
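Since size is an element count, the bytes moved during a weight update follow from the dtype's itemsize. A torch-free sketch (the helper and the itemsize table are illustrative assumptions, with standard 2-byte bfloat16/float16 and 4-byte float32):

```python
# Bytes per element for a few common dtypes (itemsize assumptions,
# not values read from torch).
DTYPE_ITEMSIZE = {"bfloat16": 2, "float16": 2, "float32": 4}

def total_transfer_bytes(param_infos: list[dict]) -> int:
    # size is the element count, so bytes = size * itemsize.
    return sum(p["size"] * DTYPE_ITEMSIZE[p["dtype"]] for p in param_infos)

infos = [
    {"name": "q_proj.weight", "dtype": "bfloat16", "size": 4096 * 4096},
    {"name": "k_proj.weight", "dtype": "bfloat16", "size": 4096 * 1024},
]
total_transfer_bytes(infos)  # 41943040 bytes (40 MiB)
```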
Usage Examples
Creating Samples
from slime.utils.types import Sample
# Basic sample
sample = Sample(
    prompt="Solve: 2x + 3 = 7",
    index=0,
)

# Multi-turn chat sample
sample = Sample(
    prompt=[
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi!"},
        {"role": "user", "content": "How are you?"},
    ],
    index=1,
)

# Multimodal sample
from PIL import Image

sample = Sample(
    prompt="Describe this image",
    multimodal_inputs={
        "images": [Image.open("cat.jpg")],
    },
    index=2,
)
Working with Sample Groups
# Create a group of samples (n_samples_per_prompt = 4)
group = [
    Sample(prompt="What is AI?", index=0, group_index=0),
    Sample(prompt="What is AI?", index=1, group_index=0),
    Sample(prompt="What is AI?", index=2, group_index=0),
    Sample(prompt="What is AI?", index=3, group_index=0),
]

# After generation and reward computation
for i, sample in enumerate(group):
    sample.response = f"Response {i}"
    sample.reward = i * 0.25  # [0.0, 0.25, 0.5, 0.75]
    sample.status = Sample.Status.COMPLETED
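Group-based RL methods typically compare each response against the others in its group, e.g. by centering rewards on the group mean. A sketch of that step on the rewards above (the centering itself is illustrative, not a slime API):

```python
def center_group_rewards(rewards: list[float]) -> list[float]:
    # Subtract the group mean so each reward measures how a response
    # compares to its siblings for the same prompt.
    mean = sum(rewards) / len(rewards)
    return [r - mean for r in rewards]

center_group_rewards([0.0, 0.25, 0.5, 0.75])  # [-0.375, -0.125, 0.125, 0.375]
```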
Processing Rollout Output
from slime.rollout.base_types import RolloutFnTrainOutput
# After rollout generation
output: RolloutFnTrainOutput = generate_rollout(args, rollout_id, data_source)

# Access samples
for group in output.samples:
    for sample in group:
        print(f"Prompt: {sample.prompt}")
        print(f"Response: {sample.response}")
        print(f"Reward: {sample.reward}")

# Access metrics
if output.metrics:
    print(f"Metrics: {output.metrics}")