slime’s extensibility makes it ideal for training sophisticated agents that can use tools, perform multi-turn reasoning, and interact with external APIs. This guide shows you how to build agent systems for RL training.

Overview

Multi-turn agent training involves:
  1. Multi-turn interaction loop: Model generates actions, executes tools, observes results
  2. Loss masking: Train on model-generated tokens, not environment responses
  3. Custom reward functions: Evaluate complete interaction trajectories
  4. Metadata management: Pass tool definitions and context through the pipeline

Example: Search-R1 Agent

We’ll use the Search-R1 example, which trains a model to answer questions using web search.

Agent Workflow

  1. Receive Question: Model receives a question as the initial prompt
  2. Generate Action: Model decides to either:
     • Search: <search>query terms</search>
     • Answer: <answer>final answer</answer>
  3. Execute Tool: If the action is a search, call the search API and format the results
  4. Observe Results: Append the search results to the conversation history
  5. Repeat or Finish: Continue until max_turns is reached or an answer is generated
  6. Compute Reward: Compare the final answer against the ground truth
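The workflow above can be sketched as a compact loop. This is a toy simulation, not the real implementation: generate_once and run_search are hypothetical stand-ins for the inference engine and search API, scripted here so the control flow is easy to follow.

```python
# Toy sketch of the agent loop; generate_once / run_search are stand-ins.
MAX_TURNS = 2

def generate_once(history: str) -> str:
    # Scripted "model": search on the first turn, answer once results arrive
    if "<information>" not in history:
        return "<search>population of Tokyo</search>"
    return "<answer>about 14 million</answer>"

def run_search(query: str) -> str:
    # Stand-in for the search backend
    return "Doc 1: Tokyo has roughly 14 million residents."

def agent_loop(question: str) -> str:
    history = question
    for _ in range(MAX_TURNS):
        action_text = generate_once(history)   # Step 2: generate action
        history += action_text
        if "<answer>" in action_text:
            break                              # Step 5: finish on answer
        query = action_text.removeprefix("<search>").removesuffix("</search>")
        results = run_search(query)            # Step 3: execute tool
        history += f"<information>{results}</information>"  # Step 4: observe
    return history
```

The real loop (shown later in this guide) additionally tracks tokens, a loss mask, and log probabilities alongside the text history.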

Implementation

1. Configuration

First, define the configuration for your agent:
# Configuration dictionary
SEARCH_R1_CONFIGS = {
    "max_turns": 2,  # Maximum search iterations
    "topk": 3,  # Number of search results
    "search_concurrency": 256,  # Concurrent search requests
    "search_backend": "local",  # "local" or "google"
    "return_logprob": True,  # Collect log probs for TIS
    "format_score": 0.2,  # Partial credit for format
}

2. Tool Functions

Implement your tool APIs:
import asyncio
from slime.utils.http_utils import post

SEMAPHORE = asyncio.Semaphore(256)  # Limit concurrency

async def search(query: str) -> str:
    """Execute search and return formatted results."""
    async with SEMAPHORE:
        if SEARCH_R1_CONFIGS["search_backend"] == "google":
            result = await google_search(
                api_key=SEARCH_R1_CONFIGS["google"]["api_key"],
                query=query,
                topk=SEARCH_R1_CONFIGS["topk"],
            )
        else:
            result = await local_search(
                url=SEARCH_R1_CONFIGS["local"]["search_url"],
                query=query,
                topk=SEARCH_R1_CONFIGS["topk"],
            )
    
    return format_search_results(result)

def format_search_results(results) -> str:
    """Format search results as text."""
    formatted = ""
    for idx, doc in enumerate(results):
        content = doc["document"]["contents"]
        title = content.split("\n")[0]
        text = "\n".join(content.split("\n")[1:])
        formatted += f"Doc {idx+1}(Title: {title}) {text}\n"
    return formatted

3. Custom Generate Function

Implement the multi-turn generation loop:
import re
from slime.rollout.sglang_rollout import GenerateState
from slime.utils.types import Sample

async def generate(args, sample: Sample, sampling_params) -> Sample:
    state = GenerateState(args)
    url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate"
    
    # Initialize
    prompt_text = sample.prompt
    prompt_tokens = state.tokenizer(prompt_text, add_special_tokens=False)["input_ids"]
    response = ""
    response_tokens = []
    loss_mask = []
    rollout_log_probs = [] if SEARCH_R1_CONFIGS["return_logprob"] else None
    
    # Multi-turn loop
    for turn_idx in range(SEARCH_R1_CONFIGS["max_turns"]):
        # Step 1: Model generates action
        payload = {
            "text": prompt_text + response,
            "sampling_params": sampling_params,
        }
        
        if SEARCH_R1_CONFIGS["return_logprob"]:
            payload["return_logprob"] = True
        
        output = await post(url, payload)
        
        # Handle abort
        if output["meta_info"]["finish_reason"]["type"] == "abort":
            sample.status = Sample.Status.ABORTED
            return sample
        
        model_text = output["text"]
        
        # Extract tokens and log probs
        if SEARCH_R1_CONFIGS["return_logprob"]:
            # Use tokens directly from output to ensure alignment
            model_tokens = [item[1] for item in output["meta_info"]["output_token_logprobs"]]
            model_log_probs = [item[0] for item in output["meta_info"]["output_token_logprobs"]]
        else:
            # Tokenize response
            model_tokens = state.tokenizer(model_text, add_special_tokens=False)["input_ids"]
        
        # Add model output to response
        response += model_text
        response_tokens += model_tokens
        loss_mask += [1] * len(model_tokens)  # TRAIN on model output
        
        if SEARCH_R1_CONFIGS["return_logprob"]:
            rollout_log_probs += model_log_probs
        
        # Step 2: Parse action
        action, content = parse_action(model_text)
        
        if action == "search":
            # Step 3: Execute search
            search_results = await search(content)
            obs_text = f"\n\n<information>{search_results}</information>\n\n"
            obs_tokens = state.tokenizer(obs_text, add_special_tokens=False)["input_ids"]
            
            # Step 4: Add observation to response
            response += obs_text
            response_tokens += obs_tokens
            loss_mask += [0] * len(obs_tokens)  # DON'T TRAIN on tool output
            
            if SEARCH_R1_CONFIGS["return_logprob"]:
                rollout_log_probs += [0.0] * len(obs_tokens)
                # Verify alignment
                assert len(response_tokens) == len(rollout_log_probs), \
                    f"Misalignment: {len(response_tokens)} tokens vs {len(rollout_log_probs)} log probs"
        
        elif action == "answer":
            # Step 5: Done - model provided final answer
            break
        
        else:
            # Invalid action - provide hint
            hint = "\nInvalid action. Use <search>query</search> or <answer>answer</answer>.\n"
            hint_tokens = state.tokenizer(hint, add_special_tokens=False)["input_ids"]
            
            response += hint
            response_tokens += hint_tokens
            loss_mask += [0] * len(hint_tokens)  # Don't train on error messages
            
            if SEARCH_R1_CONFIGS["return_logprob"]:
                rollout_log_probs += [0.0] * len(hint_tokens)
        
        # Check length limit
        if output["meta_info"]["finish_reason"]["type"] == "length":
            sample.status = Sample.Status.TRUNCATED
            break
    
    # Fill sample
    sample.tokens = prompt_tokens + response_tokens
    sample.response_length = len(response_tokens)
    sample.response = response
    sample.loss_mask = loss_mask
    sample.prompt = prompt_text
    
    if SEARCH_R1_CONFIGS["return_logprob"]:
        sample.rollout_log_probs = rollout_log_probs
    
    # Set status
    if sample.status != Sample.Status.TRUNCATED:
        finish_type = output["meta_info"]["finish_reason"]["type"]
        sample.status = Sample.Status.COMPLETED if finish_type == "stop" else Sample.Status.ABORTED
    
    return sample

def parse_action(text: str) -> tuple[str | None, str]:
    """Extract action type and content from model output."""
    pattern = r"<(search|answer)>(.*?)</\1>"
    match = re.search(pattern, text, re.DOTALL)
    
    if match:
        action = match.group(1)  # "search" or "answer"
        content = match.group(2).strip()  # Query or answer text
        return action, content
    
    return None, ""
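As a quick sanity check, the regex above picks out the first well-formed tag and ignores surrounding reasoning text; an unclosed tag yields no match, which parse_action maps to (None, ""):

```python
import re

pattern = r"<(search|answer)>(.*?)</\1>"

# Tag embedded in reasoning text
m = re.search(pattern, "Let me check. <search>Tokyo population</search>", re.DOTALL)
assert m.group(1) == "search" and m.group(2) == "Tokyo population"

# No well-formed tag: re.search returns None
assert re.search(pattern, "<search>unclosed", re.DOTALL) is None
```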

4. Custom Reward Function

Evaluate the complete interaction:
from qa_em_format import compute_score_em

async def reward_func(args, sample: Sample, **kwargs) -> float:
    """Compute reward based on answer correctness."""
    if not isinstance(sample, Sample):
        raise TypeError("Sample must be an instance of Sample class.")
    
    # Compute exact match score
    score = compute_score_em(
        solution_str=sample.prompt + sample.response,
        ground_truth=sample.label["ground_truth"],
        format_score=SEARCH_R1_CONFIGS["format_score"],
    )
    
    return score
The compute_score_em function:
  • Extracts the final answer from the response
  • Compares against ground truth
  • Returns 1.0 for exact match, 0.0 for wrong answer
  • Can give partial credit for correct format

5. Training Script Configuration

#!/bin/bash

CUSTOM_ARGS=(
   --custom-generate-function-path examples.search_r1.generate_with_search.generate
   --custom-rm-path examples.search_r1.generate_with_search.reward_func
)

ROLLOUT_ARGS=(
   --prompt-data /root/nq_search/train.jsonl
   --input-key question
   --label-key answer
   --metadata-key metadata  # Pass search config if needed
   
   --rollout-batch-size 32
   --n-samples-per-prompt 4
   --rollout-max-response-len 8192
   --rollout-temperature 1.0
   
   # Use TIS for off-policy correction
   --use-tis
   --tis-clip 2.0
)

ray job submit --address="http://127.0.0.1:8265" \
   -- python3 train.py \
   ${MODEL_ARGS[@]} \
   ${CKPT_ARGS[@]} \
   ${ROLLOUT_ARGS[@]} \
   ${CUSTOM_ARGS[@]} \
   ${GRPO_ARGS[@]} \
   ${OPTIMIZER_ARGS[@]}

Advanced: Multi-Agent System

For complex workflows, you can orchestrate multiple specialized agents. The following is a sketch drawn from the multi-agent example:

Agent Roles

class SolverAgent(Agent):
    """Generates initial solutions to problems."""
    
    async def generate_initial_solution(self, args, problem) -> str:
        prompt = SOLVER_PROMPT_TEMPLATE.format(problem_statement=problem)
        return await self.run(args, prompt, key="solver")

class RewriterAgent(Agent):
    """Improves solutions based on previous attempts."""
    
    async def rewrite(self, args, problem, previous_solutions: list[str]) -> str:
        template = generate_rewriter_template(len(previous_solutions))
        format_params = {"problem_statement": problem}
        for i, sol in enumerate(previous_solutions):
            format_params[f"solution{i+1}"] = sol
        
        prompt = template.format(**format_params)
        return await self.run(args, prompt, key="rewriter")

class SelectorAgent(Agent):
    """Selects the best solution from candidates."""
    
    async def select(self, args, problem, candidates: list[str]) -> str:
        template = generate_select_template(len(candidates))
        format_params = {"problem_statement": problem}
        for i, sol in enumerate(candidates):
            format_params[f"solution{i+1}"] = sol
        
        prompt = template.format(**format_params)
        return await self.run(args, prompt, key="selector")

Orchestration

async def run_agent_system(args, sample: Sample) -> list[Sample]:
    """Run parallel agent workflows."""
    problem = sample.prompt
    results_dict = {"solver": [], "rewriter": [], "selector": []}
    
    # Step 1: Generate multiple initial solutions in parallel
    solver_tasks = [
        solver_worker(args, problem, worker_id)
        for worker_id in range(args.num_parallel)
    ]
    solutions = await asyncio.gather(*solver_tasks)
    
    # Compute rewards for solver outputs
    # (in the full example, the worker helpers populate results_dict)
    rewards = await batched_async_rm(args, results_dict["solver"])
    for solver_sample, reward in zip(results_dict["solver"], rewards):
        solver_sample.reward = reward
    
    # Step 2: Rewrite solutions
    rewritten = []
    valid_solutions = [s for s in solutions if s is not None]
    if valid_solutions:
        rewriter_tasks = [
            rewrite_worker(args, valid_solutions, problem, worker_id)
            for worker_id in range(args.num_parallel)
        ]
        rewritten = await asyncio.gather(*rewriter_tasks)
        
        # Compute rewards for rewritten solutions
        rewards = await batched_async_rm(args, results_dict["rewriter"])
        for rewriter_sample, reward in zip(results_dict["rewriter"], rewards):
            rewriter_sample.reward = reward
    
    # Step 3: Select best solution
    if len(rewritten) > 0:
        selector = SelectorAgent()
        selection = await selector.select(args, problem, rewritten)
        
        # Find selected solution's reward
        selected_idx = selector.extract_selected_solution_idx(
            selection, rewritten
        )
        if selected_idx is not None:
            results_dict["selector"][0].reward = \
                results_dict["rewriter"][selected_idx].reward
    
    # Adjust rewards based on the final outcome
    if results_dict["selector"] and results_dict["selector"][0].reward == 1.0:
        # Success - boost all rewards
        adjust_rewards(results_dict, weight=1.2)
    else:
        # Failure - reduce all rewards
        adjust_rewards(results_dict, weight=0.8)
    
    # Return all samples for training
    return (
        results_dict["solver"] + 
        results_dict["rewriter"] + 
        results_dict["selector"]
    )

Configuration for Multi-Agent

CUSTOM_ARGS=(
   --custom-generate-function-path examples.multi_agent.rollout_with_multi_agents.generate_with_multi_agents
)

ROLLOUT_ARGS=(
   --rollout-batch-size 16
   --n-samples-per-prompt 1  # Each prompt generates multiple samples internally
   --rollout-max-response-len 4096
)

Key Considerations

Loss Masking Strategy

Critical: Always set loss_mask = 0 for environment-generated tokens (tool outputs, observations, error messages). Only train on model-generated tokens.
# Model output - TRAIN
model_tokens = tokenizer(model_text)["input_ids"]
loss_mask += [1] * len(model_tokens)

# Tool output - DON'T TRAIN
tool_tokens = tokenizer(tool_result)["input_ids"]
loss_mask += [0] * len(tool_tokens)
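A toy end-to-end check makes the invariant concrete. The whitespace "tokenizer" below is a stand-in for the real one; the point is that the mask must stay aligned with the tokens after every append:

```python
def toy_tokenize(text: str) -> list[str]:
    # Whitespace "tokenizer" standing in for the real one
    return text.split()

response_tokens: list[str] = []
loss_mask: list[int] = []

# Model output - TRAIN
model_tokens = toy_tokenize("the capital is <search> Paris </search>")
response_tokens += model_tokens
loss_mask += [1] * len(model_tokens)

# Tool output - DON'T TRAIN
tool_tokens = toy_tokenize("<information> Paris , France </information>")
response_tokens += tool_tokens
loss_mask += [0] * len(tool_tokens)

# The invariant to maintain after every append:
assert len(response_tokens) == len(loss_mask)
assert sum(loss_mask) == len(model_tokens)  # only model tokens are trained
```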

Log Probability Collection

When using TIS or other off-policy methods:
# Request log probs from inference engine
payload = {
    "text": prompt,
    "sampling_params": sampling_params,
    "return_logprob": True,  # Enable log prob collection
}

output = await post(url, payload)

# Extract tokens and log probs together
model_tokens = [item[1] for item in output["meta_info"]["output_token_logprobs"]]
model_log_probs = [item[0] for item in output["meta_info"]["output_token_logprobs"]]

# Never re-tokenize when collecting log probs!
# This ensures perfect alignment between tokens and probabilities
When return_logprob=True, always use tokens directly from the output. Never re-tokenize the text, as this may produce different tokens.

Metadata for Tool Definitions

Pass tool definitions through metadata:
{
  "question": "What is the population of Tokyo?",
  "answer": "14 million",
  "metadata": {
    "tools": [
      {
        "name": "search",
        "description": "Search the web",
        "parameters": {"query": "string"}
      },
      {
        "name": "calculator",
        "description": "Perform calculations",
        "parameters": {"expression": "string"}
      }
    ]
  }
}
Access in custom generate function:
async def generate(args, sample: Sample, sampling_params) -> Sample:
    available_tools = sample.metadata.get("tools", [])
    
    # Build system prompt with tool descriptions
    tool_descriptions = "\n".join(
        f"- {tool['name']}: {tool['description']}" 
        for tool in available_tools
    )
    
    system_prompt = f"You have access to:\n{tool_descriptions}\n\n"
    prompt = system_prompt + sample.prompt
    
    # Continue with generation...

Error Handling

async def generate(args, sample: Sample, sampling_params) -> Sample:
    try:
        # Generation logic (payload construction and action parsing elided for brevity)
        for turn in range(max_turns):
            output = await post(url, payload)
            
            if output["meta_info"]["finish_reason"]["type"] == "abort":
                sample.status = Sample.Status.ABORTED
                return sample
            
            # Process action
            try:
                result = await execute_tool(action, content)
            except ToolExecutionError as e:
                # Handle tool errors gracefully
                error_msg = f"<error>Tool failed: {str(e)}</error>"
                error_tokens = tokenizer(error_msg)["input_ids"]
                
                response += error_msg
                response_tokens += error_tokens
                loss_mask += [0] * len(error_tokens)
                continue
    
    except Exception as e:
        logger.error(f"Generation failed: {e}")
        sample.status = Sample.Status.ABORTED
        return sample
    
    return sample

Performance Tips

Concurrency Control

import asyncio

# Limit concurrent tool calls
TOOL_SEMAPHORE = asyncio.Semaphore(256)

async def call_tool(query):
    async with TOOL_SEMAPHORE:
        result = await api_call(query)
    return result

Caching Tool Results

# Simple in-memory cache to avoid duplicate API calls.
# functools.lru_cache only works for synchronous functions, so for an
# async search API a dict-based cache is the simplest option.
_SEARCH_CACHE: dict[str, str] = {}

async def search(query: str) -> str:
    """Return cached results when available, otherwise call the API."""
    if query in _SEARCH_CACHE:
        return _SEARCH_CACHE[query]
    result = await search_api(query)
    _SEARCH_CACHE[query] = result
    return result

Batched Tool Calls

async def batch_search(queries: list[str]) -> list[str]:
    """Execute multiple searches in parallel."""
    tasks = [search(q) for q in queries]
    results = await asyncio.gather(*tasks)
    return results

Debugging

Log Sample Details

# In custom generate function
logger.info(f"Turn {turn_idx}: action={action}, content={content[:50]}...")
logger.debug(f"Response length: {len(response_tokens)}, mask length: {len(loss_mask)}")

# Verify alignment
assert len(response_tokens) == len(loss_mask), \
    f"Token/mask mismatch: {len(response_tokens)} vs {len(loss_mask)}"

if rollout_log_probs:
    assert len(response_tokens) == len(rollout_log_probs), \
        f"Token/logprob mismatch: {len(response_tokens)} vs {len(rollout_log_probs)}"

Save Debug Data

# In training script
ray job submit ... \
   -- python3 train.py \
   --save-debug-rollout-data /root/debug/rollout_{rollout_id}.pt \
   ...
Inspect saved data:
import torch

data = torch.load("/root/debug/rollout_0.pt")
for sample in data["samples"]:
    print(f"Prompt: {sample.prompt[:100]}...")
    print(f"Response: {sample.response[:200]}...")
    print(f"Reward: {sample.reward}")
    print(f"Loss mask: {sample.loss_mask}")
    print("-" * 80)

Next Steps

Customization Guide

Deep dive into all customization options

Distributed Training

Scale your agent training across multiple nodes
