slime provides extensive customization capabilities through function path arguments. These allow you to inject custom logic at various stages of the training and rollout pipeline without modifying the core codebase.

Overview of Customization Points

slime supports customization at multiple stages:
Stage | Argument | Purpose
--- | --- | ---
Generation | --custom-generate-function-path | Override generation logic for RAG, tool calling, etc.
Reward | --custom-rm-path | Implement custom reward computation
Filtering | --dynamic-sampling-filter-path | Filter samples during dynamic sampling
Data Processing | --rollout-data-postprocess-path | Post-process data after rollout
Loss Calculation | --custom-loss-function-path | Implement custom training objectives
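Each of these arguments takes a dotted Python path of the form module.submodule.function. To show what that implies, here is a minimal sketch of resolving such a path with importlib. The helper name load_function is hypothetical, and slime's own loader may differ in details. Either way, the practical consequence is the same: the module (e.g. my_module) must be importable, for instance via PYTHONPATH, wherever slime loads it.

```python
import importlib
from typing import Callable


def load_function(path: str) -> Callable:
    """Resolve a dotted path such as "my_module.rewards.math_reward" (illustrative sketch).

    Everything before the last dot is imported as a module; the last
    component is looked up as an attribute of that module.
    """
    module_path, _, attr_name = path.rpartition(".")
    if not module_path:
        raise ValueError(f"expected 'package.module.function', got {path!r}")
    module = importlib.import_module(module_path)  # raises ModuleNotFoundError if not importable
    fn = getattr(module, attr_name)
    if not callable(fn):
        raise TypeError(f"{path!r} resolved to a non-callable {type(fn).__name__}")
    return fn
```

For example, load_function("math.sqrt") returns the standard library's math.sqrt.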

Custom Generate Function

The custom generate function allows you to override the default generation logic, enabling complex behaviors like multi-turn conversations, tool calling, and retrieval-augmented generation.

Function Signature

async def generate(args, sample: Sample, sampling_params: dict) -> Sample:
    """Custom generation function.
    
    Args:
        args: Global arguments object
        sample: Sample object with prompt, metadata, etc.
        sampling_params: Dictionary with temperature, top_p, max_tokens, etc.
    
    Returns:
        Sample object with response, tokens, loss_mask, and status filled
    """
    pass

Required Sample Fields

Your function must set these fields:
  • sample.response: Generated text response
  • sample.tokens: Full token sequence (prompt + response)
  • sample.response_length: Length of response in tokens
  • sample.loss_mask: Binary mask indicating which tokens to train on
  • sample.status: Completion status (COMPLETED, TRUNCATED, or ABORTED)
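A malformed sample tends to fail later, deep inside training, so it can help to check these fields right after generation. The sketch below uses a hypothetical FakeSample dataclass and Status enum as stand-ins for slime's Sample so that it runs on its own; check_generated_sample is likewise illustrative, not a slime API. It encodes the contract listed above: loss_mask covers exactly the response tokens, and tokens (prompt + response) is at least as long as the response.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum


class Status(Enum):
    # Hypothetical stand-in for slime's Sample.Status
    COMPLETED = "completed"
    TRUNCATED = "truncated"
    ABORTED = "aborted"


@dataclass
class FakeSample:
    # Hypothetical minimal stand-in for slime's Sample with only the required fields
    prompt: str
    response: str = ""
    tokens: list[int] = field(default_factory=list)
    response_length: int = 0
    loss_mask: list[int] | None = None
    status: Status | None = None


def check_generated_sample(sample: FakeSample) -> list[str]:
    """Return a list of contract violations (empty if the sample looks valid)."""
    problems = []
    if not isinstance(sample.response, str):
        problems.append("response must be a str")
    if sample.response_length < 0 or sample.response_length > len(sample.tokens):
        problems.append("response_length must be between 0 and len(tokens)")
    if sample.loss_mask is None or len(sample.loss_mask) != sample.response_length:
        problems.append("len(loss_mask) must equal response_length")
    elif any(m not in (0, 1) for m in sample.loss_mask):
        problems.append("loss_mask must contain only 0 and 1")
    if sample.status is None:
        problems.append("status must be set")
    return problems
```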

Basic Example

Here’s a simple custom generation function:
from slime.rollout.sglang_rollout import GenerateState
from slime.utils.http_utils import post
from slime.utils.types import Sample

async def generate(args, sample: Sample, sampling_params) -> Sample:
    state = GenerateState(args)
    url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate"
    
    # Prepare prompt
    prompt_text = sample.prompt
    prompt_tokens = state.tokenizer(prompt_text, add_special_tokens=False)["input_ids"]
    
    # Call inference engine
    payload = {
        "text": prompt_text,
        "sampling_params": sampling_params,
    }
    output = await post(url, payload)
    
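    # Process response. Re-tokenizing the text is simple, but it may not
    # reproduce the exact token IDs the engine sampled.
    response = output["text"]
    response_tokens = state.tokenizer(response, add_special_tokens=False)["input_ids"]
    
    # Fill sample
    sample.tokens = prompt_tokens + response_tokens
    sample.response_length = len(response_tokens)
    sample.response = response
    sample.loss_mask = [1] * len(response_tokens)  # Train on every response token
    # Simplification: production code should set TRUNCATED or ABORTED based on
    # output["meta_info"]["finish_reason"]
    sample.status = Sample.Status.COMPLETED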
    
    return sample

Configuration

Specify your function in the training script:
CUSTOM_ARGS=(
   --custom-generate-function-path my_module.generation.generate
)

Custom Reward Function

Reward functions evaluate generated samples and return a scalar score.

Function Signature

async def custom_rm(args, sample: Sample, **kwargs) -> float:
    """Compute reward for a sample.
    
    Args:
        args: Global arguments
        sample: Sample with prompt, response, and label
        **kwargs: Additional arguments
    
    Returns:
        Float reward score (typically 0-1)
    """
    pass

Example: Rule-Based Reward

import re
from slime.utils.types import Sample

async def math_reward(args, sample: Sample, **kwargs) -> float:
    """Reward function for math problems."""
    # Extract answer from response
    pattern = r"<answer>(.*?)</answer>"
    match = re.search(pattern, sample.response, re.DOTALL)
    
    if not match:
        return 0.0  # No valid answer format
    
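    predicted = match.group(1).strip()
    # Depending on --label-key, the label may be a dict or a plain string
    label = sample.label
    if isinstance(label, dict):
        ground_truth = str(label.get("answer", ""))
    else:
        ground_truth = "" if label is None else str(label)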
    
    # Exact match
    if predicted.lower() == ground_truth.lower():
        return 1.0
    
    return 0.0
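Keeping the string matching in plain synchronous helpers makes the reward easy to unit test without an event loop or a real Sample. This sketch differs from the example above in two deliberate ways: it takes the last <answer> tag (useful if the model restates its answer) and it collapses whitespace before comparing. The helper names are hypothetical, and the label handling assumes the label may be either a dict or a plain string.

```python
from __future__ import annotations

import re

ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)


def extract_answer(response: str) -> str | None:
    """Return the text inside the last <answer>...</answer> tag, or None."""
    matches = ANSWER_RE.findall(response)
    return matches[-1].strip() if matches else None


def normalize(text: str) -> str:
    # Lowercase and collapse whitespace so "  PARIS " and "paris" compare equal
    return " ".join(text.lower().split())


def exact_match_reward(response: str, ground_truth: str) -> float:
    predicted = extract_answer(response)
    if predicted is None:
        return 0.0  # No well-formed answer tag
    return 1.0 if normalize(predicted) == normalize(ground_truth) else 0.0


async def math_reward(args, sample, **kwargs) -> float:
    # Thin async wrapper matching the reward-function signature above
    label = sample.label
    ground_truth = label.get("answer", "") if isinstance(label, dict) else str(label)
    return exact_match_reward(sample.response, ground_truth)
```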

Example: Remote Reward Model

from slime.utils.http_utils import post
from slime.utils.types import Sample

async def remote_reward(args, sample: Sample, **kwargs) -> float:
    """Call external reward model API."""
    payload = {
        "prompt": sample.prompt,
        "response": sample.response,
        "label": sample.label,
    }
    
    result = await post(args.rm_url, payload)
    return float(result["reward"])
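A remote reward model is a network dependency, so timeouts and transient errors are worth handling explicitly. Below is a minimal retry sketch with exponential backoff. It takes the HTTP call as a parameter (post_fn) so it can be tested with a fake; in slime you would presumably pass slime.utils.http_utils.post, but reward_with_retry itself and its defaults are hypothetical.

```python
import asyncio
from typing import Any, Awaitable, Callable

# Any async callable with the shape post(url, payload) -> dict
PostFn = Callable[[str, dict], Awaitable[dict[str, Any]]]


async def reward_with_retry(
    post_fn: PostFn,
    url: str,
    payload: dict,
    retries: int = 3,
    timeout: float = 30.0,
    backoff: float = 1.0,
    default: float = 0.0,
) -> float:
    """Call a remote reward endpoint, retrying on errors and timeouts.

    Returns `default` if every attempt fails, so a flaky reward server
    yields a score instead of crashing the rollout.
    """
    for attempt in range(retries):
        try:
            result = await asyncio.wait_for(post_fn(url, payload), timeout=timeout)
            return float(result["reward"])
        except Exception:  # Broad on purpose: network errors, timeouts, bad payloads
            if attempt < retries - 1:
                await asyncio.sleep(backoff * 2 ** attempt)  # Exponential backoff
    return default
```

Returning a default score keeps the rollout alive but silently injects a reward the model did not earn; if that bias matters, re-raising after the last attempt may be the better choice.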

Configuration

ROLLOUT_ARGS=(
   --custom-rm-path my_module.rewards.math_reward
)

Loss Masking for Multi-Turn Agents

Loss masking is crucial for agent training. It controls which tokens the model learns from.
The loss_mask must have the same length as the response tokens (response_length). Tokens with mask 1 contribute to the loss; tokens with mask 0 are ignored.

General Principle

  • Model-generated tokens (thinking, actions, answers) → loss_mask = 1
  • Environment-returned tokens (tool outputs, observations) → loss_mask = 0

Multi-Turn Example

Here’s a simplified version of multi-turn generation with tool calling:
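from slime.rollout.sglang_rollout import GenerateState
from slime.utils.http_utils import post
from slime.utils.types import Sample

MAX_TURNS = 3  # Illustrative turn budget

# extract_query() and call_search_api() are placeholders for your own
# query parser and search client.
async def generate_with_tools(args, sample: Sample, sampling_params) -> Sample:
    state = GenerateState(args)
    url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate"
    
    prompt_text = sample.prompt
    prompt_tokens = state.tokenizer(prompt_text, add_special_tokens=False)["input_ids"]
    
    response = ""
    response_tokens = []
    loss_mask = []
    
    for _ in range(MAX_TURNS):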
        # 1. Model generates action
        payload = {
            "text": prompt_text + response,
            "sampling_params": sampling_params,
        }
        output = await post(url, payload)
        model_output = output["text"]
        model_tokens = state.tokenizer(model_output, add_special_tokens=False)["input_ids"]
        
        # Model-generated tokens: train on these
        response += model_output
        response_tokens += model_tokens
        loss_mask += [1] * len(model_tokens)  # Train on model output
        
        # 2. Check if done
        if "<answer>" in model_output:
            break
        
        # 3. Execute tool call
        if "<search>" in model_output:
            query = extract_query(model_output)
            tool_result = await call_search_api(query)
            tool_text = f"<result>{tool_result}</result>"
            tool_tokens = state.tokenizer(tool_text, add_special_tokens=False)["input_ids"]
            
            # Tool output tokens: don't train on these
            response += tool_text
            response_tokens += tool_tokens
            loss_mask += [0] * len(tool_tokens)  # Don't train on tool output
    
    sample.tokens = prompt_tokens + response_tokens
    sample.response_length = len(response_tokens)
    sample.response = response
    sample.loss_mask = loss_mask
    sample.status = Sample.Status.COMPLETED
    
    return sample

Dynamic Sampling Filter

Dynamic filters evaluate groups of samples and decide whether to keep or discard them.

Function Signature

from dataclasses import dataclass
from slime.utils.types import Sample

@dataclass
class DynamicFilterOutput:
    keep: bool  # Whether to keep this sample group
    reason: str | None  # Reason for filtering (for logging)

def filter_function(args, samples: list[Sample], **kwargs) -> DynamicFilterOutput:
    """Filter a group of samples."""
    pass

Example: Reward Diversity Filter

import torch
from slime.utils.types import Sample

# DynamicFilterOutput is the dataclass shown in the signature above.

def check_reward_nonzero_std(args, samples: list[Sample], **kwargs):
    """Filter out groups where all rewards are identical."""
    rewards = [sample.get_reward_value(args) for sample in samples]
    
    # Calculate standard deviation
    reward_std = torch.tensor(rewards, dtype=torch.float).std()
    
    # Keep only if rewards differ within the group (convert the 0-d tensor to a Python bool)
    keep = bool(reward_std > 0.0)
    
    return DynamicFilterOutput(
        keep=keep,
        reason=None if keep else f"zero_std_{round(rewards[0], 1)}",
    )

Configuration

ROLLOUT_ARGS=(
   --over-sampling-batch-size 64
   --rollout-batch-size 32
   --n-samples-per-prompt 8
   --dynamic-sampling-filter-path my_module.filters.check_reward_nonzero_std
)
This configuration:
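  1. Samples 64 prompts × 8 responses = 512 samples
  2. Filters out groups in which all 8 responses received identical rewards (with group-relative advantages such as GRPO's, these groups have zero advantage and provide no learning signal)
  3. Keeps 32 prompts × 8 responses = 256 samples whose rewards vary within each group
  4. Draws additional batches automatically if too many groups are filtered out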
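These steps amount to an over-sample-then-filter loop. The sketch below models it in plain Python, representing each group as the list of rewards for one prompt's responses. collect_groups, max_rounds, and the trimming of surplus groups are assumptions made for illustration; slime's actual scheduling (for example, what happens to surplus groups) may differ. statistics.pstdev computes the population standard deviation, whereas torch's std() defaults to the sample version, but for groups of two or more both are zero exactly when all rewards are equal, so the keep decision is the same.

```python
import statistics
from typing import Callable

Group = list[float]  # Rewards for the n_samples_per_prompt responses to one prompt


def nonzero_std(group: Group) -> bool:
    # Same idea as check_reward_nonzero_std: keep groups whose rewards differ
    return len(group) > 1 and statistics.pstdev(group) > 0.0


def collect_groups(
    draw_batch: Callable[[int], list[Group]],
    keep: Callable[[Group], bool],
    over_sampling_batch_size: int,
    rollout_batch_size: int,
    max_rounds: int = 10,
) -> list[Group]:
    """Draw over-sized batches until enough groups pass the filter."""
    kept: list[Group] = []
    for _ in range(max_rounds):
        for group in draw_batch(over_sampling_batch_size):
            if keep(group):
                kept.append(group)
        if len(kept) >= rollout_batch_size:
            return kept[:rollout_batch_size]  # Trim surplus groups
    raise RuntimeError(f"only {len(kept)} of {rollout_batch_size} groups passed the filter")
```

With the configuration above, draw_batch would be called with 64 and collection stops once 32 groups have passed.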

Buffer Filter

Buffer filters select samples from the rollout buffer before training.

Function Signature

def buffer_filter(
    args, 
    rollout_id: int, 
    buffer: list[list[Sample]], 
    num_samples: int
) -> list[list[Sample]]:
    """Select samples from buffer.
    
    Args:
        args: Global arguments
        rollout_id: Current rollout ID
        buffer: List of sample groups in buffer
        num_samples: Number of sample groups to return
    
    Returns:
        Selected sample groups
    """
    pass
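The simplest policy that satisfies this signature is first-in, first-out. The sketch below (fifo_buffer_filter is a hypothetical name) returns the oldest groups and deletes them from the buffer in place, mirroring how the priority example below removes the groups it selects.

```python
from typing import TypeVar

T = TypeVar("T")


def fifo_buffer_filter(args, rollout_id: int, buffer: list[list[T]], num_samples: int) -> list[list[T]]:
    """Pop the oldest `num_samples` groups from the buffer (in place)."""
    num_samples = min(num_samples, len(buffer))  # Never ask for more than is buffered
    selected = buffer[:num_samples]
    del buffer[:num_samples]  # Remove the returned groups so they are not reused
    return selected
```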

Example: Priority Sampling

import torch
from slime.utils.types import Sample

def priority_filter(args, rollout_id, buffer: list[list[Sample]], num_samples: int):
    """Sample high-reward groups with higher probability."""
    num_samples = min(num_samples, len(buffer))  # multinomial without replacement needs this
    if num_samples == 0:
        return []
    
    # Calculate average reward for each group
    group_rewards = []
    for group in buffer:
        avg_reward = sum(s.get_reward_value(args) for s in group) / len(group)
        group_rewards.append(avg_reward)
    
    # Convert to probabilities
    probs = torch.softmax(torch.tensor(group_rewards, dtype=torch.float), dim=0)
    
    # Sample without replacement; convert to plain ints for list indexing
    indices = torch.multinomial(probs, num_samples, replacement=False).tolist()
    
    selected = [buffer[i] for i in indices]
    
    # Remove selected groups from the buffer, highest index first
    for i in sorted(indices, reverse=True):
        del buffer[i]
    
    return selected

Rollout Data Postprocess

Post-process samples after log probabilities are computed.

Function Signature

def postprocess_function(args, samples: list[list[Sample]]) -> None:
    """Post-process rollout data.
    
    Args:
        args: Global arguments
        samples: List of sample groups with log_probs computed
    
    Note:
        This function modifies samples in-place
    """
    pass

Example: Dynamic Loss Masking

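from slime.utils.types import Sample

def adjust_loss_masks(args, samples: list[list[Sample]]) -> None:
    """Adjust loss masks based on log probabilities."""
    for group in samples:
        for sample in group:
            if sample.rollout_log_probs is None:
                continue  # No log probs recorded for this sample
            if sample.loss_mask is None:
                sample.loss_mask = [1] * sample.response_length
            # Mask out very low-probability tokens
            for i, log_prob in enumerate(sample.rollout_log_probs):
                if log_prob < -10.0:  # Very unlikely token (p < ~4.5e-5)
                    sample.loss_mask[i] = 0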

Custom Loss Function

Implement entirely custom training objectives.

Configuration

GRPO_ARGS=(
   --loss-type custom_loss
   --custom-loss-function-path my_module.losses.custom_loss
)
Refer to slime’s built-in loss functions in slime/trainer/megatron_loss.py for implementation examples.

Real-World Example: Search-R1

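Here’s a condensed version of the Search-R1 implementation that combines custom generation with tool calling:
# Condensed from examples/search-r1/generate_with_search.py
from slime.rollout.sglang_rollout import GenerateState
from slime.utils.http_utils import post
from slime.utils.types import Sample

MAX_TURNS = 3  # Illustrative turn budget
# parse_action() and search() stand in for the example's action parser
# and search client; they are not shown here.

async def generate(args, sample: Sample, sampling_params) -> Sample:
    state = GenerateState(args)
    url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate"
    
    prompt_text = sample.prompt
    prompt_tokens = state.tokenizer(prompt_text, add_special_tokens=False)["input_ids"]
    response = ""
    response_tokens = []
    loss_mask = []
    rollout_log_probs = []
    
    for _ in range(MAX_TURNS):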
        # Model generates search query or answer
        payload = {
            "text": prompt_text + response,
            "sampling_params": sampling_params,
            "return_logprob": True,
        }
        output = await post(url, payload)
        
        if output["meta_info"]["finish_reason"]["type"] == "abort":
            sample.status = Sample.Status.ABORTED
            return sample
        
        # Extract model output; each output_token_logprobs entry is
        # (logprob, token_id, token_text)
        model_text = output["text"]
        model_tokens = [item[1] for item in output["meta_info"]["output_token_logprobs"]]
        model_log_probs = [item[0] for item in output["meta_info"]["output_token_logprobs"]]
        
        # Add to response
        response += model_text
        response_tokens += model_tokens
        loss_mask += [1] * len(model_tokens)  # Train on model output
        rollout_log_probs += model_log_probs
        
        # Parse action
        action, content = parse_action(model_text)
        
        if action == "search":
            # Execute search
            search_result = await search(content)
            obs_text = f"<information>{search_result}</information>"
            obs_tokens = state.tokenizer(obs_text, add_special_tokens=False)["input_ids"]
            
            # Add observation
            response += obs_text
            response_tokens += obs_tokens
            loss_mask += [0] * len(obs_tokens)  # Don't train on search results
            rollout_log_probs += [0.0] * len(obs_tokens)  # Placeholders; these tokens are masked out
            
        elif action == "answer":
            break  # Done
    
    sample.tokens = prompt_tokens + response_tokens
    sample.response_length = len(response_tokens)
    sample.response = response
    sample.loss_mask = loss_mask
    sample.rollout_log_probs = rollout_log_probs
    sample.status = Sample.Status.COMPLETED
    
    return sample
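The example above only distinguishes aborted requests. Below is a sketch of a helper that also flags truncation, assuming SGLang reports a finish_reason type of "length" when the token limit is reached; status_from_finish_reason and the Status stand-in are hypothetical.

```python
from enum import Enum


class Status(Enum):
    # Hypothetical stand-in for slime's Sample.Status
    COMPLETED = "completed"
    TRUNCATED = "truncated"
    ABORTED = "aborted"


def status_from_finish_reason(output: dict) -> Status:
    """Map an SGLang /generate response's finish_reason to a sample status."""
    finish_reason = output.get("meta_info", {}).get("finish_reason") or {}
    reason_type = finish_reason.get("type")
    if reason_type == "abort":
        return Status.ABORTED
    if reason_type == "length":
        return Status.TRUNCATED  # Hit the token limit before a stop condition
    return Status.COMPLETED
```

In a multi-turn loop, one reasonable policy is to stop and mark the whole sample TRUNCATED as soon as any turn is truncated, since later turns would build on an incomplete action.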

Metadata and Custom Fields

Pass additional data to your custom functions through the metadata field.

Dataset Preparation

Create a metadata field in your JSONL dataset:
{
  "question": "What is the capital of France?",
  "answer": "Paris",
  "metadata": {
    "session_id": "sess_123",
    "difficulty": "easy",
    "tools": ["search", "calculator"]
  }
}

Configuration

ROLLOUT_ARGS=(
   --prompt-data /path/to/data.jsonl
   --input-key question
   --label-key answer
   --metadata-key metadata
)
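Conceptually, these three flags name the JSONL fields that become the prompt, the label, and the metadata. The sketch below imitates that mapping with the json module; load_prompts is hypothetical, and slime's real data loader does more (for example, it may apply a chat template to the prompt). One consequence worth noting: with --label-key answer, the label in this sketch is the plain string "Paris", not a dict, so a reward function should not assume sample.label supports .get().

```python
import json


def load_prompts(path: str, input_key: str = "question", label_key: str = "answer",
                 metadata_key: str = "metadata") -> list[dict]:
    """Read a JSONL file and pick out prompt, label, and metadata fields."""
    records = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # Tolerate blank lines
            row = json.loads(line)
            records.append({
                "prompt": row[input_key],                 # Required: missing key raises KeyError
                "label": row.get(label_key),              # Optional
                "metadata": row.get(metadata_key) or {},  # Default to an empty dict
            })
    return records
```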

Accessing in Custom Functions

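from slime.utils.types import Sample

async def generate(args, sample: Sample, sampling_params) -> Sample:
    # Access metadata (.get avoids a KeyError when a sample lacks a field)
    session_id = sample.metadata.get("session_id")
    tools = sample.metadata.get("tools", [])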
    
    # Use metadata in generation logic
    if "search" in tools:
        # Enable search capability
        pass
    
    return sample

Testing Custom Functions

slime provides CPU-only contract tests for validation:
# Test custom generate function
python tests/plugin_contracts/test_plugin_generate_contracts.py \
  --custom-generate-function-path my_module.generation.generate

# Test custom reward function
python tests/plugin_contracts/test_plugin_path_loading_contracts.py \
  --custom-rm-path my_module.rewards.custom_rm
These tests verify function signatures, return types, and basic functionality without requiring GPUs.
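Between full contract-test runs, a quick local smoke test can catch the most common mistakes. This sketch (check_reward_fn is hypothetical, not part of slime's test suite) checks that a reward function is async, accepts **kwargs, and returns a number, using a SimpleNamespace in place of real arguments:

```python
import asyncio
import inspect
from types import SimpleNamespace


def check_reward_fn(fn, sample) -> float:
    """Lightweight local check for a custom reward function (illustrative sketch)."""
    if not inspect.iscoroutinefunction(fn):
        raise TypeError(f"{fn.__name__} must be defined with 'async def'")
    params = inspect.signature(fn).parameters
    has_var_kw = any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())
    if len(params) < 2 or not has_var_kw:
        raise TypeError(f"{fn.__name__} must accept (args, sample, **kwargs)")
    # Run the coroutine once with a dummy args object
    reward = asyncio.run(fn(SimpleNamespace(), sample))
    if isinstance(reward, bool) or not isinstance(reward, (int, float)):
        raise TypeError(f"expected a float reward, got {type(reward).__name__}")
    return float(reward)
```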

Best Practices

  1. Start simple: begin with a basic custom function and test it thoroughly before adding complexity.
  2. Validate loss masks: always verify that len(loss_mask) equals response_length. A mismatch causes training failures.
  3. Handle edge cases: check for empty responses, truncation, and API failures, and set sample.status accordingly.
  4. Use async properly: always await HTTP calls and other I/O operations so that samples can be processed concurrently.
  5. Log metadata: store debugging information in sample.metadata for later analysis.
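As an illustration of the async guideline: if one custom function needs several independent I/O calls (for example, several search queries issued in the same turn), awaiting them together rather than one at a time cuts latency. Here is a minimal sketch with asyncio.gather and a semaphore that caps in-flight requests; run_bounded and max_concurrency are hypothetical names.

```python
import asyncio


async def run_bounded(fn, items, max_concurrency: int = 32):
    """Await fn(item) for every item, with at most max_concurrency calls in flight."""
    semaphore = asyncio.Semaphore(max_concurrency)

    async def run_one(item):
        async with semaphore:  # Blocks while max_concurrency calls are running
            return await fn(item)

    # gather returns results in the same order as the inputs
    return await asyncio.gather(*(run_one(item) for item in items))
```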

Next Steps

  • Multi-Turn Agents: learn how to build agent systems with tool calling
  • Configuration: return to the configuration reference
