slime provides extensive customization capabilities through function path arguments. These allow you to inject custom logic at various stages of the training and rollout pipeline without modifying the core codebase.
Overview of Customization Points
slime supports customization at multiple stages:
| Stage | Argument | Purpose |
| --- | --- | --- |
| Generation | `--custom-generate-function-path` | Override generation logic for RAG, tool calling, etc. |
| Reward | `--custom-rm-path` | Implement custom reward computation |
| Filtering | `--dynamic-sampling-filter-path` | Filter samples during dynamic sampling |
| Data Processing | `--rollout-data-postprocess-path` | Post-process data after rollout |
| Loss Calculation | `--custom-loss-function-path` | Implement custom training objectives |
Custom Generate Function
The custom generate function allows you to override the default generation logic, enabling complex behaviors like multi-turn conversations, tool calling, and retrieval-augmented generation.
Function Signature
```python
async def generate(args, sample: Sample, sampling_params: dict) -> Sample:
    """Custom generation function.

    Args:
        args: Global arguments object
        sample: Sample object with prompt, metadata, etc.
        sampling_params: Dictionary with temperature, top_p, max_tokens, etc.

    Returns:
        Sample object with response, tokens, loss_mask, and status filled
    """
    pass
```
Required Sample Fields
Your function must set these fields:
- `sample.response`: Generated text response
- `sample.tokens`: Full token sequence (prompt + response)
- `sample.response_length`: Length of response in tokens
- `sample.loss_mask`: Binary mask indicating which tokens to train on
- `sample.status`: Completion status (`COMPLETED`, `TRUNCATED`, or `ABORTED`)
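This field contract can be checked with a small validator before returning a sample. The sketch below uses a stand-in dataclass rather than slime's actual `Sample` class, so it is illustrative only:

```python
from dataclasses import dataclass, field

# Stand-in for slime's Sample class, for illustration only.
@dataclass
class FakeSample:
    response: str = ""
    tokens: list = field(default_factory=list)
    response_length: int = 0
    loss_mask: list = field(default_factory=list)
    status: str = "COMPLETED"

def validate_sample(sample) -> list:
    """Return a list of contract violations (empty means valid)."""
    errors = []
    if len(sample.loss_mask) != sample.response_length:
        errors.append("loss_mask length != response_length")
    if sample.response_length > len(sample.tokens):
        errors.append("response_length exceeds total token count")
    if any(m not in (0, 1) for m in sample.loss_mask):
        errors.append("loss_mask must be binary")
    return errors
```

Running such a check at the end of a custom generate function surfaces mask misalignment immediately instead of as an opaque training failure later.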
Basic Example
Here’s a simple custom generation function:
```python
from slime.rollout.sglang_rollout import GenerateState
from slime.utils.http_utils import post
from slime.utils.types import Sample


async def generate(args, sample: Sample, sampling_params) -> Sample:
    state = GenerateState(args)
    url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate"

    # Prepare prompt
    prompt_text = sample.prompt
    prompt_tokens = state.tokenizer(prompt_text, add_special_tokens=False)["input_ids"]

    # Call inference engine
    payload = {
        "text": prompt_text,
        "sampling_params": sampling_params,
    }
    output = await post(url, payload)

    # Process response
    response = output["text"]
    response_tokens = state.tokenizer(response, add_special_tokens=False)["input_ids"]

    # Fill sample
    sample.tokens = prompt_tokens + response_tokens
    sample.response_length = len(response_tokens)
    sample.response = response
    sample.loss_mask = [1] * len(response_tokens)  # Train on all tokens
    sample.status = Sample.Status.COMPLETED
    return sample
```
Configuration
Specify your function in the training script:
```bash
CUSTOM_ARGS=(
   --custom-generate-function-path my_module.generation.generate
)
```
Custom Reward Function
Reward functions evaluate generated samples and return a scalar score.
Function Signature
```python
async def custom_rm(args, sample: Sample, **kwargs) -> float:
    """Compute reward for a sample.

    Args:
        args: Global arguments
        sample: Sample with prompt, response, and label
        **kwargs: Additional arguments

    Returns:
        Float reward score (typically 0-1)
    """
    pass
```
Example: Rule-Based Reward
```python
import re

from slime.utils.types import Sample


async def math_reward(args, sample: Sample, **kwargs) -> float:
    """Reward function for math problems."""
    # Extract answer from response
    pattern = r"<answer>(.*?)</answer>"
    match = re.search(pattern, sample.response, re.DOTALL)
    if not match:
        return 0.0  # No valid answer format

    predicted = match.group(1).strip()
    ground_truth = sample.label.get("answer", "")

    # Exact match
    if predicted.lower() == ground_truth.lower():
        return 1.0
    return 0.0
```
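The extraction step can be exercised on its own. A quick check of the answer-tag regex, using only the standard `re` module:

```python
import re

pattern = r"<answer>(.*?)</answer>"

response = "Let me think. <answer> Paris </answer>"
match = re.search(pattern, response, re.DOTALL)
predicted = match.group(1).strip() if match else None
print(predicted)  # Paris

# Responses without the tag produce no match and thus zero reward
assert re.search(pattern, "The capital is Paris.", re.DOTALL) is None
```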
Example: Remote Reward Model
```python
from slime.utils.http_utils import post
from slime.utils.types import Sample


async def remote_reward(args, sample: Sample, **kwargs) -> float:
    """Call external reward model API."""
    payload = {
        "prompt": sample.prompt,
        "response": sample.response,
        "label": sample.label,
    }
    result = await post(args.rm_url, payload)
    return float(result["reward"])
```
Configuration
```bash
ROLLOUT_ARGS=(
   --custom-rm-path my_module.rewards.math_reward
)
```
Loss Masking for Multi-Turn Agents
Loss masking is crucial for agent training. It controls which tokens the model learns from.
The `loss_mask` must be the same length as the response tokens: tokens with mask `1` contribute to the loss, while tokens with mask `0` are ignored.
General Principle
- Model-generated tokens (thinking, actions, answers) → `loss_mask = 1`
- Environment-returned tokens (tool outputs, observations) → `loss_mask = 0`
Multi-Turn Example
Here’s a simplified version of multi-turn generation with tool calling:
```python
async def generate_with_tools(args, sample: Sample, sampling_params) -> Sample:
    state = GenerateState(args)
    url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate"

    prompt_text = sample.prompt
    prompt_tokens = state.tokenizer(prompt_text, add_special_tokens=False)["input_ids"]

    response = ""
    response_tokens = []
    loss_mask = []
    max_turns = 8  # Maximum number of model/tool turns; tune for your task

    for turn in range(max_turns):
        # 1. Model generates action
        payload = {
            "text": prompt_text + response,
            "sampling_params": sampling_params,
        }
        output = await post(url, payload)
        model_output = output["text"]
        model_tokens = state.tokenizer(model_output, add_special_tokens=False)["input_ids"]

        # Model-generated tokens: train on these
        response += model_output
        response_tokens += model_tokens
        loss_mask += [1] * len(model_tokens)

        # 2. Check if done
        if "<answer>" in model_output:
            break

        # 3. Execute tool call (extract_query and call_search_api are user-provided helpers)
        if "<search>" in model_output:
            query = extract_query(model_output)
            tool_result = await call_search_api(query)
            tool_text = f"<result>{tool_result}</result>"
            tool_tokens = state.tokenizer(tool_text, add_special_tokens=False)["input_ids"]

            # Tool output tokens: don't train on these
            response += tool_text
            response_tokens += tool_tokens
            loss_mask += [0] * len(tool_tokens)

    sample.tokens = prompt_tokens + response_tokens
    sample.response_length = len(response_tokens)
    sample.response = response
    sample.loss_mask = loss_mask
    sample.status = Sample.Status.COMPLETED
    return sample
```
Dynamic Sampling Filter
Dynamic filters evaluate groups of samples and decide whether to keep or discard them.
Function Signature
```python
from dataclasses import dataclass

from slime.utils.types import Sample


@dataclass
class DynamicFilterOutput:
    keep: bool  # Whether to keep this sample group
    reason: str | None  # Reason for filtering (for logging)


def filter_function(args, samples: list[Sample], **kwargs) -> DynamicFilterOutput:
    """Filter a group of samples."""
    pass
```
Example: Reward Diversity Filter
```python
import torch

from slime.utils.types import Sample


def check_reward_nonzero_std(args, samples: list[Sample], **kwargs):
    """Filter out groups where all rewards are identical."""
    rewards = [sample.get_reward_value(args) for sample in samples]

    # Calculate standard deviation
    reward_std = torch.tensor(rewards, dtype=torch.float).std()

    # Keep only if there's diversity in rewards
    keep = bool(reward_std > 0.0)
    return DynamicFilterOutput(
        keep=keep,
        reason=None if keep else f"zero_std_{round(rewards[0], 1)}",
    )
```
Configuration
```bash
ROLLOUT_ARGS=(
   --over-sampling-batch-size 64
   --rollout-batch-size 32
   --n-samples-per-prompt 8
   --dynamic-sampling-filter-path my_module.filters.check_reward_nonzero_std
)
```
This configuration:
- Samples 64 prompts × 8 responses = 512 samples
- Filters out groups where all 8 responses have identical rewards
- Keeps only 32 prompts × 8 responses = 256 diverse samples
- Automatically resamples if too many groups are filtered
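The filtering arithmetic can be simulated with plain Python. The sketch below substitutes `statistics.pstdev` for the torch computation and uses mock reward groups, with no slime dependencies:

```python
import statistics

def nonzero_std(rewards):
    """Keep a group only if its rewards are not all identical."""
    return statistics.pstdev(rewards) > 0.0

# Mock rollout: 64 prompts x 8 responses each; half the groups are degenerate
groups = [[1.0] * 8 for _ in range(32)]               # identical rewards -> filtered
groups += [[0.0] * 4 + [1.0] * 4 for _ in range(32)]  # mixed rewards -> kept

kept = [g for g in groups if nonzero_std(g)]
print(len(groups) * 8, "->", len(kept) * 8)  # 512 -> 256
```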
Buffer Filter
Buffer filters select samples from the rollout buffer before training.
Function Signature
```python
def buffer_filter(
    args,
    rollout_id: int,
    buffer: list[list[Sample]],
    num_samples: int,
) -> list[list[Sample]]:
    """Select samples from buffer.

    Args:
        args: Global arguments
        rollout_id: Current rollout ID
        buffer: List of sample groups in buffer
        num_samples: Number of sample groups to return

    Returns:
        Selected sample groups
    """
    pass
```
Example: Priority Sampling
```python
import torch


def priority_filter(args, rollout_id, buffer: list[list[Sample]], num_samples: int):
    """Sample high-reward groups with higher probability."""
    # Calculate average reward for each group
    group_rewards = []
    for group in buffer:
        avg_reward = sum(s.reward for s in group) / len(group)
        group_rewards.append(avg_reward)

    # Convert to probabilities
    probs = torch.softmax(torch.tensor(group_rewards), dim=0)

    # Sample without replacement
    indices = torch.multinomial(probs, num_samples, replacement=False)
    selected = [buffer[i] for i in indices]

    # Remove selected from buffer (highest index first to keep positions valid)
    for i in sorted(indices.tolist(), reverse=True):
        del buffer[i]
    return selected
```
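The softmax weighting behind this selection can be checked in isolation. A stdlib-only version of the probability computation (assumed equivalent to the `torch.softmax` call above):

```python
import math

def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

group_rewards = [0.1, 0.5, 0.9]
probs = softmax(group_rewards)

# Probabilities sum to 1 and the highest-reward group is the most likely pick
print([round(p, 3) for p in probs])
```

Subtracting the maximum before exponentiating does not change the result but prevents overflow when rewards are large.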
Rollout Data Postprocess
Post-process samples after log probabilities are computed.
Function Signature
```python
def postprocess_function(args, samples: list[list[Sample]]) -> None:
    """Post-process rollout data.

    Args:
        args: Global arguments
        samples: List of sample groups with log_probs computed

    Note:
        This function modifies samples in-place
    """
    pass
```
Example: Dynamic Loss Masking
```python
def adjust_loss_masks(args, samples: list[list[Sample]]) -> None:
    """Adjust loss masks based on log probabilities."""
    for group in samples:
        for sample in group:
            # Mask out very low probability tokens
            for i, log_prob in enumerate(sample.rollout_log_probs):
                if log_prob < -10.0:  # Very unlikely token
                    sample.loss_mask[i] = 0
```
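The thresholding logic on its own, with mock log-probabilities and no slime objects:

```python
THRESHOLD = -10.0  # Illustrative cutoff, matching the example above

rollout_log_probs = [-0.2, -11.5, -1.3, -25.0]
loss_mask = [1, 1, 1, 1]

# Zero out the mask wherever the sampled token was extremely unlikely
for i, log_prob in enumerate(rollout_log_probs):
    if log_prob < THRESHOLD:
        loss_mask[i] = 0

print(loss_mask)  # [1, 0, 1, 0]
```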
Custom Loss Function
Implement entirely custom training objectives.
Configuration
```bash
GRPO_ARGS=(
   --loss-type custom_loss
   --custom-loss-function-path my_module.losses.custom_loss
)
```
Refer to slime’s built-in loss functions in `slime/trainer/megatron_loss.py` for implementation examples.
Real-World Example: Search-R1
Here’s a complete example from the Search-R1 implementation that combines custom generation with tool calling:
Generation Function
```python
# From examples/search-r1/generate_with_search.py
async def generate(args, sample: Sample, sampling_params) -> Sample:
    state = GenerateState(args)
    url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate"

    prompt_text = sample.prompt
    prompt_tokens = state.tokenizer(prompt_text, add_special_tokens=False)["input_ids"]

    response = ""
    response_tokens = []
    loss_mask = []
    rollout_log_probs = []
    max_turns = 8  # Maximum search/answer turns; tune for your task

    for turn in range(max_turns):
        # Model generates search query or answer
        payload = {
            "text": prompt_text + response,
            "sampling_params": sampling_params,
            "return_logprob": True,
        }
        output = await post(url, payload)

        if output["meta_info"]["finish_reason"]["type"] == "abort":
            sample.status = Sample.Status.ABORTED
            return sample

        # Extract model output
        model_text = output["text"]
        model_tokens = [item[1] for item in output["meta_info"]["output_token_logprobs"]]
        model_log_probs = [item[0] for item in output["meta_info"]["output_token_logprobs"]]

        # Add to response
        response += model_text
        response_tokens += model_tokens
        loss_mask += [1] * len(model_tokens)  # Train on model output
        rollout_log_probs += model_log_probs

        # Parse action (parse_action and search are defined in the example module)
        action, content = parse_action(model_text)
        if action == "search":
            # Execute search
            search_result = await search(content)
            obs_text = f"<information>{search_result}</information>"
            obs_tokens = state.tokenizer(obs_text, add_special_tokens=False)["input_ids"]

            # Add observation
            response += obs_text
            response_tokens += obs_tokens
            loss_mask += [0] * len(obs_tokens)  # Don't train on search results
            rollout_log_probs += [0.0] * len(obs_tokens)
        elif action == "answer":
            break  # Done

    sample.tokens = prompt_tokens + response_tokens
    sample.response_length = len(response_tokens)
    sample.response = response
    sample.loss_mask = loss_mask
    sample.rollout_log_probs = rollout_log_probs
    sample.status = Sample.Status.COMPLETED
    return sample
```
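`parse_action` is not shown in the excerpt. A minimal tag-based parser along these lines would satisfy the calls in the loop; this is a hedged sketch, not the actual Search-R1 helper:

```python
import re

def parse_action(text):
    """Return (action, content) for the first <search> or <answer> tag, else (None, None)."""
    for action in ("search", "answer"):
        match = re.search(rf"<{action}>(.*?)</{action}>", text, re.DOTALL)
        if match:
            return action, match.group(1).strip()
    return None, None

print(parse_action("<search>capital of France</search>"))  # ('search', 'capital of France')
print(parse_action("<answer>Paris</answer>"))              # ('answer', 'Paris')
```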
Passing Metadata
Pass additional data to your custom functions through the `metadata` field.
Dataset Preparation
Create a metadata field in your JSONL dataset:
```json
{
  "question": "What is the capital of France?",
  "answer": "Paris",
  "metadata": {
    "session_id": "sess_123",
    "difficulty": "easy",
    "tools": ["search", "calculator"]
  }
}
```
Configuration
```bash
ROLLOUT_ARGS=(
   --prompt-data /path/to/data.jsonl
   --input-key question
   --label-key answer
   --metadata-key metadata
)
```
Accessing in Custom Functions
```python
async def generate(args, sample: Sample, sampling_params) -> Sample:
    # Access metadata
    session_id = sample.metadata["session_id"]
    tools = sample.metadata["tools"]

    # Use metadata in generation logic
    if "search" in tools:
        # Enable search capability
        pass

    return sample
```
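Since not every dataset row is guaranteed to carry every key, defensive access with `dict.get` avoids a `KeyError` at rollout time. A stdlib-only sketch; the metadata dict here is a mock, not a real dataset row:

```python
metadata = {"session_id": "sess_123", "tools": ["search"]}  # mock row metadata

# Fall back to sensible defaults when a key is absent
session_id = metadata.get("session_id", "unknown")
difficulty = metadata.get("difficulty", "medium")  # missing in this row
tools = metadata.get("tools", [])

print(session_id, difficulty, "search" in tools)  # sess_123 medium True
```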
Testing Custom Functions
slime provides CPU-only contract tests for validation:
# Test custom generate function
python tests/plugin_contracts/test_plugin_generate_contracts.py \
--custom-generate-function-path my_module.generation.generate
# Test custom reward function
python tests/plugin_contracts/test_plugin_path_loading_contracts.py \
--custom-rm-path my_module.rewards.custom_rm
These tests verify function signatures, return types, and basic functionality without requiring GPUs.
Best Practices
Start Simple
Begin with a basic custom function and test it thoroughly before adding complexity.
Validate Loss Masks
Always verify that the `loss_mask` length matches `response_length`; misalignment causes training failures.
Handle Edge Cases
Check for empty responses, truncation, and API failures. Set appropriate sample.status values.
Use Async Properly
Always use await for HTTP calls and I/O operations. This enables concurrent processing.
Log Metadata
Store debugging information in sample.metadata for later analysis.
Next Steps
- Multi-Turn Agents: learn how to build agent systems with tool calling
- Configuration: return to the configuration reference