Overview
Multi-turn agent training involves:
- Multi-turn interaction loop: the model generates actions, executes tools, and observes results
- Loss masking: Train on model-generated tokens, not environment responses
- Custom reward functions: Evaluate complete interaction trajectories
- Metadata management: Pass tool definitions and context through the pipeline
Example: Search-R1 Agent
We'll use the Search-R1 example, which trains a model to answer questions using web search.

Agent Workflow
Generate Action
Model decides to either:
- Search:
<search>query terms</search> - Answer:
<answer>final answer</answer>
Implementation
1. Configuration
First, define the configuration for your agent:
# Configuration dictionary
# Configuration for the Search-R1 agent rollout.
SEARCH_R1_CONFIGS = {
    "max_turns": 2,  # Maximum search iterations
    "topk": 3,  # Number of search results
    "search_concurrency": 256,  # Concurrent search requests
    "search_backend": "local",  # "local" or "google"
    "return_logprob": True,  # Collect log probs for TIS
    "format_score": 0.2,  # Partial credit for format
    # Backend-specific settings read by search(); without these keys the
    # google/local branches raise KeyError. Fill in real values per deployment.
    "google": {"api_key": ""},
    "local": {"search_url": "http://127.0.0.1:8000/retrieve"},
}
2. Tool Functions
Implement your tool APIs:import asyncio
from slime.utils.http_utils import post

# Limit concurrent tool calls. Derive the bound from the config instead of
# duplicating the literal 256, so the two values cannot drift apart.
SEMAPHORE = asyncio.Semaphore(SEARCH_R1_CONFIGS["search_concurrency"])
async def search(query: str) -> str:
    """Run one query through the configured backend and return formatted hits."""
    async with SEMAPHORE:
        backend = SEARCH_R1_CONFIGS["search_backend"]
        topk = SEARCH_R1_CONFIGS["topk"]
        if backend == "google":
            raw_results = await google_search(
                api_key=SEARCH_R1_CONFIGS["google"]["api_key"],
                query=query,
                topk=topk,
            )
        else:
            raw_results = await local_search(
                url=SEARCH_R1_CONFIGS["local"]["search_url"],
                query=query,
                topk=topk,
            )
        return format_search_results(raw_results)
def format_search_results(results) -> str:
    """Format retrieved documents as numbered text snippets.

    Each element of *results* is expected to be a dict with
    ``["document"]["contents"]``, where the first line of the contents is the
    title and the remaining lines are the body text.
    """
    parts = []
    for idx, doc in enumerate(results, start=1):
        content = doc["document"]["contents"]
        # partition splits on the first newline only, replacing the original's
        # two full content.split("\n") passes; if there is no newline the body
        # text is "" — same as the original behavior.
        title, _, text = content.partition("\n")
        parts.append(f"Doc {idx}(Title: {title}) {text}\n")
    # join instead of repeated string += (avoids quadratic concatenation).
    return "".join(parts)
3. Custom Generate Function
Implement the multi-turn generation loop:import re
from slime.rollout.sglang_rollout import GenerateState
from slime.utils.types import Sample
async def generate(args, sample: Sample, sampling_params) -> Sample:
    """Multi-turn generate loop for the Search-R1 agent.

    Alternates between model generation and search-tool execution for up to
    SEARCH_R1_CONFIGS["max_turns"] turns, accumulating: the response text,
    the response tokens, a per-token loss mask (1 = model output, 0 =
    environment text), and — when return_logprob is enabled — rollout log
    probs aligned 1:1 with the response tokens.
    """
    state = GenerateState(args)
    url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate"
    # Initialize per-trajectory accumulators.
    prompt_text = sample.prompt
    prompt_tokens = state.tokenizer(prompt_text, add_special_tokens=False)["input_ids"]
    response = ""
    response_tokens = []
    loss_mask = []
    rollout_log_probs = [] if SEARCH_R1_CONFIGS["return_logprob"] else None
    # Multi-turn loop
    for turn_idx in range(SEARCH_R1_CONFIGS["max_turns"]):
        # Step 1: Model generates an action, conditioned on prompt + everything so far.
        payload = {
            "text": prompt_text + response,
            "sampling_params": sampling_params,
        }
        if SEARCH_R1_CONFIGS["return_logprob"]:
            payload["return_logprob"] = True
        output = await post(url, payload)
        # Handle abort (e.g. engine shutdown / rollout cancellation).
        if output["meta_info"]["finish_reason"]["type"] == "abort":
            sample.status = Sample.Status.ABORTED
            return sample
        model_text = output["text"]
        # Extract tokens and log probs.
        if SEARCH_R1_CONFIGS["return_logprob"]:
            # Use tokens directly from the engine output to guarantee
            # alignment — re-tokenizing model_text could yield a different
            # tokenization than the one the log probs were computed for.
            model_tokens = [item[1] for item in output["meta_info"]["output_token_logprobs"]]
            model_log_probs = [item[0] for item in output["meta_info"]["output_token_logprobs"]]
        else:
            # No log probs requested: tokenizing the returned text is fine here.
            model_tokens = state.tokenizer(model_text, add_special_tokens=False)["input_ids"]
        # Add model output to response.
        response += model_text
        response_tokens += model_tokens
        loss_mask += [1] * len(model_tokens)  # TRAIN on model output
        if SEARCH_R1_CONFIGS["return_logprob"]:
            rollout_log_probs += model_log_probs
        # Step 2: Parse action
        action, content = parse_action(model_text)
        if action == "search":
            # Step 3: Execute search
            search_results = await search(content)
            obs_text = f"\n\n<information>{search_results}</information>\n\n"
            obs_tokens = state.tokenizer(obs_text, add_special_tokens=False)["input_ids"]
            # Step 4: Add observation to response
            response += obs_text
            response_tokens += obs_tokens
            loss_mask += [0] * len(obs_tokens)  # DON'T TRAIN on tool output
            if SEARCH_R1_CONFIGS["return_logprob"]:
                # Environment tokens get placeholder log probs; they are
                # masked out of the loss anyway.
                rollout_log_probs += [0.0] * len(obs_tokens)
                # Verify the parallel lists stayed aligned.
                assert len(response_tokens) == len(rollout_log_probs), \
                    f"Misalignment: {len(response_tokens)} tokens vs {len(rollout_log_probs)} log probs"
        elif action == "answer":
            # Step 5: Done - model provided final answer
            break
        else:
            # Invalid action - feed a corrective hint back as environment text.
            hint = "\nInvalid action. Use <search>query</search> or <answer>answer</answer>.\n"
            hint_tokens = state.tokenizer(hint, add_special_tokens=False)["input_ids"]
            response += hint
            response_tokens += hint_tokens
            loss_mask += [0] * len(hint_tokens)  # Don't train on error messages
            if SEARCH_R1_CONFIGS["return_logprob"]:
                rollout_log_probs += [0.0] * len(hint_tokens)
        # Check length limit: stop if the engine truncated this turn's output.
        if output["meta_info"]["finish_reason"]["type"] == "length":
            sample.status = Sample.Status.TRUNCATED
            break
    # Fill the sample with the accumulated trajectory.
    sample.tokens = prompt_tokens + response_tokens
    sample.response_length = len(response_tokens)
    sample.response = response
    sample.loss_mask = loss_mask
    sample.prompt = prompt_text
    if SEARCH_R1_CONFIGS["return_logprob"]:
        sample.rollout_log_probs = rollout_log_probs
    # Set final status from the last turn's finish reason (unless truncated above).
    if sample.status != Sample.Status.TRUNCATED:
        finish_type = output["meta_info"]["finish_reason"]["type"]
        sample.status = Sample.Status.COMPLETED if finish_type == "stop" else Sample.Status.ABORTED
    return sample
def parse_action(text: str) -> tuple[str, str]:
"""Extract action type and content from model output."""
pattern = r"<(search|answer)>(.*?)</\1>"
match = re.search(pattern, text, re.DOTALL)
if match:
action = match.group(1) # "search" or "answer"
content = match.group(2).strip() # Query or answer text
return action, content
return None, ""
4. Custom Reward Function
Evaluate the complete interaction:
from qa_em_format import compute_score_em
async def reward_func(args, sample: Sample, **kwargs) -> float:
    """Score a finished trajectory by exact match on its final answer.

    Concatenates prompt and response so the scorer can locate the
    <answer> tag, and grants partial credit for correct formatting.
    """
    if not isinstance(sample, Sample):
        raise TypeError("Sample must be an instance of Sample class.")
    full_text = sample.prompt + sample.response
    return compute_score_em(
        solution_str=full_text,
        ground_truth=sample.label["ground_truth"],
        format_score=SEARCH_R1_CONFIGS["format_score"],
    )
compute_score_em function:
- Extracts the final answer from the response
- Compares against ground truth
- Returns 1.0 for exact match, 0.0 for wrong answer
- Can give partial credit for correct format
5. Training Script Configuration
#!/bin/bash
# Training launch script for the Search-R1 example.

# Custom rollout + reward implementations (importable module paths).
CUSTOM_ARGS=(
    --custom-generate-function-path examples.search_r1.generate_with_search.generate
    --custom-rm-path examples.search_r1.generate_with_search.reward_func
)

ROLLOUT_ARGS=(
    --prompt-data /root/nq_search/train.jsonl
    --input-key question
    --label-key answer
    --metadata-key metadata # Pass search config if needed
    --rollout-batch-size 32
    --n-samples-per-prompt 4
    --rollout-max-response-len 8192
    --rollout-temperature 1.0
    # Use TIS for off-policy correction
    --use-tis
    --tis-clip 2.0
)

# Quote the array expansions so arguments containing spaces or glob
# characters are passed through intact (unquoted ${ARR[@]} re-splits them).
ray job submit --address="http://127.0.0.1:8265" \
    -- python3 train.py \
    "${MODEL_ARGS[@]}" \
    "${CKPT_ARGS[@]}" \
    "${ROLLOUT_ARGS[@]}" \
    "${CUSTOM_ARGS[@]}" \
    "${GRPO_ARGS[@]}" \
    "${OPTIMIZER_ARGS[@]}"
Advanced: Multi-Agent System
For complex workflows, you can orchestrate multiple specialized agents. The following is taken from the multi-agent example:

Agent Roles
class SolverAgent(Agent):
    """Produces a first-pass solution for a problem statement."""

    async def generate_initial_solution(self, args, problem) -> str:
        """Fill the solver template with the problem and run one generation."""
        return await self.run(
            args,
            SOLVER_PROMPT_TEMPLATE.format(problem_statement=problem),
            key="solver",
        )
class RewriterAgent(Agent):
    """Refines a solution given all earlier attempts."""

    async def rewrite(self, args, problem, previous_solutions: list[str]) -> str:
        """Build the rewriter prompt from every previous solution and run it."""
        fields = {"problem_statement": problem}
        fields.update(
            (f"solution{idx}", attempt)
            for idx, attempt in enumerate(previous_solutions, start=1)
        )
        template = generate_rewriter_template(len(previous_solutions))
        return await self.run(args, template.format(**fields), key="rewriter")
class SelectorAgent(Agent):
    """Picks the best solution out of a list of candidates."""

    async def select(self, args, problem, candidates: list[str]) -> str:
        """Present all candidates to the model and ask it to choose one."""
        fields = {
            "problem_statement": problem,
            **{f"solution{idx}": cand for idx, cand in enumerate(candidates, start=1)},
        }
        template = generate_select_template(len(candidates))
        return await self.run(args, template.format(**fields), key="selector")
Orchestration
async def run_agent_system(args, sample: Sample) -> list[Sample]:
    """Run parallel solver -> rewriter -> selector agent workflows.

    The *_worker coroutines are presumed to append their Sample objects into
    results_dict as a side effect — TODO confirm against the worker
    implementations, which are not visible here. Returns every produced
    sample (solver + rewriter + selector) for training.
    """
    problem = sample.prompt
    results_dict = {"solver": [], "rewriter": [], "selector": []}
    # NOTE(review): the zip loops below rebind `sample`, shadowing the
    # parameter; the original prompt is captured in `problem` beforehand.
    # Step 1: Generate multiple initial solutions in parallel
    solver_tasks = [
        solver_worker(args, problem, worker_id)
        for worker_id in range(args.num_parallel)
    ]
    solutions = await asyncio.gather(*solver_tasks)
    # Compute rewards for solver outputs
    rewards = await batched_async_rm(args, results_dict["solver"])
    for sample, reward in zip(results_dict["solver"], rewards):
        sample.reward = reward
    # Step 2: Rewrite solutions (skipped when every solver attempt failed)
    valid_solutions = [s for s in solutions if s is not None]
    if len(valid_solutions) > 0:
        rewriter_tasks = [
            rewrite_worker(args, valid_solutions, problem, worker_id)
            for worker_id in range(args.num_parallel)
        ]
        rewritten = await asyncio.gather(*rewriter_tasks)
        # Compute rewards for rewritten solutions
        rewards = await batched_async_rm(args, results_dict["rewriter"])
        for sample, reward in zip(results_dict["rewriter"], rewards):
            sample.reward = reward
        # Step 3: Select best solution
        # NOTE(review): kept inside the valid-solutions branch so `rewritten`
        # is always defined when referenced — confirm this nesting against
        # the original example (the extracted source lost its indentation).
        if len(rewritten) > 0:
            selector = SelectorAgent()
            selection = await selector.select(args, problem, rewritten)
            # Find selected solution's reward
            selected_idx = selector.extract_selected_solution_idx(
                selection, rewritten
            )
            if selected_idx is not None:
                results_dict["selector"][0].reward = \
                    results_dict["rewriter"][selected_idx].reward
            # Adjust rewards based on final outcome
            if results_dict["selector"][0].reward == 1.0:
                # Success - boost all rewards
                adjust_rewards(results_dict, weight=1.2)
            else:
                # Failure - reduce all rewards
                adjust_rewards(results_dict, weight=0.8)
    # Return all samples for training
    return (
        results_dict["solver"] +
        results_dict["rewriter"] +
        results_dict["selector"]
    )
Configuration for Multi-Agent
# Custom rollout entry point for the multi-agent example.
CUSTOM_ARGS=(
    --custom-generate-function-path examples.multi_agent.rollout_with_multi_agents.generate_with_multi_agents
)
# Rollout sizing: batch of 16 prompts, one invocation per prompt — the
# multi-agent rollout fans out into several samples internally.
ROLLOUT_ARGS=(
    --rollout-batch-size 16
    --n-samples-per-prompt 1 # Each prompt generates multiple samples internally
    --rollout-max-response-len 4096
)
Key Considerations
Loss Masking Strategy
Critical: Always set
loss_mask = 0 for environment-generated tokens (tool outputs, observations, error messages). Only train on model-generated tokens.# Model output - TRAIN
model_tokens = tokenizer(model_text)["input_ids"]
loss_mask += [1] * len(model_tokens)  # 1: model-generated token, contributes to loss
# Tool output - DON'T TRAIN
tool_tokens = tokenizer(tool_result)["input_ids"]
loss_mask += [0] * len(tool_tokens)  # 0: environment token, excluded from loss
Log Probability Collection
When using TIS or other off-policy methods:# Request log probs from inference engine
# Request log probs from the inference engine alongside the generated text.
payload = {
    "text": prompt,
    "sampling_params": sampling_params,
    "return_logprob": True,  # Enable log prob collection
}
output = await post(url, payload)
# Extract tokens and log probs together from the same structure:
# each entry of output_token_logprobs pairs a log prob with its token id.
model_tokens = [item[1] for item in output["meta_info"]["output_token_logprobs"]]
model_log_probs = [item[0] for item in output["meta_info"]["output_token_logprobs"]]
# Never re-tokenize when collecting log probs!
# This ensures perfect alignment between tokens and probabilities
When `return_logprob=True`, always use tokens directly from the output. Never re-tokenize the text, as re-tokenization may produce different tokens than the ones the log probs were computed for.

Metadata for Tool Definitions
Pass tool definitions through metadata:{
"question": "What is the population of Tokyo?",
"answer": "14 million",
"metadata": {
"tools": [
{
"name": "search",
"description": "Search the web",
"parameters": {"query": "string"}
},
{
"name": "calculator",
"description": "Perform calculations",
"parameters": {"expression": "string"}
}
]
}
}
async def generate(args, sample: Sample, sampling_params) -> Sample:
    """Build a tool-aware system prompt from per-sample metadata."""
    # Tool definitions arrive through the dataset's metadata field
    # (see --metadata-key); default to no tools when absent.
    available_tools = sample.metadata.get("tools", [])
    # Build system prompt with tool descriptions
    tool_descriptions = "\n".join(
        f"- {tool['name']}: {tool['description']}"
        for tool in available_tools
    )
    system_prompt = f"You have access to:\n{tool_descriptions}\n\n"
    prompt = system_prompt + sample.prompt
    # Continue with generation...
Error Handling
async def generate(args, sample: Sample, sampling_params) -> Sample:
    """Sketch of defensive error handling inside a multi-turn generate loop."""
    try:
        # Generation logic
        for turn in range(max_turns):
            output = await post(url, payload)
            if output["meta_info"]["finish_reason"]["type"] == "abort":
                sample.status = Sample.Status.ABORTED
                return sample
            # Process action
            try:
                result = await execute_tool(action, content)
            except ToolExecutionError as e:
                # Handle tool errors gracefully: surface the failure to the
                # model as an observation instead of crashing the rollout.
                error_msg = f"<error>Tool failed: {str(e)}</error>"
                error_tokens = tokenizer(error_msg)["input_ids"]
                response += error_msg
                response_tokens += error_tokens
                loss_mask += [0] * len(error_tokens)  # environment text: no loss
                continue
    except Exception as e:
        # Unexpected failure anywhere in the loop: mark the sample aborted
        # rather than propagating and killing the whole rollout batch.
        logger.error(f"Generation failed: {e}")
        sample.status = Sample.Status.ABORTED
        return sample
    return sample
Performance Tips
Concurrency Control
import asyncio

# Cap the number of simultaneous tool invocations.
TOOL_SEMAPHORE = asyncio.Semaphore(256)


async def call_tool(query):
    """Invoke the external tool API under the shared concurrency limit."""
    async with TOOL_SEMAPHORE:
        return await api_call(query)
Caching Tool Results
from functools import lru_cache
@lru_cache(maxsize=10000)
def cached_search(query: str) -> str:
"""Cache search results to avoid duplicate API calls."""
return search_api(query)
async def search(query: str) -> str:
# Check cache first
cached = cached_search(query)
if cached:
return cached
# Call API
result = await search_api(query)
cached_search.cache_info() # Monitor cache hit rate
return result
Batched Tool Calls
async def batch_search(queries: list[str]) -> list[str]:
    """Run every query concurrently; results come back in input order."""
    return await asyncio.gather(*(search(q) for q in queries))
Debugging
Log Sample Details
# In custom generate function
logger.info(f"Turn {turn_idx}: action={action}, content={content[:50]}...")
logger.debug(f"Response length: {len(response_tokens)}, mask length: {len(loss_mask)}")
# Verify the parallel per-token lists stayed aligned before training on them.
assert len(response_tokens) == len(loss_mask), \
    f"Token/mask mismatch: {len(response_tokens)} vs {len(loss_mask)}"
if rollout_log_probs:
    assert len(response_tokens) == len(rollout_log_probs), \
        f"Token/logprob mismatch: {len(response_tokens)} vs {len(rollout_log_probs)}"
Save Debug Data
# In training script
ray job submit ... \
-- python3 train.py \
--save-debug-rollout-data /root/debug/rollout_{rollout_id}.pt \
...
# Inspect saved rollout data offline.
import torch

data = torch.load("/root/debug/rollout_0.pt")
for sample in data["samples"]:
    print(f"Prompt: {sample.prompt[:100]}...")
    print(f"Response: {sample.response[:200]}...")
    print(f"Reward: {sample.reward}")
    print(f"Loss mask: {sample.loss_mask}")
    print("-" * 80)
Next Steps
Customization Guide
Deep dive into all customization options
Distributed Training
Scale your agent training across multiple nodes